From 6b1d059eda21c1bd421f3d352786fca2cab61954 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Sat, 18 Jan 2025 05:18:37 +0100 Subject: [PATCH 01/32] Support ROCM builds from source distribution, and improve error handling (#1446) * Always update both submodules to include them in sdist Always update both submodules, irrespectively of whether a CUDA or a ROCM build is being done, to ensure that the necessary files from both are present in sdist. Otherwise, attempt to perform a ROCM build from sdist fails because of missing `composable_kernel` srouces. * Include `*.py` files from composable_kernel in sdist Include the `*.py` files from `csrc` in sdist, to ensure that the `generate.py` script is present. * Replace the `os.system()` calls in `setup.py` with `subprocess.run()` * Add error checking to `subprocess.run()` calls in `setup.py` Add error checking to ensure that `setup.py` fails immediately if one of the commands fail. Otherwise, the failures result only in messages to stderr that could be missed, and could lead to more confusing errors later in the build process. * Call git in `setup.py` only when working in a git repository Call git commands in `setup.py` only when the `.git` directory is present, indicating that we are working in a git checkout. Otherwise, just assert that the needed files are there. With this, building from a source distribution no longer attempts to call git in an incorrect directory. --- MANIFEST.in | 1 + setup.py | 24 ++++++++++++++++-------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 021b4d0f7d3..d3c4b4eda1a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,7 @@ recursive-include csrc *.h recursive-include csrc *.cuh recursive-include csrc *.cpp recursive-include csrc *.hpp +recursive-include csrc *.py recursive-include flash_attn *.cu recursive-include flash_attn *.h diff --git a/setup.py b/setup.py index a802a7e65e4..264b0eed511 100644 --- a/setup.py +++ b/setup.py @@ -145,11 +145,19 @@ def validate_and_update_archs(archs): # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source. -if IS_ROCM: - if not USE_TRITON_ROCM: - subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"]) +if os.path.isdir(".git"): + subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"], check=True) + subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], check=True) else: - subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) + if IS_ROCM: + if not USE_TRITON_ROCM: + assert ( + os.path.exists("csrc/composable_kernel/example/ck_tile/01_fmha/generate.py") + ), "csrc/composable_kernel is missing, please use source distribution or git clone" + else: + assert ( + os.path.exists("csrc/cutlass/include/cutlass/cutlass.h") + ), "csrc/cutlass is missing, please use source distribution or git clone" if not SKIP_CUDA_BUILD and not IS_ROCM: print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) @@ -324,10 +332,10 @@ def validate_and_update_archs(archs): if not os.path.exists("./build"): os.makedirs("build") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_appendkv --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd_splitkv --output_dir build --receipt 2") - os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2") + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd", "--output_dir", "build", "--receipt", "2"], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_appendkv", "--output_dir", "build", "--receipt", "2"], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "fwd_splitkv", "--output_dir", "build", "--receipt", "2"], check=True) + subprocess.run([sys.executable, f"{ck_dir}/example/ck_tile/01_fmha/generate.py", "-d", "bwd", "--output_dir", "build", "--receipt", "2"], check=True) # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h # See https://github.com/pytorch/pytorch/pull/70650 From cd393e0ace51f8b0812b6e4f071ef2094082056a Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Wed, 29 Jan 2025 13:27:59 -0800 Subject: [PATCH 02/32] [Build] Update version of setuptools used to generate core package (#1460) --- .github/workflows/publish.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4746c714930..5dffc0d1413 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -203,7 +203,9 @@ jobs: - name: Install dependencies run: | - pip install ninja packaging setuptools wheel twine + pip install ninja packaging wheel twine + # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv) + pip install setuptools==75.8.0 # We don't want to download anything CUDA-related here pip install torch --index-url https://download.pytorch.org/whl/cpu From bb135af07c362236bde418e9fe3db029d1e7ed88 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 16:31:54 -0500 Subject: [PATCH 03/32] Don't compile for CUDA 11, compile for official pytorch 2.6.0 --- .github/workflows/publish.yml | 8 ++++---- README.md | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 5dffc0d1413..3d67cfbf6a7 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,8 +44,8 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001'] - cuda-version: ['11.8.0', '12.3.2'] + torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] + cuda-version: ['12.4.1'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) @@ -113,7 +113,7 @@ jobs: run: | pip install --upgrade pip # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools - pip install setuptools==68.0.0 + pip install setuptools==75.8.0 # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable pip install typing-extensions==4.12.2 @@ -149,7 +149,7 @@ jobs: # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 # However this still fails so I'm using a newer version of setuptools - pip install setuptools==68.0.0 + pip install setuptools==75.8.0 pip install ninja packaging wheel export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH diff --git a/README.md b/README.md index 033dba41006..9f57bd56cbb 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ flash_attn_interface.flash_attn_func() ## Installation and features **Requirements:** - CUDA toolkit or ROCm toolkit -- PyTorch 1.12 and above. +- PyTorch 2.1 and above. - `packaging` Python package (`pip install packaging`) - `ninja` Python package (`pip install ninja`) * - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue. @@ -98,7 +98,7 @@ MAX_JOBS=4 pip install flash-attn --no-build-isolation ### NVIDIA CUDA Support **Requirements:** -- CUDA 11.7 and above. +- CUDA 12.0 and above. We recommend the [Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) From 979702c87a8713a8e0a5e9fee122b90d2ef13be5 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 16:34:02 -0500 Subject: [PATCH 04/32] Bump to v2.7.4 --- flash_attn/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 07d16cd0f48..094b3233d2f 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.7.3" +__version__ = "2.7.4" from flash_attn.flash_attn_interface import ( flash_attn_func, From 5231d95fe13733fb534c01895f7ea88c6a6c7793 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 16:42:56 -0500 Subject: [PATCH 05/32] Drop Pytorch 2.1 --- .github/workflows/publish.yml | 11 +++-------- README.md | 2 +- flash_attn/__init__.py | 2 +- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3d67cfbf6a7..6f227d1abe1 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,7 +44,7 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] + torch-version: ['2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] cuda-version: ['12.4.1'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. @@ -53,12 +53,7 @@ jobs: cxx11_abi: ['FALSE', 'TRUE'] exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # Pytorch < 2.2 does not support Python 3.12 - - torch-version: '2.1.2' - python-version: '3.12' # Pytorch < 2.5 does not support Python 3.13 - - torch-version: '2.1.2' - python-version: '3.13' - torch-version: '2.2.2' python-version: '3.13' - torch-version: '2.3.1' @@ -122,8 +117,8 @@ jobs: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # This code is ugly, maybe there's a better way to do this. export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \ + minv = {'2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \ print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then diff --git a/README.md b/README.md index 9f57bd56cbb..aa545ceb071 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ flash_attn_interface.flash_attn_func() ## Installation and features **Requirements:** - CUDA toolkit or ROCm toolkit -- PyTorch 2.1 and above. +- PyTorch 2.2 and above. - `packaging` Python package (`pip install packaging`) - `ninja` Python package (`pip install ninja`) * - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue. diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 094b3233d2f..db131242dd4 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.7.4" +__version__ = "2.7.4.post1" from flash_attn.flash_attn_interface import ( flash_attn_func, From 454ce31594aaf0978e394ff9a21635b6f6ce56c4 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 18:01:58 -0500 Subject: [PATCH 06/32] [FA3] Compile with nvcc 12.8 instead of 12.3 --- README.md | 2 +- hopper/flash_fwd_launch_template.h | 3 +- hopper/setup.py | 47 +++++++++++++++++------------- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index aa545ceb071..c5d68536d4b 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ Currently released: Requirements: H100 / H800 GPU, CUDA >= 12.3. -For now, we highly recommend CUDA 12.3 for best performance. +We highly recommend CUDA 12.8 for best performance. To install: ```sh diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 16701f160d2..57d64d6a7b8 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -191,7 +191,8 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { // Only needed here to decide if we should use cluster static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; - static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; + // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) + static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { diff --git a/hopper/setup.py b/hopper/setup.py index d95be9ad409..0104819c68f 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -333,22 +333,19 @@ def open_url(url): return urllib.request.urlopen(request, timeout=300) -def download_and_copy(name, src_path, dst_path, version, url_func): +def download_and_copy(name, src_func, dst_path, version, url_func): if is_offline_build(): return flashattn_cache_path = get_flashattn_cache_path() base_dir = os.path.dirname(__file__) system = platform.system() - try: - arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()] - except KeyError: - arch = platform.machine() + arch = platform.machine() + arch = {"arm64": "aarch64"}.get(arch, arch) supported = {"Linux": "linux", "Darwin": "linux"} url = url_func(supported[system], arch, version) + src_path = src_func(supported[system], arch, version) tmp_path = os.path.join(flashattn_cache_path, "nvidia", name) # path to cache the download dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path - platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux" - src_path = src_path(platform_name, version) if callable(src_path) else src_path src_path = os.path.join(tmp_path, src_path) download = not os.path.exists(src_path) if download: @@ -364,11 +361,12 @@ def download_and_copy(name, src_path, dst_path, version, url_func): def nvcc_threads_args(): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" + nvcc_threads = os.getenv("NVCC_THREADS") or "2" return ["--threads", nvcc_threads] -NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} +# NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} +NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.8.61"} exe_extension = sysconfig.get_config_var("EXE") @@ -389,24 +387,31 @@ def nvcc_threads_args(): if bare_metal_version < Version("12.3"): raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above") - if bare_metal_version != Version("12.3"): # nvcc 12.3 gives the best perf currently + if bare_metal_version != Version("12.8"): # nvcc 12.8 gives the best perf currently download_and_copy( - name="nvcc", src_path=f"bin", dst_path="bin", - version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], url_func=lambda system, arch, version: - ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") - (*version.split('.')))) + name="nvcc", + # src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas{exe_extension}", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", + dst_path="bin", + version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) download_and_copy( - name="nvcc", src_path=f"nvvm/bin", dst_path="bin", - version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], url_func=lambda system, arch, version: - ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") - (*version.split('.')))) + name="nvcc", + # src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas{exe_extension}", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin", + dst_path="nvvm/bin", + version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) base_dir = os.path.dirname(__file__) ctk_path_new = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", "bin") nvcc_path_new = os.path.join(ctk_path_new, f"nvcc{exe_extension}") # Need to append to path otherwise nvcc can't find cicc in nvvm/bin/cicc - os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"] + # nvcc 12.8 seems to hard-code looking for cicc in ../nvvm/bin/cicc + # os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"] os.environ["PYTORCH_NVCC"] = nvcc_path_new # Make nvcc executable, sometimes after the copy it loses its permissions os.chmod(nvcc_path_new, os.stat(nvcc_path_new).st_mode | stat.S_IEXEC) From 803f609aa1c2b7c0f0ddea3a0e7e9fdeaa77e071 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 21:44:20 -0500 Subject: [PATCH 07/32] Fix comment in assert --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index dbbf2f8f821..3af51566b01 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -68,7 +68,7 @@ struct CollectiveMainloopFwdSm90 { // Leaving this option here for reference. static constexpr bool Mma0_is_RS = false; // We can have Mma1 (P @ V) with P in smem in rmem to reduce register pressure at the cost of more smem. - static_assert(!(!Mma1_is_RS && !IntraWGOverlap), "Mma1 must be RS if IntraWGOverlap is enabled"); + static_assert(!(!Mma1_is_RS && !IntraWGOverlap), "Mma1 must be RS if IntraWGOverlap is disabled"); static_assert(!(!Mma1_is_RS && Is_FP8), "Mma1 must be RS if FP8"); static_assert(!(!Mma1_is_RS && Transpose_V), "Mma1 must be RS if Transpose_V"); From 02541ac9e8382f4d8e17f1f2ba0d7de2c792390c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Wed, 29 Jan 2025 21:48:03 -0500 Subject: [PATCH 08/32] [CE] Assert logit_scale > 0 --- flash_attn/ops/triton/cross_entropy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py index 7b0315b9793..1b5a415b73f 100644 --- a/flash_attn/ops/triton/cross_entropy.py +++ b/flash_attn/ops/triton/cross_entropy.py @@ -166,6 +166,7 @@ def forward( if labels.dtype == torch.long and labels.data_ptr() % 16 != 0: labels = F.pad(labels, (0, 1))[..., :-1] assert labels.data_ptr() % 16 == 0 + assert logit_scale > 0.0 n_rows, n_cols = logits.shape assert labels.shape == (n_rows,) world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) From 2a204125ae71d2010bd3c9634d72a81c63967f3b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 3 Feb 2025 00:19:25 -0500 Subject: [PATCH 09/32] Implement HeadDim_V != HeadDim_QK, support hdimQK=192, hdimV=128 --- hopper/benchmark_attn.py | 34 +- hopper/epilogue_fwd.hpp | 44 +- hopper/flash.h | 7 +- hopper/flash_api.cpp | 119 ++- hopper/flash_fwd_combine.cu | 3 + hopper/flash_fwd_combine_launch_template.h | 2 +- hopper/flash_fwd_kernel_sm90.h | 27 +- hopper/flash_fwd_launch_template.h | 20 +- hopper/generate_kernels.py | 22 +- .../flash_fwd_hdim128_bf16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_bf16_paged_sm80.cu | 4 +- .../flash_fwd_hdim128_bf16_paged_sm90.cu | 2 +- ...ash_fwd_hdim128_bf16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim128_bf16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim128_bf16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim128_bf16_paged_split_sm90.cu | 2 +- ...d_hdim128_bf16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim128_bf16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_bf16_sm80.cu | 4 +- .../flash_fwd_hdim128_bf16_sm90.cu | 2 +- ...h_fwd_hdim128_bf16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_bf16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim128_bf16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_bf16_split_sm80.cu | 4 +- .../flash_fwd_hdim128_bf16_split_sm90.cu | 2 +- ...ash_fwd_hdim128_bf16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim128_bf16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_e4m3_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_e4m3_paged_sm90.cu | 2 +- ...ash_fwd_hdim128_e4m3_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim128_e4m3_paged_split_sm90.cu | 2 +- ...d_hdim128_e4m3_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_e4m3_sm90.cu | 2 +- ...h_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_e4m3_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_e4m3_split_sm90.cu | 2 +- ...ash_fwd_hdim128_e4m3_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_fp16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_fp16_paged_sm80.cu | 4 +- .../flash_fwd_hdim128_fp16_paged_sm90.cu | 2 +- ...ash_fwd_hdim128_fp16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim128_fp16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim128_fp16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim128_fp16_paged_split_sm90.cu | 2 +- ...d_hdim128_fp16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim128_fp16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_fp16_sm80.cu | 4 +- .../flash_fwd_hdim128_fp16_sm90.cu | 2 +- ...h_fwd_hdim128_fp16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim128_fp16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim128_fp16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim128_fp16_split_sm80.cu | 4 +- .../flash_fwd_hdim128_fp16_split_sm90.cu | 2 +- ...ash_fwd_hdim128_fp16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim128_fp16_split_softcap_sm90.cu | 2 +- ...flash_fwd_hdim192_128_bf16_packgqa_sm90.cu | 9 + .../flash_fwd_hdim192_128_bf16_paged_sm90.cu | 9 + ...fwd_hdim192_128_bf16_paged_softcap_sm90.cu | 9 + ...h_fwd_hdim192_128_bf16_paged_split_sm90.cu | 9 + ...im192_128_bf16_paged_split_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_bf16_sm90.cu | 9 + ...d_hdim192_128_bf16_softcap_packgqa_sm90.cu | 9 + ...flash_fwd_hdim192_128_bf16_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_bf16_split_sm90.cu | 9 + ...fwd_hdim192_128_bf16_split_softcap_sm90.cu | 9 + ...flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu | 9 + .../flash_fwd_hdim192_128_e4m3_paged_sm90.cu | 9 + ...fwd_hdim192_128_e4m3_paged_softcap_sm90.cu | 9 + ...h_fwd_hdim192_128_e4m3_paged_split_sm90.cu | 9 + ...im192_128_e4m3_paged_split_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_e4m3_sm90.cu | 9 + ...d_hdim192_128_e4m3_softcap_packgqa_sm90.cu | 9 + ...flash_fwd_hdim192_128_e4m3_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_e4m3_split_sm90.cu | 9 + ...fwd_hdim192_128_e4m3_split_softcap_sm90.cu | 9 + ...flash_fwd_hdim192_128_fp16_packgqa_sm90.cu | 9 + .../flash_fwd_hdim192_128_fp16_paged_sm90.cu | 9 + ...fwd_hdim192_128_fp16_paged_softcap_sm90.cu | 9 + ...h_fwd_hdim192_128_fp16_paged_split_sm90.cu | 9 + ...im192_128_fp16_paged_split_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_fp16_sm90.cu | 9 + ...d_hdim192_128_fp16_softcap_packgqa_sm90.cu | 9 + ...flash_fwd_hdim192_128_fp16_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_128_fp16_split_sm90.cu | 9 + ...fwd_hdim192_128_fp16_split_softcap_sm90.cu | 9 + .../flash_fwd_hdim192_bf16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_bf16_paged_sm80.cu | 4 +- .../flash_fwd_hdim192_bf16_paged_sm90.cu | 2 +- ...ash_fwd_hdim192_bf16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim192_bf16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim192_bf16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim192_bf16_paged_split_sm90.cu | 2 +- ...d_hdim192_bf16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim192_bf16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_bf16_sm80.cu | 4 +- .../flash_fwd_hdim192_bf16_sm90.cu | 2 +- ...h_fwd_hdim192_bf16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_bf16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim192_bf16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_bf16_split_sm80.cu | 4 +- .../flash_fwd_hdim192_bf16_split_sm90.cu | 2 +- ...ash_fwd_hdim192_bf16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim192_bf16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_e4m3_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_e4m3_paged_sm90.cu | 2 +- ...ash_fwd_hdim192_e4m3_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim192_e4m3_paged_split_sm90.cu | 2 +- ...d_hdim192_e4m3_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_e4m3_sm90.cu | 2 +- ...h_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_e4m3_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_e4m3_split_sm90.cu | 2 +- ...ash_fwd_hdim192_e4m3_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_fp16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_fp16_paged_sm80.cu | 4 +- .../flash_fwd_hdim192_fp16_paged_sm90.cu | 2 +- ...ash_fwd_hdim192_fp16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim192_fp16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim192_fp16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim192_fp16_paged_split_sm90.cu | 2 +- ...d_hdim192_fp16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim192_fp16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_fp16_sm80.cu | 4 +- .../flash_fwd_hdim192_fp16_sm90.cu | 2 +- ...h_fwd_hdim192_fp16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim192_fp16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim192_fp16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim192_fp16_split_sm80.cu | 4 +- .../flash_fwd_hdim192_fp16_split_sm90.cu | 2 +- ...ash_fwd_hdim192_fp16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim192_fp16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_bf16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_bf16_paged_sm80.cu | 4 +- .../flash_fwd_hdim256_bf16_paged_sm90.cu | 2 +- ...ash_fwd_hdim256_bf16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim256_bf16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim256_bf16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim256_bf16_paged_split_sm90.cu | 2 +- ...d_hdim256_bf16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim256_bf16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_bf16_sm80.cu | 4 +- .../flash_fwd_hdim256_bf16_sm90.cu | 2 +- ...h_fwd_hdim256_bf16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_bf16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim256_bf16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_bf16_split_sm80.cu | 4 +- .../flash_fwd_hdim256_bf16_split_sm90.cu | 2 +- ...ash_fwd_hdim256_bf16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim256_bf16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_e4m3_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_e4m3_paged_sm90.cu | 2 +- ...ash_fwd_hdim256_e4m3_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim256_e4m3_paged_split_sm90.cu | 2 +- ...d_hdim256_e4m3_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_e4m3_sm90.cu | 2 +- ...h_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_e4m3_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_e4m3_split_sm90.cu | 2 +- ...ash_fwd_hdim256_e4m3_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_fp16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_fp16_paged_sm80.cu | 4 +- .../flash_fwd_hdim256_fp16_paged_sm90.cu | 2 +- ...ash_fwd_hdim256_fp16_paged_softcap_sm80.cu | 4 +- ...ash_fwd_hdim256_fp16_paged_softcap_sm90.cu | 2 +- ...flash_fwd_hdim256_fp16_paged_split_sm80.cu | 4 +- ...flash_fwd_hdim256_fp16_paged_split_sm90.cu | 2 +- ...d_hdim256_fp16_paged_split_softcap_sm80.cu | 4 +- ...d_hdim256_fp16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_fp16_sm80.cu | 4 +- .../flash_fwd_hdim256_fp16_sm90.cu | 2 +- ...h_fwd_hdim256_fp16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim256_fp16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim256_fp16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim256_fp16_split_sm80.cu | 4 +- .../flash_fwd_hdim256_fp16_split_sm90.cu | 2 +- ...ash_fwd_hdim256_fp16_split_softcap_sm80.cu | 4 +- ...ash_fwd_hdim256_fp16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_paged_sm80.cu | 4 +- .../flash_fwd_hdim64_bf16_paged_sm90.cu | 2 +- ...lash_fwd_hdim64_bf16_paged_softcap_sm80.cu | 4 +- ...lash_fwd_hdim64_bf16_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_paged_split_sm80.cu | 4 +- .../flash_fwd_hdim64_bf16_paged_split_sm90.cu | 2 +- ...wd_hdim64_bf16_paged_split_softcap_sm80.cu | 4 +- ...wd_hdim64_bf16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_sm80.cu | 4 +- .../flash_fwd_hdim64_bf16_sm90.cu | 2 +- ...sh_fwd_hdim64_bf16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim64_bf16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_bf16_split_sm80.cu | 4 +- .../flash_fwd_hdim64_bf16_split_sm90.cu | 2 +- ...lash_fwd_hdim64_bf16_split_softcap_sm80.cu | 4 +- ...lash_fwd_hdim64_bf16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_paged_sm90.cu | 2 +- ...lash_fwd_hdim64_e4m3_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_paged_split_sm90.cu | 2 +- ...wd_hdim64_e4m3_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_sm90.cu | 2 +- ...sh_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_e4m3_split_sm90.cu | 2 +- ...lash_fwd_hdim64_e4m3_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_paged_sm80.cu | 4 +- .../flash_fwd_hdim64_fp16_paged_sm90.cu | 2 +- ...lash_fwd_hdim64_fp16_paged_softcap_sm80.cu | 4 +- ...lash_fwd_hdim64_fp16_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_paged_split_sm80.cu | 4 +- .../flash_fwd_hdim64_fp16_paged_split_sm90.cu | 2 +- ...wd_hdim64_fp16_paged_split_softcap_sm80.cu | 4 +- ...wd_hdim64_fp16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_sm80.cu | 4 +- .../flash_fwd_hdim64_fp16_sm90.cu | 2 +- ...sh_fwd_hdim64_fp16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim64_fp16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim64_fp16_split_sm80.cu | 4 +- .../flash_fwd_hdim64_fp16_split_sm90.cu | 2 +- ...lash_fwd_hdim64_fp16_split_softcap_sm80.cu | 4 +- ...lash_fwd_hdim64_fp16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_paged_sm80.cu | 4 +- .../flash_fwd_hdim96_bf16_paged_sm90.cu | 2 +- ...lash_fwd_hdim96_bf16_paged_softcap_sm80.cu | 4 +- ...lash_fwd_hdim96_bf16_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_paged_split_sm80.cu | 4 +- .../flash_fwd_hdim96_bf16_paged_split_sm90.cu | 2 +- ...wd_hdim96_bf16_paged_split_softcap_sm80.cu | 4 +- ...wd_hdim96_bf16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_sm80.cu | 4 +- .../flash_fwd_hdim96_bf16_sm90.cu | 2 +- ...sh_fwd_hdim96_bf16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim96_bf16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_bf16_split_sm80.cu | 4 +- .../flash_fwd_hdim96_bf16_split_sm90.cu | 2 +- ...lash_fwd_hdim96_bf16_split_softcap_sm80.cu | 4 +- ...lash_fwd_hdim96_bf16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_paged_sm90.cu | 2 +- ...lash_fwd_hdim96_e4m3_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_paged_split_sm90.cu | 2 +- ...wd_hdim96_e4m3_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_sm90.cu | 2 +- ...sh_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_e4m3_split_sm90.cu | 2 +- ...lash_fwd_hdim96_e4m3_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_paged_sm80.cu | 4 +- .../flash_fwd_hdim96_fp16_paged_sm90.cu | 2 +- ...lash_fwd_hdim96_fp16_paged_softcap_sm80.cu | 4 +- ...lash_fwd_hdim96_fp16_paged_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_paged_split_sm80.cu | 4 +- .../flash_fwd_hdim96_fp16_paged_split_sm90.cu | 2 +- ...wd_hdim96_fp16_paged_split_softcap_sm80.cu | 4 +- ...wd_hdim96_fp16_paged_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_sm80.cu | 4 +- .../flash_fwd_hdim96_fp16_sm90.cu | 2 +- ...sh_fwd_hdim96_fp16_softcap_packgqa_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_softcap_sm80.cu | 4 +- .../flash_fwd_hdim96_fp16_softcap_sm90.cu | 2 +- .../flash_fwd_hdim96_fp16_split_sm80.cu | 4 +- .../flash_fwd_hdim96_fp16_split_sm90.cu | 2 +- ...lash_fwd_hdim96_fp16_split_softcap_sm80.cu | 4 +- ...lash_fwd_hdim96_fp16_split_softcap_sm90.cu | 2 +- .../flash_fwd_hdimall_bf16_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_bf16_paged_sm90.cu | 1 + ...ash_fwd_hdimall_bf16_paged_softcap_sm90.cu | 1 + ...flash_fwd_hdimall_bf16_paged_split_sm90.cu | 1 + ...d_hdimall_bf16_paged_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_bf16_sm90.cu | 1 + ...h_fwd_hdimall_bf16_softcap_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_bf16_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_bf16_split_sm90.cu | 1 + ...ash_fwd_hdimall_bf16_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_e4m3_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_e4m3_paged_sm90.cu | 1 + ...ash_fwd_hdimall_e4m3_paged_softcap_sm90.cu | 1 + ...flash_fwd_hdimall_e4m3_paged_split_sm90.cu | 1 + ...d_hdimall_e4m3_paged_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_e4m3_sm90.cu | 1 + ...h_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_e4m3_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_e4m3_split_sm90.cu | 1 + ...ash_fwd_hdimall_e4m3_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_fp16_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_fp16_paged_sm90.cu | 1 + ...ash_fwd_hdimall_fp16_paged_softcap_sm90.cu | 1 + ...flash_fwd_hdimall_fp16_paged_split_sm90.cu | 1 + ...d_hdimall_fp16_paged_split_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_fp16_sm90.cu | 1 + ...h_fwd_hdimall_fp16_softcap_packgqa_sm90.cu | 1 + .../flash_fwd_hdimall_fp16_softcap_sm90.cu | 1 + .../flash_fwd_hdimall_fp16_split_sm90.cu | 1 + ...ash_fwd_hdimall_fp16_split_softcap_sm90.cu | 1 + hopper/mainloop_fwd_sm80.hpp | 15 +- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 89 +- hopper/paged_kv.h | 62 +- hopper/setup.py | 4 +- hopper/test_flash_attn.py | 850 +++++++++--------- hopper/test_util.py | 9 +- hopper/tile_size.h | 6 +- 306 files changed, 1312 insertions(+), 921 deletions(-) create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 5f7522a8ac3..e61cea9e67e 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -56,7 +56,7 @@ def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) -def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size=(-1, -1)): +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)): if causal: avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 else: @@ -67,7 +67,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size= col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) avg_seqlen = (col_right - col_left + 1).float().mean().item() - return batch * nheads * 2 * seqlen_q * avg_seqlen * headdim * 2 + return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) def convert_to_cudnn_type(torch_type): @@ -263,7 +263,7 @@ def run(*args, **kwargs): # for headdim in [64, 96, 128]: # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192, 256]: -for headdim in [128]: +for headdim in [192]: nheads = dim // headdim # headdim = 64 # batch_size = 64 @@ -272,6 +272,8 @@ def run(*args, **kwargs): # headdim = 128 nheads_kv = nheads # nheads_kv = nheads // 4 + headdim_v = headdim + # headdim_v = 128 for batch_size, seqlen in bs_seqlen_vals: num_splits = 1 @@ -285,15 +287,15 @@ def run(*args, **kwargs): # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) - v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) + v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() v_fa3 = v if not V_colmajor else v_colmajor # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) - # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) - g = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) - o = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) + # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype) + g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) + o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32) a = torch.randn(batch_size, seqlen, seqlen, device=device, dtype=dtype_gen) b = torch.randn(batch_size, dim * 2, seqlen, device=device, dtype=dtype_gen).transpose(-1, -2) @@ -320,14 +322,14 @@ def run(*args, **kwargs): for causal in [False, True]: # for causal in [False]: print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") - nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, causal=causal, window_size=window_size) + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: # if False: - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0]) cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0]) # _, m0 = benchmark_forward(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: # if False: if not varlen: m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') @@ -343,7 +345,7 @@ def run(*args, **kwargs): repeats=repeats, verbose=False, desc='Fav2') time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True) - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: if triton_attention is not None: qt, kt, vt = [x.detach().transpose(1, 2).contiguous().requires_grad_() for x in [q, k, v]] time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark @@ -356,7 +358,7 @@ def run(*args, **kwargs): # # pytorch_profiler(triton_attention, q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous(), causal, 1 / math.sqrt(headdim), backward=True) if cudnn is not None: # if False: - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN') time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean @@ -380,7 +382,7 @@ def run(*args, **kwargs): # nFLOPS_matmul = nFLOPS # nFLOPS_matmul = 2 * x.shape[0] * x.shape[1] * w.shape[1] # m5 = time_fwd(torch.matmul, x, w, desc='cuBLAS') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: time.sleep(1) if not varlen: _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic, @@ -396,11 +398,11 @@ def run(*args, **kwargs): # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True) # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: # if False: print(f'Fav2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS') print(f'Fav2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS') - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: if triton_attention is not None: print(f'Triton fwd: {m3.mean * 1e3:.3f}ms, {(nFLOPS / m3.mean * 1e-12):.1f} TFLOPS') # if causal: @@ -409,7 +411,7 @@ def run(*args, **kwargs): print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS') # benchmark_forward(torch.square, k) # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS') diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index 0f916060260..d8f2c15c977 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -20,11 +20,11 @@ namespace flash { using namespace cute; -template struct CollectiveEpilogueFwd { - using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = TileShape_MNK_PV_; using ClusterShape = ClusterShape_; using Element = Element_; using ArchTag = ArchTag_; @@ -37,21 +37,21 @@ struct CollectiveEpilogueFwd { static_assert(ArchTag::kMinComputeCapability >= 80); static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); + static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; // These are for storing the output tensor without TMA (e.g., for setting output to zero) static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); + static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times // we need to call divmod. - static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBytePerRow = kHeadDimV * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - // static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerStore, NumEpilogueThreads); + // static constexpr int kBlockKGmem = kHeadDimV % 128 == 0 ? 128 : (kHeadDimV % 64 == 0 ? 64 : 32); + // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDimV / kGmemElemsPerStore, NumEpilogueThreads); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore; // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0); @@ -65,15 +65,15 @@ struct CollectiveEpilogueFwd { Layout>>{})); // Val layout, 8 or 16 vals per store using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 2>(TileShape_MNK{}))); + decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>()); + using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{}))); static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); - using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{}))); using SmemLayoutO = std::conditional_t= 90, SmemLayoutOTMA, SmemLayoutOSTS>; using ShapeO = cute::Shape; // (seqlen_q, d, head, batch, num_splits) @@ -109,7 +109,7 @@ struct CollectiveEpilogueFwd { GmemTiledCopyOTMA{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeO{}, StrideO{}), SmemLayoutOTMA{}, - select<0, 2>(TileShape_MNK{}), + select<0, 1>(TileShape_MNK_PV{}), _1{})), // no mcast for O std::nullptr_t >; @@ -148,7 +148,7 @@ struct CollectiveEpilogueFwd { Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); TMA_O tma_store_O = [&]{ if constexpr (Use_TMA_O) { - return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK{}), _1{}); // no mcast + return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast } else { return nullptr; } @@ -243,14 +243,14 @@ struct CollectiveEpilogueFwd { // Step 2: Write LSE from rmem -> gmem auto thread_mma = tiled_mma.get_thread_slice(thread_idx); // (MMA,MMA_M,MMA_K) - Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); static_assert(decltype(size<0, 0>(taccOcO))::value == 2); static_assert(decltype(size<0, 1>(taccOcO))::value == 2); Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; + using PackGQAt = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } @@ -267,7 +267,7 @@ struct CollectiveEpilogueFwd { // Step 3: Write O from smem -> gmem if constexpr (Use_TMA_O) { Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) auto block_tma_O = params.tma_store_O.get_slice(_0{}); Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) @@ -287,7 +287,7 @@ struct CollectiveEpilogueFwd { } } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } if constexpr (Use_smem) { GmemTiledCopyO gmem_tiled_copy_O; @@ -305,7 +305,7 @@ struct CollectiveEpilogueFwd { } if constexpr (!PackGQA) { // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); Tensor tOpO = make_tensor(make_shape(size<2>(tOsO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } @@ -361,7 +361,7 @@ struct CollectiveEpilogueFwd { int thread_idx, cute::tuple const& block_coord ) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); auto [m_block, bidh, bidb, split_idx] = block_coord; flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; @@ -391,12 +391,12 @@ struct CollectiveEpilogueFwd { GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); if constexpr (!PackGQA) { Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) Tensor tOgO = gmem_thr_copy_O.partition_D(gO); Tensor tOrO = make_fragment_like(tOgO); cute::clear(tOrO); @@ -406,7 +406,7 @@ struct CollectiveEpilogueFwd { ); } else { // If PackGQA, we split the work of compute O_ptr among threads in the same row - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; + using PackGQAt = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); cute::clear(tOrO); PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); diff --git a/hopper/flash.h b/hopper/flash.h index 4559a1352e4..9f8cb1bcae1 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -65,6 +65,7 @@ struct Flash_fwd_params : public Qkv_params { int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; int total_q, total_k, total_knew; int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q + int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim // The scaling factors for the kernel. float scale_softmax; @@ -197,9 +198,9 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -template +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); -template +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 82643d9fff4..94fcf5d78f5 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -271,36 +271,48 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (!params.is_e4m3) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_(params, stream); } + if (params.d <= 64) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } + if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } + if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_(params, stream); } + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { + return run_mha_fwd_(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } + if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif } else { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_(params, stream); } + if (params.d <= 64) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } + if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } + if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_(params, stream); } + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { + return run_mha_fwd_(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } + if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP16."); @@ -309,19 +321,25 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { #ifndef FLASHATTENTION_DISABLE_FP8 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); + } else { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP8."); @@ -339,28 +357,34 @@ void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively // so that kBlockM is smaller and we have more parallelism. if (params.is_fp32) { - if (params.d <= 64) { + if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { + } else if (params.dv <= 128) { run_mha_fwd_combine_(params, stream); - } else { + } else if (params.dv <= 256) { run_mha_fwd_combine_(params, stream); + } else { + run_mha_fwd_combine_(params, stream); } } else if (params.is_bf16) { - if (params.d <= 64) { + if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { + } else if (params.dv <= 128) { run_mha_fwd_combine_(params, stream); - } else { + } else if (params.dv <= 256) { run_mha_fwd_combine_(params, stream); + } else { + run_mha_fwd_combine_(params, stream); } } else { - if (params.d <= 64) { + if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { + } else if (params.dv <= 128) { run_mha_fwd_combine_(params, stream); - } else { + } else if (params.dv <= 256) { run_mha_fwd_combine_(params, stream); + } else { + run_mha_fwd_combine_(params, stream); } } #else @@ -378,7 +402,7 @@ inline bool get_pack_gqa(Flash_fwd_params const& params) { // params.page_table must already be set if (params.h == params.h_k) { return false; } // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); #endif @@ -392,10 +416,10 @@ inline int get_num_splits(Flash_fwd_params const& params) { // params.page_table must already be set // This needs to match the kernel configs bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits // has not been set here. It's OK though because we might just underestimate kBlockN a bit - auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); @@ -460,10 +484,10 @@ inline int round_up_headdim(int head_size) { std::vector mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - const at::Tensor &v, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &v_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &out_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q std::optional &cu_seqlens_q_, // b+1 std::optional &cu_seqlens_k_, // b+1 std::optional &cu_seqlens_k_new_, // b+1 @@ -551,6 +575,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; int num_heads = q.size(-2); int const head_size = q.size(-1); + int const head_size_v = v.size(-1); int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); int const num_pages = !paged_KV ? 0 : k.size(0); int const page_size = !paged_KV ? 1 : k.size(1); @@ -564,6 +589,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int const max_headdim = get_max_headdim(); TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (head_size_v != head_size) { + TORCH_CHECK(head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128, "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128]"); + TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM // TODO: check this @@ -583,15 +612,15 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (!paged_KV) { if (!is_varlen_k) { CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); } else { CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); } } else { CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); } @@ -610,6 +639,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); auto opts = q.options(); auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type; @@ -620,16 +650,19 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); if (!is_varlen_q) { - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); } else { - CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_v); } } else { - out = torch::empty_like(q, opts.dtype(out_type)); + out = !is_varlen_q + ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) + : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = round_up_headdim(head_size_v); int const seqlen_q_rounded = round_multiple(seqlen_q, 128); int const seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -667,6 +700,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.total_k = total_k; params.sink_token_length = sink_token_length; params.b_k = batch_size_k; + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; if (paged_KV) { params.page_table = page_table.data_ptr(); @@ -702,10 +737,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0); if (!is_varlen_k_new) { CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); } else { CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); } params.seqlen_knew = seqlen_k_new; @@ -772,12 +807,12 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (params.num_splits > 1) { TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); if (!is_varlen_q) { - out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(outaccum_type)); + out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); params.oaccum_batch_stride = out_accum.stride(1); params.lseaccum_batch_stride = softmax_lse_accum.stride(1); } else { - out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size}, opts.dtype(outaccum_type)); + out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat)); } params.is_fp32 = false; @@ -1258,7 +1293,7 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x const int seqlen = sizes[2]; const int num_heads = sizes[3]; const int head_size_og = sizes[4]; - TORCH_CHECK(head_size_og <= 256, "FlashAttention combine only supports head dimension at most 256"); + TORCH_CHECK(head_size_og <= 512, "FlashAttention combine only supports head dimension at most 512"); TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); @@ -1307,7 +1342,7 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x params.b = batch_size; params.h = num_heads; params.seqlen_q = seqlen; - params.d = head_size; + params.dv = head_size; params.num_splits = num_splits; params.oaccum_split_stride = out_partial_padded.stride(0); params.oaccum_row_stride = out_partial_padded.stride(2); diff --git a/hopper/flash_fwd_combine.cu b/hopper/flash_fwd_combine.cu index 5b7d9eed655..57392ee75f4 100644 --- a/hopper/flash_fwd_combine.cu +++ b/hopper/flash_fwd_combine.cu @@ -6,11 +6,14 @@ template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 33e66c21f82..5cbed2b0c74 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -24,7 +24,7 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { typename CombineKernel::Arguments args { static_cast(params.oaccum_ptr), - {!Varlen ? params.seqlen_q : params.total_q, params.d, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial + {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial static_cast(params.softmax_lseaccum_ptr), {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index e5411042dc9..05ce4d0ae60 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -45,10 +45,11 @@ class FlashAttnFwdSm90 { static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O; static constexpr bool PackGQA = CollectiveMainloop::PackGQA; static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; + static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim; using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; // Mainloop derived types - using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; + using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV; using TiledMma0 = typename CollectiveMainloop::TiledMma0; using TiledMma1 = typename CollectiveMainloop::TiledMma1; using ArchTag = typename CollectiveMainloop::ArchTag; @@ -176,7 +177,7 @@ class FlashAttnFwdSm90 { static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; - static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK; using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV; @@ -222,6 +223,11 @@ class FlashAttnFwdSm90 { pipeline_params_k.producer_arv_count = NumProducerThreads; } + PipelineParamsV pipeline_params_v = pipeline_params_k; + if constexpr (Use_TMA_KV && !SameHeadDim) { + pipeline_params_v.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + } + MainloopPipelineK pipeline_k = [&] { if constexpr (Use_TMA_KV) { return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{}); @@ -234,9 +240,9 @@ class FlashAttnFwdSm90 { if constexpr (!Transpose_V) { static_assert(is_same_v); if constexpr (Use_TMA_KV) { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k, ClusterShape{}); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{}); } else { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); } } else { PipelineParamsV pipeline_params_v; @@ -256,11 +262,11 @@ class FlashAttnFwdSm90 { // However, the thread role isn't used in the pipeline implementation. MainloopPipelineVt pipeline_vt = [&] { if constexpr (Use_TMA_KV) { - pipeline_params_k.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k, ClusterShape{}); + pipeline_params_v.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_v, ClusterShape{}); } else { - pipeline_params_k.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k); + pipeline_params_v.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_v); } }(); @@ -272,6 +278,9 @@ class FlashAttnFwdSm90 { pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0; pipeline_params_kv_new.num_consumers = NumMmaThreads; auto pipeline_k_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr); + if constexpr (!SameHeadDim) { + pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + } auto pipeline_v_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr); CollectiveMainloop collective_mainloop; @@ -357,7 +366,7 @@ class FlashAttnFwdSm90 { work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { // Attention output (GEMM-II) accumulator. - Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); + Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 1>(TileShape_MNK_PV{})); float softmax_scale_log2 = params.mainloop.softmax_scale_log2; // If there's tanh softcap, the scaling will be done before tanh. auto block_coord = work_tile_info.get_block_coord(params.scheduler); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 57d64d6a7b8..3f4bea96ee4 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -23,7 +23,7 @@ using namespace cute; -template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { @@ -35,8 +35,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; // Can't use structured binding since it's not compatible with constexpr - static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap); - static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); + static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap); + static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); static constexpr bool Mma1_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); @@ -46,13 +46,14 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); using TileShape_MNK = cute::Shape, Int, Int>; + using TileShape_MNK_PV = cute::Shape, Int, Int>; using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, - flash::CollectiveMainloopFwdSm80 + flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm80 >; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; using SchedulerPersistent = std::conditional_t(params.v_ptr), + params.dv, // headdim_v v_strides, // stride_V static_cast(params.knew_ptr), {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new @@ -179,7 +181,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { CHECK_CUDA_KERNEL_LAUNCH(); } -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; @@ -189,7 +191,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { // Only needed here to decide if we should use cluster - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; @@ -197,7 +199,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + run_flash_fwd(params, stream); }); }); }); diff --git a/hopper/generate_kernels.py b/hopper/generate_kernels.py index e741c13826f..7a5eb47d08b 100644 --- a/hopper/generate_kernels.py +++ b/hopper/generate_kernels.py @@ -38,7 +38,7 @@ KERNEL_IMPL_TEMPLATE_FWD_SM90 = """#include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} -template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif """ @@ -46,8 +46,8 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} -template void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif """ @@ -85,6 +85,7 @@ class Kernel: sm: int dtype: str head_dim: int + head_dim_v: int split: bool paged_kv: bool softcap: bool @@ -98,14 +99,15 @@ def template(self) -> str: # Always enable PackGQA for PagedKV or Split to reduce compilation packgqa = self.packgqa or self.paged_kv or self.split return KERNEL_IMPL_TEMPLATE_FWD_SM90.format( - ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], + HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v, SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(), SOFTCAP=str(self.softcap).lower(), PACKGQA=str(packgqa).lower() ) else: # Always enable PackGQA for Sm8x to reduce compilation return KERNEL_IMPL_TEMPLATE_FWD_SM8x.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v, SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(), SOFTCAP=str(self.softcap).lower(), PACKGQA=str(True).lower() ) @@ -117,13 +119,13 @@ def template(self) -> str: ) else: return KERNEL_IMPL_TEMPLATE_BWD_SM8x.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, SOFTCAP=str(self.softcap).lower() ) @property def filename(self) -> str: - return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu" + return f"flash_{self.direction}_hdim{self.head_dim}{f'_{self.head_dim_v}' if self.head_dim_v != self.head_dim else ''}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu" def get_all_kernels() -> List[Kernel]: @@ -133,9 +135,11 @@ def get_all_kernels() -> List[Kernel]: if packgqa and (sm < 90 or (sm >= 90 and (paged_kv or split))): continue if sm >= 90 or dtype in DTYPE_MAP_FWD_SM8x: - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + if sm == 90 and head_dim == 192: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM): - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") def batch_hdim(kernels_all) -> List[KERNEL_BATCH]: diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu index 18879eff6ee..affc7a4dd96 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu index 35c0ad78fd1..7e13614bfea 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu index 7a39869a001..670041341bc 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu index fb7ba5caed5..f315fbb4545 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu index 296ec9e91fc..bde3024a4a6 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu index 8cffb6de830..2724463e621 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu index 12d564ce364..a38a1d5cf33 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu index 845b1fa5d06..284eeba1823 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu index 25fbfda38d2..0c40ddba8fe 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu index 1130ca747d1..cc89c4d5d25 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu index 502bc1d1771..3a236b712c4 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu index 537e42ba56b..8449104c5aa 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu index 2255e7949e2..b152b90bab7 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu index 086f55b3588..8cc4fed1739 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu index 54590eebbcc..1db3f1e6d80 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu index af322d1d15f..9b3e294f1b3 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu index 3e83398e7f6..07bd687fc34 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu index 3f917d26abc..5f44833b10d 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu index 87c78f28929..9f95ca29f6b 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu index e56b64c3d9a..ad97737d4f3 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu index 8202bfadde5..d77d37ec041 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu index ee7439b277c..ae05c7ce5f0 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu index 812239ef50e..bc52a9f356f 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu index 74e52315bd4..480d485d069 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu index fe0bff6a1d3..d3da5f4e665 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu index 55df1a66635..1c1c2d8207f 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu index 03a9c61e409..371d933e3e1 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu index 67ba153c605..7491148dcde 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu index 9f7bcec9ed9..d04159a62a0 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu index 7116702f3fd..28ad6c14963 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu index 04f18ac0fac..7afb267e3eb 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu index c7c7c9e69f5..69758584cb6 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu index b4ea8bc3301..3be45956bb4 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu index ec99965c928..698095dad6a 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu index d1dd9645233..16d443a9ad1 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu index 83274ca3fdb..1e8f6af71bd 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu index 80e9eb0e2c0..4ec68886112 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu index fbbc273b7e2..670b5952d9d 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu index f4f4829f331..b9778dc92e1 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu index c768a89fdfb..446e917c795 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu index 89c2db39e61..fd62a2c5435 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu index 5b87286aef4..0a397f4acf2 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu index 7506097821d..4d3c553e296 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu index d3b7b0f87b2..77621846ffe 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu index 4d8625cd6dc..7d217ac2733 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu index f6f129c550b..0b6430abc2f 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu new file mode 100644 index 00000000000..ea1e266f8d4 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu new file mode 100644 index 00000000000..2d7488fefe2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..8718571e30c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu new file mode 100644 index 00000000000..f7dfc18fc1e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..935f5a0fe60 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu new file mode 100644 index 00000000000..3f4d858ff57 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..54d720efeb3 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu new file mode 100644 index 00000000000..b9b93af4fc5 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu new file mode 100644 index 00000000000..39d9167b9f1 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu new file mode 100644 index 00000000000..0f86458012a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu new file mode 100644 index 00000000000..bd6f4df8f69 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu new file mode 100644 index 00000000000..1824b86c64c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu new file mode 100644 index 00000000000..87dd01725a5 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu new file mode 100644 index 00000000000..6594d560123 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..d7dc84ebc1c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu new file mode 100644 index 00000000000..b9d6e54cbed --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..a8c47652ec1 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu new file mode 100644 index 00000000000..32d17c7665d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu new file mode 100644 index 00000000000..365017c256d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu new file mode 100644 index 00000000000..82cfdf040b0 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu new file mode 100644 index 00000000000..f3254936a47 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu new file mode 100644 index 00000000000..931a6dbf869 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..5c8877a756d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu new file mode 100644 index 00000000000..1e230ab084b --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..03716c86237 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu new file mode 100644 index 00000000000..54c66c9552e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..e5e0ec47db1 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu new file mode 100644 index 00000000000..e4411b5db32 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu new file mode 100644 index 00000000000..157ed06dddf --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu new file mode 100644 index 00000000000..7ef5adc9e85 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu index 96243edf0ae..bf8386b8297 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu index a51a8945888..cbc6f988424 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu index 515d88a11aa..d5aa15b5c8c 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu index e5a154c18db..b8593612df3 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu index 2bd860c7758..a03514d919b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu index 6e1d8037819..df547749e93 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu index 942685e148f..1ddb1916209 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu index d6050520e02..cefffcd2169 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu index 7ee500a80ee..3d4333b9e1f 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu index 1f9d8bfd56c..35a2abef8c9 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu index 0313ad1b2c8..99e34ac0bfb 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu index 8d87eb21f2d..ed1cf22d5c4 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu index 081bb31b12e..4527d9a2793 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu index a9b5aa0de49..41fcf800170 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu index d465545ef8f..704cbcb337e 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu index 68c57145532..e0ea082156b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu index e1d656e5ae3..a9c00408a8b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu index 57d1c73d85e..1497e7aa843 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu index 5104d439810..c66ea9baca1 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu index cbc61f27e76..a7e472b478b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu index f08ba1459e1..9f090aeeda8 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu index e413758de8f..2205168a67f 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu index c8205c1605f..2a01898b560 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu index f0db959e0f3..888e241a9f1 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu index 249cae97ff7..2a6bde7a39f 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu index 14b073deb31..3d315187b2d 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu index 8152dbaa6b4..3c3d0938034 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu index d0b0df02798..4ca103566d6 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu index 24f3e128dca..16debf27799 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu index 6eabe0ee269..43c2615718e 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu index 5c780da81f7..d9d483838f1 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu index 5a943660174..70543998d94 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu index 9815dd13551..c30c7e3b8b9 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu index 66fc2cb8a6b..7ae26e69c96 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu index 2ceddd8cae6..155b5a539fd 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu index 4c64bc61c57..3e6173c31c2 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu index 6ad1a1529c1..e1e3191a202 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu index f0ee8c0159f..8272ecb76cb 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu index 4a9583196bb..74606c39373 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu index 2b65a88f06c..89a58502b37 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu index e324a932671..b13373806a1 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu index a8be65709d1..1335fad7f2b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu index 1ad82d7edfb..18c31bdfc0e 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu index 75f53ee4f6e..18a5603cf1f 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu index 09f76526338..4e99c7db027 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu index e5299154c3d..82f8204aa66 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu index 364579e1b32..cb851a77110 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu index a5f821becd7..ae2871c1655 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu index 364bd2b3aee..ed24fbffef9 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu index 3d2e337e164..ffca9c7f8fe 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu index 310c4a5c38f..57a06bd6e66 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu index 96f5bbf3ada..ccdcf21e492 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu index 7d3131bd5bb..c2bc7787765 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu index 7715a52531d..6bba953fc69 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu index 686bdfa5c7f..25c96174c79 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu index 97fdc0094c0..f172239e5b9 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu index 25a90d3be3a..9dde6adb04b 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu index 4c91ee5bc06..2317adef8c5 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu index ef12a584c65..b9b3b74867e 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu index e4e746f9d3d..c57a5a30abb 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu index 99924af52c7..4f59a6aea92 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu index 705582b9f22..2c2de1574ac 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu index 7e969012051..0dbd062c79f 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu index 058eca375d5..bee54c702de 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu index 679066d5443..c02e6833494 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu index e4ce6f9aa15..02b50b98b8d 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu index 03eff4c6f7b..6599de63bbd 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu index 26df5e592eb..a1cdc775cbb 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu index 57de7421dfd..6d01be60f58 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu index 53974f3e61f..968bbf36f83 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu index 24e1f635638..d564a622111 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu index a2fc325dad1..cb5bccc176c 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu index 2c1f5f56f1f..146a7bc3430 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu index 7cbdff3e8a5..a195e0931c0 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu index b81bf0b99b0..045fc71bedb 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu index 88a00e91212..a31da2eddf4 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu index c28edfd8f95..7382b58a231 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu index dbcd163308e..87ca31ce902 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu index 63620ec90a4..60f4d6ebbf6 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu index d8c11ee6a0e..e0d5d318bcd 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu index 4af31d0bf9c..dec7db046bd 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu index c7a04dc47b8..7b71f435226 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu index 9bca3a1c5ee..08fc989af8b 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu index acd0fa660fa..2cc8b5b86d4 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu index a38430fb3c7..644e268469f 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu index 03bb0516f77..1ebcec8b3fe 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu index 8ea90bd417a..780ade7f6b8 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu index f9144326426..bfcffe2a39a 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu index e7e1cecd1f8..ba4ba78ad49 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu index 18b79da92c9..f04260ba4f1 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu index 1c1c9470d6d..33c78e53059 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu index 6cadc2641d5..8388420921e 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu index 4b650f53cfc..4134d7d80bb 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu index 29cb3fe18be..11e3503b0d9 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu index 2612bc9c98b..67e39bd7371 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu index 4c5fae060a8..c37844daa56 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu index c0b58521bc5..f0c40e2f89f 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu index 0a058847247..3ed9694908c 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu index b421199714b..4a16aae66c0 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu index 7f337595b56..b5b5fc26b28 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu index c4c35a18c7c..3b29be627ed 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu index 9ea549e1173..5f1c298c4c4 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu index 8ffc852e3e9..64895643d20 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu index 7143da2f79a..dd508590d66 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu index 4f7cd4f8e4c..8411b6fccbd 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu index 5a9bb142056..b5b4f40770e 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu index dc9b71a5b3d..e608da04b02 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu index 4c5440436a2..c69b78ac3b6 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu index d988a48f990..170cdb5cb8c 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu index c6ae246e7f2..ef0d1e921c1 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu index 761a625564e..6a7fc29ddda 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu index a74d7c2c3b3..faeb6c487fc 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu index 6d48fb099b1..655258d5194 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu index 0e49f26aaaa..4bd8ad8f267 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu index f780a8eb73a..657820f2854 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu index 948c8b17c85..cb0955d1a53 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu index 519783851fe..357b64e83b6 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu index d5392ef3b07..c1207925864 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu index 06086d40840..21687f8932b 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu index a15ab4c60da..4df8ed64d7b 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu index 7038c0ad726..b601195d7e0 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu index 9a805fd3e5b..ced47531898 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu index b23cb43e770..03090f73cb2 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu index c18f470fcfe..d6fe1559ca1 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu index d61b04a07e1..7b5ae4a56aa 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu index 1d33fe12e05..6c603b4dcaa 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu index 03ac4d2f84e..26d25fc1909 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu index 7b031a49031..05a0baf18b6 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu index 77dbc58123b..3a45776537f 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu index 6bae5faa535..9b80bae51f6 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu index 30f666a73fc..f6810efafb8 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu index 358e813eca8..98c018893f1 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu index f5df3f502b8..a10dfaca722 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu index f16185c3ac2..b912a81443e 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu index 796e4d63a3e..8603c396e1f 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu index 6eeb977415f..dc55dbc66aa 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu index aa1d2cd05f0..ef48844972a 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu index 5a92ebdddfb..b1c0ead6e5c 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu index 78c390e5ef0..5d76d0fff04 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu index 2b5aaff0d87..44ea823d272 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu index f0fa3ac63d1..30fe623508b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu index 0d9407b2ce4..6eb12dc80a6 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu index 223b6783e02..b806fc9d501 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu index 2f49d5f5aae..8f0a26da03a 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu index 9661156d889..6de2819a172 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu index b5f6d7f8757..16927295b82 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu index 82b827e180a..08413072092 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu index 042dd0cc71b..7d4dcdc293b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu index 4712aed6c3b..b4dfbf7f8b2 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu index 8295033deeb..1fa048752dc 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu index 21c43e6dbd1..e0b6a75e635 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu index d3317ad6280..e257b42f79c 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu index 86218988c2c..f97ab4733a0 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu index 7a6450373c8..cee43ef94cd 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu index 34c1a3d3f04..0442e1f94b5 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu index 96affd254c9..bc71fa9e71f 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu index 489717ff2cd..b61dd71885d 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu index 69917aa1ea4..f47e1f5cdac 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu index 3e3cc66f667..215752f1b06 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu index e5f53e49c51..207afc79242 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu index 0899aa8987b..6c38c083384 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu index 22f4cf6b14c..dc2eb35dc29 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu index d601d694d4f..f04e8bca6f3 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu index 1c5ba9b0066..2697f6910e7 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu index 8073b677a1d..e7a98b2e6ee 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu index 857be35920c..98fb39c86ee 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu index 6931ffa2792..cb938ad93b0 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu index 84facb47e70..e2dc45c79c6 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu index 878d160ff2b..64f99c05a32 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu index e5561f7d63f..3fdbbf23bac 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu index 30474d3543d..ffe202ee394 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu index 074f7232f23..42740f0228b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu index 734abb7b0e9..829929980d0 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu index 285e7ef520d..d6a330432a4 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu index d552e45db1a..39c774e6f77 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu index 64ca02345eb..bc54be11e6c 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu index 3d8bb7c2775..a68790500d8 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu index 6fab8802c5a..3bca3065c7f 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu index 1fb30696ddb..985692b9fa1 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu index af9b88d9a3d..3c99cb6b5a0 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu index 5f9794a9873..cf77a1ae819 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu index c906649acc6..f9a46a44dd5 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu index 2d7ac26e250..9b4dbbba58a 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu index 171f28e9ce2..da5373fd13e 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu index 8b659e8321b..e8ed21cda49 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim128_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim192_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim256_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu index c84d02b6d04..f7de8fa2019 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_paged_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu index 6aaf7d12f56..64e5ce4a33f 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu index 11712141419..44619cce59b 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu index 6175723086c..a059735824d 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu index 2aac1970b1b..daea288fe3a 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_sm90.cu" #include "flash_fwd_hdim128_bf16_sm90.cu" #include "flash_fwd_hdim192_bf16_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_sm90.cu" #include "flash_fwd_hdim256_bf16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu index be0c5af080b..62640192c68 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu index fd5893c59f4..79b0d52fa55 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu index bcde9c94582..333406cb439 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_split_sm90.cu" #include "flash_fwd_hdim128_bf16_split_sm90.cu" #include "flash_fwd_hdim192_bf16_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_sm90.cu" #include "flash_fwd_hdim256_bf16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu index 160eb3a18e4..b6c1fb54c4a 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu index 28819a690a3..abf0b10e46e 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim128_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim192_e4m3_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim256_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu index 933ad982719..22b310e5aba 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_paged_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_paged_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu index a934f7d9924..f9eed0732d7 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu index 8475e878ae2..b91c7f85ad7 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu index dd1405b17f0..a6b215bfdfd 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu index 7e7d806c6d5..ddec44c68ca 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_sm90.cu" #include "flash_fwd_hdim128_e4m3_sm90.cu" #include "flash_fwd_hdim192_e4m3_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_sm90.cu" #include "flash_fwd_hdim256_e4m3_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu index f973a4e411d..81601b9ec21 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu index 30390838d39..ae9a362c109 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu index 0b629bd2b32..163ee761be1 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_split_sm90.cu" #include "flash_fwd_hdim128_e4m3_split_sm90.cu" #include "flash_fwd_hdim192_e4m3_split_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_split_sm90.cu" #include "flash_fwd_hdim256_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu index 818c7fafb7a..ba2d427ddd4 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu index 6652824d075..34d1763483a 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim128_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim192_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim256_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu index 05d11e2e258..326a2ea901a 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_paged_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu index b638138eb26..a9e032a071c 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu index 3619a2175f0..d7cc300b89b 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu index 3a408ceacbd..fa4de4e298f 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu index eec11be9162..cb345586694 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_sm90.cu" #include "flash_fwd_hdim128_fp16_sm90.cu" #include "flash_fwd_hdim192_fp16_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_sm90.cu" #include "flash_fwd_hdim256_fp16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu index ca2a1e1b843..5dbd70ec5d3 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu index 8cf31a8a85f..9a97b96041b 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu index 5ee7ace63ac..5aacbf02664 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_split_sm90.cu" #include "flash_fwd_hdim128_fp16_split_sm90.cu" #include "flash_fwd_hdim192_fp16_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_sm90.cu" #include "flash_fwd_hdim256_fp16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu index 4da0ee704eb..cfaabd990ab 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu @@ -6,4 +6,5 @@ #include "flash_fwd_hdim96_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index e43904518cf..2d2ba06f220 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -22,7 +22,7 @@ namespace flash { using namespace cute; -template struct CollectiveMainloopFwdSm80 { @@ -30,6 +30,7 @@ struct CollectiveMainloopFwdSm80 { static constexpr int kStages = Stages; static_assert(kStages > 0, "kStages must be greater than 0"); using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -177,6 +178,7 @@ struct CollectiveMainloopFwdSm80 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -218,6 +220,7 @@ struct CollectiveMainloopFwdSm80 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -272,7 +275,7 @@ struct CollectiveMainloopFwdSm80 { // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, - args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.stride_V, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, @@ -430,11 +433,11 @@ struct CollectiveMainloopFwdSm80 { } cute::cp_async_fence(); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k ); @@ -730,11 +733,11 @@ struct CollectiveMainloopFwdSm80 { params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position ); diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 3af51566b01..da5f902eae1 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -27,7 +27,7 @@ namespace flash { using namespace cute; -template struct CollectiveMainloopFwdSm90 { @@ -35,6 +35,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kStages = Stages; using ClusterShape = ClusterShape_; using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -53,6 +54,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Use_TMA_KV = !PagedKV; static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); + static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; using SeqlenInfo_t = flash::SeqlenInfoQKNewK; static_assert(ArchTag::kMinComputeCapability >= 90); @@ -84,9 +86,9 @@ struct CollectiveMainloopFwdSm90 { std::conditional_t< !Mma1_is_RS, decltype(cute::GMMA::ss_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()), + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()), decltype(cute::GMMA::rs_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()) + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()) >{}, AtomLayoutMNK{})); @@ -107,25 +109,25 @@ struct CollectiveMainloopFwdSm90 { make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK_PV{})), decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVt = decltype(tile_to_shape( SmemLayoutAtomVt{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK_PV{})), decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVtMma = decltype(tile_to_shape( SmemLayoutAtomVtMma{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); // Only used if we're using cp.async to load V using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK{})), Int>()); using SmemLayoutVCpAsync = decltype(tile_to_shape( SmemLayoutAtomVCpAsync{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + make_shape(shape<1>(TileShape_MNK{}), Int{}, Int{}))); using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); @@ -135,26 +137,26 @@ struct CollectiveMainloopFwdSm90 { // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major. // For FP16/BF16 we don't do any transposing. - static_assert(!Transpose_V || (kHeadDim % 32 == 0 && kBlockN % 32 == 0)); - static constexpr bool kHeadDim_multiple_64 = kHeadDim % 64 == 0; - // Either kHeadDim is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), + static_assert(!Transpose_V || (kHeadDimV % 32 == 0 && kBlockN % 32 == 0)); + static constexpr bool kHeadDimV_multiple_64 = kHeadDimV % 64 == 0; + // Either kHeadDimV is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose). - static_assert(!Transpose_V || (kHeadDim_multiple_64 || kBlockN % 64 == 0)); - using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; - using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; + static_assert(!Transpose_V || (kHeadDimV_multiple_64 || kBlockN % 64 == 0)); + using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; + using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; using LDSM_value_shape = Shape<_2, _2, _1, _4>; using LDSM_value_stride = Stride<_1, _2, _16, _4>; - using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; + using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; using S2RTiledCopyVt = decltype(make_tiled_copy( Copy_Atom{}, Layout{}, Layout{})); - using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; - using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; + using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; + using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; using STSM_value_shape = Shape<_1, _4, _2, _2>; using STSM_value_stride = Stride<_0, _1, _4, _8>; using STSM_divide_shape = Shape<_8, _16>; - // These will not permute the columns of V (the kHeadDim dimension) but incur bank conflicts + // These will not permute the columns of V (the kHeadDimV dimension) but incur bank conflicts // so a little slower (e.g. 1150 TFLOPS for hdim 256 instead of 1200 TFLOPS). // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue. // using STSM_value_shape = Shape<_2, _4, _1, _2>; @@ -168,14 +170,15 @@ struct CollectiveMainloopFwdSm90 { using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); // We use CpAsync for K and V if PagedKV and AppendKV, since TMA doesn't work there + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will // load twice from the same row. - static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); + static constexpr int kBytePerHalfRow = kHeadDimGCD / 2 * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); @@ -221,14 +224,13 @@ struct CollectiveMainloopFwdSm90 { GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesV = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); - static_assert(TmaTransactionBytesK == TmaTransactionBytesV); using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; using MainloopPipelineK = std::conditional_t>; @@ -294,6 +296,7 @@ struct CollectiveMainloopFwdSm90 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -335,6 +338,7 @@ struct CollectiveMainloopFwdSm90 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -388,12 +392,14 @@ struct CollectiveMainloopFwdSm90 { take<0, 2>(SmemLayoutK{}), TileShape_MNK{}, ClusterShape{}); // mcast along M mode for this N load, if any - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), select<1, 0, 2, 3>(args.shape_K), select<1, 0, 2, 3>(args.stride_V)); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), + make_shape(args.headdim_v, get<0>(args.shape_K), get<2>(args.shape_K), get<3>(args.shape_K)), + select<1, 0, 2, 3>(args.stride_V)); TMA_V tma_load_V = make_tma_copy( GmemTiledCopyKV{}, mV, take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new); TMA_K tma_load_K_new = make_tma_copy_B_sm90( @@ -402,12 +408,14 @@ struct CollectiveMainloopFwdSm90 { take<0, 2>(SmemLayoutK{}), TileShape_MNK{}, ClusterShape{}); // mcast along M mode for this N load, if any - Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), select<1, 0, 2, 3>(args.shape_K_new), select<1, 0, 2, 3>(args.stride_V_new)); + Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), + make_shape(args.headdim_v, get<0>(args.shape_K_new), get<2>(args.shape_K_new), get<3>(args.shape_K_new)), + select<1, 0, 2, 3>(args.stride_V_new)); TMA_V tma_load_V_new = make_tma_copy( GmemTiledCopyKV{}, cute::conditional_return(mVnew, mV), take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); @@ -429,7 +437,7 @@ struct CollectiveMainloopFwdSm90 { // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, - args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.stride_V, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, @@ -555,12 +563,13 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K))(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) @@ -573,11 +582,11 @@ struct CollectiveMainloopFwdSm90 { Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k) Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k ); @@ -1210,7 +1219,7 @@ struct CollectiveMainloopFwdSm90 { Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K_new))(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) @@ -1306,7 +1315,7 @@ struct CollectiveMainloopFwdSm90 { int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<2, 1>(TileShape_MNK_PV{}), make_coord(_, _0{})); // (N, K_v, _) static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); @@ -1317,11 +1326,11 @@ struct CollectiveMainloopFwdSm90 { params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position ); @@ -1347,6 +1356,12 @@ struct CollectiveMainloopFwdSm90 { Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); #pragma unroll for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(params.shape_K); } + Tensor cV = cute::make_identity_tensor(select<2, 1>(TileShape_MNK_PV{})); // (BLK_N,BLK_K_V) -> (blk_n,blk_k_v) + Tensor tVcV = cute::conditional_return(tKcK, gmem_thr_copy_kv.partition_D(cV)); + Tensor tVpV_ = make_tensor(make_shape(size<2>(tVsV))); + #pragma unroll + for (int k = 0; k < size(tVpV_); ++k) { tVpV_(k) = get<1>(tVcV(_0{}, _0{}, k)) < params.headdim_v; } + Tensor tVpV = cute::conditional_return(tKpK, tVpV_); auto store_K = [&] (int const n_block, auto const& smem_pipe_read) { int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); @@ -1392,7 +1407,7 @@ struct CollectiveMainloopFwdSm90 { Tensor tVgV_cur = tVgV(_, _, _, n_block); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tKcK, tKpK, n_limit); + gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tVcV, tVpV, n_limit); } else { paged_kv_manager.store_V(n_block, tVsV_cur); } diff --git a/hopper/paged_kv.h b/hopper/paged_kv.h index 0f710e54935..9431f384f39 100644 --- a/hopper/paged_kv.h +++ b/hopper/paged_kv.h @@ -14,7 +14,7 @@ namespace flash { using namespace cute; -template +template struct PagedKVManager { // If KV_Same_Iter=false, then we do load_page_table(0), load_K(0), load_page_table(1), load_K(1), load_V(0), // load_page_table(2), load_K(2), load_V(1), etc. @@ -23,14 +23,17 @@ struct PagedKVManager { // LoadsPerRow_LB is the lower bound on number of loads per row in the K direction. This is useful for // rotary where we want each thread to have at least 2 loads per row. + static constexpr bool SameHeadDim = (kHeadDim == kHeadDimV); + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); + // We use CpAsync for K and V if PagedKV, since TMA doesn't work there static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. // In the case of PackGQA, this reduces the number of times we need to call divmod. - static_assert(kHeadDim % LoadsPerRow_LB == 0, "Headdim must be a multiple of LoadsPerRow_LB"); - static constexpr int kBytePerRow = kHeadDim / LoadsPerRow_LB * sizeof(Element); + static_assert(kHeadDimGCD % LoadsPerRow_LB == 0, "Headdim and HeaddimV must be a multiple of LoadsPerRow_LB"); + static constexpr int kBytePerRow = kHeadDimGCD / LoadsPerRow_LB * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); @@ -59,6 +62,8 @@ struct PagedKVManager { using GmemThrCopyKVCpAsync = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0))); using TensortKcK = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); using TensortKpK = decltype(make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{})); + using TensortVcV = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortVpV = decltype(make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{})); // For PagedKV, it's expensive the calculate the pointers to K and V for each page table entry, // since those require int64_t arithmetic. We optimize by having threads split this work. @@ -66,6 +71,7 @@ struct PagedKVManager { // that each thread needs to load for the case of hdim 128 and kBlockN = 176. // So each of those 8 threads will calculate the K_ptr and V_ptr for 11 / 8 = 2 rows. // We then use __shfl_sync to broadcast the pointers to the other threads in the warp. + static_assert(CUTE_STATIC_V(size<1>(TensortKcK{})) == CUTE_STATIC_V(size<1>(TensortVcV{}))); static constexpr int kPageEntryPerThread = cute::ceil_div(size<1>(TensortKcK{}), kGmemThreadsPerRow); using TensorPageOffset = decltype(make_tensor>(Shape>{})); using TensorKVPtr = decltype(make_tensor(Shape>{})); @@ -79,15 +85,15 @@ struct PagedKVManager { TensorPageTable mPageTable; TensorKV mK_paged, mV_paged; TensortKpK tKpK; + TensortVpV tVpV; TensorPageOffset tPrPageOffset; TensorKVPtr tPrVPtr; - CUTLASS_DEVICE PagedKVManager(int const* const ptr_page_table, ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable, Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K, - Element* const ptr_V, StrideKV const &stride_V, + Element* const ptr_V, int const headdim_v, StrideKV const &stride_V, cutlass::FastDivmod const &page_size_divmod, int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k ) @@ -100,13 +106,19 @@ struct PagedKVManager { { mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _); mK_paged = make_tensor(make_gmem_ptr(ptr_K), shape_K, stride_K)(_, _, bidh, _); - mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_K, stride_V)(_, _, bidh, _); + auto shape_V = make_shape(get<0>(shape_K), headdim_v, get<2>(shape_K), get<3>(shape_K)); + mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_V, stride_V)(_, _, bidh, _); tKpK = make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{}); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); #pragma unroll for (int k = 0; k < size<1>(tKpK); ++k) { tKpK(_0{}, k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(shape_K); } + Tensor tVpV_ = make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{}); + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + #pragma unroll + for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_K); } + tVpV = cute::conditional_return(tKpK, tVpV_); }; template @@ -200,27 +212,27 @@ struct PagedKVManager { // Only for index calculation, since all the indices of thread 0 are known at compile time auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts - Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); - Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); - int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tVcV(_0{}, _0{}, _0{})); #pragma unroll for (int m = 0; m < size<1>(tVsV); ++m) { // Faster to rely on the cp.async to clear smem that are out of bound, // rather than calling cute::clear directly. // We have to be careful not to write to smem past `kBlockN` if !EvenN. // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to checked - if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKcK(_0{}, m, _0{})) < kBlockN) { - bool const should_load = !Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tVcV(_0{}, m, _0{})) < kBlockN) { + bool const should_load = !Seqlenk_mask || get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; Element const* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); #pragma unroll for (int k = 0; k < size<2>(tVsV); ++k) { - int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; - cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(gmem_tiled_copy_kv.with(tVpV(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); } } } @@ -269,24 +281,24 @@ struct PagedKVManager { if constexpr (KV_Same_Iter) { compute_V_ptr(); } // Only for index calculation, since all the indices of thread 0 are known at compile time auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts - Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); - Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); GmemTiledCopyKVStore gmem_tiled_copy_kv_store; - int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tVcV(_0{}, _0{}, _0{})); #pragma unroll for (int m = 0; m < size<1>(tVrV); ++m) { - bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + bool const should_load = get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; Element* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); if (should_load) { #pragma unroll for (int k = 0; k < size<2>(tVrV); ++k) { - int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; - if (tKpK(_0{}, k)) { + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (tVpV(_0{}, k)) { cute::copy(gmem_tiled_copy_kv_store, tVrV(_, m, k), mV_paged_cur_copy(_, ki)); } } diff --git a/hopper/setup.py b/hopper/setup.py index 0104819c68f..db89902550b 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -182,10 +182,10 @@ def sanitize_flags(flags): # to make this work on Windows too. nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d' cuda_compile_rule_sm80 = ['rule cuda_compile_sm80'] + cuda_compile_rule[1:] + [ - f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80' + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80' ] cuda_compile_rule_sm80_sm90 = ['rule cuda_compile_sm80_sm90'] + cuda_compile_rule[1:] + [ - f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90' + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90' ] cuda_compile_rule.append( f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags') diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 1fe43e21fa2..d0590b5f1e7 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -113,88 +113,89 @@ def test_flash_attn_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) - if softcap > 0.0: - # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4) - q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - # window_size = (-1, -1) if not local else (16, 0) - if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] - else: - q_descale, k_descale, v_descale = None, None, None - q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] - if V_colmajor: - v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() - out_ref, attn_ref = attention_ref( - q_ref, - k_ref, - v_ref, - None, - None, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - sink_token_length=sink_token_length, - softcap=softcap - ) - out_pt, attn_pt = attention_ref( - q_ref, - k_ref, - v_ref, - None, - None, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - sink_token_length=sink_token_length, - softcap=softcap, - upcast=False, - reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, - ) - - # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # exp_sum = s_tmp.sum(-1) - # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) - # lse_ref = torch.logsumexp(qk, dim=-1) - - # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 - - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out, lse = flash_attn_func( - q, - k, - v, + for dv in [128, d] if d > 128 and d <= 192 else [d]: + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + sink_token_length=sink_token_length, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, causal=causal, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, - pack_gqa=pack_gqa, - num_splits=num_splits + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out, lse = flash_attn_func( + q, + k, + v, + causal=causal, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + sink_token_length=sink_token_length, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor: g = torch.randn_like(out) @@ -320,132 +321,133 @@ def test_flash_attn_varlen_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) - if softcap > 0.0: - # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4).detach().requires_grad_() - q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] - else: - q_descale, k_descale, v_descale = None, None, None - q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] - query_padding_mask = generate_random_padding_mask( - seqlen_q, batch_size, device, mode="random", zero_lengths=False - ) - key_padding_mask = generate_random_padding_mask( - seqlen_k, batch_size, device, mode="random", zero_lengths=True - ) - - def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): - if add_unused: - another_mask = generate_random_padding_mask(max_seq_len, bs, device) - attn_mask = torch.logical_and(padding_mask, another_mask) - unused_mask = torch.logical_xor( - torch.logical_or(padding_mask, another_mask), attn_mask - ) + for dv in [128, d] if d > 128 and d <= 192 else [d]: + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] else: - attn_mask = padding_mask - unused_mask = None - return attn_mask, unused_mask - - query_padding_mask, query_unused_mask = _gen_unused_masks( - query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device - ) - key_padding_mask, key_unused_mask = _gen_unused_masks( - key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device - ) - - ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - dq_pad_fn, - dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, - query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) - q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] - out_ref, attn_ref = attention_ref( - q_ref, - k_ref, - v_ref, - query_padding_mask, - key_padding_mask, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap - ) - out_pt, attn_pt = attention_ref( - q_ref, - k_ref, - v_ref, - query_padding_mask, - key_padding_mask, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap, - upcast=False, - reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, - ) - - - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) - if query_unused_mask is not None: - q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask - # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) - pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out_unpad, lse = flash_attn_varlen_func( + ( q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, causal=causal, - q_descale=q_descale, - k_descale=k_descale, v_descale=v_descale, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) - out = output_pad_fn(out_unpad) + + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + if query_unused_mask is not None: - out.masked_fill_(q_zero_masking, 0.0) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") - # Check that FlashAttention's numerical error is at most 3x the numerical error - # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad, lse = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, seqused_k, + max_seqlen_q, + max_seqlen_k, + causal=causal, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + softcap=softcap, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn: @@ -557,7 +559,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -@pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -614,261 +617,262 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - if varlen_q: - query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) - output_pad_fn = lambda output_unpad: pad_input( - output_unpad, indices_q, batch_size, seqlen_q - ) - else: - query_padding_mask = None - q_unpad = q - cu_seqlens_q, max_seqlen_q = None, None - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + for dv in [128, d] if d > 128 and d <= 192 else [d]: + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + else: + query_padding_mask = None + q_unpad = q + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() - cu_seqlens_k_new = None - key_new_padding_mask = None - if new_kv: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - if varlen_q: # k & v are also varlen - key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") - k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) - v_unpad, *rest = unpad_input(v, key_new_padding_mask) + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v else: - k_unpad, v_unpad = k, v - else: - k, v, k_unpad, v_unpad = None, None, None, None - if page_size is None: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - page_table = None - else: - ( - k_cache, - v_cache, - page_table, - k_cache_paged, - v_cache_paged, - num_blocks, - ) = _generate_block_kvcache( - seqlen_k, page_size, batch_size_cache, nheads_k, d, device, dtype_ref - ) - cache_seqlens = torch.randint( - 0 if new_kv else 1, - # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough - ( - (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) - if new_kv - else (seqlen_k + 1) - ), - (batch_size,), - dtype=torch.int32, - device=device, - ) - if has_leftpad: - cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) - if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) - for i in range(batch_size)]) - else: - cache_leftpad = None - if has_batch_idx: - cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ - :batch_size - ] - else: - cache_batch_idx = None - arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - if not new_kv: - key_padding_mask = arange < cache_seqlens_expanded - else: - k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new - key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens - if has_leftpad: - key_padding_mask = torch.logical_and( - key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) - ) - # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) - if rotary_dim > 0: - angle = ( - torch.rand( - seqlen_k if page_size is None else num_blocks * page_size, - rotary_dim // 2, - device=device, + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype_ref ) - * 2 - * math.pi + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, ) - cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) - sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) - if causal or local: - q_ro = apply_rotary_emb( - q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved ) else: - q_ro = rearrange( - apply_rotary_emb( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=cache_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=seqlen_q, + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens ) - # q_ro = q - k_ro = apply_rotary_emb( - k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + window_size=window_size, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + window_size=window_size, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None ) - else: - cos, sin = None, None - q_ro, k_ro = q, k - # k_cache[:, 64:] = -1 - k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() - v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() - if new_kv: - update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + num_splits=num_splits, + return_softmax_lse=True ) - k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") - v_to_update = rearrange(v, "b s ... -> (b s) ...") if varlen_q: - k_to_update = k_to_update[indices_k] - v_to_update = v_to_update[indices_k] - k_cache_ref[update_mask] = k_to_update - v_cache_ref[update_mask] = v_to_update - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - query_padding_mask, - key_padding_mask, - causal=causal, - window_size=window_size, - key_leftpad=cache_leftpad, - ) - out_pt, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - query_padding_mask, - key_padding_mask, - causal=causal, - window_size=window_size, - upcast=False, - reorder_ops=True, - key_leftpad=cache_leftpad, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None - ) - q = q.to(dtype) - q_unpad = q_unpad.to(dtype) if varlen_q else None - k_cache = k_cache.to(dtype) - v_cache = v_cache.to(dtype) - k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None - v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None - k = k.to(dtype) if k is not None else None - v = v.to(dtype) if v is not None else None - k_unpad = k_unpad.to(dtype) if k_unpad is not None else None - v_unpad = v_unpad.to(dtype) if v_unpad is not None else None - cos = cos.to(dtype) if cos is not None else None - sin = sin.to(dtype) if sin is not None else None - out, lse, *rest = flash_attn_with_kvcache( - q if not varlen_q else q_unpad, - k_cache if page_size is None else k_cache_paged, - v_cache if page_size is None else v_cache_paged, - k if not new_kv or not varlen_q else k_unpad, - v if not new_kv or not varlen_q else v_unpad, - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - page_table=page_table, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=max_seqlen_q, - causal=causal, - window_size=window_size, - rotary_interleaved=rotary_interleaved, - num_splits=num_splits, - return_softmax_lse=True - ) - if varlen_q: - out = output_pad_fn(out) - # out = flash_attn_with_kvcache( - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size - # ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) - # probs = torch.softmax(qk, dim=-1) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: - if page_size is None: - k_cache_select = ( - k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] - ) - v_cache_select = ( - v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] - ) - else: - k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) - v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: - assert torch.equal(v_cache_select, v_cache_ref) - else: - assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # breakpoint() - # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: - if rotary_dim == 0: - assert torch.equal(k_cache_select, k_cache_ref) - else: - # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): - # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) if dtype is not torch.float8_e4m3fn: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.equal(v_cache_select, v_cache_ref) else: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) - mult = 4 if dtype == torch.float8_e4m3fn else 2 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 - mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 - assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() -def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, device, dtype): +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype): num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 k_cache_paged = torch.randn( num_blocks, page_size, nheads_k, d, device=device, dtype=dtype ) v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype ) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), @@ -990,12 +994,12 @@ def attention_combine_ref(out_partial, lse_partial): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float32]) # @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -@pytest.mark.parametrize("d", [64, 96, 128, 192, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) # @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024, 2048]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) # @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) # @pytest.mark.parametrize("seqlen", [15]) -@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 155]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) # @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) # @pytest.mark.parametrize("num_splits", [128]) def test_flash_attn_combine(num_splits, seqlen, d, dtype): diff --git a/hopper/test_util.py b/hopper/test_util.py index 54eb195eb36..cbf44103126 100644 --- a/hopper/test_util.py +++ b/hopper/test_util.py @@ -37,15 +37,16 @@ def generate_qkv( Arguments: q: (batch_size, seqlen_q, nheads, d) k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d_v) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) batch_size, seqlen_q, nheads, d = q.shape + d_v = v.shape[-1] _, seqlen_k, nheads_k, _ = k.shape assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d_v) if query_unused_mask is not None or key_unused_mask is not None: assert not kvpacked assert not qkvpacked @@ -208,7 +209,7 @@ def attention_ref( Arguments: q: (batch_size, seqlen_q, nheads, head_dim) k: (batch_size, seqlen_k, nheads, head_dim) - v: (batch_size, seqlen_k, nheads, head_dim) + v: (batch_size, seqlen_k, nheads, head_dim_v) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) @@ -221,7 +222,7 @@ def attention_ref( without changing the math. This is to estimate the numerical error from operation reordering. Output: - output: (batch_size, seqlen_q, nheads, head_dim) + output: (batch_size, seqlen_q, nheads, head_dim_v) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ if causal: diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 127f518bbb6..66ab1a7fd49 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -8,7 +8,7 @@ // Return {kBlockM, kBlockN, Mma1_is_RS, IntraWGOverlap} constexpr std::tuple tile_size_fwd_sm90( - int headdim, bool is_causal, bool is_local, int element_size=2, + int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, bool v_colmajor=false, bool paged_kv=false, bool softcap=false) { if (element_size == 2) { if (headdim <= 64) { @@ -22,7 +22,7 @@ constexpr std::tuple tile_size_fwd_sm90( // {128, 192, false, false} and {192, 128, false, true} are quite good too // 128 x 192 hits the limit of smem if Mma1_is_RS, 128 x 144 hits the limit if !Mma1_is_RS } else if (headdim <= 192) { - return {128, paged_kv || is_local ? 96 : 112, true, true}; // 128 x 112 hits the limit of smem + return {128, paged_kv || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem } else { return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem } @@ -43,7 +43,7 @@ constexpr std::tuple tile_size_fwd_sm90( // Return {kBlockM, kBlockN, kNWarps, kStages, Q_in_regs} constexpr std::tuple tile_size_fwd_sm8x( - bool sm86_or_89, int headdim, bool is_causal, bool is_local, int element_size=2, + bool sm86_or_89, int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, bool paged_kv=false, bool varlen_and_split=false, bool softcap=false, bool append_kv=false) { if (element_size == 2) { From 6d199aa20721fbb51340aff6ec19d70cb03063b9 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 3 Feb 2025 20:06:33 -0500 Subject: [PATCH 10/32] Fix shape_O in epilogue params when kHeadDimV != kHeadDim --- hopper/flash_fwd_launch_template.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 3f4bea96ee4..de17b39c977 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -126,7 +126,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(!Split ? params.o_ptr : params.oaccum_ptr), - {seqlen_q, params.d, params.h, batch_q, params.num_splits}, // shape_O + {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O {!Split ? params.o_row_stride : params.oaccum_row_stride, _1{}, !Split ? params.o_head_stride : params.oaccum_head_stride, From 86bcd0552ff5e817c23d58e3b476e1185dfd2965 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Mon, 3 Feb 2025 20:12:13 -0500 Subject: [PATCH 11/32] Remove old combine.h --- hopper/combine.h | 248 ----------------------------------------------- 1 file changed, 248 deletions(-) delete mode 100644 hopper/combine.h diff --git a/hopper/combine.h b/hopper/combine.h deleted file mode 100644 index c26f7ea5623..00000000000 --- a/hopper/combine.h +++ /dev/null @@ -1,248 +0,0 @@ - -#pragma once - -#include - -#include -#include "cutlass/layout/layout.h" -#include -#include - -#include "kernel_traits.h" -#include "utils.h" - -namespace flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SharedStorageLSE { - cute::array_aligned> smem_lse; - cute::array_aligned> smem_valid_splits; -}; - -// DONT use Kernel_traits here to avoid redundant compilation. -// template -template -__global__ void combine_attn_seqk_parallel(Params const params) { - // using Element = typename Kernel_traits::OutputType; - // using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = int64_t; // Kernel_traits::index_t - constexpr int kMaxSplits = 1 << Log_max_splits; - // constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNThreads = 128; //Kernel_traits::kNThreads; - - static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); - static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); - static_assert(kNThreads == 128, "We assume that each block has 128 threads"); - - // Shared memory. - // kBlockM + 1 instead of kBlockM to reduce bank conflicts. - //__shared__ __align__(16) ElementAccum sLSE[kMaxSplits][kBlockM+1]; - extern __shared__ char smem_[]; - using SharedStorage = SharedStorageLSE, Int>, Shape>>; - SharedStorage &shared_storage = - *reinterpret_cast(smem_); - Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape, Int>{}); - Tensor sValidSplits = make_tensor(make_smem_ptr(shared_storage.smem_valid_splits.data()), Shape>{}); - - // The thread and block index. - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; - - const index_t lse_size = params.b * params.h * params.seqlen_q; - //if (cute::thread0()) print ("final %d %d %d %d\n", params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); - - const index_t row_offset_lse = bidx * kBlockM; - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), - Shape, Int>{}, - make_stride(lse_size, _1{})); - - // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. - // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); - - // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. - Layout flat_layout = make_layout(lse_size); - Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); - auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); - Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); - Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); - - Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); - - constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; - - // Read the LSE values from gmem and store them in shared memory, then transpose them. - constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadLSE + tidx / kBlockM; - const int col = tidx % kBlockM; - ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; - if (row < kMaxSplits) { sLSE(row,col) = lse; } - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } - } - __syncthreads(); - - // Reduce along the kBlockM dimension to determine valid splits (store in SMEM) - // One thread per split. Know NumThreads = 128 >= NumMaxSplits - if (tidx < kMaxSplits) { - bool is_valid_split = false; - #pragma unroll - for (int col = 0; col < kBlockM; ++col) { - if(sLSE(tidx,col) != -INFINITY) { - is_valid_split = true; - } - } - sValidSplits(tidx) = is_valid_split; - } - __syncthreads(); - // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } - - Tensor lse_accum = make_tensor(Shape>{}); - constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); - // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits - // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, - // kBlockM rows, so each time we load we can load 128 / kBlockM rows). - // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; - // static_assert(kThreadsPerSplit <= 32); - static_assert(kRowsPerLoadTranspose <= 32); - static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - //if (bidx == 0 && tidx < 128) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } - lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row,col) : -INFINITY; - - } - //return; - - // Compute the logsumexp of the LSE along the split dimension. - ElementAccum lse_max = lse_accum(0); - #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } - MaxOp max_op; - lse_max = Allreduce::run(lse_max, max_op); - lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf - float lse_sum = expf(lse_accum(0) - lse_max); - #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } - SumOp sum_op; - lse_sum = Allreduce::run(lse_sum, sum_op); - // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise - // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } - if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { - if (params.unpadded_lse) { - const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; - if (lse_offset < lse_size) { - gLSE_unpadded(lse_offset) = lse_logsum; - } - } else { - gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; - } - } - //if (cute::thread0()) printf ("lse_logsum = %f\n", lse_logsum); - - // Store the scales exp(lse - lse_logsum) in shared memory. - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - if (row < params.num_splits && col < kBlockM) { sLSE(row,col) = expf(lse_accum(l) - lse_logsum); } - } - __syncthreads(); - - const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), - Shape, Int>{}, - Stride, _1>{}); - constexpr int kBlockN = kNThreads / kBlockM; - using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; - using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomOaccum{}, - Layout>{})); // Val layout, 4 vals per store - GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); - Tensor tOrO = make_tensor(shape(tOgOaccum)); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - clear(tOrO); - - // Predicates - Tensor cOaccum = make_identity_tensor(Shape, Int>{}); - //if (cute::thread0()) print_tensor (cOaccum); - // Repeat the partitioning with identity layouts - Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); - Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } - } - // Load Oaccum in then scale and accumulate to O - for (int split = 0; split < params.num_splits; ++split) { - // DONT copy in Oaccum if lse(split) = -inf for all kBlockM. - if(sValidSplits(split)) { - flash::copy( - gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM - ); - #pragma unroll - for (int m = 0; m < size<1>(tOrOaccum); ++m) { - int row = get<0>(tOcOaccum(0, m, 0)); - ElementAccum lse_scale = sLSE(split,row); - if (lse_scale != 0.f) { - #pragma unroll - for (int k = 0; k < size<2>(tOrOaccum); ++k) { - #pragma unroll - for (int i = 0; i < size<0>(tOrOaccum); ++i) { - tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); - //tOrO(i, m, k) += tOrOaccum(i, m, k); - } - } - } - //if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE(split, 0), sLSE(split, 1)); print_tensor(tOrOaccum); } - } - } - tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; - } - //if (cute::thread0()) { print_tensor(tOrO); } - - Tensor rO = flash::convert_type(tOrO); - // Write to gO - #pragma unroll - for (int m = 0; m < size<1>(rO); ++m) { - const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); - //if (cute::thread0()) print ("final %d %d %d %d %d\n", idx, params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); - if (idx < params.b * params.h * params.seqlen_q) { - //print ("final2\n"); - const int batch_idx = idx / (params.h * params.seqlen_q); - const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; - // The index to the rows of Q - const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; - auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride - + head_idx * params.o_head_stride + row * params.o_row_stride; - #pragma unroll - for (int k = 0; k < size<2>(rO); ++k) { - if (Is_even_K || tOpOaccum(k)) { - const int col = get<1>(tOcOaccum(0, m, k)); - Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), - Shape(rO))::value>>{}, Stride<_1>{}); - // TODO: Should check if this is using vectorized store, but it seems pretty fast - copy(rO(_, m, k), gO); - //if (cute::thread0()) { print ("final\n"); print_tensor(gO); } - // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } - // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); - } - } - } - } -} - -} // namespace flash From e3b2400a31e1a094411102dfd474b3c582a42305 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 4 Feb 2025 01:31:38 -0500 Subject: [PATCH 12/32] Fix loading paged V when kHeadDimV != kHeadDim --- hopper/paged_kv.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/paged_kv.h b/hopper/paged_kv.h index 9431f384f39..80ee61b9a41 100644 --- a/hopper/paged_kv.h +++ b/hopper/paged_kv.h @@ -117,7 +117,7 @@ struct PagedKVManager { Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); #pragma unroll - for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_K); } + for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_V); } tVpV = cute::conditional_return(tKpK, tVpV_); }; From 9e07d6d3cfc3a5ab3ea134af70e3d879d855aa70 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 4 Feb 2025 02:26:10 -0500 Subject: [PATCH 13/32] Fix shape_V for storing new KV when kHeadDimV != kHeadDim --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index da5f902eae1..0a1bf98a1e1 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -1216,7 +1216,8 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K_new))(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); + Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) @@ -1311,7 +1312,8 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) From f0f25239bd0c5a39c0b481cc5686a835f6c746f5 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 4 Feb 2025 02:28:16 -0500 Subject: [PATCH 14/32] Implement the case of LargeHeadDimV --- hopper/epilogue_fwd.hpp | 19 ++- hopper/flash_fwd_kernel_sm90.h | 35 +++- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 200 +++++++++++++++++++++-- hopper/named_barrier.hpp | 2 + 4 files changed, 222 insertions(+), 34 deletions(-) diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index d8f2c15c977..1c13988ebd7 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -40,6 +40,8 @@ struct CollectiveEpilogueFwd { static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); + static constexpr bool LargeHeadDimV = kHeadDimV > 256; + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; // These are for storing the output tensor without TMA (e.g., for setting output to zero) @@ -239,6 +241,7 @@ struct CollectiveEpilogueFwd { bool is_varlen = Varlen && params.cu_seqlens; int offset_o = seqlen_info.offset; int seqlen_o = seqlen_info.seqlen; + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); // Step 2: Write LSE from rmem -> gmem auto thread_mma = tiled_mma.get_thread_slice(thread_idx); @@ -254,14 +257,16 @@ struct CollectiveEpilogueFwd { Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } - if constexpr (!PackGQA) { - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); - if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } + if (!LargeHeadDimV || warp_group_idx == 0) { + if constexpr (!PackGQA) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); + if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } + } + } else { + PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } - } else { - PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } // Step 3: Write O from smem -> gmem diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index 05ce4d0ae60..5e1dceb0934 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -46,11 +46,12 @@ class FlashAttnFwdSm90 { static constexpr bool PackGQA = CollectiveMainloop::PackGQA; static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim; + static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV; + static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV); using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; // Mainloop derived types using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV; - using TiledMma0 = typename CollectiveMainloop::TiledMma0; using TiledMma1 = typename CollectiveMainloop::TiledMma1; using ArchTag = typename CollectiveMainloop::ArchTag; using ClusterShape = typename CollectiveMainloop::ClusterShape; @@ -69,8 +70,8 @@ class FlashAttnFwdSm90 { using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma0{})) / cutlass::NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma0{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma1{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma1{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -217,15 +218,18 @@ class FlashAttnFwdSm90 { if constexpr (Use_TMA_KV) { pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; pipeline_params_k.is_leader = warp_group_thread_idx == 0; - pipeline_params_k.num_consumers = NumMmaThreads; + pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; } else { - pipeline_params_k.consumer_arv_count = NumMmaThreads; + pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; pipeline_params_k.producer_arv_count = NumProducerThreads; } PipelineParamsV pipeline_params_v = pipeline_params_k; if constexpr (Use_TMA_KV && !SameHeadDim) { pipeline_params_v.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + if constexpr (LargeHeadDimV) { pipeline_params_v.num_consumers = NumMmaThreads; } + } else { + if constexpr (LargeHeadDimV) { pipeline_params_v.consumer_arv_count = NumMmaThreads; } } MainloopPipelineK pipeline_k = [&] { @@ -378,7 +382,7 @@ class FlashAttnFwdSm90 { float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; softmax_scale_log2 *= q_descale * k_descale; } - flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2); + flash::Softmax softmax(softmax_scale_log2); SeqlenInfo_t seqlen_info{ bidb, @@ -404,9 +408,22 @@ class FlashAttnFwdSm90 { // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); } } } - bool tile_valid = collective_mainloop.mma( - params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, - tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + bool tile_valid; + if constexpr (!LargeHeadDimV) { + tile_valid = collective_mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + } else { // mma1_only might not compile if !LargeHeadDimV + if (warp_group_idx == 1) { + tile_valid = collective_mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + } else { + tile_valid = collective_mainloop.mma1_only( + params.mainloop, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); + } + } if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma1, diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 0a1bf98a1e1..67f645e60e1 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -55,6 +55,7 @@ struct CollectiveMainloopFwdSm90 { static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; + static constexpr bool LargeHeadDimV = kHeadDimV > 256; using SeqlenInfo_t = flash::SeqlenInfoQKNewK; static_assert(ArchTag::kMinComputeCapability >= 90); @@ -66,6 +67,10 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); + static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); + static_assert(!LargeHeadDimV || !Mma1_is_RS, "Mma1 must be SS for large Headdim_V"); + // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. static constexpr bool Mma0_is_RS = false; @@ -74,26 +79,34 @@ struct CollectiveMainloopFwdSm90 { static_assert(!(!Mma1_is_RS && Is_FP8), "Mma1 must be RS if FP8"); static_assert(!(!Mma1_is_RS && Transpose_V), "Mma1 must be RS if Transpose_V"); - using AtomLayoutMNK = Layout, _1, _1>>; + using AtomLayoutQK = Layout, _1, _1>>; using TiledMma0 = decltype(cute::make_tiled_mma( std::conditional_t< !Mma0_is_RS, decltype(cute::GMMA::ss_op_selector()), decltype(cute::GMMA::rs_op_selector()) >{}, - AtomLayoutMNK{})); + AtomLayoutQK{})); + using AtomLayoutPV = std::conditional_t< + !LargeHeadDimV, + AtomLayoutQK, + Layout, _1>> + >; + using TileShapeAtomPV = Shape, Int, Int>; using TiledMma1 = decltype(cute::make_tiled_mma( std::conditional_t< !Mma1_is_RS, decltype(cute::GMMA::ss_op_selector()), + TileShapeAtomPV, GMMA::Major::K, MmaMajorV>()), decltype(cute::GMMA::rs_op_selector()) + TileShapeAtomPV, GMMA::Major::K, MmaMajorV>()) >{}, - AtomLayoutMNK{})); + AtomLayoutPV{})); - static constexpr int NumMmaThreads = size(TiledMma0{}); + static constexpr int NumMmaThreadsMma0 = size(TiledMma0{}); + static constexpr int NumMmaThreads = size(TiledMma1{}); static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup; + static_assert(NumMmaThreadsMma0 % cutlass::NumThreadsPerWarpGroup == 0); static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0); static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -133,6 +146,9 @@ struct CollectiveMainloopFwdSm90 { decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); + // Only for LargeHeadDimV where WG0 sends WG1 the scales + using SmemLayoutScale = cute::Layout, Int>>; + using SmemCopyAtomP = Copy_Atom; // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major. @@ -251,6 +267,7 @@ struct CollectiveMainloopFwdSm90 { static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes // smem size to go from 227KB to 228KB and we get "invalid argument". @@ -266,8 +283,19 @@ struct CollectiveMainloopFwdSm90 { cute::array_aligned, SmemAlignmentK> smem_k; SmemP_t smem_p; }; + struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { + cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemP_t smem_p; + SmemScale_t smem_scale; + }; - using TensorStorageNoTranspose = std::conditional_t; + using TensorStorageNoTranspose = std::conditional_t< + Mma1_is_RS, + TensorStorageWithoutPNoTranspose, + std::conditional_t + >; static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{}); @@ -277,14 +305,16 @@ struct CollectiveMainloopFwdSm90 { cute::array_aligned, SmemAlignmentVt> smem_vt; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemScale_t smem_scale; }; using TensorStorage = std::conditional_t; // These are tuned for speed. They don't affect correctness. - static constexpr bool UseSchedulerBarrier = IntraWGOverlap + static constexpr bool UseSchedulerBarrier = (IntraWGOverlap ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128) - : NumMmaWarpGroups == 2; + : NumMmaWarpGroups == 2) + && !LargeHeadDimV; static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor); // Host side kernel arguments @@ -699,7 +729,7 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Use_TMA_Q) { // Wait for the MMA warpgroups to signal that smem_q is ready if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0 + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); } if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { @@ -708,7 +738,7 @@ struct CollectiveMainloopFwdSm90 { tQgQ, tQsQ); } } else { // Load Q with cp.async - cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0 + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; @@ -830,13 +860,19 @@ struct CollectiveMainloopFwdSm90 { CUTLASS_DEVICE void mma_init() { + int warp_group_idx = flash::canonical_warp_group_idx_nosync(); // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + if (!LargeHeadDimV || warp_group_idx == 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + } + if (LargeHeadDimV && warp_group_idx > 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } if constexpr (UseSchedulerBarrier) { // We have NamedBarrier for up to 3 WGs static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); // WG1 needs the very first signal to start - if (flash::canonical_warp_group_idx_nosync() == 1) { + if (warp_group_idx == 1) { cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); } } @@ -883,6 +919,13 @@ struct CollectiveMainloopFwdSm90 { return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); } }(); + Tensor sScale = [&] { + if constexpr (LargeHeadDimV) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + } else { // won't be used, just a placeholder + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutScale{}); + } + }(); if constexpr (!Mma0_is_RS) { static_assert(stride<0>(typename TiledMma0::ALayout{}) == 0 and @@ -891,7 +934,7 @@ struct CollectiveMainloopFwdSm90 { size<0>(typename TiledMma0::BLayout{}) == cutlass::NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); } - constexpr int MmaWarpGroups = size(TiledMma0{}) / cutlass::NumThreadsPerWarpGroup; + static constexpr int MmaWarpGroups = size(TiledMma1{}) / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), make_stride(Int{})); @@ -911,6 +954,21 @@ struct CollectiveMainloopFwdSm90 { Tensor tOsP = wg_mma1.partition_fragment_A(sP); Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); + // For storing scales to smem, only used when LargeHeadDimV + auto thread_mma1 = tiled_mma1.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma1.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto store_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + if (get<1>(taccOcO_row(_0{})) == 0) { + sScale(get<0>(taccOcO_row(mi)), stage) = scales(mi); + } + } + }; + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -947,7 +1005,7 @@ struct CollectiveMainloopFwdSm90 { } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; - using Rotary_t = Rotary; + using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); @@ -970,7 +1028,7 @@ struct CollectiveMainloopFwdSm90 { } // SMEM fence to make sure the rotated Q is visible to GMMA cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); } else { barrier_Q.wait(work_idx % 2); } @@ -996,6 +1054,8 @@ struct CollectiveMainloopFwdSm90 { mask.template apply(tSrS, m_block, n_block); Tensor scores_scale = softmax.template max_get_scale(tSrS); + // Don't need to store scales to send to WG1 (in the case of LargeHeadDimV) since it's 1.f + softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); @@ -1003,9 +1063,15 @@ struct CollectiveMainloopFwdSm90 { convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } if constexpr (!Mma1_is_RS) { + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); cutlass::arch::fence_view_async_shared(); __syncwarp(); // Only need syncwarp since each warp is using its own P values for Mma1 + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } } --n_block; @@ -1027,17 +1093,24 @@ struct CollectiveMainloopFwdSm90 { scoremod_premask_fn(tSrS); mask_fn(tSrS, n_block); cute::copy(softmax.template max_get_scale(tSrS), scores_scale); + if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); } softmax.template online_softmax(tSrS); warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read_v); // release V if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } if constexpr (!Mma1_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } if constexpr (!Mma1_is_RS) { cutlass::arch::fence_view_async_shared(); __syncwarp(); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } } }; @@ -1077,12 +1150,17 @@ struct CollectiveMainloopFwdSm90 { // } } // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } consumer_wait(pipeline_v, smem_pipe_read); flash::gemm(tiled_mma1, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; cute::copy(softmax.finalize(v_descale), scores_scale); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + store_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V, otherwise producers will hang softmax.rescale_o(tOrO, scores_scale); @@ -1158,7 +1236,7 @@ struct CollectiveMainloopFwdSm90 { } warp_scheduler_barrier_arrive(); // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); softmax.rescale_o(tOrO, scores_scale); @@ -1168,6 +1246,92 @@ struct CollectiveMainloopFwdSm90 { return true; } + template + CUTLASS_DEVICE bool + mma1_only(Params const& params, + MainloopPipelineV pipeline_v, + PipelineState& smem_pipe_read, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda + int const m_block = get<0>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { return false; } + } + + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); + Tensor sScale = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + static constexpr int MmaWarpGroups = size(TiledMma1{}) / cutlass::NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(make_shape(Int{}), + make_stride(Int{})); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + TiledMma1 tiled_mma1; + auto wg_mma1 = tiled_mma1.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate "fragments/descriptors" + Tensor tOrV = wg_mma1.partition_fragment_B(sV); + Tensor tOsP = wg_mma1.partition_fragment_A(sP); + + // For load scales to smem, pretend thread_idx is thread_idx % 128 + auto thread_mma1 = tiled_mma1.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup); + Tensor taccOcO = thread_mma1.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto load_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + scales(mi) = sScale(get<0>(taccOcO_row(mi)), stage); + } + }; + + clear(tOrO); + // tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; + + typename Softmax::TensorT scores_scale; + + int n_block = n_block_max - 1; + pipeline_v.consumer_wait(smem_pipe_read); + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + flash::gemm(tiled_mma1, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + --n_block; + + for (; n_block >= n_block_min; --n_block) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + softmax.rescale_o(tOrO, scores_scale); + ++smem_pipe_read; + auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); + pipeline_v.consumer_wait(smem_pipe_read, barrier_token); + flash::gemm(tiled_mma1, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + }; + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + // if (thread_idx == 128) { print_tensor(scores_scale); } + // if (thread_idx == 128) { print_tensor(sScale); } + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; + return true; + } + CUTLASS_DEVICE cute::tuple get_n_block_k_new_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, int m_block, int bidb, int split_idx=0, int num_splits=1) { diff --git a/hopper/named_barrier.hpp b/hopper/named_barrier.hpp index f77ea778298..8d07f6aa2fc 100644 --- a/hopper/named_barrier.hpp +++ b/hopper/named_barrier.hpp @@ -57,6 +57,8 @@ enum class FwdNamedBarriers { WarpSchedulerWG3 = 6, AppendKV = 7, QueryRotated = 8, + PFull = 9, + PEmpty = 6, // HACK: PEmpty is only used when we don't have 3 WGs }; enum class BwdNamedBarriers { From 4c8819d8c68e8021cb82cf5b2df38c5eb5340531 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 7 Feb 2025 13:55:44 -0500 Subject: [PATCH 15/32] Rename Mma0->MmaQK, Mma1->MmaPV, use Cluster only if hdimV >= 192 --- hopper/flash_fwd_kernel_sm90.h | 16 +-- hopper/flash_fwd_launch_template.h | 6 +- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 162 +++++++++++------------ hopper/tile_size.h | 7 +- 4 files changed, 96 insertions(+), 95 deletions(-) diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index 5e1dceb0934..aad099bd31b 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -52,7 +52,7 @@ class FlashAttnFwdSm90 { // Mainloop derived types using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV; - using TiledMma1 = typename CollectiveMainloop::TiledMma1; + using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; using ArchTag = typename CollectiveMainloop::ArchTag; using ClusterShape = typename CollectiveMainloop::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; @@ -70,8 +70,8 @@ class FlashAttnFwdSm90 { using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma1{})) / cutlass::NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma1{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -354,7 +354,7 @@ class FlashAttnFwdSm90 { TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); // Initialize matmul objects. - TiledMma1 tiled_mma1; + TiledMmaPV tiled_mma_pv; PipelineState smem_pipe_read; PipelineState smem_pipe_read_new; @@ -370,7 +370,7 @@ class FlashAttnFwdSm90 { work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { // Attention output (GEMM-II) accumulator. - Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 1>(TileShape_MNK_PV{})); + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); float softmax_scale_log2 = params.mainloop.softmax_scale_log2; // If there's tanh softcap, the scaling will be done before tanh. auto block_coord = work_tile_info.get_block_coord(params.scheduler); @@ -413,20 +413,20 @@ class FlashAttnFwdSm90 { tile_valid = collective_mainloop.mma( params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); - } else { // mma1_only might not compile if !LargeHeadDimV + } else { // mma_pv might not compile if !LargeHeadDimV if (warp_group_idx == 1) { tile_valid = collective_mainloop.mma( params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); } else { - tile_valid = collective_mainloop.mma1_only( + tile_valid = collective_mainloop.mma_pv( params.mainloop, pipeline_v, smem_pipe_read, tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); } } if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } - collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma1, + collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, threadIdx.x - MmaThreadOffset, block_coord); } else { // Write 0 to gO and -inf to gLSE. diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index de17b39c977..f8a98a08fe2 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -39,7 +39,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); - static constexpr bool Mma1_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); + static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap); static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS); static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); @@ -50,7 +50,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm90, flash::CollectiveMainloopFwdSm80 >; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; @@ -194,7 +194,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) - static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; + static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDimV >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 67f645e60e1..c4911c359cb 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -69,20 +69,20 @@ struct CollectiveMainloopFwdSm90 { static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); - static_assert(!LargeHeadDimV || !Mma1_is_RS, "Mma1 must be SS for large Headdim_V"); + static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. - static constexpr bool Mma0_is_RS = false; - // We can have Mma1 (P @ V) with P in smem in rmem to reduce register pressure at the cost of more smem. - static_assert(!(!Mma1_is_RS && !IntraWGOverlap), "Mma1 must be RS if IntraWGOverlap is disabled"); - static_assert(!(!Mma1_is_RS && Is_FP8), "Mma1 must be RS if FP8"); - static_assert(!(!Mma1_is_RS && Transpose_V), "Mma1 must be RS if Transpose_V"); + static constexpr bool MmaQK_is_RS = false; + // We can have MmaPV with P in smem in rmem to reduce register pressure at the cost of more smem. + static_assert(!(!MmaPV_is_RS && !IntraWGOverlap), "MmaPV must be RS if IntraWGOverlap is disabled"); + static_assert(!(!MmaPV_is_RS && Is_FP8), "MmaPV must be RS if FP8"); + static_assert(!(!MmaPV_is_RS && Transpose_V), "MmaPV must be RS if Transpose_V"); using AtomLayoutQK = Layout, _1, _1>>; - using TiledMma0 = decltype(cute::make_tiled_mma( + using TiledMmaQK = decltype(cute::make_tiled_mma( std::conditional_t< - !Mma0_is_RS, + !MmaQK_is_RS, decltype(cute::GMMA::ss_op_selector()), decltype(cute::GMMA::rs_op_selector()) >{}, @@ -93,9 +93,9 @@ struct CollectiveMainloopFwdSm90 { Layout, _1>> >; using TileShapeAtomPV = Shape, Int, Int>; - using TiledMma1 = decltype(cute::make_tiled_mma( + using TiledMmaPV = decltype(cute::make_tiled_mma( std::conditional_t< - !Mma1_is_RS, + !MmaPV_is_RS, decltype(cute::GMMA::ss_op_selector()), decltype(cute::GMMA::rs_op_selector{}, AtomLayoutPV{})); - static constexpr int NumMmaThreadsMma0 = size(TiledMma0{}); - static constexpr int NumMmaThreads = size(TiledMma1{}); + static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); + static constexpr int NumMmaThreads = size(TiledMmaPV{}); static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup; - static_assert(NumMmaThreadsMma0 % cutlass::NumThreadsPerWarpGroup == 0); + static_assert(NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0); static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0); static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -259,14 +259,14 @@ struct CollectiveMainloopFwdSm90 { // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned // and have sQ being position_independent_swizzle_tensor. // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned. - static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !AppendKV && !Mma0_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); + static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !AppendKV && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{}); static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); - using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes // smem size to go from 227KB to 228KB and we get "invalid argument". @@ -292,7 +292,7 @@ struct CollectiveMainloopFwdSm90 { }; using TensorStorageNoTranspose = std::conditional_t< - Mma1_is_RS, + MmaPV_is_RS, TensorStorageWithoutPNoTranspose, std::conditional_t >; @@ -729,7 +729,7 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Use_TMA_Q) { // Wait for the MMA warpgroups to signal that smem_q is ready if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { - cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0 + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); } if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { @@ -738,7 +738,7 @@ struct CollectiveMainloopFwdSm90 { tQgQ, tQsQ); } } else { // Load Q with cp.async - cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0 + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; @@ -863,7 +863,7 @@ struct CollectiveMainloopFwdSm90 { int warp_group_idx = flash::canonical_warp_group_idx_nosync(); // Tell producers that smem_q is ready if (!LargeHeadDimV || warp_group_idx == 1) { - cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); } if (LargeHeadDimV && warp_group_idx > 1) { cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); @@ -912,8 +912,8 @@ struct CollectiveMainloopFwdSm90 { Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); Tensor sP = [&] { - if constexpr (Mma1_is_RS) { - // We might not have smem_p if !Mma1_is_RS1, just use smem_q as a placeholder since we don't use it + if constexpr (MmaPV_is_RS) { + // We might not have smem_p if !MmaPV_is_RS, just use smem_q as a placeholder since we don't use it return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutP{}); } else { return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); @@ -927,36 +927,36 @@ struct CollectiveMainloopFwdSm90 { } }(); - if constexpr (!Mma0_is_RS) { - static_assert(stride<0>(typename TiledMma0::ALayout{}) == 0 and - stride<0>(typename TiledMma0::BLayout{}) == 0 and - size<0>(typename TiledMma0::ALayout{}) == cutlass::NumThreadsPerWarpGroup and - size<0>(typename TiledMma0::BLayout{}) == cutlass::NumThreadsPerWarpGroup, + if constexpr (!MmaQK_is_RS) { + static_assert(stride<0>(typename TiledMmaQK::ALayout{}) == 0 and + stride<0>(typename TiledMmaQK::BLayout{}) == 0 and + size<0>(typename TiledMmaQK::ALayout{}) == cutlass::NumThreadsPerWarpGroup and + size<0>(typename TiledMmaQK::BLayout{}) == cutlass::NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); } - static constexpr int MmaWarpGroups = size(TiledMma1{}) / cutlass::NumThreadsPerWarpGroup; + static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), make_stride(Int{})); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); - TiledMma0 tiled_mma0; - TiledMma1 tiled_mma1; - auto wg_mma0 = tiled_mma0.get_slice(warp_group_thread_layout(warp_group_idx)); - auto wg_mma1 = tiled_mma1.get_slice(warp_group_thread_layout(warp_group_idx)); + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + auto wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); - auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma0); + auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma_qk); auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); // Allocate "fragments/descriptors" - Tensor tSrQ = wg_mma0.partition_fragment_A(sQ); - Tensor tSrK = wg_mma0.partition_fragment_B(sK); - Tensor tOrV = wg_mma1.partition_fragment_B(sV); - Tensor tOsP = wg_mma1.partition_fragment_A(sP); + Tensor tSrQ = wg_mma_qk.partition_fragment_A(sQ); + Tensor tSrK = wg_mma_qk.partition_fragment_B(sK); + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); // For storing scales to smem, only used when LargeHeadDimV - auto thread_mma1 = tiled_mma1.get_thread_slice(thread_idx); - Tensor taccOcO = thread_mma1.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); auto store_scales = [&](auto& scales, int stage) { @@ -976,13 +976,13 @@ struct CollectiveMainloopFwdSm90 { // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter clear(tOrO); - // tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; int n_block = n_block_max - 1; - flash::Mask mask( + flash::Mask mask( thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, params.qhead_per_khead_divmod ); @@ -1005,7 +1005,7 @@ struct CollectiveMainloopFwdSm90 { } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; - using Rotary_t = Rotary; + using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); @@ -1028,15 +1028,15 @@ struct CollectiveMainloopFwdSm90 { } // SMEM fence to make sure the rotated Q is visible to GMMA cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumMmaThreadsMma0, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); } else { barrier_Q.wait(work_idx % 2); } } - if constexpr (Mma0_is_RS) { + if constexpr (MmaQK_is_RS) { using SmemCopyAtomQ = Copy_Atom; - auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma0); + auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma_qk); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(cute::as_position_independent_swizzle_tensor(sQ)); @@ -1045,9 +1045,9 @@ struct CollectiveMainloopFwdSm90 { // TODO: check the case where n_block_max <= n_block_min but there are sink tokens if constexpr (IntraWGOverlap) { - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); warpgroup_wait<0>(); pipeline_k.consumer_release(smem_pipe_read); scoremod_premask_fn(tSrS); @@ -1058,17 +1058,17 @@ struct CollectiveMainloopFwdSm90 { softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!Mma1_is_RS) { + if constexpr (!MmaPV_is_RS) { if constexpr (LargeHeadDimV) { cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); } cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); cutlass::arch::fence_view_async_shared(); - __syncwarp(); // Only need syncwarp since each warp is using its own P values for Mma1 + __syncwarp(); // Only need syncwarp since each warp is using its own P values for MmaPV if constexpr (LargeHeadDimV) { cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); } @@ -1080,13 +1080,13 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Check_inf = decltype(check_inf_type)::value; PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); ++smem_pipe_read; - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); } warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } - flash::gemm(tiled_mma1, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K @@ -1103,9 +1103,9 @@ struct CollectiveMainloopFwdSm90 { if constexpr (LargeHeadDimV) { cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); } - if constexpr (!Mma1_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } + if constexpr (!MmaPV_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!Mma1_is_RS) { + if constexpr (!MmaPV_is_RS) { cutlass::arch::fence_view_async_shared(); __syncwarp(); if constexpr (LargeHeadDimV) { @@ -1150,10 +1150,10 @@ struct CollectiveMainloopFwdSm90 { // } } // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } consumer_wait(pipeline_v, smem_pipe_read); - flash::gemm(tiled_mma1, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; cute::copy(softmax.finalize(v_descale), scores_scale); if constexpr (LargeHeadDimV) { @@ -1174,9 +1174,9 @@ struct CollectiveMainloopFwdSm90 { auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; static constexpr bool Check_inf = decltype(check_inf_type)::value; - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); warp_scheduler_barrier_arrive(); warpgroup_wait<0>(); pipeline_k.consumer_release(smem_pipe_read); // release K @@ -1185,14 +1185,14 @@ struct CollectiveMainloopFwdSm90 { Tensor scores_scale = softmax.template max_get_scale(tSrS); softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } consumer_wait(pipeline_v, smem_pipe_read); warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); pipeline_v.consumer_release(smem_pipe_read); // release V ++smem_pipe_read; }; @@ -1236,7 +1236,7 @@ struct CollectiveMainloopFwdSm90 { } warp_scheduler_barrier_arrive(); // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreadsMma0 + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); softmax.rescale_o(tOrO, scores_scale); @@ -1248,16 +1248,16 @@ struct CollectiveMainloopFwdSm90 { template CUTLASS_DEVICE bool - mma1_only(Params const& params, - MainloopPipelineV pipeline_v, - PipelineState& smem_pipe_read, - FrgTensorO& tOrO, - Softmax& softmax, - int const thread_idx, - SeqlenInfo_t const& seqlen_info, - cute::tuple block_coord, - SharedStorage& shared_storage - ) { + mma_pv(Params const& params, + MainloopPipelineV pipeline_v, + PipelineState& smem_pipe_read, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { static_assert(is_rmem::value, "O tensor must be rmem resident."); // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda int const m_block = get<0>(block_coord); @@ -1272,21 +1272,21 @@ struct CollectiveMainloopFwdSm90 { Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); Tensor sScale = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); - static constexpr int MmaWarpGroups = size(TiledMma1{}) / cutlass::NumThreadsPerWarpGroup; + static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), make_stride(Int{})); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); - TiledMma1 tiled_mma1; - auto wg_mma1 = tiled_mma1.get_slice(warp_group_thread_layout(warp_group_idx)); + TiledMmaPV tiled_mma_pv; + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); // Allocate "fragments/descriptors" - Tensor tOrV = wg_mma1.partition_fragment_B(sV); - Tensor tOsP = wg_mma1.partition_fragment_A(sP); + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); // For load scales to smem, pretend thread_idx is thread_idx % 128 - auto thread_mma1 = tiled_mma1.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup); - Tensor taccOcO = thread_mma1.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup); + Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); auto load_scales = [&](auto& scales, int stage) { @@ -1298,14 +1298,14 @@ struct CollectiveMainloopFwdSm90 { }; clear(tOrO); - // tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; typename Softmax::TensorT scores_scale; int n_block = n_block_max - 1; pipeline_v.consumer_wait(smem_pipe_read); cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); - flash::gemm(tiled_mma1, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V --n_block; @@ -1317,7 +1317,7 @@ struct CollectiveMainloopFwdSm90 { ++smem_pipe_read; auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); pipeline_v.consumer_wait(smem_pipe_read, barrier_token); - flash::gemm(tiled_mma1, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); pipeline_v.consumer_release(smem_pipe_read); // release V }; diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 66ab1a7fd49..997664bcbc5 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -6,13 +6,14 @@ #include -// Return {kBlockM, kBlockN, Mma1_is_RS, IntraWGOverlap} +// Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} constexpr std::tuple tile_size_fwd_sm90( int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, bool v_colmajor=false, bool paged_kv=false, bool softcap=false) { if (element_size == 2) { if (headdim <= 64) { - return {192, 128, true, true}; + bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 + return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, true}; // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { @@ -20,7 +21,7 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 128) { return {128, is_causal || is_local || paged_kv ? 128 : 176, true, true}; // {128, 192, false, false} and {192, 128, false, true} are quite good too - // 128 x 192 hits the limit of smem if Mma1_is_RS, 128 x 144 hits the limit if !Mma1_is_RS + // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { return {128, paged_kv || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem } else { From dd876913f435b3349cb15a2d83b7b5af366f0acd Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 7 Feb 2025 21:52:35 -0500 Subject: [PATCH 16/32] Pass _1 or _0 to cute::aligned_struct --- hopper/flash_fwd_kernel_sm90.h | 5 ++--- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 17 +++++++---------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index aad099bd31b..b6ab92e0b84 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -90,7 +90,7 @@ class FlashAttnFwdSm90 { static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v))); static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_; struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _1> { union { struct { cute::array padding_; @@ -100,8 +100,7 @@ class FlashAttnFwdSm90 { typename CollectiveEpilogue::TensorStorage epilogue; }; } tensors; - - struct PipelineStorage : cute::aligned_struct<16> { + struct PipelineStorage : cute::aligned_struct<16, _1> { alignas(16) BarrierQ barrier_Q; alignas(16) cutlass::arch::ClusterBarrier barrier_O; alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k; diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index c4911c359cb..797c88d6441 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -92,14 +92,13 @@ struct CollectiveMainloopFwdSm90 { AtomLayoutQK, Layout, _1>> >; - using TileShapeAtomPV = Shape, Int, Int>; using TiledMmaPV = decltype(cute::make_tiled_mma( std::conditional_t< !MmaPV_is_RS, decltype(cute::GMMA::ss_op_selector()), + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()), decltype(cute::GMMA::rs_op_selector()) + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()) >{}, AtomLayoutPV{})); @@ -259,7 +258,7 @@ struct CollectiveMainloopFwdSm90 { // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned // and have sQ being position_independent_swizzle_tensor. // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned. - static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !AppendKV && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); + static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); @@ -271,19 +270,19 @@ struct CollectiveMainloopFwdSm90 { // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes // smem size to go from 227KB to 228KB and we get "invalid argument". - struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { + struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; }; - struct TensorStorageWithPNoTranspose : cute::aligned_struct { + struct TensorStorageWithPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; SmemP_t smem_p; }; - struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { + struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; @@ -300,7 +299,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{}); static_assert(SmemAlignmentVt >= 128 and SmemAlignmentV >= 128, "Require at least 128B alignment"); - struct TensorStorageTransposeV : cute::aligned_struct { + struct TensorStorageTransposeV : cute::aligned_struct { cute::array_aligned, SmemAlignmentV> smem_v; cute::array_aligned, SmemAlignmentVt> smem_vt; cute::array_aligned, SmemAlignmentQ> smem_q; @@ -1324,8 +1323,6 @@ struct CollectiveMainloopFwdSm90 { cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); load_scales(scores_scale, smem_pipe_read.index()); cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); - // if (thread_idx == 128) { print_tensor(scores_scale); } - // if (thread_idx == 128) { print_tensor(sScale); } softmax.rescale_o(tOrO, scores_scale); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } ++smem_pipe_read; From ed53b5fc4c3b01a6d98d747f21380e444056e042 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 7 Feb 2025 22:25:17 -0500 Subject: [PATCH 17/32] Fix compilation for FP8 when kHeadDimV != kHeadDim --- hopper/flash_api.cpp | 4 ++++ hopper/flash_fwd_kernel_sm90.h | 22 +++++++++++----------- hopper/flash_fwd_launch_template.h | 2 +- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 94fcf5d78f5..7fd8dfc3e28 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -592,6 +592,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (head_size_v != head_size) { TORCH_CHECK(head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128, "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128]"); TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + if (head_size_v > 256) { + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "HeaddimV > 256 requires fp16 and bf16 data type"); + } } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index b6ab92e0b84..aeb81977c56 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -223,12 +223,13 @@ class FlashAttnFwdSm90 { pipeline_params_k.producer_arv_count = NumProducerThreads; } - PipelineParamsV pipeline_params_v = pipeline_params_k; + static_assert(is_same_v); + PipelineParamsVt pipeline_params_vt = pipeline_params_k; if constexpr (Use_TMA_KV && !SameHeadDim) { - pipeline_params_v.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; - if constexpr (LargeHeadDimV) { pipeline_params_v.num_consumers = NumMmaThreads; } + pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; } } else { - if constexpr (LargeHeadDimV) { pipeline_params_v.consumer_arv_count = NumMmaThreads; } + if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; } } MainloopPipelineK pipeline_k = [&] { @@ -243,9 +244,9 @@ class FlashAttnFwdSm90 { if constexpr (!Transpose_V) { static_assert(is_same_v); if constexpr (Use_TMA_KV) { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{}); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{}); } else { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt); } } else { PipelineParamsV pipeline_params_v; @@ -257,7 +258,6 @@ class FlashAttnFwdSm90 { return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); } }(); - static_assert(is_same_v); // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then // the producer WG will read from pipeline_vt and write to pipeline_v. // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used. @@ -265,11 +265,11 @@ class FlashAttnFwdSm90 { // However, the thread role isn't used in the pipeline implementation. MainloopPipelineVt pipeline_vt = [&] { if constexpr (Use_TMA_KV) { - pipeline_params_v.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_v, ClusterShape{}); + pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{}); } else { - pipeline_params_v.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_v); + pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt); } }(); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index f8a98a08fe2..118ccb26b46 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -194,7 +194,7 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) - static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDimV >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; + static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { From 4e8496a78179416ea18ae111508dfa4341dc1e37 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 7 Feb 2025 22:55:07 -0500 Subject: [PATCH 18/32] Support Qv --- hopper/flash.h | 5 + hopper/flash_api.cpp | 25 +++++ hopper/flash_attn_interface.py | 29 +++-- hopper/flash_fwd_kernel_sm90.h | 5 + hopper/flash_fwd_launch_template.h | 19 ++-- hopper/mainloop_fwd_sm80.hpp | 2 + hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 133 +++++++++++++++++++++-- hopper/test_flash_attn.py | 59 ++++++++-- hopper/test_util.py | 22 +++- 9 files changed, 260 insertions(+), 39 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index 9f8cb1bcae1..9cce795b759 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -104,6 +104,11 @@ struct Flash_fwd_params : public Qkv_params { index_t knew_head_stride; index_t vnew_head_stride; + void *__restrict__ qv_ptr; + index_t qv_batch_stride; + index_t qv_row_stride; + index_t qv_head_stride; + // The cos and sin matrices for rotary embedding. void * __restrict__ rotary_cos_ptr; void * __restrict__ rotary_sin_ptr; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 7fd8dfc3e28..54ec78bce7c 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -487,6 +487,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q std::optional &cu_seqlens_q_, // b+1 std::optional &cu_seqlens_k_, // b+1 @@ -765,6 +766,30 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } } + if (q_v_.has_value()) { + TORCH_CHECK(false, "q_v should be None for now"); + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + at::Tensor q_v = q_v_.value(); + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + CHECK_DEVICE(q_v); + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + } + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. + params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + if (!is_varlen_q) { + params.qv_batch_stride = q_v.stride(0); + } + } + if (leftpad_k_.has_value()) { auto leftpad_k = leftpad_k_.value(); TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 5f1e4899c92..adee1a0ff26 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -22,6 +22,7 @@ def _flash_attn_forward( v, k_new, v_new, + qv, out, cu_seqlens_q, cu_seqlens_k, @@ -64,6 +65,7 @@ def _flash_attn_forward( v, k_new, v_new, + qv, out, cu_seqlens_q, cu_seqlens_k, @@ -239,6 +241,7 @@ def forward( v, softmax_scale, causal, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -249,13 +252,14 @@ def forward( sm_margin=0, ): if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward( out, softmax_lse, *rest = _flash_attn_forward( q, k, v, None, None, # k_new, v_new + qv, # qv None, # out None, None, None, # cu_seqlens_q/k/k_new None, None, # seqused_q/k @@ -311,7 +315,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -330,6 +334,7 @@ def forward( max_seqlen_k, softmax_scale, causal, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -340,13 +345,14 @@ def forward( sm_margin=0, ): if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward( out, softmax_lse, *rest = _flash_attn_forward( q, k, v, None, None, # k_new, v_new + qv, # qv None, # out cu_seqlens_q, cu_seqlens_k, @@ -411,7 +417,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( @@ -478,6 +484,7 @@ def flash_attn_func( v, softmax_scale=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -538,6 +545,7 @@ def flash_attn_func( v, softmax_scale, causal, + qv, q_descale, k_descale, v_descale, window_size, sink_token_length, @@ -561,6 +569,7 @@ def flash_attn_varlen_func( max_seqlen_k, softmax_scale=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -582,6 +591,7 @@ def flash_attn_varlen_func( max_seqlen_k, softmax_scale, causal, + qv, q_descale, k_descale, v_descale, window_size, sink_token_length, @@ -603,6 +613,7 @@ def flash_attn_with_kvcache( v_cache, k=None, v=None, + qv=None, rotary_cos=None, rotary_sin=None, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, @@ -673,11 +684,12 @@ def flash_attn_with_kvcache( k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) page_block_size must be a multiple of 256. - v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no _table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens. - v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. + qv [optional]: (batch_size, seqlen, nheads, headdim_v) rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. @@ -714,7 +726,7 @@ def flash_attn_with_kvcache( assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device @@ -726,6 +738,7 @@ def flash_attn_with_kvcache( v_cache, k, v, + qv, None, # out cu_seqlens_q, None, # cu_seqlens_k diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index aeb81977c56..c7fec6df559 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -40,6 +40,7 @@ class FlashAttnFwdSm90 { static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; static constexpr bool AppendKV = CollectiveMainloop::AppendKV; + static constexpr bool HasQv = CollectiveMainloop::HasQv; static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q; static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV; static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O; @@ -102,6 +103,7 @@ class FlashAttnFwdSm90 { } tensors; struct PipelineStorage : cute::aligned_struct<16, _1> { alignas(16) BarrierQ barrier_Q; + alignas(16) BarrierQ barrier_Qv; alignas(16) cutlass::arch::ClusterBarrier barrier_O; alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k; alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v; @@ -206,6 +208,9 @@ class FlashAttnFwdSm90 { if (warp_idx == 0 && lane_predicate) { shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + } shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/); } diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 118ccb26b46..b4f80a04e7c 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -24,7 +24,7 @@ using namespace cute; template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); @@ -50,7 +50,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm90, flash::CollectiveMainloopFwdSm80 >; using CollectiveEpilogue = flash::CollectiveEpilogueFwd; @@ -101,6 +101,8 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new static_cast(params.vnew_ptr), {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new + static_cast(params.qv_ptr), + {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv static_cast(params.rotary_cos_ptr), {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter {params.rotary_dim / 2, _1{}}, // stride_rotary_cos @@ -195,11 +197,14 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; - APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not varlen - CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { - static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 and false; + APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not varlen + CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; + run_flash_fwd(params, stream); + }); }); }); }); diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index 2d2ba06f220..0fb32c7a900 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -185,6 +185,8 @@ struct CollectiveMainloopFwdSm80 { StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 797c88d6441..1834f200c57 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -28,14 +28,15 @@ namespace flash { using namespace cute; template + bool Is_causal_, bool Is_local_, bool Has_softcap_, bool Varlen_, bool PagedKV_, bool AppendKV_, bool HasQv_, + bool MmaPV_is_RS, bool IntraWGOverlap, bool PackGQA_, bool Split_, bool V_colmajor_> struct CollectiveMainloopFwdSm90 { static constexpr int kStages = Stages; using ClusterShape = ClusterShape_; using TileShape_MNK = TileShape_MNK_; using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; + using TileShape_MNK_QV = Shape(TileShape_MNK{})), decltype(get<1>(TileShape_MNK{})), Int>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -46,6 +47,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Varlen = Varlen_; static constexpr bool PagedKV = PagedKV_; static constexpr bool AppendKV = AppendKV_; + static constexpr bool HasQv = HasQv_; static constexpr bool PackGQA = PackGQA_; static constexpr bool Split = Split_; static constexpr bool V_colmajor = V_colmajor_; @@ -70,6 +72,7 @@ struct CollectiveMainloopFwdSm90 { static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); + static_assert(!(HasQv && !IntraWGOverlap), "HasQv requires IntraWGOverlap"); // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. @@ -101,6 +104,9 @@ struct CollectiveMainloopFwdSm90 { TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()) >{}, AtomLayoutPV{})); + using TiledMmaQV = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutQK{})); static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); static constexpr int NumMmaThreads = size(TiledMmaPV{}); @@ -134,6 +140,16 @@ struct CollectiveMainloopFwdSm90 { make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); + using SmemLayoutAtomQv = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + using SmemLayoutQv = decltype(tile_to_shape(SmemLayoutAtomQv{}, select<0, 2>(TileShape_MNK_QV{}))); + using SmemLayoutAtomVMmaQV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + using SmemLayoutVMmaQV = decltype(tile_to_shape( + SmemLayoutAtomVMmaQV{}, + make_shape(shape<1>(TileShape_MNK_QV{}), shape<2>(TileShape_MNK_QV{}), Int{}))); + static_assert(CUTE_STATIC_V(size(SmemLayoutVMmaQV{})) == size(SmemLayoutVtMma{})); + // Only used if we're using cp.async to load V using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), Int>()); @@ -242,10 +258,19 @@ struct CollectiveMainloopFwdSm90 { select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Qv_ = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{})); + using TMA_Qv = std::conditional_t; + // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesV = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesQv = static_cast(size(SmemLayoutQv{}) * cutlass::sizeof_bits_v / 8); using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; using MainloopPipelineK = std::conditional_t>; @@ -261,12 +286,14 @@ struct CollectiveMainloopFwdSm90 { static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); + static constexpr size_t SmemAlignmentQv = Use_TMA_Q ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQv{}); static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{}); static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; + using SmemQv_t = std::conditional_t, cute::array_aligned, SmemAlignmentQv>>; // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes // smem size to go from 227KB to 228KB and we get "invalid argument". @@ -274,18 +301,21 @@ struct CollectiveMainloopFwdSm90 { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; }; struct TensorStorageWithPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; SmemP_t smem_p; }; struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; SmemP_t smem_p; SmemScale_t smem_scale; }; @@ -304,6 +334,7 @@ struct CollectiveMainloopFwdSm90 { cute::array_aligned, SmemAlignmentVt> smem_vt; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; SmemScale_t smem_scale; }; @@ -332,6 +363,8 @@ struct CollectiveMainloopFwdSm90 { StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; @@ -374,6 +407,10 @@ struct CollectiveMainloopFwdSm90 { StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideV const stride_Qv; + ShapeQPacked const shape_Qv_packed; + StrideQPacked const stride_Qv_packed; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; @@ -390,6 +427,7 @@ struct CollectiveMainloopFwdSm90 { TMA_V tma_load_V; TMA_K tma_load_K_new; TMA_V tma_load_V_new; + TMA_Qv tma_load_Qv; float const softmax_scale_log2; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; @@ -446,6 +484,20 @@ struct CollectiveMainloopFwdSm90 { take<0, 2>(SmemLayoutVt{}), select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + auto shape_Qv = make_shape(get<0>(args.shape_Q), args.headdim_v, get<2>(args.shape_Q), get<3>(args.shape_Q)); + Tensor mQv = make_tensor(make_gmem_ptr(args.ptr_Qv), shape_Qv, args.stride_Qv); + TMA_Qv tma_load_Qv = [&] { + if constexpr (HasQv) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + mQv, + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{}); // no mcast for Qv + } else { + return nullptr; + } + }(); // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); auto const shape_Q_packed = cute::conditional_return( @@ -456,6 +508,14 @@ struct CollectiveMainloopFwdSm90 { args.stride_Q, make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) ); + auto const shape_Qv_packed = cute::conditional_return( + shape_Qv, + make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv)) + ); + auto const stride_Qv_packed = cute::conditional_return( + args.stride_Qv, + make_stride(make_stride(get<2>(args.stride_Qv), get<0>(args.stride_Qv)), get<1>(args.stride_Qv), get<2>(args.stride_Qv) * qhead_per_khead, get<3>(args.stride_Qv)) + ); if (get<1>(args.shape_rotary) > 0) { assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); } @@ -468,12 +528,13 @@ struct CollectiveMainloopFwdSm90 { return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, + args.ptr_Qv, args.stride_Qv, shape_Qv_packed, stride_Qv_packed, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, cutlass::FastDivmod(int(get<0>(args.shape_K))), cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), - tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, + tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, tma_load_Qv, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, @@ -490,6 +551,9 @@ struct CollectiveMainloopFwdSm90 { static void prefetch_tma_descriptors(Params const& params) { if constexpr (Use_TMA_Q) { cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); + if constexpr (HasQv) { + cute::prefetch_tma_descriptor(params.tma_load_Qv.get_tma_descriptor()); + } } if constexpr (Use_TMA_KV) { cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); @@ -546,7 +610,11 @@ struct CollectiveMainloopFwdSm90 { int &work_idx ) { - auto [m_block, bidh, bidb, split_idx] = block_coord; + // some of these are captured in lambda so can't use structured binding + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. if constexpr (Is_causal || Is_local || Varlen || Split) { @@ -578,6 +646,7 @@ struct CollectiveMainloopFwdSm90 { return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{})); } }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); int const thread_idx = threadIdx.x % NumProducerThreads; int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; @@ -610,6 +679,19 @@ struct CollectiveMainloopFwdSm90 { auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x); Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k) Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) + auto [tQvgQv, tQvsQv] = [&] { + if constexpr (HasQv) { + auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q)); + Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv) + auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{}); + Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA) + Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA) + return cute::make_tuple(tQvgQv, tQvsQv); + } else { + return cute::make_tuple(nullptr, nullptr); + } + }(); using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( @@ -735,6 +817,11 @@ struct CollectiveMainloopFwdSm90 { shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), tQgQ, tQsQ); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv); + copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + tQvgQv, tQvsQv); + } } } else { // Load Q with cp.async cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); @@ -745,6 +832,15 @@ struct CollectiveMainloopFwdSm90 { auto &barrier_Q = shared_storage.pipelines.barrier_Q; cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Q)); barrier_Q.arrive(); + if constexpr (HasQv) { + Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv); + using PackGQAt = flash::PackGQAManager(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>; + PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); + auto &barrier_Qv = shared_storage.pipelines.barrier_Qv; + cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Qv)); + barrier_Qv.arrive(); + } } // Wait for the MMA WGs to signal that smem_v are ready and V can be copied from gmem @@ -925,6 +1021,8 @@ struct CollectiveMainloopFwdSm90 { return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutScale{}); } }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); + Tensor sVMmaQV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVMmaQV{}); if constexpr (!MmaQK_is_RS) { static_assert(stride<0>(typename TiledMmaQK::ALayout{}) == 0 and @@ -940,8 +1038,10 @@ struct CollectiveMainloopFwdSm90 { int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); TiledMmaQK tiled_mma_qk; TiledMmaPV tiled_mma_pv; + TiledMmaQV tiled_mma_qv; auto wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)); auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_qv = tiled_mma_qv.get_slice(warp_group_thread_layout(warp_group_idx)); auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma_qk); auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); @@ -951,6 +1051,8 @@ struct CollectiveMainloopFwdSm90 { Tensor tSrK = wg_mma_qk.partition_fragment_B(sK); Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); + Tensor tSrQv = wg_mma_qv.partition_fragment_A(sQv); + Tensor tSrV = wg_mma_qv.partition_fragment_B(sVMmaQV); Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); // For storing scales to smem, only used when LargeHeadDimV @@ -1049,6 +1151,11 @@ struct CollectiveMainloopFwdSm90 { flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); warpgroup_wait<0>(); pipeline_k.consumer_release(smem_pipe_read); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } scoremod_premask_fn(tSrS); mask.template apply(tSrS, m_block, n_block); @@ -1084,18 +1191,28 @@ struct CollectiveMainloopFwdSm90 { warp_scheduler_barrier_sync(); flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } + if constexpr(!HasQv) { + if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } + } flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K + if constexpr (HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } scoremod_premask_fn(tSrS); mask_fn(tSrS, n_block); cute::copy(softmax.template max_get_scale(tSrS), scores_scale); if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); } softmax.template online_softmax(tSrS); - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V + if constexpr (!HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + } if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } @@ -1151,7 +1268,7 @@ struct CollectiveMainloopFwdSm90 { // Tell producers that smem_q is ready cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - consumer_wait(pipeline_v, smem_pipe_read); + if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; cute::copy(softmax.finalize(v_descale), scores_scale); diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index d0590b5f1e7..6d5d8f8e2d0 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -50,6 +50,8 @@ # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -96,7 +98,7 @@ ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, mha_type, dtype + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype ): # sink_token_length = 0 if not local else 4 sink_token_length = 0 if not local else 0 @@ -121,6 +123,10 @@ def test_flash_attn_output( q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) # window_size = (-1, -1) if not local else (16, 0) @@ -129,6 +135,7 @@ def test_flash_attn_output( else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None if V_colmajor: v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() out_ref, attn_ref = attention_ref( @@ -138,6 +145,7 @@ def test_flash_attn_output( None, None, causal=causal, + qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, sink_token_length=sink_token_length, @@ -150,6 +158,7 @@ def test_flash_attn_output( None, None, causal=causal, + qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, sink_token_length=sink_token_length, @@ -160,6 +169,8 @@ def test_flash_attn_output( ) # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() # m = qk.amax(-1, keepdim=True) # s_tmp = torch.exp((qk - m) / math.sqrt(d)) # exp_sum = s_tmp.sum(-1) @@ -180,6 +191,7 @@ def test_flash_attn_output( k, v, causal=causal, + qv=qv, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, sink_token_length=sink_token_length, @@ -197,7 +209,7 @@ def test_flash_attn_output( # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv: g = torch.randn_like(out) do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) # import flash_attn_3_cuda @@ -249,7 +261,7 @@ def test_flash_attn_output( # breakpoint() - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv: dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) @@ -264,6 +276,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -308,7 +322,7 @@ def test_flash_attn_output( ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, mha_type, dtype + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype ): device = "cuda" # set seed @@ -329,6 +343,10 @@ def test_flash_attn_varlen_output( q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) if dtype == torch.float8_e4m3fn: @@ -336,6 +354,7 @@ def test_flash_attn_varlen_output( else: q_descale, k_descale, v_descale = None, None, None q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None query_padding_mask = generate_random_padding_mask( seqlen_q, batch_size, device, mode="random", zero_lengths=False ) @@ -366,6 +385,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): q_unpad, k_unpad, v_unpad, + qv_unpad, cu_seqlens_q, cu_seqlens_k, seqused_q, @@ -375,10 +395,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): q, k, v, + qv, output_pad_fn, dq_pad_fn, dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] out_ref, attn_ref = attention_ref( @@ -388,6 +409,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): query_padding_mask, key_padding_mask, causal=causal, + qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, softcap=softcap @@ -399,6 +421,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): query_padding_mask, key_padding_mask, causal=causal, + qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, softcap=softcap, @@ -431,6 +454,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): max_seqlen_q, max_seqlen_k, causal=causal, + qv=qv_unpad, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, @@ -450,7 +474,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: g_unpad = torch.randn_like(out_unpad) do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) # import flash_attn_3_cuda @@ -518,7 +542,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) @@ -554,7 +578,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_batch_idx", [False, True]) # @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [True]) +# @pytest.mark.parametrize("varlen_q", [False]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) @@ -572,8 +596,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): (3, 799), (64, 2048), (16, 20000), - (1, 128 * 1024), - (16, 128 * 1024), + # (1, 128 * 1024), + # (16, 128 * 1024), (128, 128), (256, 512), # To test appending KV with more than 1 block (2048, 3577), # Enough tile to test persistent scheduler @@ -617,17 +641,25 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - for dv in [128, d] if d > 128 and d <= 192 else [d]: + dv_vals = [128, d] if d > 128 and d <= 192 else [d] + has_qv_vals = [False] + for dv, has_qv in itertools.product(dv_vals, has_qv_vals): q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if has_qv: + qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv = None if varlen_q: query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None else: query_padding_mask = None q_unpad = q + qv_unpad = qv cu_seqlens_q, max_seqlen_q = None, None # Put window_size after QKV randn so that window_size changes from test to test window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) @@ -755,6 +787,7 @@ def test_flash_attn_kvcache( query_padding_mask, key_padding_mask, causal=causal, + qv=qv, window_size=window_size, key_leftpad=cache_leftpad, ) @@ -765,6 +798,7 @@ def test_flash_attn_kvcache( query_padding_mask, key_padding_mask, causal=causal, + qv=qv, window_size=window_size, upcast=False, reorder_ops=True, @@ -781,6 +815,8 @@ def test_flash_attn_kvcache( v = v.to(dtype) if v is not None else None k_unpad = k_unpad.to(dtype) if k_unpad is not None else None v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None cos = cos.to(dtype) if cos is not None else None sin = sin.to(dtype) if sin is not None else None out, lse, *rest = flash_attn_with_kvcache( @@ -789,6 +825,7 @@ def test_flash_attn_kvcache( v_cache if page_size is None else v_cache_paged, k if not new_kv or not varlen_q else k_unpad, v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, rotary_cos=cos, rotary_sin=sin, cache_seqlens=cache_seqlens, diff --git a/hopper/test_util.py b/hopper/test_util.py index cbf44103126..b7ea3d3b752 100644 --- a/hopper/test_util.py +++ b/hopper/test_util.py @@ -30,7 +30,7 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", def generate_qkv( - q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False, + q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False, query_unused_mask=None, key_unused_mask=None, ): """ @@ -58,6 +58,7 @@ def generate_qkv( output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None else: q_unpad = rearrange(q, "b s h d -> (b s) h d") cu_seqlens_q = torch.arange( @@ -68,6 +69,7 @@ def generate_qkv( output_pad_fn = lambda output_unpad: rearrange( output_unpad, "(b s) h d -> b s h d", b=batch_size ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None if key_padding_mask is not None: k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( @@ -135,6 +137,7 @@ def generate_qkv( q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(), v_unpad.detach().requires_grad_(), + qv_unpad.detach() if qv is not None else None, cu_seqlens_q, cu_seqlens_k, seqused_q, @@ -144,6 +147,7 @@ def generate_qkv( q.detach().requires_grad_(), k.detach().requires_grad_(), v.detach().requires_grad_(), + qv.detach() if qv is not None else None, output_pad_fn, dq_pad_fn, dk_pad_fn, @@ -197,6 +201,7 @@ def attention_ref( dropout_p=0.0, dropout_mask=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), # -1 means infinite window size sink_token_length=0, @@ -210,6 +215,7 @@ def attention_ref( q: (batch_size, seqlen_q, nheads, head_dim) k: (batch_size, seqlen_k, nheads, head_dim) v: (batch_size, seqlen_k, nheads, head_dim_v) + qv: (batch_size, seqlen_q, nheads, head_dim_v) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) @@ -230,9 +236,11 @@ def attention_ref( dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None if q_descale is not None: - q_descale = repeat(q_descale, "b h -> b (h g)", g = q.shape[2] // k.shape[2]) - q = (q.float() * rearrange(q_descale, "b h -> b 1 h 1")).to(dtype=q.dtype) + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g = q.shape[2] // k.shape[2]).to(dtype=q.dtype) + q = q.float() * q_descale + qv = qv.float() * q_descale if qv is not None else None if k_descale is not None: k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) if v_descale is not None: @@ -241,10 +249,14 @@ def attention_ref( k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) if not reorder_ops: - scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) else: - scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) if softcap > 0: scores = torch.tanh(scores / softcap) * softcap if key_padding_mask is not None: From 893a22ab5703ab3d61eda256f3a9a73a66b4444c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 7 Feb 2025 23:04:20 -0500 Subject: [PATCH 19/32] Test varlen_q=True by default for kvcache --- hopper/test_flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 6d5d8f8e2d0..e9cd8c9d6cb 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -578,7 +578,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_batch_idx", [False, True]) # @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("varlen_q", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) From 5fab938555597b5e6150b16b190415d3420b1c67 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 8 Feb 2025 01:25:09 -0500 Subject: [PATCH 20/32] Fix num_splits heuristic being called before get_pack_gqa --- hopper/flash_api.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 54ec78bce7c..6820f93416e 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -715,8 +715,9 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.page_size = page_size; params.num_pages = num_pages; - params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); + // get_num_splits need params.pack_gqa to decide + params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; if (k_new_.has_value()) { at::Tensor k_new, v_new; From 5fc5ebf82b27adc47ffb364a3e0c654fc266a321 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 8 Feb 2025 16:21:44 -0500 Subject: [PATCH 21/32] Fix num_splits heuristic again when PackGQA --- hopper/flash_api.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 6820f93416e..402e1a6aa73 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -429,7 +429,8 @@ inline int get_num_splits(Flash_fwd_params const& params) { : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; - return num_splits_heuristic(params.b * (!params.pack_gqa ? params.h : params.h_k) * num_m_blocks, params.num_sm, num_n_blocks, 128); + // Always enable PackGQA for Split + return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.num_sm, num_n_blocks, 128); // return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.b * params.h_k, // params.num_sm, num_n_blocks, 128, params.d_rounded); #endif @@ -715,9 +716,9 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.page_size = page_size; params.num_pages = num_pages; - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); - // get_num_splits need params.pack_gqa to decide params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide + params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); if (k_new_.has_value()) { at::Tensor k_new, v_new; From 5378bc3204bf9a2d959f1c66fe2f9bf60d582b43 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sat, 8 Feb 2025 16:30:51 -0500 Subject: [PATCH 22/32] Tile fwd_combine kernel along headdim, don't need kBlockM > 128 --- hopper/flash.h | 2 +- hopper/flash_api.cpp | 18 +----- hopper/flash_fwd_combine.cu | 6 -- hopper/flash_fwd_combine_kernel.h | 65 +++++++++++++--------- hopper/flash_fwd_combine_launch_template.h | 23 ++++---- 5 files changed, 55 insertions(+), 59 deletions(-) diff --git a/hopper/flash.h b/hopper/flash.h index 9cce795b759..8e95f5ff75c 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -207,5 +207,5 @@ template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); -template +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 402e1a6aa73..7dad5b9c7bc 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -359,32 +359,20 @@ void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.is_fp32) { if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 128) { - run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 256) { - run_mha_fwd_combine_(params, stream); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream); } } else if (params.is_bf16) { if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 128) { - run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 256) { - run_mha_fwd_combine_(params, stream); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream); } } else { if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 128) { - run_mha_fwd_combine_(params, stream); - } else if (params.dv <= 256) { - run_mha_fwd_combine_(params, stream); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream); } } #else diff --git a/hopper/flash_fwd_combine.cu b/hopper/flash_fwd_combine.cu index 57392ee75f4..a1725cf2a82 100644 --- a/hopper/flash_fwd_combine.cu +++ b/hopper/flash_fwd_combine.cu @@ -5,15 +5,9 @@ template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index aaec31e5807..20685a15656 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -40,11 +40,11 @@ class FlashAttnFwdCombine { static constexpr uint32_t MinBlocksPerMultiprocessor = 2; static constexpr int kBlockM = get<0>(TileShape_MK{}); - static constexpr int kHeadDim = get<1>(TileShape_MK{}); + static constexpr int kBlockK = get<1>(TileShape_MK{}); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad"); + static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 64 : 32); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); using GmemCopyAtom = std::conditional_t< @@ -98,8 +98,8 @@ class FlashAttnFwdCombine { Stride, _1>>{})); using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape, Int>{})); - using SmemLayoutO = Layout, Int, Int>, - Stride, _1, Int>>; + using SmemLayoutO = Layout, Int, Int>, + Stride, _1, Int>>; // We want each column (kMaxSplits) to be processed by threads in the same warp. // To reduce the number of shuffles, we want as few threads on the same column as possible. @@ -194,7 +194,8 @@ class FlashAttnFwdCombine { Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{}); int const thread_idx = threadIdx.x; - int const m_block = blockIdx.x; + int const k_block = blockIdx.x; + int const m_block = blockIdx.y; int const batch = !Varlen ? 0 : blockIdx.y; int const num_splits = get<1>(params.shape_LSE_partial); flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; @@ -254,7 +255,8 @@ class FlashAttnFwdCombine { Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO); - Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), params.shape_O_partial, params.stride_O_partial); // (seqlen, d, num_splits, head, batch) + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), + params.shape_O_partial, params.stride_O_partial); // (seqlen, d, num_splits, head, batch) // Precompute these values to avoid recomputing them in the loop Tensor tOmidx = make_tensor(make_shape(size<1>(tOcO))); @@ -271,7 +273,7 @@ class FlashAttnFwdCombine { tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx); tObidb[m] = 0; } - tOrOptr[m] = &mOpartial(tOmidx(m), _0{}, _0{}, tObidh(m), tObidb(m)); + tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m), tObidb(m)); if (idx >= max_idx) { tObidb[m] = -1; } @@ -280,7 +282,7 @@ class FlashAttnFwdCombine { Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); if constexpr (!(Is_even_K)) { #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial); } + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; } } Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO); @@ -358,26 +360,36 @@ class FlashAttnFwdCombine { // Store the scales exp(lse - lse_logsum) back to smem cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE); - // Step 5: store final LSE back to gmem - auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE); + // Store max_valid_split to smem #pragma unroll for (int m = 0; m < size<2>(ts2rrLSE); ++m) { - if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); - int idx = m_block * kBlockM + mi; - if (idx < max_idx) { - int m_idx, bidh, bidb; - if constexpr (!Varlen) { - bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); - } else { - bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); - bidb = 0; + if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } + } + } + + // Step 5: store final LSE back to gmem + if (k_block == 0) { + auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE); + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem + int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); + int idx = m_block * kBlockM + mi; + if (idx < max_idx) { + int m_idx, bidh, bidb; + if constexpr (!Varlen) { + bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); + } else { + bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); + bidb = 0; + } + // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); + mLSE(m_idx, bidh, bidb) = lse_sum(m); } - // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); - mLSE(m_idx, bidh, bidb) = lse_sum(m); } - if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } } } @@ -427,8 +439,9 @@ class FlashAttnFwdCombine { // Step 7: Write the final O to gmem Tensor rO = make_tensor_like(tOrO); flash::convert_type_out(tOrO, rO); - auto shape_O = select<0, 1, 3, 4>(params.shape_O_partial); - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O)), shape_O, params.stride_O); + auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial)); + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)), + shape_O, params.stride_O); Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int>{}); GmemTiledCopy gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 5cbed2b0c74..101f894b2d6 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -16,9 +16,9 @@ using namespace cute; -template +template void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { - using TileShape_MK = cute::Shape, Int>; + using TileShape_MK = cute::Shape, Int>; using CombineKernel = flash::FlashAttnFwdCombine; @@ -37,8 +37,9 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); + int num_blocks_k = cute::ceil_div(params.dv, kBlockK); int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h * (!Varlen ? params.b : 1), kBlockM); - dim3 grid_m(num_blocks_m, !Varlen ? 1 : params.b); + dim3 grid_m(num_blocks_k, num_blocks_m, !Varlen ? 1 : params.b); auto kernel = cutlass::device_kernel; int smem_size = CombineKernel::SharedStorageSize; if (smem_size >= 48 * 1024) { @@ -48,27 +49,27 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { CHECK_CUDA_KERNEL_LAUNCH(); } -template +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream) { // We want kBlockM to be as small as possible to maximize parallelism. // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). - static_assert(kHeadDim % 32 == 0, "kHeadDim must be a multiple of 32"); - static constexpr int kBlockM = kHeadDim % 128 == 0 ? 8 : (kHeadDim % 64 == 0 ? 16 : 32); + static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32"); + static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32); BOOL_SWITCH(params.seqused_q != nullptr, Varlen, [&] { if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. if (params.num_splits <= 16) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); return; } } if (params.num_splits <= 32) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } else if (params.num_splits <= 64) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } else if (params.num_splits <= 128) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } else { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } }); } From db8ca796092463a38db8faf1089bde4f29def745 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 13:02:28 -0500 Subject: [PATCH 23/32] Use bf16 instead of fp16 in benchmark_gemm.py --- benchmarks/benchmark_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py index df0d56b8f23..3f3639e0b53 100644 --- a/benchmarks/benchmark_gemm.py +++ b/benchmarks/benchmark_gemm.py @@ -26,7 +26,7 @@ def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs torch.manual_seed(0) repeats = 30 -dtype = torch.float16 +dtype = torch.bfloat16 device = 'cuda' verbose = False m, n = 8192, 8192 From 982c480c57c1b9a8e8ec3f70358957c69355f47a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 15:52:15 -0500 Subject: [PATCH 24/32] Update Cutlass to 3.7 --- csrc/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cutlass b/csrc/cutlass index c506e16788c..b78588d1630 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit c506e16788cb08416a4a57e11a9067beeee29420 +Subproject commit b78588d1630aa6643bf021613717bafb705df4ef From 58ebfa5865516c7fb4ad83783501c802484260bb Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 16:02:41 -0500 Subject: [PATCH 25/32] Use nvcc 12.6 but ptxas 12.8 --- hopper/benchmark_attn.py | 8 ++++---- hopper/setup.py | 23 ++++++++++++++++------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index e61cea9e67e..6dc253e00f6 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -261,9 +261,9 @@ def run(*args, **kwargs): # for headdim in [64, 96, 128, 192]: # for headdim in [64, 96, 128, 192, 256]: # for headdim in [64, 96, 128]: -# for headdim in [64, 128, 256]: +for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192, 256]: -for headdim in [192]: +# for headdim in [128]: nheads = dim // headdim # headdim = 64 # batch_size = 64 @@ -276,7 +276,7 @@ def run(*args, **kwargs): # headdim_v = 128 for batch_size, seqlen in bs_seqlen_vals: - num_splits = 1 + num_splits = 0 window_size = (-1, -1) # window_size = (seqlen // 2 - 1, 0) sink_token_length = 0 @@ -320,7 +320,7 @@ def run(*args, **kwargs): page_table = None for causal in [False, True]: - # for causal in [False]: + # for causal in [True]: print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: diff --git a/hopper/setup.py b/hopper/setup.py index db89902550b..1fb22acae43 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -366,7 +366,7 @@ def nvcc_threads_args(): # NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} -NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.8.61"} +NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.61", "cicc": "12.8.61"} exe_extension = sysconfig.get_config_var("EXE") @@ -387,10 +387,12 @@ def nvcc_threads_args(): if bare_metal_version < Version("12.3"): raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above") - if bare_metal_version != Version("12.8"): # nvcc 12.8 gives the best perf currently + # ptxas 12.8 gives the best perf currently + # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8 + # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have. + if bare_metal_version != Version("12.8"): download_and_copy( name="nvcc", - # src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas{exe_extension}", src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", dst_path="bin", version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], @@ -398,11 +400,18 @@ def nvcc_threads_args(): f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", ) download_and_copy( - name="nvcc", - # src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas{exe_extension}", + name="ptxas", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas", + dst_path="bin", + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) + download_and_copy( + name="cicc", src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin", dst_path="nvvm/bin", - version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], + version=NVIDIA_TOOLCHAIN_VERSION["cicc"], url_func=lambda system, arch, version: f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", ) @@ -411,7 +420,7 @@ def nvcc_threads_args(): nvcc_path_new = os.path.join(ctk_path_new, f"nvcc{exe_extension}") # Need to append to path otherwise nvcc can't find cicc in nvvm/bin/cicc # nvcc 12.8 seems to hard-code looking for cicc in ../nvvm/bin/cicc - # os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"] + os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"] os.environ["PYTORCH_NVCC"] = nvcc_path_new # Make nvcc executable, sometimes after the copy it loses its permissions os.chmod(nvcc_path_new, os.stat(nvcc_path_new).st_mode | stat.S_IEXEC) From ed435c6b364288b3a98a1ec26975adfa9f645f6b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 16:11:22 -0500 Subject: [PATCH 26/32] cicc uses the same version as ptxas --- hopper/setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hopper/setup.py b/hopper/setup.py index 1fb22acae43..30063dd9350 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -366,7 +366,7 @@ def nvcc_threads_args(): # NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} -NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.61", "cicc": "12.8.61"} +NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.61"} exe_extension = sysconfig.get_config_var("EXE") @@ -408,10 +408,10 @@ def nvcc_threads_args(): f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", ) download_and_copy( - name="cicc", + name="ptxas", src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin", dst_path="nvvm/bin", - version=NVIDIA_TOOLCHAIN_VERSION["cicc"], + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda system, arch, version: f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", ) From 86688236356ad19f560a698525c47f99b06531f2 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 16:32:59 -0500 Subject: [PATCH 27/32] Split hdimdiff into a separate translation unit --- hopper/generate_kernels.py | 11 ++++++++++- .../flash_fwd_hdim64_512_bf16_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_paged_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_paged_split_sm90.cu | 9 +++++++++ ...sh_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu | 9 +++++++++ .../instantiations/flash_fwd_hdim64_512_bf16_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_split_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_paged_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_paged_split_sm90.cu | 9 +++++++++ ...sh_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu | 9 +++++++++ .../instantiations/flash_fwd_hdim64_512_fp16_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_split_sm90.cu | 9 +++++++++ .../flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu | 9 +++++++++ .../flash_fwd_hdimall_bf16_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_paged_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_paged_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_paged_split_sm90.cu | 1 - ...flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu | 1 - hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_split_sm90.cu | 1 - .../flash_fwd_hdimall_bf16_split_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_paged_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_paged_split_sm90.cu | 1 - ...flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu | 1 - hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_split_sm90.cu | 1 - .../flash_fwd_hdimall_e4m3_split_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_paged_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_paged_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_paged_split_sm90.cu | 1 - ...flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu | 1 - hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_softcap_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_split_sm90.cu | 1 - .../flash_fwd_hdimall_fp16_split_softcap_sm90.cu | 1 - .../flash_fwd_hdimdiff_bf16_packgqa_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_paged_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_paged_split_sm90.cu | 6 ++++++ ...lash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu | 6 ++++++ hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_softcap_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_split_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_paged_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu | 5 +++++ ...lash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu | 5 +++++ hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_softcap_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_split_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu | 5 +++++ .../flash_fwd_hdimdiff_fp16_packgqa_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_paged_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_paged_split_sm90.cu | 6 ++++++ ...lash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu | 6 ++++++ hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_softcap_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_split_sm90.cu | 6 ++++++ .../flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu | 6 ++++++ hopper/setup.py | 2 +- 82 files changed, 361 insertions(+), 32 deletions(-) create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu create mode 100644 hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu diff --git a/hopper/generate_kernels.py b/hopper/generate_kernels.py index 7a5eb47d08b..19a6e90d345 100644 --- a/hopper/generate_kernels.py +++ b/hopper/generate_kernels.py @@ -138,6 +138,8 @@ def get_all_kernels() -> List[Kernel]: yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") if sm == 90 and head_dim == 192: yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + if sm == 90 and head_dim == 64 and dtype in ["bf16", "fp16"]: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=512, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM): yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") @@ -146,11 +148,18 @@ def batch_hdim(kernels_all) -> List[KERNEL_BATCH]: for dtype, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM): if sm < 90: continue - kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm] + # Same hdim and hdimv + kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim == k.head_dim_v] if len(kernels) > 0: filename = f"flash_fwd_hdimall_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu" template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) yield KERNEL_BATCH(template, filename) + # Different hdim and hdimv + kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim != k.head_dim_v] + if len(kernels) > 0: + filename = f"flash_fwd_hdimdiff_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu" + template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) + yield KERNEL_BATCH(template, filename) def batch_softcap(kernels_all) -> List[KERNEL_BATCH]: diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu new file mode 100644 index 00000000000..2f4ceaaed53 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu new file mode 100644 index 00000000000..5fd59af3486 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..e0f885b0f72 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu new file mode 100644 index 00000000000..6dcda019627 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..5d20be6d2a7 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu new file mode 100644 index 00000000000..47463a7151c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..622b5533ce8 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu new file mode 100644 index 00000000000..c83f44722cd --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu new file mode 100644 index 00000000000..5c9130f8648 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu new file mode 100644 index 00000000000..a152022cb65 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu new file mode 100644 index 00000000000..ef05aa2038d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu new file mode 100644 index 00000000000..19fe6d94f7d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..6eb2d3d134b --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu new file mode 100644 index 00000000000..ffbc9982122 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..3d35075b48d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu new file mode 100644 index 00000000000..c2af33cf533 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..e07547c92d0 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu new file mode 100644 index 00000000000..1a04eb01f5e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu new file mode 100644 index 00000000000..da9afc11571 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu new file mode 100644 index 00000000000..5e63a15515f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu index e8ed21cda49..8b659e8321b 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim128_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim192_bf16_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" #include "flash_fwd_hdim256_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu index f7de8fa2019..c84d02b6d04 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_paged_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu index 64e5ce4a33f..6aaf7d12f56 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu index 44619cce59b..11712141419 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_split_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu index a059735824d..6175723086c 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu index daea288fe3a..2aac1970b1b 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_sm90.cu" #include "flash_fwd_hdim128_bf16_sm90.cu" #include "flash_fwd_hdim192_bf16_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_sm90.cu" #include "flash_fwd_hdim256_bf16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu index 62640192c68..be0c5af080b 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu index 79b0d52fa55..fd5893c59f4 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu index 333406cb439..bcde9c94582 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_split_sm90.cu" #include "flash_fwd_hdim128_bf16_split_sm90.cu" #include "flash_fwd_hdim192_bf16_split_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_split_sm90.cu" #include "flash_fwd_hdim256_bf16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu index b6c1fb54c4a..160eb3a18e4 100644 --- a/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim128_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_bf16_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" #include "flash_fwd_hdim256_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu index abf0b10e46e..28819a690a3 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim128_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim192_e4m3_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu" #include "flash_fwd_hdim256_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu index 22b310e5aba..933ad982719 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_paged_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_paged_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu index f9eed0732d7..a934f7d9924 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu index b91c7f85ad7..8475e878ae2 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_split_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu index a6b215bfdfd..dd1405b17f0 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu index ddec44c68ca..7e7d806c6d5 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_sm90.cu" #include "flash_fwd_hdim128_e4m3_sm90.cu" #include "flash_fwd_hdim192_e4m3_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_sm90.cu" #include "flash_fwd_hdim256_e4m3_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu index 81601b9ec21..f973a4e411d 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu index ae9a362c109..30390838d39 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu index 163ee761be1..0b629bd2b32 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_split_sm90.cu" #include "flash_fwd_hdim128_e4m3_split_sm90.cu" #include "flash_fwd_hdim192_e4m3_split_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_split_sm90.cu" #include "flash_fwd_hdim256_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu index ba2d427ddd4..818c7fafb7a 100644 --- a/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim128_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim192_e4m3_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu" #include "flash_fwd_hdim256_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu index 34d1763483a..6652824d075 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim128_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim192_fp16_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" #include "flash_fwd_hdim256_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu index 326a2ea901a..05d11e2e258 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_paged_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu index a9e032a071c..b638138eb26 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu index d7cc300b89b..3619a2175f0 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_split_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu index fa4de4e298f..3a408ceacbd 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu index cb345586694..eec11be9162 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_sm90.cu" #include "flash_fwd_hdim128_fp16_sm90.cu" #include "flash_fwd_hdim192_fp16_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_sm90.cu" #include "flash_fwd_hdim256_fp16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu index 5dbd70ec5d3..ca2a1e1b843 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" #include "flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu index 9a97b96041b..8cf31a8a85f 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu index 5aacbf02664..5ee7ace63ac 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_split_sm90.cu" #include "flash_fwd_hdim128_fp16_split_sm90.cu" #include "flash_fwd_hdim192_fp16_split_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_split_sm90.cu" #include "flash_fwd_hdim256_fp16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu index cfaabd990ab..4da0ee704eb 100644 --- a/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu @@ -6,5 +6,4 @@ #include "flash_fwd_hdim96_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim128_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim192_fp16_split_softcap_sm90.cu" -#include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" #include "flash_fwd_hdim256_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu new file mode 100644 index 00000000000..cc3a8a7c913 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu new file mode 100644 index 00000000000..d6d6df0d4ee --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..bd85f7608f6 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu new file mode 100644 index 00000000000..733511adb43 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..c62ccf28d3c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu new file mode 100644 index 00000000000..b7e51551a04 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..0dbd0045425 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu new file mode 100644 index 00000000000..51a14371284 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu new file mode 100644 index 00000000000..24a64e8e49e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu new file mode 100644 index 00000000000..50c78f3d5d4 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu new file mode 100644 index 00000000000..526a51fb71e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu new file mode 100644 index 00000000000..4e5d9cc4fe2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu new file mode 100644 index 00000000000..f553af139f2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu new file mode 100644 index 00000000000..aa2a8260d25 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..bbc4449ba21 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu new file mode 100644 index 00000000000..02ca85ad672 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..d090fde972b --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu new file mode 100644 index 00000000000..d48f60ad7e2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu new file mode 100644 index 00000000000..9dda19d1cea --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu new file mode 100644 index 00000000000..f3e51fc9ebd --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu new file mode 100644 index 00000000000..453282a4f29 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu new file mode 100644 index 00000000000..72736d8ef7a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..97895aa708c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu new file mode 100644 index 00000000000..423c42221e0 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..98c89572117 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu new file mode 100644 index 00000000000..69108d025fa --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..da39ba2731a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu new file mode 100644 index 00000000000..be6496d1956 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu new file mode 100644 index 00000000000..a5a80909072 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu new file mode 100644 index 00000000000..62fe142562d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/setup.py b/hopper/setup.py index 30063dd9350..560ddcc1cc3 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -470,7 +470,7 @@ def nvcc_threads_args(): + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) - HEAD_DIMENSIONS_FWD = ["all"] + HEAD_DIMENSIONS_FWD = ["all", "diff"] HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else []) PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else []) From b2fc79d17526ab56d7561091441a62f241056a4b Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 16:39:17 -0500 Subject: [PATCH 28/32] Update benchmark script --- hopper/benchmark_attn.py | 28 ++-------------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 6dc253e00f6..5d1f5369214 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -242,28 +242,13 @@ def run(*args, **kwargs): time_f = {} time_b = {} -# tflops_matmul = {} -# m, n = 8192, 8192 -# for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]: -# a = torch.randn(m, k, device=device, dtype=dtype) -# b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2) -# nFLOPS_matmul = 2 * m * n * k -# m5 = time_fwd(torch.matmul, a, b, desc='cuBLAS') -# print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS') -# tflops_matmul[k] = nFLOPS_matmul / m5.mean * 1e-12 -# # import pickle -# # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp: -# # with open(f'flash3_matmul_tflops_h100.plk', 'wb') as fp: -# # pickle.dump(tflops_matmul, fp, protocol=pickle.HIGHEST_PROTOCOL) -# exit(0) - # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192]: # for headdim in [64, 96, 128, 192, 256]: # for headdim in [64, 96, 128]: -for headdim in [64, 128, 256]: +# for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192, 256]: -# for headdim in [128]: +for headdim in [128]: nheads = dim // headdim # headdim = 64 # batch_size = 64 @@ -297,10 +282,6 @@ def run(*args, **kwargs): g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32) - a = torch.randn(batch_size, seqlen, seqlen, device=device, dtype=dtype_gen) - b = torch.randn(batch_size, dim * 2, seqlen, device=device, dtype=dtype_gen).transpose(-1, -2) - # x = torch.randn(batch_size * seqlen, 4096, device=device, dtype=dtype) - # w = torch.randn(4096 * 2, 4096, device=device, dtype=dtype).transpose(-1, -2) if varlen: q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_() for x in [q, k, v]] cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q @@ -377,11 +358,6 @@ def run(*args, **kwargs): m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean - # time.sleep(1) - # m5 = time_fwd(torch.bmm, a, b, desc='cuBLAS', repeats=repeats, verbose=False) - # nFLOPS_matmul = nFLOPS - # nFLOPS_matmul = 2 * x.shape[0] * x.shape[1] * w.shape[1] - # m5 = time_fwd(torch.matmul, x, w, desc='cuBLAS') if dtype != torch.float8_e4m3fn and headdim == headdim_v: time.sleep(1) if not varlen: From c09154572015e803123a5c875e7548cef423cd90 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 18:32:23 -0500 Subject: [PATCH 29/32] Update Cutlass to 3.8 --- csrc/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cutlass b/csrc/cutlass index b78588d1630..833f6990e03 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit b78588d1630aa6643bf021613717bafb705df4ef +Subproject commit 833f6990e031b48b4cd2fcf55e0849c51ef6bac2 From 5e39b100b421e104c3dca3011353e9889e8839ea Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 18:47:23 -0500 Subject: [PATCH 30/32] Adjust tile size for hdim 64 --- hopper/tile_size.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 997664bcbc5..5d0bd6e2634 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -13,7 +13,11 @@ constexpr std::tuple tile_size_fwd_sm90( if (element_size == 2) { if (headdim <= 64) { bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 - return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, true}; + // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, true}; + // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why + // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 + // Switch to tile size 192 x 192 for now + return {same_hdim ? 192 : 64, same_hdim ? 192 : 64, false, true}; // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { From 1a7f4dfa9e51f6a90177a3244a5bc9c687894cdd Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 9 Feb 2025 19:01:26 -0500 Subject: [PATCH 31/32] Adjust ninja build file --- hopper/setup.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/hopper/setup.py b/hopper/setup.py index 560ddcc1cc3..f638558a0a9 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -150,6 +150,8 @@ def sanitize_flags(flags): flags.append(f'cuda_post_cflags_sm80 = {" ".join(cuda_post_cflags_sm80)}') cuda_post_cflags_sm80_sm90 = cuda_post_cflags + ['-gencode', 'arch=compute_80,code=sm_80'] flags.append(f'cuda_post_cflags_sm80_sm90 = {" ".join(cuda_post_cflags_sm80_sm90)}') + cuda_post_cflags_sm100 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_100a,code=sm_100a' for s in cuda_post_cflags] + flags.append(f'cuda_post_cflags_sm100 = {" ".join(cuda_post_cflags_sm100)}') flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}') flags.append(f'ldflags = {" ".join(ldflags)}') @@ -187,6 +189,9 @@ def sanitize_flags(flags): cuda_compile_rule_sm80_sm90 = ['rule cuda_compile_sm80_sm90'] + cuda_compile_rule[1:] + [ f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90' ] + cuda_compile_rule_sm100 = ['rule cuda_compile_sm100'] + cuda_compile_rule[1:] + [ + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm100' + ] cuda_compile_rule.append( f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags') @@ -199,6 +204,8 @@ def sanitize_flags(flags): rule = 'cuda_compile' elif source_file.endswith('_sm80.cu'): rule = 'cuda_compile_sm80' + elif source_file.endswith('_sm100.cu'): + rule = 'cuda_compile_sm100' else: rule = 'cuda_compile_sm80_sm90' else: @@ -244,6 +251,7 @@ def sanitize_flags(flags): blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined] blocks.append(cuda_compile_rule_sm80) # type: ignore[possibly-undefined] blocks.append(cuda_compile_rule_sm80_sm90) # type: ignore[possibly-undefined] + blocks.append(cuda_compile_rule_sm100) # type: ignore[possibly-undefined] blocks += [devlink_rule, link_rule, build, devlink, link, default] content = "\n\n".join("\n".join(b) for b in blocks) # Ninja requires a new lines at the end of the .ninja file From 21fcadeabc0d65ec41e0cea07a565fe90bbef8db Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 11 Feb 2025 15:32:37 +0000 Subject: [PATCH 32/32] build head diff + fix build errors Signed-off-by: Lucas Wilkinson --- CMakeLists.txt | 6 ++++ hopper/flash_api_torch_lib.cpp | 42 +++++++++++++------------ vllm_flash_attn/flash_attn_interface.py | 4 ++- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 759c87f2e9d..e229b6f92ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -176,12 +176,18 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0) # BF16 source files file(GLOB FA3_BF16_GEN_SRCS "hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu") + file(GLOB FA3_BF16_GEN_SRCS_ + "hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu") + list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_}) file(GLOB FA3_BF16_GEN_SRCS_ "hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu") list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_}) # FP16 source files file(GLOB FA3_FP16_GEN_SRCS "hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu") + file(GLOB FA3_FP16_GEN_SRCS_ + "hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu") + list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_}) file(GLOB FA3_FP16_GEN_SRCS_ "hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu") list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_}) diff --git a/hopper/flash_api_torch_lib.cpp b/hopper/flash_api_torch_lib.cpp index 81b522b17fa..2406d1a5076 100644 --- a/hopper/flash_api_torch_lib.cpp +++ b/hopper/flash_api_torch_lib.cpp @@ -12,26 +12,27 @@ std::vector mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - const at::Tensor &v, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - c10::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - c10::optional &v_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - c10::optional &out_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - c10::optional &cu_seqlens_q_, // b+1 - c10::optional &cu_seqlens_k_, // b+1 - c10::optional &cu_seqlens_k_new_, // b+1 - c10::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - c10::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. - c10::optional max_seqlen_q_, + const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. + std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional &cu_seqlens_q_, // b+1 + std::optional &cu_seqlens_k_, // b+1 + std::optional &cu_seqlens_k_new_, // b+1 + std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, // TODO: check if we need max_seqlen_k - c10::optional max_seqlen_k_, - c10::optional &page_table_, // (b_k, max_num_pages_per_seq) - c10::optional &kv_batch_idx_, // b. indices to index into the KV cache - c10::optional &leftpad_k_, // b - c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) - c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) - c10::optional &q_descale_, // (b, h_k), not (b, h) - c10::optional &k_descale_, // (b, h_k) - c10::optional &v_descale_, // (b, h_k) + std::optional max_seqlen_k_, + std::optional &page_table_, // (b_k, max_num_pages_per_seq) + std::optional &kv_batch_idx_, // b. indices to index into the KV cache + std::optional &leftpad_k_, // b + std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional &q_descale_, // (b, h_k), not (b, h) + std::optional &k_descale_, // (b, h_k) + std::optional &v_descale_, // (b, h_k) float const softmax_scale, bool is_causal, int window_size_left, @@ -40,7 +41,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits, - c10::optional pack_gqa_, + std::optional pack_gqa_, int const sm_margin); /** @@ -52,6 +53,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor v," " Tensor? k_new," " Tensor? v_new," + " Tensor? q_v," " Tensor!? out," " Tensor? cu_seqlens_q," " Tensor? cu_seqlens_k," diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py index 81e2c22e57f..9bf430b3672 100644 --- a/vllm_flash_attn/flash_attn_interface.py +++ b/vllm_flash_attn/flash_attn_interface.py @@ -82,6 +82,7 @@ def flash_attn_varlen_func( max_seqlen_k, cu_seqlens_k=None, # only used for non-paged prefill seqused_k=None, + q_v=None, dropout_p=0.0, softmax_scale=None, causal=False, @@ -91,7 +92,6 @@ def flash_attn_varlen_func( deterministic=False, return_attn_probs=False, block_table=None, - *, return_softmax_lse=False, out=None, fa_version: int = DEFAULT_FA_VERSION, @@ -196,6 +196,7 @@ def flash_attn_varlen_func( out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd( q, k, v, None, None, # k_new, v_new + q_v, # out, cu_seqlens_q, cu_seqlens_k, # cu_seqlens_k @@ -369,6 +370,7 @@ def flash_attn_with_kvcache( out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd( q, k_cache, v_cache, # q, k, v k, v, # k_new, v_new + None, # q_v out, None, None, # cu_seqlens_q, cu_seqlens_k None, # cu_seqlens_k_new