diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4746c714930..6f227d1abe1 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,8 +44,8 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] - torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001'] - cuda-version: ['11.8.0', '12.3.2'] + torch-version: ['2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0'] + cuda-version: ['12.4.1'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) @@ -53,12 +53,7 @@ jobs: cxx11_abi: ['FALSE', 'TRUE'] exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix - # Pytorch < 2.2 does not support Python 3.12 - - torch-version: '2.1.2' - python-version: '3.12' # Pytorch < 2.5 does not support Python 3.13 - - torch-version: '2.1.2' - python-version: '3.13' - torch-version: '2.2.2' python-version: '3.13' - torch-version: '2.3.1' @@ -113,7 +108,7 @@ jobs: run: | pip install --upgrade pip # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools - pip install setuptools==68.0.0 + pip install setuptools==75.8.0 # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable pip install typing-extensions==4.12.2 @@ -122,8 +117,8 @@ jobs: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # This code is ugly, maybe there's a better way to do this. export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \ + minv = {'2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \ print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then @@ -149,7 +144,7 @@ jobs: # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 # However this still fails so I'm using a newer version of setuptools - pip install setuptools==68.0.0 + pip install setuptools==75.8.0 pip install ninja packaging wheel export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH @@ -203,7 +198,9 @@ jobs: - name: Install dependencies run: | - pip install ninja packaging setuptools wheel twine + pip install ninja packaging wheel twine + # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv) + pip install setuptools==75.8.0 # We don't want to download anything CUDA-related here pip install torch --index-url https://download.pytorch.org/whl/cpu diff --git a/CMakeLists.txt b/CMakeLists.txt index 759c87f2e9d..e229b6f92ef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -176,12 +176,18 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0) # BF16 source files file(GLOB FA3_BF16_GEN_SRCS "hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu") + file(GLOB FA3_BF16_GEN_SRCS_ + "hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu") + list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_}) file(GLOB FA3_BF16_GEN_SRCS_ "hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu") list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_}) # FP16 source files file(GLOB FA3_FP16_GEN_SRCS "hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu") + file(GLOB FA3_FP16_GEN_SRCS_ + "hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu") + list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_}) file(GLOB FA3_FP16_GEN_SRCS_ "hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu") list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_}) diff --git a/MANIFEST.in b/MANIFEST.in index 65e31646ed3..90378828ebf 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,7 @@ recursive-include csrc *.h recursive-include csrc *.cuh recursive-include csrc *.cpp recursive-include csrc *.hpp +recursive-include csrc *.py recursive-include vllm_flash_attn *.cu recursive-include vllm_flash_attn *.h diff --git a/README.md b/README.md index 033dba41006..c5d68536d4b 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ Currently released: Requirements: H100 / H800 GPU, CUDA >= 12.3. -For now, we highly recommend CUDA 12.3 for best performance. +We highly recommend CUDA 12.8 for best performance. To install: ```sh @@ -65,7 +65,7 @@ flash_attn_interface.flash_attn_func() ## Installation and features **Requirements:** - CUDA toolkit or ROCm toolkit -- PyTorch 1.12 and above. +- PyTorch 2.2 and above. - `packaging` Python package (`pip install packaging`) - `ninja` Python package (`pip install ninja`) * - Linux. Might work for Windows starting v2.3.2 (we've seen a few positive [reports](https://github.com/Dao-AILab/flash-attention/issues/595)) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via Github issue. @@ -98,7 +98,7 @@ MAX_JOBS=4 pip install flash-attn --no-build-isolation ### NVIDIA CUDA Support **Requirements:** -- CUDA 11.7 and above. +- CUDA 12.0 and above. We recommend the [Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) diff --git a/benchmarks/benchmark_gemm.py b/benchmarks/benchmark_gemm.py index df0d56b8f23..3f3639e0b53 100644 --- a/benchmarks/benchmark_gemm.py +++ b/benchmarks/benchmark_gemm.py @@ -26,7 +26,7 @@ def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs torch.manual_seed(0) repeats = 30 -dtype = torch.float16 +dtype = torch.bfloat16 device = 'cuda' verbose = False m, n = 8192, 8192 diff --git a/csrc/cutlass b/csrc/cutlass index c506e16788c..833f6990e03 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit c506e16788cb08416a4a57e11a9067beeee29420 +Subproject commit 833f6990e031b48b4cd2fcf55e0849c51ef6bac2 diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 07d16cd0f48..db131242dd4 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.7.3" +__version__ = "2.7.4.post1" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py index 7b0315b9793..1b5a415b73f 100644 --- a/flash_attn/ops/triton/cross_entropy.py +++ b/flash_attn/ops/triton/cross_entropy.py @@ -166,6 +166,7 @@ def forward( if labels.dtype == torch.long and labels.data_ptr() % 16 != 0: labels = F.pad(labels, (0, 1))[..., :-1] assert labels.data_ptr() % 16 == 0 + assert logit_scale > 0.0 n_rows, n_cols = logits.shape assert labels.shape == (n_rows,) world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) diff --git a/hopper/benchmark_attn.py b/hopper/benchmark_attn.py index 5f7522a8ac3..5d1f5369214 100644 --- a/hopper/benchmark_attn.py +++ b/hopper/benchmark_attn.py @@ -56,7 +56,7 @@ def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs): return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3) -def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size=(-1, -1)): +def flops(batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(-1, -1)): if causal: avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 else: @@ -67,7 +67,7 @@ def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size= col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0)) col_right = torch.minimum(row_idx + seqlen_k - seqlen_q - window_size[1], torch.tensor(seqlen_k - 1)) avg_seqlen = (col_right - col_left + 1).float().mean().item() - return batch * nheads * 2 * seqlen_q * avg_seqlen * headdim * 2 + return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) def convert_to_cudnn_type(torch_type): @@ -242,21 +242,6 @@ def run(*args, **kwargs): time_f = {} time_b = {} -# tflops_matmul = {} -# m, n = 8192, 8192 -# for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]: -# a = torch.randn(m, k, device=device, dtype=dtype) -# b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2) -# nFLOPS_matmul = 2 * m * n * k -# m5 = time_fwd(torch.matmul, a, b, desc='cuBLAS') -# print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS') -# tflops_matmul[k] = nFLOPS_matmul / m5.mean * 1e-12 -# # import pickle -# # # with open(f'flash3_attn_time_h100_hdim{headdim}_causal.plk', 'wb') as fp: -# # with open(f'flash3_matmul_tflops_h100.plk', 'wb') as fp: -# # pickle.dump(tflops_matmul, fp, protocol=pickle.HIGHEST_PROTOCOL) -# exit(0) - # for headdim in [64, 128, 256]: # for headdim in [64, 96, 128, 192]: # for headdim in [64, 96, 128, 192, 256]: @@ -272,9 +257,11 @@ def run(*args, **kwargs): # headdim = 128 nheads_kv = nheads # nheads_kv = nheads // 4 + headdim_v = headdim + # headdim_v = 128 for batch_size, seqlen in bs_seqlen_vals: - num_splits = 1 + num_splits = 0 window_size = (-1, -1) # window_size = (seqlen // 2 - 1, 0) sink_token_length = 0 @@ -285,20 +272,16 @@ def run(*args, **kwargs): # leftpad_k = torch.full((batch_size,), 0, device=device, dtype=torch.int32) q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) - v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True) + v = torch.randn(batch_size, seqlen, nheads_kv, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]] v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_() v_fa3 = v if not V_colmajor else v_colmajor # q = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) # k = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) - # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim), device=device, dtype=torch.int32).to(dtype) - g = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) - o = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True) + # v = torch.randint(-2, 3, (batch_size, seqlen, nheads, headdim_v), device=device, dtype=torch.int32).to(dtype) + g = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) + o = torch.randn(batch_size, seqlen_q, nheads, headdim_v, device=device, dtype=dtype_gen, requires_grad=True) stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32) - a = torch.randn(batch_size, seqlen, seqlen, device=device, dtype=dtype_gen) - b = torch.randn(batch_size, dim * 2, seqlen, device=device, dtype=dtype_gen).transpose(-1, -2) - # x = torch.randn(batch_size * seqlen, 4096, device=device, dtype=dtype) - # w = torch.randn(4096 * 2, 4096, device=device, dtype=dtype).transpose(-1, -2) if varlen: q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_() for x in [q, k, v]] cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q @@ -318,16 +301,16 @@ def run(*args, **kwargs): page_table = None for causal in [False, True]: - # for causal in [False]: + # for causal in [True]: print(f"\n### {headdim = }, {causal = }, {seqlen = } ###") - nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, causal=causal, window_size=window_size) + nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, headdim_v, causal=causal, window_size=window_size) if cudnn is not None: # if False: - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0]) cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0]) # _, m0 = benchmark_forward(flash_attn_func, q, k, v, dropout_p, causal=causal, repeats=repeats, verbose=verbose, desc='Fav2') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: # if False: if not varlen: m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2') @@ -343,7 +326,7 @@ def run(*args, **kwargs): repeats=repeats, verbose=False, desc='Fav2') time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean # pytorch_profiler(flash_attn_func, q, k, v, dropout_p, causal=causal, backward=True) - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: if triton_attention is not None: qt, kt, vt = [x.detach().transpose(1, 2).contiguous().requires_grad_() for x in [q, k, v]] time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark @@ -356,7 +339,7 @@ def run(*args, **kwargs): # # pytorch_profiler(triton_attention, q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous(), causal, 1 / math.sqrt(headdim), backward=True) if cudnn is not None: # if False: - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: time.sleep(1) # Sleep to avoid residual power throttling from the previous benchmark m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN') time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean @@ -375,12 +358,7 @@ def run(*args, **kwargs): m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3') # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits) time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean - # time.sleep(1) - # m5 = time_fwd(torch.bmm, a, b, desc='cuBLAS', repeats=repeats, verbose=False) - # nFLOPS_matmul = nFLOPS - # nFLOPS_matmul = 2 * x.shape[0] * x.shape[1] * w.shape[1] - # m5 = time_fwd(torch.matmul, x, w, desc='cuBLAS') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: time.sleep(1) if not varlen: _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic, @@ -396,11 +374,11 @@ def run(*args, **kwargs): # pytorch_profiler(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, causal=causal, deterministic=deterministic, backward=True) # benchmark_forward(torch.clone, k, repeats=repeats, verbose=verbose, desc='Memcpy') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: # if False: print(f'Fav2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS') print(f'Fav2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS') - if headdim <= 256 and dtype != torch.float8_e4m3fn: + if headdim <= 256 and dtype != torch.float8_e4m3fn and headdim == headdim_v: if triton_attention is not None: print(f'Triton fwd: {m3.mean * 1e3:.3f}ms, {(nFLOPS / m3.mean * 1e-12):.1f} TFLOPS') # if causal: @@ -409,7 +387,7 @@ def run(*args, **kwargs): print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS') print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS') print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS') - if dtype != torch.float8_e4m3fn: + if dtype != torch.float8_e4m3fn and headdim == headdim_v: print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS') # benchmark_forward(torch.square, k) # print(f'cuBLAS: {m5.mean * 1e3:.3f}ms, {(nFLOPS_matmul / m5.mean * 1e-12):.1f} TFLOPS') diff --git a/hopper/combine.h b/hopper/combine.h deleted file mode 100644 index c26f7ea5623..00000000000 --- a/hopper/combine.h +++ /dev/null @@ -1,248 +0,0 @@ - -#pragma once - -#include - -#include -#include "cutlass/layout/layout.h" -#include -#include - -#include "kernel_traits.h" -#include "utils.h" - -namespace flash { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SharedStorageLSE { - cute::array_aligned> smem_lse; - cute::array_aligned> smem_valid_splits; -}; - -// DONT use Kernel_traits here to avoid redundant compilation. -// template -template -__global__ void combine_attn_seqk_parallel(Params const params) { - // using Element = typename Kernel_traits::OutputType; - // using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = int64_t; // Kernel_traits::index_t - constexpr int kMaxSplits = 1 << Log_max_splits; - // constexpr int kHeadDim = Kernel_traits::kHeadDim; - constexpr int kNThreads = 128; //Kernel_traits::kNThreads; - - static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); - static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); - static_assert(kNThreads == 128, "We assume that each block has 128 threads"); - - // Shared memory. - // kBlockM + 1 instead of kBlockM to reduce bank conflicts. - //__shared__ __align__(16) ElementAccum sLSE[kMaxSplits][kBlockM+1]; - extern __shared__ char smem_[]; - using SharedStorage = SharedStorageLSE, Int>, Shape>>; - SharedStorage &shared_storage = - *reinterpret_cast(smem_); - Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape, Int>{}); - Tensor sValidSplits = make_tensor(make_smem_ptr(shared_storage.smem_valid_splits.data()), Shape>{}); - - // The thread and block index. - const int tidx = threadIdx.x; - const int bidx = blockIdx.x; - - const index_t lse_size = params.b * params.h * params.seqlen_q; - //if (cute::thread0()) print ("final %d %d %d %d\n", params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); - - const index_t row_offset_lse = bidx * kBlockM; - Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), - Shape, Int>{}, - make_stride(lse_size, _1{})); - - // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. - // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); - - // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. - Layout flat_layout = make_layout(lse_size); - Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); - auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); - Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); - Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); - - Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); - - constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; - - // Read the LSE values from gmem and store them in shared memory, then transpose them. - constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadLSE + tidx / kBlockM; - const int col = tidx % kBlockM; - ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; - if (row < kMaxSplits) { sLSE(row,col) = lse; } - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } - } - __syncthreads(); - - // Reduce along the kBlockM dimension to determine valid splits (store in SMEM) - // One thread per split. Know NumThreads = 128 >= NumMaxSplits - if (tidx < kMaxSplits) { - bool is_valid_split = false; - #pragma unroll - for (int col = 0; col < kBlockM; ++col) { - if(sLSE(tidx,col) != -INFINITY) { - is_valid_split = true; - } - } - sValidSplits(tidx) = is_valid_split; - } - __syncthreads(); - // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } - - Tensor lse_accum = make_tensor(Shape>{}); - constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); - // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits - // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, - // kBlockM rows, so each time we load we can load 128 / kBlockM rows). - // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; - // static_assert(kThreadsPerSplit <= 32); - static_assert(kRowsPerLoadTranspose <= 32); - static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - //if (bidx == 0 && tidx < 128) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } - lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE(row,col) : -INFINITY; - - } - //return; - - // Compute the logsumexp of the LSE along the split dimension. - ElementAccum lse_max = lse_accum(0); - #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } - MaxOp max_op; - lse_max = Allreduce::run(lse_max, max_op); - lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf - float lse_sum = expf(lse_accum(0) - lse_max); - #pragma unroll - for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } - SumOp sum_op; - lse_sum = Allreduce::run(lse_sum, sum_op); - // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise - // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; - // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } - if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { - if (params.unpadded_lse) { - const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; - if (lse_offset < lse_size) { - gLSE_unpadded(lse_offset) = lse_logsum; - } - } else { - gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; - } - } - //if (cute::thread0()) printf ("lse_logsum = %f\n", lse_logsum); - - // Store the scales exp(lse - lse_logsum) in shared memory. - #pragma unroll - for (int l = 0; l < kNLsePerThread; ++l) { - const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; - const int col = tidx / kRowsPerLoadTranspose; - if (row < params.num_splits && col < kBlockM) { sLSE(row,col) = expf(lse_accum(l) - lse_logsum); } - } - __syncthreads(); - - const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; - Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), - Shape, Int>{}, - Stride, _1>{}); - constexpr int kBlockN = kNThreads / kBlockM; - using GmemLayoutAtomOaccum = Layout, Int>, Stride, _1>>; - using GmemTiledCopyOaccum = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomOaccum{}, - Layout>{})); // Val layout, 4 vals per store - GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; - auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); - Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); - Tensor tOrO = make_tensor(shape(tOgOaccum)); - Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); - clear(tOrO); - - // Predicates - Tensor cOaccum = make_identity_tensor(Shape, Int>{}); - //if (cute::thread0()) print_tensor (cOaccum); - // Repeat the partitioning with identity layouts - Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); - Tensor tOpOaccum = make_tensor(make_shape(size<2>(tOgOaccum))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } - } - // Load Oaccum in then scale and accumulate to O - for (int split = 0; split < params.num_splits; ++split) { - // DONT copy in Oaccum if lse(split) = -inf for all kBlockM. - if(sValidSplits(split)) { - flash::copy( - gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM - ); - #pragma unroll - for (int m = 0; m < size<1>(tOrOaccum); ++m) { - int row = get<0>(tOcOaccum(0, m, 0)); - ElementAccum lse_scale = sLSE(split,row); - if (lse_scale != 0.f) { - #pragma unroll - for (int k = 0; k < size<2>(tOrOaccum); ++k) { - #pragma unroll - for (int i = 0; i < size<0>(tOrOaccum); ++i) { - tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); - //tOrO(i, m, k) += tOrOaccum(i, m, k); - } - } - } - //if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE(split, 0), sLSE(split, 1)); print_tensor(tOrOaccum); } - } - } - tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; - } - //if (cute::thread0()) { print_tensor(tOrO); } - - Tensor rO = flash::convert_type(tOrO); - // Write to gO - #pragma unroll - for (int m = 0; m < size<1>(rO); ++m) { - const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); - //if (cute::thread0()) print ("final %d %d %d %d %d\n", idx, params.b, params.h, params.seqlen_q, params.b * params.h * params.seqlen_q); - if (idx < params.b * params.h * params.seqlen_q) { - //print ("final2\n"); - const int batch_idx = idx / (params.h * params.seqlen_q); - const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; - // The index to the rows of Q - const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; - auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride - + head_idx * params.o_head_stride + row * params.o_row_stride; - #pragma unroll - for (int k = 0; k < size<2>(rO); ++k) { - if (Is_even_K || tOpOaccum(k)) { - const int col = get<1>(tOcOaccum(0, m, k)); - Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), - Shape(rO))::value>>{}, Stride<_1>{}); - // TODO: Should check if this is using vectorized store, but it seems pretty fast - copy(rO(_, m, k), gO); - //if (cute::thread0()) { print ("final\n"); print_tensor(gO); } - // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } - // reinterpret_cast(o_ptr)[col / 4] = recast(rO)(0, m, k); - } - } - } - } -} - -} // namespace flash diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index 0f916060260..1c13988ebd7 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -20,11 +20,11 @@ namespace flash { using namespace cute; -template struct CollectiveEpilogueFwd { - using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = TileShape_MNK_PV_; using ClusterShape = ClusterShape_; using Element = Element_; using ArchTag = ArchTag_; @@ -37,21 +37,23 @@ struct CollectiveEpilogueFwd { static_assert(ArchTag::kMinComputeCapability >= 80); static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); - static constexpr int kBlockM = get<0>(TileShape_MNK{}); - static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); + static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); + + static constexpr bool LargeHeadDimV = kHeadDimV > 256; using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; // These are for storing the output tensor without TMA (e.g., for setting output to zero) static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); + static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times // we need to call divmod. - static constexpr int kBytePerRow = kHeadDim * sizeof(Element); + static constexpr int kBytePerRow = kHeadDimV * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); - // static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerStore, NumEpilogueThreads); + // static constexpr int kBlockKGmem = kHeadDimV % 128 == 0 ? 128 : (kHeadDimV % 64 == 0 ? 64 : 32); + // static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDimV / kGmemElemsPerStore, NumEpilogueThreads); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore; // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0); @@ -65,15 +67,15 @@ struct CollectiveEpilogueFwd { Layout>>{})); // Val layout, 8 or 16 vals per store using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 2>(TileShape_MNK{}))); + decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>()); + using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{}))); static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1)); static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4); using SmemLayoutAtomO = decltype( composition(Swizzle{}, Layout>, Stride, _1>>{})); - using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{}))); using SmemLayoutO = std::conditional_t= 90, SmemLayoutOTMA, SmemLayoutOSTS>; using ShapeO = cute::Shape; // (seqlen_q, d, head, batch, num_splits) @@ -109,7 +111,7 @@ struct CollectiveEpilogueFwd { GmemTiledCopyOTMA{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeO{}, StrideO{}), SmemLayoutOTMA{}, - select<0, 2>(TileShape_MNK{}), + select<0, 1>(TileShape_MNK_PV{}), _1{})), // no mcast for O std::nullptr_t >; @@ -148,7 +150,7 @@ struct CollectiveEpilogueFwd { Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); TMA_O tma_store_O = [&]{ if constexpr (Use_TMA_O) { - return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 2>(TileShape_MNK{}), _1{}); // no mcast + return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast } else { return nullptr; } @@ -239,35 +241,38 @@ struct CollectiveEpilogueFwd { bool is_varlen = Varlen && params.cu_seqlens; int offset_o = seqlen_info.offset; int seqlen_o = seqlen_info.seqlen; + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); // Step 2: Write LSE from rmem -> gmem auto thread_mma = tiled_mma.get_thread_slice(thread_idx); // (MMA,MMA_M,MMA_K) - Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); static_assert(decltype(size<0, 0>(taccOcO))::value == 2); static_assert(decltype(size<0, 1>(taccOcO))::value == 2); Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; + using PackGQAt = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset_o * get<0>(params.stride_LSE)), params.shape_LSE_packed, params.stride_LSE_packed)(_, bidh, !is_varlen ? bidb : 0, split_idx); // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); } - if constexpr (!PackGQA) { - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); - if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } + if (!LargeHeadDimV || warp_group_idx == 0) { + if constexpr (!PackGQA) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + int const row = m_block * kBlockM + get<0>(taccOcO_row(mi)); + if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); } + } + } else { + PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } - } else { - PackGQAt::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); } // Step 3: Write O from smem -> gmem if constexpr (Use_TMA_O) { Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) auto block_tma_O = params.tma_store_O.get_slice(_0{}); Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) @@ -287,7 +292,7 @@ struct CollectiveEpilogueFwd { } } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx); - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast(&mO(0)) - reinterpret_cast(params.ptr_O)); } if constexpr (Use_smem) { GmemTiledCopyO gmem_tiled_copy_O; @@ -305,7 +310,7 @@ struct CollectiveEpilogueFwd { } if constexpr (!PackGQA) { // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); Tensor tOpO = make_tensor(make_shape(size<2>(tOsO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } @@ -361,7 +366,7 @@ struct CollectiveEpilogueFwd { int thread_idx, cute::tuple const& block_coord ) { - static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); auto [m_block, bidh, bidb, split_idx] = block_coord; flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; @@ -391,12 +396,12 @@ struct CollectiveEpilogueFwd { GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 2>(TileShape_MNK{}))); + Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); if constexpr (!PackGQA) { Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); #pragma unroll for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); } - Tensor gO = local_tile(mO, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) Tensor tOgO = gmem_thr_copy_O.partition_D(gO); Tensor tOrO = make_fragment_like(tOgO); cute::clear(tOrO); @@ -406,7 +411,7 @@ struct CollectiveEpilogueFwd { ); } else { // If PackGQA, we split the work of compute O_ptr among threads in the same row - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumEpilogueThreads, Element>; + using PackGQAt = flash::PackGQAManager(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>; Tensor tOrO = make_tensor(make_shape(Shape<_1, Int>{}, size<1>(tOcO), size<2>(tOcO))); cute::clear(tOrO); PackGQAt::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block); diff --git a/hopper/flash.h b/hopper/flash.h index 4559a1352e4..8e95f5ff75c 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -65,6 +65,7 @@ struct Flash_fwd_params : public Qkv_params { int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; int total_q, total_k, total_knew; int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q + int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim // The scaling factors for the kernel. float scale_softmax; @@ -103,6 +104,11 @@ struct Flash_fwd_params : public Qkv_params { index_t knew_head_stride; index_t vnew_head_stride; + void *__restrict__ qv_ptr; + index_t qv_batch_stride; + index_t qv_row_stride; + index_t qv_head_stride; + // The cos and sin matrices for rotary embedding. void * __restrict__ rotary_cos_ptr; void * __restrict__ rotary_sin_ptr; @@ -197,9 +203,9 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -template +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); -template +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 9dd55d0d6d3..d990f986fbd 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -270,36 +270,48 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (!params.is_e4m3) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_(params, stream); } + if (params.d <= 64) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } + if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } + if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_(params, stream); } + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { + return run_mha_fwd_(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } + if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif } else { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_(params, stream); } + if (params.d <= 64) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_(params, stream); } + if (params.d <= 96) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_(params, stream); } + if (params.d <= 128) { return run_mha_fwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_(params, stream); } + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_(params, stream); + } else { + return run_mha_fwd_(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_(params, stream); } + if (params.d <= 256) { return run_mha_fwd_(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP16."); @@ -308,19 +320,25 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { #ifndef FLASHATTENTION_DISABLE_FP8 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 192) { + if (params.dv <= 128 && Arch == 90) { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); + } else { + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); + } + } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } + if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); } #endif #else TORCH_CHECK(false, "This flash attention build does not support FP8."); @@ -338,28 +356,22 @@ void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively // so that kBlockM is smaller and we have more parallelism. if (params.is_fp32) { - if (params.d <= 64) { + if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { - run_mha_fwd_combine_(params, stream); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream); } } else if (params.is_bf16) { - if (params.d <= 64) { + if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { - run_mha_fwd_combine_(params, stream); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream); } } else { - if (params.d <= 64) { + if (params.dv <= 64) { run_mha_fwd_combine_(params, stream); - } else if (params.d <= 128) { - run_mha_fwd_combine_(params, stream); } else { - run_mha_fwd_combine_(params, stream); + run_mha_fwd_combine_(params, stream); } } #else @@ -377,7 +389,7 @@ inline bool get_pack_gqa(Flash_fwd_params const& params) { // params.page_table must already be set if (params.h == params.h_k) { return false; } // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); #endif @@ -391,10 +403,10 @@ inline int get_num_splits(Flash_fwd_params const& params) { // params.page_table must already be set // This needs to match the kernel configs bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k; - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f); // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits // has not been set here. It's OK though because we might just underestimate kBlockN a bit - auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); + auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x); int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k); @@ -404,7 +416,8 @@ inline int get_num_splits(Flash_fwd_params const& params) { : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM)); int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN; int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM; - return num_splits_heuristic(params.b * (!params.pack_gqa ? params.h : params.h_k) * num_m_blocks, params.num_sm, num_n_blocks, 128); + // Always enable PackGQA for Split + return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.num_sm, num_n_blocks, 128); // return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.b * params.h_k, // params.num_sm, num_n_blocks, 128, params.d_rounded); #endif @@ -459,10 +472,11 @@ inline int round_up_headdim(int head_size) { std::vector mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - const at::Tensor &v, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. + const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &v_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - std::optional &out_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q std::optional &cu_seqlens_q_, // b+1 std::optional &cu_seqlens_k_, // b+1 std::optional &cu_seqlens_k_new_, // b+1 @@ -550,6 +564,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0]; int num_heads = q.size(-2); int const head_size = q.size(-1); + int const head_size_v = v.size(-1); int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1); int const num_pages = !paged_KV ? 0 : k.size(0); int const page_size = !paged_KV ? 1 : k.size(1); @@ -563,6 +578,14 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int const max_headdim = get_max_headdim(); TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (head_size_v != head_size) { + TORCH_CHECK(head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128, "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128]"); + TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim"); + if (head_size_v > 256) { + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "HeaddimV > 256 requires fp16 and bf16 data type"); + } + } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM // TODO: check this @@ -582,15 +605,15 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (!paged_KV) { if (!is_varlen_k) { CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size); - CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v); } else { CHECK_SHAPE(k, total_k, num_heads_k, head_size); - CHECK_SHAPE(v, total_k, num_heads_k, head_size); + CHECK_SHAPE(v, total_k, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); } } else { CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); } @@ -609,6 +632,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8; TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); auto opts = q.options(); auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type; @@ -619,16 +643,19 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); if (!is_varlen_q) { - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v); } else { - CHECK_SHAPE(out, total_q, num_heads, head_size); + CHECK_SHAPE(out, total_q, num_heads, head_size_v); } } else { - out = torch::empty_like(q, opts.dtype(out_type)); + out = !is_varlen_q + ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)) + : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type)); } auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = round_up_headdim(head_size_v); int const seqlen_q_rounded = round_multiple(seqlen_q, 128); int const seqlen_k_rounded = round_multiple(seqlen_k, 128); @@ -666,6 +693,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.total_k = total_k; params.sink_token_length = sink_token_length; params.b_k = batch_size_k; + params.dv = head_size_v; + params.dv_rounded = head_size_v_rounded; if (paged_KV) { params.page_table = page_table.data_ptr(); @@ -675,6 +704,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq params.num_pages = num_pages; params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); if (k_new_.has_value()) { @@ -701,10 +731,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0); if (!is_varlen_k_new) { CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v); } else { CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size); - CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size); + CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v); CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1); } params.seqlen_knew = seqlen_k_new; @@ -725,6 +755,30 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } } + if (q_v_.has_value()) { + TORCH_CHECK(false, "q_v should be None for now"); + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + at::Tensor q_v = q_v_.value(); + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + CHECK_DEVICE(q_v); + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + if (!is_varlen_q) { + CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v); + } else { + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + } + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. + params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + if (!is_varlen_q) { + params.qv_batch_stride = q_v.stride(0); + } + } + if (leftpad_k_.has_value()) { auto leftpad_k = leftpad_k_.value(); TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); @@ -771,12 +825,12 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (params.num_splits > 1) { TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported"); if (!is_varlen_q) { - out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(outaccum_type)); + out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); params.oaccum_batch_stride = out_accum.stride(1); params.lseaccum_batch_stride = softmax_lse_accum.stride(1); } else { - out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size}, opts.dtype(outaccum_type)); + out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type)); softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat)); } params.is_fp32 = false; @@ -1257,7 +1311,7 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x const int seqlen = sizes[2]; const int num_heads = sizes[3]; const int head_size_og = sizes[4]; - TORCH_CHECK(head_size_og <= 256, "FlashAttention combine only supports head dimension at most 256"); + TORCH_CHECK(head_size_og <= 512, "FlashAttention combine only supports head dimension at most 512"); TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256"); CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og); @@ -1306,7 +1360,7 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x params.b = batch_size; params.h = num_heads; params.seqlen_q = seqlen; - params.d = head_size; + params.dv = head_size; params.num_splits = num_splits; params.oaccum_split_stride = out_partial_padded.stride(0); params.oaccum_row_stride = out_partial_padded.stride(2); diff --git a/hopper/flash_api_torch_lib.cpp b/hopper/flash_api_torch_lib.cpp index 81b522b17fa..2406d1a5076 100644 --- a/hopper/flash_api_torch_lib.cpp +++ b/hopper/flash_api_torch_lib.cpp @@ -12,26 +12,27 @@ std::vector mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - const at::Tensor &v, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table. - c10::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - c10::optional &v_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new - c10::optional &out_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - c10::optional &cu_seqlens_q_, // b+1 - c10::optional &cu_seqlens_k_, // b+1 - c10::optional &cu_seqlens_k_new_, // b+1 - c10::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. - c10::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. - c10::optional max_seqlen_q_, + const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table. + std::optional &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new + std::optional &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new + std::optional &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + std::optional &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + std::optional &cu_seqlens_q_, // b+1 + std::optional &cu_seqlens_k_, // b+1 + std::optional &cu_seqlens_k_new_, // b+1 + std::optional &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used. + std::optional &seqused_k_, // b. If given, only this many elements of each batch element's keys are used. + std::optional max_seqlen_q_, // TODO: check if we need max_seqlen_k - c10::optional max_seqlen_k_, - c10::optional &page_table_, // (b_k, max_num_pages_per_seq) - c10::optional &kv_batch_idx_, // b. indices to index into the KV cache - c10::optional &leftpad_k_, // b - c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) - c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) - c10::optional &q_descale_, // (b, h_k), not (b, h) - c10::optional &k_descale_, // (b, h_k) - c10::optional &v_descale_, // (b, h_k) + std::optional max_seqlen_k_, + std::optional &page_table_, // (b_k, max_num_pages_per_seq) + std::optional &kv_batch_idx_, // b. indices to index into the KV cache + std::optional &leftpad_k_, // b + std::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional &q_descale_, // (b, h_k), not (b, h) + std::optional &k_descale_, // (b, h_k) + std::optional &v_descale_, // (b, h_k) float const softmax_scale, bool is_causal, int window_size_left, @@ -40,7 +41,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits, - c10::optional pack_gqa_, + std::optional pack_gqa_, int const sm_margin); /** @@ -52,6 +53,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor v," " Tensor? k_new," " Tensor? v_new," + " Tensor? q_v," " Tensor!? out," " Tensor? cu_seqlens_q," " Tensor? cu_seqlens_k," diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 5f1e4899c92..adee1a0ff26 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -22,6 +22,7 @@ def _flash_attn_forward( v, k_new, v_new, + qv, out, cu_seqlens_q, cu_seqlens_k, @@ -64,6 +65,7 @@ def _flash_attn_forward( v, k_new, v_new, + qv, out, cu_seqlens_q, cu_seqlens_k, @@ -239,6 +241,7 @@ def forward( v, softmax_scale, causal, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -249,13 +252,14 @@ def forward( sm_margin=0, ): if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) # out, q, k, v, out_padded, softmax_lse = _flash_attn_forward( out, softmax_lse, *rest = _flash_attn_forward( q, k, v, None, None, # k_new, v_new + qv, # qv None, # out None, None, None, # cu_seqlens_q/k/k_new None, None, # seqused_q/k @@ -311,7 +315,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -330,6 +334,7 @@ def forward( max_seqlen_k, softmax_scale, causal, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -340,13 +345,14 @@ def forward( sm_margin=0, ): if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) # out, q, k, v, out_padded, softmax_lse = _flash_attn_varlen_forward( out, softmax_lse, *rest = _flash_attn_forward( q, k, v, None, None, # k_new, v_new + qv, # qv None, # out cu_seqlens_q, cu_seqlens_k, @@ -411,7 +417,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( @@ -478,6 +484,7 @@ def flash_attn_func( v, softmax_scale=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -538,6 +545,7 @@ def flash_attn_func( v, softmax_scale, causal, + qv, q_descale, k_descale, v_descale, window_size, sink_token_length, @@ -561,6 +569,7 @@ def flash_attn_varlen_func( max_seqlen_k, softmax_scale=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), sink_token_length=0, @@ -582,6 +591,7 @@ def flash_attn_varlen_func( max_seqlen_k, softmax_scale, causal, + qv, q_descale, k_descale, v_descale, window_size, sink_token_length, @@ -603,6 +613,7 @@ def flash_attn_with_kvcache( v_cache, k=None, v=None, + qv=None, rotary_cos=None, rotary_sin=None, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, @@ -673,11 +684,12 @@ def flash_attn_with_kvcache( k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) page_block_size must be a multiple of 256. - v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no _table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate k with k_cache, starting at the indices specified by cache_seqlens. - v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. + qv [optional]: (batch_size, seqlen, nheads, headdim_v) rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. @@ -714,7 +726,7 @@ def flash_attn_with_kvcache( assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): cache_seqlens = torch.full( (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device @@ -726,6 +738,7 @@ def flash_attn_with_kvcache( v_cache, k, v, + qv, None, # out cu_seqlens_q, None, # cu_seqlens_k diff --git a/hopper/flash_fwd_combine.cu b/hopper/flash_fwd_combine.cu index 5b7d9eed655..a1725cf2a82 100644 --- a/hopper/flash_fwd_combine.cu +++ b/hopper/flash_fwd_combine.cu @@ -5,12 +5,9 @@ template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index aaec31e5807..20685a15656 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -40,11 +40,11 @@ class FlashAttnFwdCombine { static constexpr uint32_t MinBlocksPerMultiprocessor = 2; static constexpr int kBlockM = get<0>(TileShape_MK{}); - static constexpr int kHeadDim = get<1>(TileShape_MK{}); + static constexpr int kBlockK = get<1>(TileShape_MK{}); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad"); + static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 64 : 32); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow"); using GmemCopyAtom = std::conditional_t< @@ -98,8 +98,8 @@ class FlashAttnFwdCombine { Stride, _1>>{})); using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape, Int>{})); - using SmemLayoutO = Layout, Int, Int>, - Stride, _1, Int>>; + using SmemLayoutO = Layout, Int, Int>, + Stride, _1, Int>>; // We want each column (kMaxSplits) to be processed by threads in the same warp. // To reduce the number of shuffles, we want as few threads on the same column as possible. @@ -194,7 +194,8 @@ class FlashAttnFwdCombine { Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{}); int const thread_idx = threadIdx.x; - int const m_block = blockIdx.x; + int const k_block = blockIdx.x; + int const m_block = blockIdx.y; int const batch = !Varlen ? 0 : blockIdx.y; int const num_splits = get<1>(params.shape_LSE_partial); flash::SeqlenInfo seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused}; @@ -254,7 +255,8 @@ class FlashAttnFwdCombine { Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO); - Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), params.shape_O_partial, params.stride_O_partial); // (seqlen, d, num_splits, head, batch) + Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)), + params.shape_O_partial, params.stride_O_partial); // (seqlen, d, num_splits, head, batch) // Precompute these values to avoid recomputing them in the loop Tensor tOmidx = make_tensor(make_shape(size<1>(tOcO))); @@ -271,7 +273,7 @@ class FlashAttnFwdCombine { tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx); tObidb[m] = 0; } - tOrOptr[m] = &mOpartial(tOmidx(m), _0{}, _0{}, tObidh(m), tObidb(m)); + tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m), tObidb(m)); if (idx >= max_idx) { tObidb[m] = -1; } @@ -280,7 +282,7 @@ class FlashAttnFwdCombine { Tensor tOpO = make_tensor(make_shape(size<2>(tOcO))); if constexpr (!(Is_even_K)) { #pragma unroll - for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial); } + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; } } Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO); @@ -358,26 +360,36 @@ class FlashAttnFwdCombine { // Store the scales exp(lse - lse_logsum) back to smem cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE); - // Step 5: store final LSE back to gmem - auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); - Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE); + // Store max_valid_split to smem #pragma unroll for (int m = 0; m < size<2>(ts2rrLSE); ++m) { - if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); - int idx = m_block * kBlockM + mi; - if (idx < max_idx) { - int m_idx, bidh, bidb; - if constexpr (!Varlen) { - bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); - } else { - bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); - bidb = 0; + if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } + } + } + + // Step 5: store final LSE back to gmem + if (k_block == 0) { + auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial); + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE); + #pragma unroll + for (int m = 0; m < size<2>(ts2rrLSE); ++m) { + if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem + int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m))); + int idx = m_block * kBlockM + mi; + if (idx < max_idx) { + int m_idx, bidh, bidb; + if constexpr (!Varlen) { + bidb = params.head_divmod.divmod(bidh, params.seqlen_divmod.divmod(m_idx, idx)); + } else { + bidh = seqlen_divmod_dynamic.divmod(m_idx, idx); + bidb = 0; + } + // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); + mLSE(m_idx, bidh, bidb) = lse_sum(m); } - // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m)); - mLSE(m_idx, bidh, bidb) = lse_sum(m); } - if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; } } } @@ -427,8 +439,9 @@ class FlashAttnFwdCombine { // Step 7: Write the final O to gmem Tensor rO = make_tensor_like(tOrO); flash::convert_type_out(tOrO, rO); - auto shape_O = select<0, 1, 3, 4>(params.shape_O_partial); - Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O)), shape_O, params.stride_O); + auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial)); + Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)), + shape_O, params.stride_O); Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int>{}); GmemTiledCopy gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 33e66c21f82..101f894b2d6 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -16,15 +16,15 @@ using namespace cute; -template +template void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { - using TileShape_MK = cute::Shape, Int>; + using TileShape_MK = cute::Shape, Int>; using CombineKernel = flash::FlashAttnFwdCombine; typename CombineKernel::Arguments args { static_cast(params.oaccum_ptr), - {!Varlen ? params.seqlen_q : params.total_q, params.d, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial + {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial static_cast(params.softmax_lseaccum_ptr), {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial @@ -37,8 +37,9 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); + int num_blocks_k = cute::ceil_div(params.dv, kBlockK); int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h * (!Varlen ? params.b : 1), kBlockM); - dim3 grid_m(num_blocks_m, !Varlen ? 1 : params.b); + dim3 grid_m(num_blocks_k, num_blocks_m, !Varlen ? 1 : params.b); auto kernel = cutlass::device_kernel; int smem_size = CombineKernel::SharedStorageSize; if (smem_size >= 48 * 1024) { @@ -48,27 +49,27 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream) { CHECK_CUDA_KERNEL_LAUNCH(); } -template +template void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream) { // We want kBlockM to be as small as possible to maximize parallelism. // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats). - static_assert(kHeadDim % 32 == 0, "kHeadDim must be a multiple of 32"); - static constexpr int kBlockM = kHeadDim % 128 == 0 ? 8 : (kHeadDim % 64 == 0 ? 16 : 32); + static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32"); + static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32); BOOL_SWITCH(params.seqused_q != nullptr, Varlen, [&] { if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32. if (params.num_splits <= 16) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); return; } } if (params.num_splits <= 32) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } else if (params.num_splits <= 64) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } else if (params.num_splits <= 128) { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } else { - run_flash_fwd_combine(params, stream); + run_flash_fwd_combine(params, stream); } }); } diff --git a/hopper/flash_fwd_kernel_sm90.h b/hopper/flash_fwd_kernel_sm90.h index e5411042dc9..c7fec6df559 100644 --- a/hopper/flash_fwd_kernel_sm90.h +++ b/hopper/flash_fwd_kernel_sm90.h @@ -40,17 +40,20 @@ class FlashAttnFwdSm90 { static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8; static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V; static constexpr bool AppendKV = CollectiveMainloop::AppendKV; + static constexpr bool HasQv = CollectiveMainloop::HasQv; static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q; static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV; static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O; static constexpr bool PackGQA = CollectiveMainloop::PackGQA; static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads; + static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim; + static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV; + static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV); using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t; // Mainloop derived types - using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK; - using TiledMma0 = typename CollectiveMainloop::TiledMma0; - using TiledMma1 = typename CollectiveMainloop::TiledMma1; + using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV; + using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; using ArchTag = typename CollectiveMainloop::ArchTag; using ClusterShape = typename CollectiveMainloop::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; @@ -68,8 +71,8 @@ class FlashAttnFwdSm90 { using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma0{})) / cutlass::NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma0{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -88,7 +91,7 @@ class FlashAttnFwdSm90 { static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v))); static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_; struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { + struct TensorStorage : cute::aligned_struct<128, _1> { union { struct { cute::array padding_; @@ -98,9 +101,9 @@ class FlashAttnFwdSm90 { typename CollectiveEpilogue::TensorStorage epilogue; }; } tensors; - - struct PipelineStorage : cute::aligned_struct<16> { + struct PipelineStorage : cute::aligned_struct<16, _1> { alignas(16) BarrierQ barrier_Q; + alignas(16) BarrierQ barrier_Qv; alignas(16) cutlass::arch::ClusterBarrier barrier_O; alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k; alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v; @@ -176,7 +179,7 @@ class FlashAttnFwdSm90 { static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup; - static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK; using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV; @@ -205,6 +208,9 @@ class FlashAttnFwdSm90 { if (warp_idx == 0 && lane_predicate) { shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/); + } shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/); } @@ -216,12 +222,21 @@ class FlashAttnFwdSm90 { if constexpr (Use_TMA_KV) { pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; pipeline_params_k.is_leader = warp_group_thread_idx == 0; - pipeline_params_k.num_consumers = NumMmaThreads; + pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; } else { - pipeline_params_k.consumer_arv_count = NumMmaThreads; + pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup; pipeline_params_k.producer_arv_count = NumProducerThreads; } + static_assert(is_same_v); + PipelineParamsVt pipeline_params_vt = pipeline_params_k; + if constexpr (Use_TMA_KV && !SameHeadDim) { + pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; } + } else { + if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; } + } + MainloopPipelineK pipeline_k = [&] { if constexpr (Use_TMA_KV) { return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{}); @@ -234,9 +249,9 @@ class FlashAttnFwdSm90 { if constexpr (!Transpose_V) { static_assert(is_same_v); if constexpr (Use_TMA_KV) { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k, ClusterShape{}); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{}); } else { - return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_k); + return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt); } } else { PipelineParamsV pipeline_params_v; @@ -248,7 +263,6 @@ class FlashAttnFwdSm90 { return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v); } }(); - static_assert(is_same_v); // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then // the producer WG will read from pipeline_vt and write to pipeline_v. // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used. @@ -256,11 +270,11 @@ class FlashAttnFwdSm90 { // However, the thread role isn't used in the pipeline implementation. MainloopPipelineVt pipeline_vt = [&] { if constexpr (Use_TMA_KV) { - pipeline_params_k.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k, ClusterShape{}); + pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{}); } else { - pipeline_params_k.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG - return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_k); + pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG + return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt); } }(); @@ -272,6 +286,9 @@ class FlashAttnFwdSm90 { pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0; pipeline_params_kv_new.num_consumers = NumMmaThreads; auto pipeline_k_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr); + if constexpr (!SameHeadDim) { + pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + } auto pipeline_v_new = cute::conditional_return(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr); CollectiveMainloop collective_mainloop; @@ -341,7 +358,7 @@ class FlashAttnFwdSm90 { TileScheduler scheduler(reinterpret_cast(&shared_storage.pipelines.smem_scheduler)); // Initialize matmul objects. - TiledMma1 tiled_mma1; + TiledMmaPV tiled_mma_pv; PipelineState smem_pipe_read; PipelineState smem_pipe_read_new; @@ -357,7 +374,7 @@ class FlashAttnFwdSm90 { work_tile_info.is_valid(params.scheduler); work_tile_info = scheduler.template get_next_work(params.scheduler, work_tile_info)) { // Attention output (GEMM-II) accumulator. - Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{})); float softmax_scale_log2 = params.mainloop.softmax_scale_log2; // If there's tanh softcap, the scaling will be done before tanh. auto block_coord = work_tile_info.get_block_coord(params.scheduler); @@ -369,7 +386,7 @@ class FlashAttnFwdSm90 { float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)]; softmax_scale_log2 *= q_descale * k_descale; } - flash::Softmax<2 * (2 * kBlockM / NumMmaThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2); + flash::Softmax softmax(softmax_scale_log2); SeqlenInfo_t seqlen_info{ bidb, @@ -395,12 +412,25 @@ class FlashAttnFwdSm90 { // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); } } } - bool tile_valid = collective_mainloop.mma( - params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, - tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + bool tile_valid; + if constexpr (!LargeHeadDimV) { + tile_valid = collective_mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + } else { // mma_pv might not compile if !LargeHeadDimV + if (warp_group_idx == 1) { + tile_valid = collective_mainloop.mma( + params.mainloop, pipeline_k, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage); + } else { + tile_valid = collective_mainloop.mma_pv( + params.mainloop, pipeline_v, smem_pipe_read, + tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage); + } + } if (tile_valid) { // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); } - collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma1, + collective_epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv, threadIdx.x - MmaThreadOffset, block_coord); } else { // Write 0 to gO and -inf to gLSE. diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 16701f160d2..b4f80a04e7c 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -23,8 +23,8 @@ using namespace cute; -template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); @@ -35,24 +35,25 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using ArchTag = std::conditional_t= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>; // Can't use structured binding since it's not compatible with constexpr - static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap); - static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); + static constexpr std::tuple kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKV, Has_softcap); + static constexpr std::tuple kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKV, Varlen && Split, Has_softcap, AppendKV); static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS); static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS); - static constexpr bool Mma1_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); + static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap); static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS); static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); using TileShape_MNK = cute::Shape, Int, Int>; + using TileShape_MNK_PV = cute::Shape, Int, Int>; using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, - flash::CollectiveMainloopFwdSm80 + flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm80 >; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; using SchedulerPersistent = std::conditional_t(params.v_ptr), + params.dv, // headdim_v v_strides, // stride_V static_cast(params.knew_ptr), {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new static_cast(params.vnew_ptr), {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new + static_cast(params.qv_ptr), + {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv static_cast(params.rotary_cos_ptr), {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter {params.rotary_dim / 2, _1{}}, // stride_rotary_cos @@ -124,7 +128,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }; typename CollectiveEpilogue::Arguments epilogue_args { static_cast(!Split ? params.o_ptr : params.oaccum_ptr), - {seqlen_q, params.d, params.h, batch_q, params.num_splits}, // shape_O + {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O {!Split ? params.o_row_stride : params.oaccum_row_stride, _1{}, !Split ? params.o_head_stride : params.oaccum_head_stride, @@ -179,7 +183,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { CHECK_CUDA_KERNEL_LAUNCH(); } -template +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported"); static constexpr bool Is_FP8 = cute::is_same_v || cute::is_same_v; @@ -189,14 +193,18 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1; VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { // Only needed here to decide if we should use cluster - static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; + static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKV, Has_softcap)) : 128; - static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; - APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { - // Only use Cluster if number of tiles along seqlen_q is even and not varlen - CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { - static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + // On nvcc 12.8, hdim 128, without cluster is faster (730 vs 700 TFLOPS) + static constexpr bool Enable_cluster = Arch >= 90 && (sizeof(T) == 2 ? (kHeadDim >= 192) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKV && !Varlen; + BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 and false; + APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { + // Only use Cluster if number of tiles along seqlen_q is even and not varlen + CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { + static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; + run_flash_fwd(params, stream); + }); }); }); }); diff --git a/hopper/generate_kernels.py b/hopper/generate_kernels.py index e741c13826f..19a6e90d345 100644 --- a/hopper/generate_kernels.py +++ b/hopper/generate_kernels.py @@ -38,7 +38,7 @@ KERNEL_IMPL_TEMPLATE_FWD_SM90 = """#include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} -template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<{ARCH}, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif """ @@ -46,8 +46,8 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM{HEAD_DIM} -template void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, {DTYPE}, {HEAD_DIM}, {HEAD_DIM_V}, {SPLIT}, {PAGEDKV}, {SOFTCAP}, {PACKGQA}>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif """ @@ -85,6 +85,7 @@ class Kernel: sm: int dtype: str head_dim: int + head_dim_v: int split: bool paged_kv: bool softcap: bool @@ -98,14 +99,15 @@ def template(self) -> str: # Always enable PackGQA for PagedKV or Split to reduce compilation packgqa = self.packgqa or self.paged_kv or self.split return KERNEL_IMPL_TEMPLATE_FWD_SM90.format( - ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], + HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v, SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(), SOFTCAP=str(self.softcap).lower(), PACKGQA=str(packgqa).lower() ) else: # Always enable PackGQA for Sm8x to reduce compilation return KERNEL_IMPL_TEMPLATE_FWD_SM8x.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, HEAD_DIM_V=self.head_dim_v, SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(), SOFTCAP=str(self.softcap).lower(), PACKGQA=str(True).lower() ) @@ -117,13 +119,13 @@ def template(self) -> str: ) else: return KERNEL_IMPL_TEMPLATE_BWD_SM8x.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, SOFTCAP=str(self.softcap).lower() ) @property def filename(self) -> str: - return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu" + return f"flash_{self.direction}_hdim{self.head_dim}{f'_{self.head_dim_v}' if self.head_dim_v != self.head_dim else ''}_{self.dtype}{'_paged' if self.paged_kv else ''}{'_split' if self.split else ''}{'_softcap' if self.softcap else ''}{'_packgqa' if self.packgqa else ''}_sm{self.sm}.cu" def get_all_kernels() -> List[Kernel]: @@ -133,20 +135,31 @@ def get_all_kernels() -> List[Kernel]: if packgqa and (sm < 90 or (sm >= 90 and (paged_kv or split))): continue if sm >= 90 or dtype in DTYPE_MAP_FWD_SM8x: - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + if sm == 90 and head_dim == 192: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=128, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") + if sm == 90 and head_dim == 64 and dtype in ["bf16", "fp16"]: + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=512, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd") for dtype, head_dim, softcap, sm in itertools.product(DTYPE_MAP_BWD.keys(), HEAD_DIMENSIONS, SOFTCAP, SM): - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, head_dim_v=head_dim, split=False, paged_kv=False, softcap=softcap, packgqa=False, direction="bwd") def batch_hdim(kernels_all) -> List[KERNEL_BATCH]: for dtype, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM): if sm < 90: continue - kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm] + # Same hdim and hdimv + kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim == k.head_dim_v] if len(kernels) > 0: filename = f"flash_fwd_hdimall_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu" template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) yield KERNEL_BATCH(template, filename) + # Different hdim and hdimv + kernels = [k for k in kernels_all if k.direction == "fwd" and k.dtype == dtype and k.split == split and k.paged_kv == paged_kv and k.softcap == softcap and k.packgqa == packgqa and k.sm == sm and k.head_dim != k.head_dim_v] + if len(kernels) > 0: + filename = f"flash_fwd_hdimdiff_{dtype}{'_paged' if paged_kv else ''}{'_split' if split else ''}{'_softcap' if softcap else ''}{'_packgqa' if packgqa else ''}_sm{sm}.cu" + template = "\n".join([f"#include \"{k.filename}\"" for k in kernels]) + yield KERNEL_BATCH(template, filename) def batch_softcap(kernels_all) -> List[KERNEL_BATCH]: diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu index 18879eff6ee..affc7a4dd96 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu index 35c0ad78fd1..7e13614bfea 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu index 7a39869a001..670041341bc 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu index fb7ba5caed5..f315fbb4545 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu index 296ec9e91fc..bde3024a4a6 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu index 8cffb6de830..2724463e621 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu index 12d564ce364..a38a1d5cf33 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu index 845b1fa5d06..284eeba1823 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu index 25fbfda38d2..0c40ddba8fe 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu index 1130ca747d1..cc89c4d5d25 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu index 502bc1d1771..3a236b712c4 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu index 537e42ba56b..8449104c5aa 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu index 2255e7949e2..b152b90bab7 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu index 086f55b3588..8cc4fed1739 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu index 54590eebbcc..1db3f1e6d80 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu index af322d1d15f..9b3e294f1b3 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu index 3e83398e7f6..07bd687fc34 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu index 3f917d26abc..5f44833b10d 100644 --- a/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu index 87c78f28929..9f95ca29f6b 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu index e56b64c3d9a..ad97737d4f3 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu index 8202bfadde5..d77d37ec041 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu index ee7439b277c..ae05c7ce5f0 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu index 812239ef50e..bc52a9f356f 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu index 74e52315bd4..480d485d069 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu index fe0bff6a1d3..d3da5f4e665 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu index 55df1a66635..1c1c2d8207f 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu index 03a9c61e409..371d933e3e1 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu index 67ba153c605..7491148dcde 100644 --- a/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu index 9f7bcec9ed9..d04159a62a0 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu index 7116702f3fd..28ad6c14963 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu index 04f18ac0fac..7afb267e3eb 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu index c7c7c9e69f5..69758584cb6 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu index b4ea8bc3301..3be45956bb4 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu index ec99965c928..698095dad6a 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu index d1dd9645233..16d443a9ad1 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu index 83274ca3fdb..1e8f6af71bd 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu index 80e9eb0e2c0..4ec68886112 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu index fbbc273b7e2..670b5952d9d 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu index f4f4829f331..b9778dc92e1 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu index c768a89fdfb..446e917c795 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu index 89c2db39e61..fd62a2c5435 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu index 5b87286aef4..0a397f4acf2 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu index 7506097821d..4d3c553e296 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu index d3b7b0f87b2..77621846ffe 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu index 4d8625cd6dc..7d217ac2733 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<80, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu index f6f129c550b..0b6430abc2f 100644 --- a/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM128 -template void run_mha_fwd_<90, cutlass::half_t, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 128, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu new file mode 100644 index 00000000000..ea1e266f8d4 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu new file mode 100644 index 00000000000..2d7488fefe2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..8718571e30c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu new file mode 100644 index 00000000000..f7dfc18fc1e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..935f5a0fe60 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu new file mode 100644 index 00000000000..3f4d858ff57 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..54d720efeb3 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu new file mode 100644 index 00000000000..b9b93af4fc5 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu new file mode 100644 index 00000000000..39d9167b9f1 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu new file mode 100644 index 00000000000..0f86458012a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu new file mode 100644 index 00000000000..bd6f4df8f69 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu new file mode 100644 index 00000000000..1824b86c64c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu new file mode 100644 index 00000000000..87dd01725a5 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu new file mode 100644 index 00000000000..6594d560123 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..d7dc84ebc1c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu new file mode 100644 index 00000000000..b9d6e54cbed --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..a8c47652ec1 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu new file mode 100644 index 00000000000..32d17c7665d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu new file mode 100644 index 00000000000..365017c256d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu new file mode 100644 index 00000000000..82cfdf040b0 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu new file mode 100644 index 00000000000..f3254936a47 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu new file mode 100644 index 00000000000..931a6dbf869 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..5c8877a756d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu new file mode 100644 index 00000000000..1e230ab084b --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..03716c86237 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu new file mode 100644 index 00000000000..54c66c9552e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..e5e0ec47db1 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu new file mode 100644 index 00000000000..e4411b5db32 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu new file mode 100644 index 00000000000..157ed06dddf --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu new file mode 100644 index 00000000000..7ef5adc9e85 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM192 +template void run_mha_fwd_<90, cutlass::half_t, 192, 128, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu index 96243edf0ae..bf8386b8297 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu index a51a8945888..cbc6f988424 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu index 515d88a11aa..d5aa15b5c8c 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu index e5a154c18db..b8593612df3 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu index 2bd860c7758..a03514d919b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu index 6e1d8037819..df547749e93 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu index 942685e148f..1ddb1916209 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu index d6050520e02..cefffcd2169 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu index 7ee500a80ee..3d4333b9e1f 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu index 1f9d8bfd56c..35a2abef8c9 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu index 0313ad1b2c8..99e34ac0bfb 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu index 8d87eb21f2d..ed1cf22d5c4 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu index 081bb31b12e..4527d9a2793 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu index a9b5aa0de49..41fcf800170 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu index d465545ef8f..704cbcb337e 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu index 68c57145532..e0ea082156b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu index e1d656e5ae3..a9c00408a8b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu index 57d1c73d85e..1497e7aa843 100644 --- a/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu index 5104d439810..c66ea9baca1 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu index cbc61f27e76..a7e472b478b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu index f08ba1459e1..9f090aeeda8 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu index e413758de8f..2205168a67f 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu index c8205c1605f..2a01898b560 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu index f0db959e0f3..888e241a9f1 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu index 249cae97ff7..2a6bde7a39f 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu index 14b073deb31..3d315187b2d 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu index 8152dbaa6b4..3c3d0938034 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu index d0b0df02798..4ca103566d6 100644 --- a/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu index 24f3e128dca..16debf27799 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu index 6eabe0ee269..43c2615718e 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu index 5c780da81f7..d9d483838f1 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu index 5a943660174..70543998d94 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu index 9815dd13551..c30c7e3b8b9 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu index 66fc2cb8a6b..7ae26e69c96 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu index 2ceddd8cae6..155b5a539fd 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu index 4c64bc61c57..3e6173c31c2 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu index 6ad1a1529c1..e1e3191a202 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu index f0ee8c0159f..8272ecb76cb 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu index 4a9583196bb..74606c39373 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu index 2b65a88f06c..89a58502b37 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu index e324a932671..b13373806a1 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu index a8be65709d1..1335fad7f2b 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu index 1ad82d7edfb..18c31bdfc0e 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu index 75f53ee4f6e..18a5603cf1f 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu index 09f76526338..4e99c7db027 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<80, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu index e5299154c3d..82f8204aa66 100644 --- a/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM192 -template void run_mha_fwd_<90, cutlass::half_t, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 192, 192, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu index 364579e1b32..cb851a77110 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu index a5f821becd7..ae2871c1655 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu index 364bd2b3aee..ed24fbffef9 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu index 3d2e337e164..ffca9c7f8fe 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu index 310c4a5c38f..57a06bd6e66 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu index 96f5bbf3ada..ccdcf21e492 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu index 7d3131bd5bb..c2bc7787765 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu index 7715a52531d..6bba953fc69 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu index 686bdfa5c7f..25c96174c79 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu index 97fdc0094c0..f172239e5b9 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu index 25a90d3be3a..9dde6adb04b 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu index 4c91ee5bc06..2317adef8c5 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu index ef12a584c65..b9b3b74867e 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu index e4e746f9d3d..c57a5a30abb 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu index 99924af52c7..4f59a6aea92 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu index 705582b9f22..2c2de1574ac 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu index 7e969012051..0dbd062c79f 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu index 058eca375d5..bee54c702de 100644 --- a/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu index 679066d5443..c02e6833494 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu index e4ce6f9aa15..02b50b98b8d 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu index 03eff4c6f7b..6599de63bbd 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu index 26df5e592eb..a1cdc775cbb 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu index 57de7421dfd..6d01be60f58 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu index 53974f3e61f..968bbf36f83 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu index 24e1f635638..d564a622111 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu index a2fc325dad1..cb5bccc176c 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu index 2c1f5f56f1f..146a7bc3430 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu index 7cbdff3e8a5..a195e0931c0 100644 --- a/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu index b81bf0b99b0..045fc71bedb 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu index 88a00e91212..a31da2eddf4 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu index c28edfd8f95..7382b58a231 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu index dbcd163308e..87ca31ce902 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu index 63620ec90a4..60f4d6ebbf6 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu index d8c11ee6a0e..e0d5d318bcd 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu index 4af31d0bf9c..dec7db046bd 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu index c7a04dc47b8..7b71f435226 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu index 9bca3a1c5ee..08fc989af8b 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu index acd0fa660fa..2cc8b5b86d4 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu index a38430fb3c7..644e268469f 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu index 03bb0516f77..1ebcec8b3fe 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu index 8ea90bd417a..780ade7f6b8 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu index f9144326426..bfcffe2a39a 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu index e7e1cecd1f8..ba4ba78ad49 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu index 18b79da92c9..f04260ba4f1 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu index 1c1c9470d6d..33c78e53059 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<80, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu index 6cadc2641d5..8388420921e 100644 --- a/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM256 -template void run_mha_fwd_<90, cutlass::half_t, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 256, 256, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu new file mode 100644 index 00000000000..2f4ceaaed53 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu new file mode 100644 index 00000000000..5fd59af3486 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..e0f885b0f72 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu new file mode 100644 index 00000000000..6dcda019627 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..5d20be6d2a7 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu new file mode 100644 index 00000000000..47463a7151c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..622b5533ce8 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu new file mode 100644 index 00000000000..c83f44722cd --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu new file mode 100644 index 00000000000..5c9130f8648 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu new file mode 100644 index 00000000000..a152022cb65 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu new file mode 100644 index 00000000000..ef05aa2038d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu new file mode 100644 index 00000000000..19fe6d94f7d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..6eb2d3d134b --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu new file mode 100644 index 00000000000..ffbc9982122 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..3d35075b48d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu new file mode 100644 index 00000000000..c2af33cf533 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..e07547c92d0 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu new file mode 100644 index 00000000000..1a04eb01f5e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu new file mode 100644 index 00000000000..da9afc11571 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu new file mode 100644 index 00000000000..5e63a15515f --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +#ifndef FLASHATTENTION_DISABLE_HDIM64 +template void run_mha_fwd_<90, cutlass::half_t, 64, 512, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +#endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu index 4b650f53cfc..4134d7d80bb 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu index 29cb3fe18be..11e3503b0d9 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu index 2612bc9c98b..67e39bd7371 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu index 4c5fae060a8..c37844daa56 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu index c0b58521bc5..f0c40e2f89f 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu index 0a058847247..3ed9694908c 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu index b421199714b..4a16aae66c0 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu index 7f337595b56..b5b5fc26b28 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu index c4c35a18c7c..3b29be627ed 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu index 9ea549e1173..5f1c298c4c4 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu index 8ffc852e3e9..64895643d20 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu index 7143da2f79a..dd508590d66 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu index 4f7cd4f8e4c..8411b6fccbd 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu index 5a9bb142056..b5b4f40770e 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu index dc9b71a5b3d..e608da04b02 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu index 4c5440436a2..c69b78ac3b6 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu index d988a48f990..170cdb5cb8c 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu index c6ae246e7f2..ef0d1e921c1 100644 --- a/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu index 761a625564e..6a7fc29ddda 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu index a74d7c2c3b3..faeb6c487fc 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu index 6d48fb099b1..655258d5194 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu index 0e49f26aaaa..4bd8ad8f267 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu index f780a8eb73a..657820f2854 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu index 948c8b17c85..cb0955d1a53 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu index 519783851fe..357b64e83b6 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu index d5392ef3b07..c1207925864 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu index 06086d40840..21687f8932b 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu index a15ab4c60da..4df8ed64d7b 100644 --- a/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu index 7038c0ad726..b601195d7e0 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu index 9a805fd3e5b..ced47531898 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu index b23cb43e770..03090f73cb2 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu index c18f470fcfe..d6fe1559ca1 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu index d61b04a07e1..7b5ae4a56aa 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu index 1d33fe12e05..6c603b4dcaa 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu index 03ac4d2f84e..26d25fc1909 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu index 7b031a49031..05a0baf18b6 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu index 77dbc58123b..3a45776537f 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu index 6bae5faa535..9b80bae51f6 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu index 30f666a73fc..f6810efafb8 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu index 358e813eca8..98c018893f1 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu index f5df3f502b8..a10dfaca722 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu index f16185c3ac2..b912a81443e 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu index 796e4d63a3e..8603c396e1f 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu index 6eeb977415f..dc55dbc66aa 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu index aa1d2cd05f0..ef48844972a 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<80, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu index 5a92ebdddfb..b1c0ead6e5c 100644 --- a/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM64 -template void run_mha_fwd_<90, cutlass::half_t, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 64, 64, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu index 78c390e5ef0..5d76d0fff04 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu index 2b5aaff0d87..44ea823d272 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu index f0fa3ac63d1..30fe623508b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu index 0d9407b2ce4..6eb12dc80a6 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu index 223b6783e02..b806fc9d501 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu index 2f49d5f5aae..8f0a26da03a 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu index 9661156d889..6de2819a172 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu index b5f6d7f8757..16927295b82 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu index 82b827e180a..08413072092 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu index 042dd0cc71b..7d4dcdc293b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu index 4712aed6c3b..b4dfbf7f8b2 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu index 8295033deeb..1fa048752dc 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu index 21c43e6dbd1..e0b6a75e635 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu index d3317ad6280..e257b42f79c 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu index 86218988c2c..f97ab4733a0 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu index 7a6450373c8..cee43ef94cd 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu index 34c1a3d3f04..0442e1f94b5 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu index 96affd254c9..bc71fa9e71f 100644 --- a/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::bfloat16_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu index 489717ff2cd..b61dd71885d 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu index 69917aa1ea4..f47e1f5cdac 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu index 3e3cc66f667..215752f1b06 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu index e5f53e49c51..207afc79242 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu index 0899aa8987b..6c38c083384 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu index 22f4cf6b14c..dc2eb35dc29 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu index d601d694d4f..f04e8bca6f3 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu index 1c5ba9b0066..2697f6910e7 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu index 8073b677a1d..e7a98b2e6ee 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu index 857be35920c..98fb39c86ee 100644 --- a/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu index 6931ffa2792..cb938ad93b0 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu index 84facb47e70..e2dc45c79c6 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu index 878d160ff2b..64f99c05a32 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu index e5561f7d63f..3fdbbf23bac 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu index 30474d3543d..ffe202ee394 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu index 074f7232f23..42740f0228b 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu index 734abb7b0e9..829929980d0 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu index 285e7ef520d..d6a330432a4 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu index d552e45db1a..39c774e6f77 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, true, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu index 64ca02345eb..bc54be11e6c 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu index 3d8bb7c2775..a68790500d8 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, false, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu index 6fab8802c5a..3bca3065c7f 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu index 1fb30696ddb..985692b9fa1 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, false, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu index af9b88d9a3d..3c99cb6b5a0 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, false, false, true, false>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu index 5f9794a9873..cf77a1ae819 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu index c906649acc6..f9a46a44dd5 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, false, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu index 2d7ac26e250..9b4dbbba58a 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu @@ -6,7 +6,7 @@ #ifndef FLASHATTENTION_DISABLE_SM8x #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<80, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_<86, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<80, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<86, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif #endif diff --git a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu index 171f28e9ce2..da5373fd13e 100644 --- a/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu +++ b/hopper/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu @@ -5,5 +5,5 @@ #include "flash_fwd_launch_template.h" #ifndef FLASHATTENTION_DISABLE_HDIM96 -template void run_mha_fwd_<90, cutlass::half_t, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_<90, cutlass::half_t, 96, 96, true, false, true, true>(Flash_fwd_params ¶ms, cudaStream_t stream); #endif diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu new file mode 100644 index 00000000000..cc3a8a7c913 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu new file mode 100644 index 00000000000..d6d6df0d4ee --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..bd85f7608f6 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu new file mode 100644 index 00000000000..733511adb43 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..c62ccf28d3c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu new file mode 100644 index 00000000000..b7e51551a04 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..0dbd0045425 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu new file mode 100644 index 00000000000..51a14371284 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu new file mode 100644 index 00000000000..24a64e8e49e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_split_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu new file mode 100644 index 00000000000..50c78f3d5d4 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu new file mode 100644 index 00000000000..526a51fb71e --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu new file mode 100644 index 00000000000..4e5d9cc4fe2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu new file mode 100644 index 00000000000..f553af139f2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu new file mode 100644 index 00000000000..aa2a8260d25 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..bbc4449ba21 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu new file mode 100644 index 00000000000..02ca85ad672 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..d090fde972b --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu new file mode 100644 index 00000000000..d48f60ad7e2 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu new file mode 100644 index 00000000000..9dda19d1cea --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu new file mode 100644 index 00000000000..f3e51fc9ebd --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu @@ -0,0 +1,5 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu new file mode 100644 index 00000000000..453282a4f29 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu new file mode 100644 index 00000000000..72736d8ef7a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu new file mode 100644 index 00000000000..97895aa708c --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu new file mode 100644 index 00000000000..423c42221e0 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu new file mode 100644 index 00000000000..98c89572117 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu new file mode 100644 index 00000000000..69108d025fa --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu new file mode 100644 index 00000000000..da39ba2731a --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu new file mode 100644 index 00000000000..be6496d1956 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu new file mode 100644 index 00000000000..a5a80909072 --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_split_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_sm90.cu" \ No newline at end of file diff --git a/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu new file mode 100644 index 00000000000..62fe142562d --- /dev/null +++ b/hopper/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu @@ -0,0 +1,6 @@ +// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +// Splitting the different template instantiations to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu" +#include "flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu" \ No newline at end of file diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index e43904518cf..0fb32c7a900 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -22,7 +22,7 @@ namespace flash { using namespace cute; -template struct CollectiveMainloopFwdSm80 { @@ -30,6 +30,7 @@ struct CollectiveMainloopFwdSm80 { static constexpr int kStages = Stages; static_assert(kStages > 0, "kStages must be greater than 0"); using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -177,12 +178,15 @@ struct CollectiveMainloopFwdSm80 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; @@ -218,6 +222,7 @@ struct CollectiveMainloopFwdSm80 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; @@ -272,7 +277,7 @@ struct CollectiveMainloopFwdSm80 { // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, - args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.stride_V, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, @@ -430,11 +435,11 @@ struct CollectiveMainloopFwdSm80 { } cute::cp_async_fence(); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k ); @@ -730,11 +735,11 @@ struct CollectiveMainloopFwdSm80 { params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position ); diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index dbbf2f8f821..1834f200c57 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -27,14 +27,16 @@ namespace flash { using namespace cute; -template +template struct CollectiveMainloopFwdSm90 { static constexpr int kStages = Stages; using ClusterShape = ClusterShape_; using TileShape_MNK = TileShape_MNK_; + using TileShape_MNK_PV = Shape(TileShape_MNK{})), Int, decltype(get<1>(TileShape_MNK{}))>; + using TileShape_MNK_QV = Shape(TileShape_MNK{})), decltype(get<1>(TileShape_MNK{})), Int>; using Element = Element_; using ElementAccum = ElementAccum_; using ArchTag = ArchTag_; @@ -45,6 +47,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Varlen = Varlen_; static constexpr bool PagedKV = PagedKV_; static constexpr bool AppendKV = AppendKV_; + static constexpr bool HasQv = HasQv_; static constexpr bool PackGQA = PackGQA_; static constexpr bool Split = Split_; static constexpr bool V_colmajor = V_colmajor_; @@ -53,6 +56,8 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Use_TMA_KV = !PagedKV; static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); + static constexpr bool SameHeadDim = get<2>(TileShape_MNK{}) == kHeadDimV; + static constexpr bool LargeHeadDimV = kHeadDimV > 256; using SeqlenInfo_t = flash::SeqlenInfoQKNewK; static_assert(ArchTag::kMinComputeCapability >= 90); @@ -64,34 +69,49 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static_assert(!LargeHeadDimV || kHeadDimV % 256 == 0); + static_assert(!LargeHeadDimV || kBlockM <= 64, "kBlockM must be 64 or less for large Headdim_V"); + static_assert(!LargeHeadDimV || !MmaPV_is_RS, "MmaPV must be SS for large Headdim_V"); + static_assert(!(HasQv && !IntraWGOverlap), "HasQv requires IntraWGOverlap"); + // Register bandwidth is actually a bottleneck so we don't want Q to be in registers. // Leaving this option here for reference. - static constexpr bool Mma0_is_RS = false; - // We can have Mma1 (P @ V) with P in smem in rmem to reduce register pressure at the cost of more smem. - static_assert(!(!Mma1_is_RS && !IntraWGOverlap), "Mma1 must be RS if IntraWGOverlap is enabled"); - static_assert(!(!Mma1_is_RS && Is_FP8), "Mma1 must be RS if FP8"); - static_assert(!(!Mma1_is_RS && Transpose_V), "Mma1 must be RS if Transpose_V"); - - using AtomLayoutMNK = Layout, _1, _1>>; - using TiledMma0 = decltype(cute::make_tiled_mma( + static constexpr bool MmaQK_is_RS = false; + // We can have MmaPV with P in smem in rmem to reduce register pressure at the cost of more smem. + static_assert(!(!MmaPV_is_RS && !IntraWGOverlap), "MmaPV must be RS if IntraWGOverlap is disabled"); + static_assert(!(!MmaPV_is_RS && Is_FP8), "MmaPV must be RS if FP8"); + static_assert(!(!MmaPV_is_RS && Transpose_V), "MmaPV must be RS if Transpose_V"); + + using AtomLayoutQK = Layout, _1, _1>>; + using TiledMmaQK = decltype(cute::make_tiled_mma( std::conditional_t< - !Mma0_is_RS, + !MmaQK_is_RS, decltype(cute::GMMA::ss_op_selector()), decltype(cute::GMMA::rs_op_selector()) >{}, - AtomLayoutMNK{})); - using TiledMma1 = decltype(cute::make_tiled_mma( + AtomLayoutQK{})); + using AtomLayoutPV = std::conditional_t< + !LargeHeadDimV, + AtomLayoutQK, + Layout, _1>> + >; + using TiledMmaPV = decltype(cute::make_tiled_mma( std::conditional_t< - !Mma1_is_RS, + !MmaPV_is_RS, decltype(cute::GMMA::ss_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()), + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()), decltype(cute::GMMA::rs_op_selector(TileShape_MNK{})), GMMA::Major::K, MmaMajorV>()) + TileShape_MNK_PV, GMMA::Major::K, MmaMajorV>()) >{}, - AtomLayoutMNK{})); + AtomLayoutPV{})); + using TiledMmaQV = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutQK{})); - static constexpr int NumMmaThreads = size(TiledMma0{}); + static constexpr int NumMmaThreadsQK = size(TiledMmaQK{}); + static constexpr int NumMmaThreads = size(TiledMmaPV{}); static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup; + static_assert(NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0); static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0); static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup; static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); @@ -107,54 +127,67 @@ struct CollectiveMainloopFwdSm90 { make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK_PV{})), decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVt = decltype(tile_to_shape( SmemLayoutAtomVt{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); using SmemLayoutAtomVtMma = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK_PV{})), decltype(cute::get<2>(TileShape_MNK_PV{}))>()); using SmemLayoutVtMma = decltype(tile_to_shape( SmemLayoutAtomVtMma{}, - make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + make_shape(shape<1>(TileShape_MNK_PV{}), shape<2>(TileShape_MNK_PV{}), Int{}), std::conditional_t, cute::Step<_2, _1, _3>>{})); + using SmemLayoutAtomQv = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + using SmemLayoutQv = decltype(tile_to_shape(SmemLayoutAtomQv{}, select<0, 2>(TileShape_MNK_QV{}))); + using SmemLayoutAtomVMmaQV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK_QV{})), decltype(cute::get<2>(TileShape_MNK_QV{}))>()); + using SmemLayoutVMmaQV = decltype(tile_to_shape( + SmemLayoutAtomVMmaQV{}, + make_shape(shape<1>(TileShape_MNK_QV{}), shape<2>(TileShape_MNK_QV{}), Int{}))); + static_assert(CUTE_STATIC_V(size(SmemLayoutVMmaQV{})) == size(SmemLayoutVtMma{})); + // Only used if we're using cp.async to load V using SmemLayoutAtomVCpAsync = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + decltype(cute::get<1>(TileShape_MNK{})), Int>()); using SmemLayoutVCpAsync = decltype(tile_to_shape( SmemLayoutAtomVCpAsync{}, - make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + make_shape(shape<1>(TileShape_MNK{}), Int{}, Int{}))); using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); + // Only for LargeHeadDimV where WG0 sends WG1 the scales + using SmemLayoutScale = cute::Layout, Int>>; + using SmemCopyAtomP = Copy_Atom; // Use LDSM.T and STSM to transpose V in the case of FP8 and V being row-major. // For FP16/BF16 we don't do any transposing. - static_assert(!Transpose_V || (kHeadDim % 32 == 0 && kBlockN % 32 == 0)); - static constexpr bool kHeadDim_multiple_64 = kHeadDim % 64 == 0; - // Either kHeadDim is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), + static_assert(!Transpose_V || (kHeadDimV % 32 == 0 && kBlockN % 32 == 0)); + static constexpr bool kHeadDimV_multiple_64 = kHeadDimV % 64 == 0; + // Either kHeadDimV is a multiple of 64 (in which case we use a block size of 64 x 32 for the transpose), // or we need kBlockN to be a multiple of 64 (in which case we use a block size of 32 x 64 for the transpose). - static_assert(!Transpose_V || (kHeadDim_multiple_64 || kBlockN % 64 == 0)); - using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; - using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; + static_assert(!Transpose_V || (kHeadDimV_multiple_64 || kBlockN % 64 == 0)); + using LDSM_thread_shape = std::conditional_t, Shape<_16, _4, _1, _2>>; + using LDSM_thread_stride = std::conditional_t, Stride<_4, _1, _0, _64>>; using LDSM_value_shape = Shape<_2, _2, _1, _4>; using LDSM_value_stride = Stride<_1, _2, _16, _4>; - using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; + using LDSM_divide_shape = std::conditional_t, Shape<_32, _8>>; using S2RTiledCopyVt = decltype(make_tiled_copy( Copy_Atom{}, Layout{}, Layout{})); - using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; - using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; + using STSM_thread_shape = std::conditional_t, Shape<_8, _4, _2, _2>>; + using STSM_thread_stride = std::conditional_t, Stride<_4, _1, _32, _64>>; using STSM_value_shape = Shape<_1, _4, _2, _2>; using STSM_value_stride = Stride<_0, _1, _4, _8>; using STSM_divide_shape = Shape<_8, _16>; - // These will not permute the columns of V (the kHeadDim dimension) but incur bank conflicts + // These will not permute the columns of V (the kHeadDimV dimension) but incur bank conflicts // so a little slower (e.g. 1150 TFLOPS for hdim 256 instead of 1200 TFLOPS). // Instead we will permute the cols of V, and un-permute the cols of O in the epilogue. // using STSM_value_shape = Shape<_2, _4, _1, _2>; @@ -168,14 +201,15 @@ struct CollectiveMainloopFwdSm90 { using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); // We use CpAsync for K and V if PagedKV and AppendKV, since TMA doesn't work there + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. // We want each thread to have at least 2 loads in the K direction since in the case of non-interleaved // rotary (combining elements at indices 0 and rotary_dim/2, 1 and rotary_dim/2+1, etc), each thread will // load twice from the same row. - static constexpr int kBytePerHalfRow = kHeadDim / 2 * sizeof(Element); + static constexpr int kBytePerHalfRow = kHeadDimGCD / 2 * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerHalfRow % 128 == 0 ? 128 : (kBytePerHalfRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); @@ -221,14 +255,22 @@ struct CollectiveMainloopFwdSm90 { GmemTiledCopyKV{}, make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Qv_ = decltype(make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{})); + using TMA_Qv = std::conditional_t; + // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesV = static_cast(size(take<0, 2>(SmemLayoutVt{})) * cutlass::sizeof_bits_v / 8); - static_assert(TmaTransactionBytesK == TmaTransactionBytesV); + static constexpr uint32_t TmaTransactionBytesQv = static_cast(size(SmemLayoutQv{}) * cutlass::sizeof_bits_v / 8); using PipelineTmaAsync = std::conditional_t, typename cutlass::PipelineTmaAsync>; using MainloopPipelineK = std::conditional_t>; @@ -241,48 +283,68 @@ struct CollectiveMainloopFwdSm90 { // If PackGQA, we use cp.async (instead of TMA) to load Q, so we want smem_q to be aligned // and have sQ being position_independent_swizzle_tensor. // If !Use_TMA_KV, we use cp.async (instead of TMA) to load K & V, so we want smem_k and smem_v to be aligned. - static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !AppendKV && !Mma0_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); + static constexpr size_t SmemAlignmentQ = Use_TMA_Q && !MmaQK_is_RS ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQ{}); static constexpr size_t SmemAlignmentK = Use_TMA_KV && !AppendKV ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutK{}); static constexpr size_t SmemAlignmentVtNoTranspose = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); + static constexpr size_t SmemAlignmentQv = Use_TMA_Q ? 128 : cutlass::detail::alignment_for_swizzle(SmemLayoutQv{}); static_assert(SmemAlignmentQ >= 128 and SmemAlignmentK >= 128 && SmemAlignmentVtNoTranspose >= 128, "Require at least 128B alignment"); static constexpr size_t SmemAlignmentP = cutlass::detail::alignment_for_swizzle(SmemLayoutP{}); static_assert(SmemAlignmentP >= 128, "Require at least 128B alignment"); - using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + using SmemP_t = std::conditional_t, cute::array_aligned, SmemAlignmentP>>; + using SmemScale_t = std::conditional_t, cute::array_aligned, 128>>; + using SmemQv_t = std::conditional_t, cute::array_aligned, SmemAlignmentQv>>; // Sometimes even with SmemP_t = cute::array, putting it in the TensorStorage struct causes // smem size to go from 227KB to 228KB and we get "invalid argument". - struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { + struct TensorStorageWithoutPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; }; - struct TensorStorageWithPNoTranspose : cute::aligned_struct { + struct TensorStorageWithPNoTranspose : cute::aligned_struct { cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; SmemP_t smem_p; }; + struct TensorStorageWithPScaleNoTranspose : cute::aligned_struct { + cute::array_aligned, SmemAlignmentVtNoTranspose> smem_v; + cute::array_aligned, SmemAlignmentQ> smem_q; + cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + SmemP_t smem_p; + SmemScale_t smem_scale; + }; - using TensorStorageNoTranspose = std::conditional_t; + using TensorStorageNoTranspose = std::conditional_t< + MmaPV_is_RS, + TensorStorageWithoutPNoTranspose, + std::conditional_t + >; static constexpr size_t SmemAlignmentVt = cutlass::detail::alignment_for_swizzle(SmemLayoutVt{}); static constexpr size_t SmemAlignmentV = cutlass::detail::alignment_for_swizzle(SmemLayoutVtMma{}); static_assert(SmemAlignmentVt >= 128 and SmemAlignmentV >= 128, "Require at least 128B alignment"); - struct TensorStorageTransposeV : cute::aligned_struct { + struct TensorStorageTransposeV : cute::aligned_struct { cute::array_aligned, SmemAlignmentV> smem_v; cute::array_aligned, SmemAlignmentVt> smem_vt; cute::array_aligned, SmemAlignmentQ> smem_q; cute::array_aligned, SmemAlignmentK> smem_k; + SmemQv_t smem_qv; + SmemScale_t smem_scale; }; using TensorStorage = std::conditional_t; // These are tuned for speed. They don't affect correctness. - static constexpr bool UseSchedulerBarrier = IntraWGOverlap + static constexpr bool UseSchedulerBarrier = (IntraWGOverlap ? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128) - : NumMmaWarpGroups == 2; + : NumMmaWarpGroups == 2) + && !LargeHeadDimV; static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor); // Host side kernel arguments @@ -294,12 +356,15 @@ struct CollectiveMainloopFwdSm90 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideQK const stride_Qv; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; @@ -335,12 +400,17 @@ struct CollectiveMainloopFwdSm90 { ShapeQKV const shape_K; StrideQK const stride_K; Element* const ptr_V; + int32_t const headdim_v; StrideV const stride_V; Element const* const ptr_K_new; ShapeQKV const shape_K_new; StrideQK const stride_K_new; Element const* const ptr_V_new; StrideV const stride_V_new; + Element const* const ptr_Qv; + StrideV const stride_Qv; + ShapeQPacked const shape_Qv_packed; + StrideQPacked const stride_Qv_packed; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; StrideRotary const stride_rotary_cos; @@ -357,6 +427,7 @@ struct CollectiveMainloopFwdSm90 { TMA_V tma_load_V; TMA_K tma_load_K_new; TMA_V tma_load_V_new; + TMA_Qv tma_load_Qv; float const softmax_scale_log2; float const* ptr_q_descale, *ptr_k_descale, *ptr_v_descale; StrideDescale const stride_q_descale, stride_k_descale, stride_v_descale; @@ -388,12 +459,14 @@ struct CollectiveMainloopFwdSm90 { take<0, 2>(SmemLayoutK{}), TileShape_MNK{}, ClusterShape{}); // mcast along M mode for this N load, if any - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), select<1, 0, 2, 3>(args.shape_K), select<1, 0, 2, 3>(args.stride_V)); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), + make_shape(args.headdim_v, get<0>(args.shape_K), get<2>(args.shape_K), get<3>(args.shape_K)), + select<1, 0, 2, 3>(args.stride_V)); TMA_V tma_load_V = make_tma_copy( GmemTiledCopyKV{}, mV, take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new); TMA_K tma_load_K_new = make_tma_copy_B_sm90( @@ -402,13 +475,29 @@ struct CollectiveMainloopFwdSm90 { take<0, 2>(SmemLayoutK{}), TileShape_MNK{}, ClusterShape{}); // mcast along M mode for this N load, if any - Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), select<1, 0, 2, 3>(args.shape_K_new), select<1, 0, 2, 3>(args.stride_V_new)); + Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), + make_shape(args.headdim_v, get<0>(args.shape_K_new), get<2>(args.shape_K_new), get<3>(args.shape_K_new)), + select<1, 0, 2, 3>(args.stride_V_new)); TMA_V tma_load_V_new = make_tma_copy( GmemTiledCopyKV{}, cute::conditional_return(mVnew, mV), take<0, 2>(SmemLayoutVt{}), - select<2, 1>(TileShape_MNK{}), + select<1, 2>(TileShape_MNK_PV{}), size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + auto shape_Qv = make_shape(get<0>(args.shape_Q), args.headdim_v, get<2>(args.shape_Q), get<3>(args.shape_Q)); + Tensor mQv = make_tensor(make_gmem_ptr(args.ptr_Qv), shape_Qv, args.stride_Qv); + TMA_Qv tma_load_Qv = [&] { + if constexpr (HasQv) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + mQv, + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{}); // no mcast for Qv + } else { + return nullptr; + } + }(); // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); auto const shape_Q_packed = cute::conditional_return( @@ -419,6 +508,14 @@ struct CollectiveMainloopFwdSm90 { args.stride_Q, make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) ); + auto const shape_Qv_packed = cute::conditional_return( + shape_Qv, + make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv)) + ); + auto const stride_Qv_packed = cute::conditional_return( + args.stride_Qv, + make_stride(make_stride(get<2>(args.stride_Qv), get<0>(args.stride_Qv)), get<1>(args.stride_Qv), get<2>(args.stride_Qv) * qhead_per_khead, get<3>(args.stride_Qv)) + ); if (get<1>(args.shape_rotary) > 0) { assert(args.ptr_rotary_cos != nullptr && args.ptr_rotary_sin != nullptr); } @@ -429,14 +526,15 @@ struct CollectiveMainloopFwdSm90 { // (assigning it to params.softcap_val) and pre-multiply softcap_val * log2(e) // (assigning it to params.softmax_scale_log2). return {args.ptr_Q, args.shape_Q, args.stride_Q, shape_Q_packed, stride_Q_packed, - args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.stride_V, + args.ptr_K, args.shape_K, args.stride_K, args.ptr_V, args.headdim_v, args.stride_V, args.ptr_K_new, args.shape_K_new, args.stride_K_new, args.ptr_V_new, args.stride_V_new, + args.ptr_Qv, args.stride_Qv, shape_Qv_packed, stride_Qv_packed, args.ptr_rotary_cos, args.shape_rotary, args.stride_rotary_cos, args.ptr_rotary_sin, args.stride_rotary_sin, args.is_rotary_interleaved, args.ptr_pagetable, args.shape_pagetable, args.stride_pagetable, cutlass::FastDivmod(int(get<0>(args.shape_K))), cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), - tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, + tma_load_Q, tma_load_K, tma_load_V, tma_load_K_new, tma_load_V_new, tma_load_Qv, !Has_softcap ? float(args.softmax_scale * M_LOG2E) : float(args.softcap_val * M_LOG2E), args.ptr_q_descale, args.ptr_k_descale, args.ptr_v_descale, args.stride_q_descale, args.stride_k_descale, args.stride_v_descale, @@ -453,6 +551,9 @@ struct CollectiveMainloopFwdSm90 { static void prefetch_tma_descriptors(Params const& params) { if constexpr (Use_TMA_Q) { cute::prefetch_tma_descriptor(params.tma_load_Q.get_tma_descriptor()); + if constexpr (HasQv) { + cute::prefetch_tma_descriptor(params.tma_load_Qv.get_tma_descriptor()); + } } if constexpr (Use_TMA_KV) { cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); @@ -509,7 +610,11 @@ struct CollectiveMainloopFwdSm90 { int &work_idx ) { - auto [m_block, bidh, bidb, split_idx] = block_coord; + // some of these are captured in lambda so can't use structured binding + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. if constexpr (Is_causal || Is_local || Varlen || Split) { @@ -541,6 +646,7 @@ struct CollectiveMainloopFwdSm90 { return cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVCpAsync{})); } }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); int const thread_idx = threadIdx.x % NumProducerThreads; int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; @@ -555,12 +661,13 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K))(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) @@ -572,12 +679,25 @@ struct CollectiveMainloopFwdSm90 { auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x); Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k) Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) + auto [tQvgQv, tQvsQv] = [&] { + if constexpr (HasQv) { + auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q)); + Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv) + auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{}); + Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA) + Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA) + return cute::make_tuple(tQvgQv, tQvsQv); + } else { + return cute::make_tuple(nullptr, nullptr); + } + }(); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumProducerThreads, Element, Transpose_V || !IntraWGOverlap /*KV_Same_Iter*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k ); @@ -690,16 +810,21 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Use_TMA_Q) { // Wait for the MMA warpgroups to signal that smem_q is ready if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { - cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); } if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), tQgQ, tQsQ); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv); + copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + tQvgQv, tQvsQv); + } } } else { // Load Q with cp.async - cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; @@ -707,6 +832,15 @@ struct CollectiveMainloopFwdSm90 { auto &barrier_Q = shared_storage.pipelines.barrier_Q; cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Q)); barrier_Q.arrive(); + if constexpr (HasQv) { + Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv); + using PackGQAt = flash::PackGQAManager(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>; + PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); + auto &barrier_Qv = shared_storage.pipelines.barrier_Qv; + cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Qv)); + barrier_Qv.arrive(); + } } // Wait for the MMA WGs to signal that smem_v are ready and V can be copied from gmem @@ -821,13 +955,19 @@ struct CollectiveMainloopFwdSm90 { CUTLASS_DEVICE void mma_init() { + int warp_group_idx = flash::canonical_warp_group_idx_nosync(); // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + if (!LargeHeadDimV || warp_group_idx == 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + } + if (LargeHeadDimV && warp_group_idx > 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } if constexpr (UseSchedulerBarrier) { // We have NamedBarrier for up to 3 WGs static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3); // WG1 needs the very first signal to start - if (flash::canonical_warp_group_idx_nosync() == 1) { + if (warp_group_idx == 1) { cutlass::arch::NamedBarrier::arrive(2 * cutlass::NumThreadsPerWarpGroup, static_cast(FwdNamedBarriers::WarpSchedulerWG1) /*id*/); } } @@ -867,41 +1007,69 @@ struct CollectiveMainloopFwdSm90 { Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); Tensor sP = [&] { - if constexpr (Mma1_is_RS) { - // We might not have smem_p if !Mma1_is_RS1, just use smem_q as a placeholder since we don't use it + if constexpr (MmaPV_is_RS) { + // We might not have smem_p if !MmaPV_is_RS, just use smem_q as a placeholder since we don't use it return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutP{}); } else { return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); } }(); - - if constexpr (!Mma0_is_RS) { - static_assert(stride<0>(typename TiledMma0::ALayout{}) == 0 and - stride<0>(typename TiledMma0::BLayout{}) == 0 and - size<0>(typename TiledMma0::ALayout{}) == cutlass::NumThreadsPerWarpGroup and - size<0>(typename TiledMma0::BLayout{}) == cutlass::NumThreadsPerWarpGroup, + Tensor sScale = [&] { + if constexpr (LargeHeadDimV) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + } else { // won't be used, just a placeholder + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_q.data()), SmemLayoutScale{}); + } + }(); + Tensor sQv = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_qv.data()), SmemLayoutQv{}); + Tensor sVMmaQV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVMmaQV{}); + + if constexpr (!MmaQK_is_RS) { + static_assert(stride<0>(typename TiledMmaQK::ALayout{}) == 0 and + stride<0>(typename TiledMmaQK::BLayout{}) == 0 and + size<0>(typename TiledMmaQK::ALayout{}) == cutlass::NumThreadsPerWarpGroup and + size<0>(typename TiledMmaQK::BLayout{}) == cutlass::NumThreadsPerWarpGroup, "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); } - constexpr int MmaWarpGroups = size(TiledMma0{}) / cutlass::NumThreadsPerWarpGroup; + static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; Layout warp_group_thread_layout = make_layout(make_shape(Int{}), make_stride(Int{})); int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); - TiledMma0 tiled_mma0; - TiledMma1 tiled_mma1; - auto wg_mma0 = tiled_mma0.get_slice(warp_group_thread_layout(warp_group_idx)); - auto wg_mma1 = tiled_mma1.get_slice(warp_group_thread_layout(warp_group_idx)); - - auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma0); + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + TiledMmaQV tiled_mma_qv; + auto wg_mma_qk = tiled_mma_qk.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); + auto wg_mma_qv = tiled_mma_qv.get_slice(warp_group_thread_layout(warp_group_idx)); + + auto smem_tiled_copy_P = make_tiled_copy_C(SmemCopyAtomP{}, tiled_mma_qk); auto smem_thr_copy_P = smem_tiled_copy_P.get_thread_slice(thread_idx); // Allocate "fragments/descriptors" - Tensor tSrQ = wg_mma0.partition_fragment_A(sQ); - Tensor tSrK = wg_mma0.partition_fragment_B(sK); - Tensor tOrV = wg_mma1.partition_fragment_B(sV); - Tensor tOsP = wg_mma1.partition_fragment_A(sP); + Tensor tSrQ = wg_mma_qk.partition_fragment_A(sQ); + Tensor tSrK = wg_mma_qk.partition_fragment_B(sK); + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); + Tensor tSrQv = wg_mma_qv.partition_fragment_A(sQv); + Tensor tSrV = wg_mma_qv.partition_fragment_B(sVMmaQV); Tensor tPsP = smem_thr_copy_P.partition_D(cute::as_position_independent_swizzle_tensor(sP)); + // For storing scales to smem, only used when LargeHeadDimV + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto store_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + if (get<1>(taccOcO_row(_0{})) == 0) { + sScale(get<0>(taccOcO_row(mi)), stage) = scales(mi); + } + } + }; + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); @@ -909,13 +1077,13 @@ struct CollectiveMainloopFwdSm90 { // Need to initialize tOrO in the case of RescaleOBeforeGemm where we will scale tOrO even in the 1st iter clear(tOrO); - // tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; int n_block = n_block_max - 1; - flash::Mask mask( + flash::Mask mask( thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, params.sink_token_length, params.qhead_per_khead_divmod ); @@ -938,7 +1106,7 @@ struct CollectiveMainloopFwdSm90 { } else { if (get<1>(params.shape_rotary) > 0) { // Apply rotary to Q int const offset_rotary = seqlen_info.seqlen_k_og + seqlen_info.leftpad_k; - using Rotary_t = Rotary; + using Rotary_t = Rotary; Rotary_t rotary(params.ptr_rotary_cos, params.shape_rotary, params.stride_rotary_cos, params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_q, offset_rotary); @@ -961,15 +1129,15 @@ struct CollectiveMainloopFwdSm90 { } // SMEM fence to make sure the rotated Q is visible to GMMA cutlass::arch::fence_view_async_shared(); - cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK, static_cast(FwdNamedBarriers::QueryRotated) /*id*/); } else { barrier_Q.wait(work_idx % 2); } } - if constexpr (Mma0_is_RS) { + if constexpr (MmaQK_is_RS) { using SmemCopyAtomQ = Copy_Atom; - auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma0); + auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma_qk); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); Tensor tSsQ_copy_view = smem_thr_copy_Q.partition_S(cute::as_position_independent_swizzle_tensor(sQ)); @@ -978,25 +1146,38 @@ struct CollectiveMainloopFwdSm90 { // TODO: check the case where n_block_max <= n_block_min but there are sink tokens if constexpr (IntraWGOverlap) { - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); warpgroup_wait<0>(); pipeline_k.consumer_release(smem_pipe_read); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.wait(work_idx % 2); + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } scoremod_premask_fn(tSrS); mask.template apply(tSrS, m_block, n_block); Tensor scores_scale = softmax.template max_get_scale(tSrS); + // Don't need to store scales to send to WG1 (in the case of LargeHeadDimV) since it's 1.f + softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!Mma1_is_RS) { + if constexpr (!MmaPV_is_RS) { + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); cutlass::arch::fence_view_async_shared(); - __syncwarp(); // Only need syncwarp since each warp is using its own P values for Mma1 + __syncwarp(); // Only need syncwarp since each warp is using its own P values for MmaPV + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } } --n_block; @@ -1005,30 +1186,47 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool Check_inf = decltype(check_inf_type)::value; PipelineState smem_pipe_read_v(smem_pipe_read.index(), smem_pipe_read.phase(), smem_pipe_read.count()); ++smem_pipe_read; - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_k, smem_pipe_read); } warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } - flash::gemm(tiled_mma1, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + if constexpr(!HasQv) { + if (!UseSchedulerBarrier || warp_group_idx == 0) { consumer_wait(pipeline_v, smem_pipe_read_v); } + } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); warp_scheduler_barrier_arrive(); warpgroup_wait<1>(); pipeline_k.consumer_release(smem_pipe_read); // release K + if constexpr (HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + consumer_wait(pipeline_v, smem_pipe_read); + flash::gemm(tiled_mma_qv, tSrQv, tSrV(_, _, _, smem_pipe_read.index()), tSrS); + } scoremod_premask_fn(tSrS); mask_fn(tSrS, n_block); cute::copy(softmax.template max_get_scale(tSrS), scores_scale); + if constexpr (LargeHeadDimV) { store_scales(scores_scale, smem_pipe_read_v.index()); } softmax.template online_softmax(tSrS); - warpgroup_wait<0>(); - pipeline_v.consumer_release(smem_pipe_read_v); // release V + if constexpr (!HasQv) { + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + } if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } convert_type_out(make_tensor(tSrS.data(), tOrP.layout()), tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } - if constexpr (!Mma1_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + } + if constexpr (!MmaPV_is_RS) { cute::copy(smem_tiled_copy_P, smem_thr_copy_P.retile_S(tOrP), tPsP); } if constexpr (!RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - if constexpr (!Mma1_is_RS) { + if constexpr (!MmaPV_is_RS) { cutlass::arch::fence_view_async_shared(); __syncwarp(); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } } }; @@ -1068,12 +1266,17 @@ struct CollectiveMainloopFwdSm90 { // } } // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); if constexpr (RescaleOBeforeGemm) { softmax.rescale_o(tOrO, scores_scale); } - consumer_wait(pipeline_v, smem_pipe_read); - flash::gemm(tiled_mma1, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); + if constexpr (!HasQv) { consumer_wait(pipeline_v, smem_pipe_read); } + flash::gemm(tiled_mma_pv, cute::conditional_return(tOrP, tOsP), tOrV(_, _, _, smem_pipe_read.index()), tOrO); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; cute::copy(softmax.finalize(v_descale), scores_scale); + if constexpr (LargeHeadDimV) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + store_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + } warpgroup_wait<0>(); pipeline_v.consumer_release(smem_pipe_read); // release V, otherwise producers will hang softmax.rescale_o(tOrO, scores_scale); @@ -1087,9 +1290,9 @@ struct CollectiveMainloopFwdSm90 { auto fwd_step = [&](int const n_block, auto mask_fn, auto is_first_iter_type, auto check_inf_type) { static constexpr bool Is_first_iter = decltype(is_first_iter_type)::value; static constexpr bool Check_inf = decltype(check_inf_type)::value; - Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); consumer_wait(pipeline_k, smem_pipe_read); - flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); + flash::gemm(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS); warp_scheduler_barrier_arrive(); warpgroup_wait<0>(); pipeline_k.consumer_release(smem_pipe_read); // release K @@ -1098,14 +1301,14 @@ struct CollectiveMainloopFwdSm90 { Tensor scores_scale = softmax.template max_get_scale(tSrS); softmax.template online_softmax(tSrS); if constexpr (Is_FP8 && !V_colmajor) { flash::permute_Cregs_fp8(tSrS); } - Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); Tensor tOrP = make_tensor_like(tOrP_acc); convert_type_out(tOrP_acc, tOrP); if constexpr (Is_FP8 && V_colmajor) { flash::permute_Aregs_fp8(tOrP); } if constexpr (!Is_first_iter) { softmax.rescale_o(tOrO, scores_scale); } consumer_wait(pipeline_v, smem_pipe_read); warp_scheduler_barrier_sync(); - flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + flash::gemm(tiled_mma_pv, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); pipeline_v.consumer_release(smem_pipe_read); // release V ++smem_pipe_read; }; @@ -1149,7 +1352,7 @@ struct CollectiveMainloopFwdSm90 { } warp_scheduler_barrier_arrive(); // Tell producers that smem_q is ready - cutlass::arch::NamedBarrier::arrive(NumMmaThreads + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreadsQK + (Use_TMA_Q ? cutlass::NumThreadsPerWarp : NumProducerThreads), static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); float const v_descale = !Is_FP8 || params.ptr_v_descale == nullptr ? 1.0f : params.ptr_v_descale[bidb * get<0>(params.stride_v_descale) + bidh_kv * get<1>(params.stride_v_descale)]; Tensor scores_scale = softmax.finalize(v_descale); softmax.rescale_o(tOrO, scores_scale); @@ -1159,6 +1362,90 @@ struct CollectiveMainloopFwdSm90 { return true; } + template + CUTLASS_DEVICE bool + mma_pv(Params const& params, + MainloopPipelineV pipeline_v, + PipelineState& smem_pipe_read, + FrgTensorO& tOrO, + Softmax& softmax, + int const thread_idx, + SeqlenInfo_t const& seqlen_info, + cute::tuple block_coord, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda + int const m_block = get<0>(block_coord); + int const bidb = get<2>(block_coord); + int const split_idx = get<3>(block_coord); + auto [n_block_min, n_block_max] = get_n_block_min_max(params, seqlen_info, m_block, bidb, split_idx, params.num_splits); + // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier + if constexpr (Is_causal || Is_local || Varlen || Split) { + if (n_block_max <= n_block_min) { return false; } + } + + Tensor sV = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVtMma{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_p.data()), SmemLayoutP{}); + Tensor sScale = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_scale.data()), SmemLayoutScale{}); + static constexpr int MmaWarpGroups = size(TiledMmaPV{}) / cutlass::NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(make_shape(Int{}), + make_stride(Int{})); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0); + TiledMmaPV tiled_mma_pv; + auto wg_mma_pv = tiled_mma_pv.get_slice(warp_group_thread_layout(warp_group_idx)); + + // Allocate "fragments/descriptors" + Tensor tOrV = wg_mma_pv.partition_fragment_B(sV); + Tensor tOsP = wg_mma_pv.partition_fragment_A(sP); + + // For load scales to smem, pretend thread_idx is thread_idx % 128 + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx % cutlass::NumThreadsPerWarpGroup); + Tensor taccOcO = thread_mma_pv.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{}))); + Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout())); + Tensor taccOcO_row = taccOcO_rowcol(_, _0{}); + auto load_scales = [&](auto& scales, int stage) { + static_assert(CUTE_STATIC_V(size(scales)) == CUTE_STATIC_V(size(taccOcO_row))); + #pragma unroll + for (int mi = 0; mi < size(taccOcO_row); ++mi) { + scales(mi) = sScale(get<0>(taccOcO_row(mi)), stage); + } + }; + + clear(tOrO); + // tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero; + + typename Softmax::TensorT scores_scale; + + int n_block = n_block_max - 1; + pipeline_v.consumer_wait(smem_pipe_read); + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + --n_block; + + for (; n_block >= n_block_min; --n_block) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + softmax.rescale_o(tOrO, scores_scale); + ++smem_pipe_read; + auto barrier_token = pipeline_v.consumer_try_wait(smem_pipe_read); + pipeline_v.consumer_wait(smem_pipe_read, barrier_token); + flash::gemm(tiled_mma_pv, tOsP, tOrV(_, _, _, smem_pipe_read.index()), tOrO); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + pipeline_v.consumer_release(smem_pipe_read); // release V + }; + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::PFull) /*id*/); + load_scales(scores_scale, smem_pipe_read.index()); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::PEmpty) /*id*/); + softmax.rescale_o(tOrO, scores_scale); + if constexpr (Is_FP8 && !V_colmajor) { flash::permute_output_fp8(tOrO); } + ++smem_pipe_read; + return true; + } + CUTLASS_DEVICE cute::tuple get_n_block_k_new_min_max(Params const& params, SeqlenInfo_t const& seqlen_info, int m_block, int bidb, int split_idx=0, int num_splits=1) { @@ -1207,10 +1494,11 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(select<1, 0, 2, 3>(params.shape_K_new))(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); + Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<2, 1>(TileShape_MNK{}), make_coord(_0{}, _)); // (K, N, _) + Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) @@ -1302,11 +1590,12 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); int const offset_k = seqlen_info.offset_k + seqlen_info.seqlen_k_og; Tensor gK = local_tile(domain_offset(make_coord(offset_k, _0{}), mK), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(domain_offset(make_coord(offset_k, _0{}), mV), select<2, 1>(TileShape_MNK_PV{}), make_coord(_, _0{})); // (N, K_v, _) static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); @@ -1317,11 +1606,11 @@ struct CollectiveMainloopFwdSm90 { params.ptr_rotary_sin, params.stride_rotary_sin, params.is_rotary_interleaved, thread_idx, seqlen_k_new, offset_rotary); - using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; + using PagedKVManager_t = PagedKVManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), get<1>(TileShape_MNK_PV{}), NumMmaThreads, Element, true /*KV_Same_Iter*/, 2 /*LoadsPerRow_LB*/>; PagedKVManager_t paged_kv_manager( params.ptr_pagetable, params.shape_pagetable, params.stride_pagetable, params.ptr_K, params.shape_K, params.stride_K, - params.ptr_V, params.stride_V, + params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, bidb_kv, bidh_kv, thread_idx, seqlen_k_new, offset_k // passing offset_k instead of leftpad_k will move the PageTable pointer to the right position ); @@ -1347,6 +1636,12 @@ struct CollectiveMainloopFwdSm90 { Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); #pragma unroll for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(params.shape_K); } + Tensor cV = cute::make_identity_tensor(select<2, 1>(TileShape_MNK_PV{})); // (BLK_N,BLK_K_V) -> (blk_n,blk_k_v) + Tensor tVcV = cute::conditional_return(tKcK, gmem_thr_copy_kv.partition_D(cV)); + Tensor tVpV_ = make_tensor(make_shape(size<2>(tVsV))); + #pragma unroll + for (int k = 0; k < size(tVpV_); ++k) { tVpV_(k) = get<1>(tVcV(_0{}, _0{}, k)) < params.headdim_v; } + Tensor tVpV = cute::conditional_return(tKpK, tVpV_); auto store_K = [&] (int const n_block, auto const& smem_pipe_read) { int const n_limit = std::min(seqlen_k_new - n_block * kBlockN, kBlockN); @@ -1392,7 +1687,7 @@ struct CollectiveMainloopFwdSm90 { Tensor tVgV_cur = tVgV(_, _, _, n_block); // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tKcK, tKpK, n_limit); + gmem_tiled_copy_kv, tVsV_cur, tVgV_cur, tVcV, tVpV, n_limit); } else { paged_kv_manager.store_V(n_block, tVsV_cur); } diff --git a/hopper/named_barrier.hpp b/hopper/named_barrier.hpp index f77ea778298..8d07f6aa2fc 100644 --- a/hopper/named_barrier.hpp +++ b/hopper/named_barrier.hpp @@ -57,6 +57,8 @@ enum class FwdNamedBarriers { WarpSchedulerWG3 = 6, AppendKV = 7, QueryRotated = 8, + PFull = 9, + PEmpty = 6, // HACK: PEmpty is only used when we don't have 3 WGs }; enum class BwdNamedBarriers { diff --git a/hopper/paged_kv.h b/hopper/paged_kv.h index 0f710e54935..80ee61b9a41 100644 --- a/hopper/paged_kv.h +++ b/hopper/paged_kv.h @@ -14,7 +14,7 @@ namespace flash { using namespace cute; -template +template struct PagedKVManager { // If KV_Same_Iter=false, then we do load_page_table(0), load_K(0), load_page_table(1), load_K(1), load_V(0), // load_page_table(2), load_K(2), load_V(1), etc. @@ -23,14 +23,17 @@ struct PagedKVManager { // LoadsPerRow_LB is the lower bound on number of loads per row in the K direction. This is useful for // rotary where we want each thread to have at least 2 loads per row. + static constexpr bool SameHeadDim = (kHeadDim == kHeadDimV); + static constexpr int kHeadDimGCD = cute::gcd(kHeadDim, kHeadDimV); + // We use CpAsync for K and V if PagedKV, since TMA doesn't work there static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad"); + static_assert(kHeadDimGCD % kGmemElemsPerLoad == 0, "Headdim and HeaddimV must be a multiple of kGmemElemsPerLoad"); // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). E.g. if hdim=128, we want each // thread to have 4 loads in the M direction and 2 vectorized load in the K direction. // In the case of PackGQA, this reduces the number of times we need to call divmod. - static_assert(kHeadDim % LoadsPerRow_LB == 0, "Headdim must be a multiple of LoadsPerRow_LB"); - static constexpr int kBytePerRow = kHeadDim / LoadsPerRow_LB * sizeof(Element); + static_assert(kHeadDimGCD % LoadsPerRow_LB == 0, "Headdim and HeaddimV must be a multiple of LoadsPerRow_LB"); + static constexpr int kBytePerRow = kHeadDimGCD / LoadsPerRow_LB * sizeof(Element); static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element); static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad; static_assert(NumThreads % kGmemThreadsPerRow == 0, "NumThreads must be a multiple of kGmemThreadsPerRow"); @@ -59,6 +62,8 @@ struct PagedKVManager { using GmemThrCopyKVCpAsync = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0))); using TensortKcK = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); using TensortKpK = decltype(make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{})); + using TensortVcV = decltype(GmemTiledCopyKVCpAsync{}.get_thread_slice(int(0)).partition_D(cute::make_identity_tensor(Shape, Int>{}))); + using TensortVpV = decltype(make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{})); // For PagedKV, it's expensive the calculate the pointers to K and V for each page table entry, // since those require int64_t arithmetic. We optimize by having threads split this work. @@ -66,6 +71,7 @@ struct PagedKVManager { // that each thread needs to load for the case of hdim 128 and kBlockN = 176. // So each of those 8 threads will calculate the K_ptr and V_ptr for 11 / 8 = 2 rows. // We then use __shfl_sync to broadcast the pointers to the other threads in the warp. + static_assert(CUTE_STATIC_V(size<1>(TensortKcK{})) == CUTE_STATIC_V(size<1>(TensortVcV{}))); static constexpr int kPageEntryPerThread = cute::ceil_div(size<1>(TensortKcK{}), kGmemThreadsPerRow); using TensorPageOffset = decltype(make_tensor>(Shape>{})); using TensorKVPtr = decltype(make_tensor(Shape>{})); @@ -79,15 +85,15 @@ struct PagedKVManager { TensorPageTable mPageTable; TensorKV mK_paged, mV_paged; TensortKpK tKpK; + TensortVpV tVpV; TensorPageOffset tPrPageOffset; TensorKVPtr tPrVPtr; - CUTLASS_DEVICE PagedKVManager(int const* const ptr_page_table, ShapePageTable const &shape_pagetable, StridePageTable const &stride_pagetable, Element* const ptr_K, ShapeKV const &shape_K, StrideKV const &stride_K, - Element* const ptr_V, StrideKV const &stride_V, + Element* const ptr_V, int const headdim_v, StrideKV const &stride_V, cutlass::FastDivmod const &page_size_divmod, int const bidb, int const bidh, int const thread_idx, int const seqlen_k, int const leftpad_k ) @@ -100,13 +106,19 @@ struct PagedKVManager { { mPageTable = make_tensor(make_gmem_ptr(ptr_page_table), shape_pagetable, stride_pagetable)(bidb, _); mK_paged = make_tensor(make_gmem_ptr(ptr_K), shape_K, stride_K)(_, _, bidh, _); - mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_K, stride_V)(_, _, bidh, _); + auto shape_V = make_shape(get<0>(shape_K), headdim_v, get<2>(shape_K), get<3>(shape_K)); + mV_paged = make_tensor(make_gmem_ptr(ptr_V), shape_V, stride_V)(_, _, bidh, _); tKpK = make_tensor(make_shape(size<1>(TensortKcK{}), size<2>(TensortKcK{})), Stride<_0, _1>{}); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); #pragma unroll for (int k = 0; k < size<1>(tKpK); ++k) { tKpK(_0{}, k) = get<1>(tKcK(_0{}, _0{}, k)) < get<1>(shape_K); } + Tensor tVpV_ = make_tensor(make_shape(size<1>(TensortVcV{}), size<2>(TensortVcV{})), Stride<_0, _1>{}); + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + #pragma unroll + for (int k = 0; k < size<1>(tVpV_); ++k) { tVpV_(_0{}, k) = get<1>(tVcV(_0{}, _0{}, k)) < get<1>(shape_V); } + tVpV = cute::conditional_return(tKpK, tVpV_); }; template @@ -200,27 +212,27 @@ struct PagedKVManager { // Only for index calculation, since all the indices of thread 0 are known at compile time auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); Tensor tVsV = gmem_thr_copy_kv.partition_D(sV); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts - Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); - Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); - int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKcK(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tVcV(_0{}, _0{}, _0{})); #pragma unroll for (int m = 0; m < size<1>(tVsV); ++m) { // Faster to rely on the cp.async to clear smem that are out of bound, // rather than calling cute::clear directly. // We have to be careful not to write to smem past `kBlockN` if !EvenN. // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to checked - if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tKcK(_0{}, m, _0{})) < kBlockN) { - bool const should_load = !Seqlenk_mask || get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + if (EvenN || m < size<1>(tVsV) - 1 || get<0>(tVcV(_0{}, m, _0{})) < kBlockN) { + bool const should_load = !Seqlenk_mask || get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; Element const* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); #pragma unroll for (int k = 0; k < size<2>(tVsV); ++k) { - int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; - cute::copy(gmem_tiled_copy_kv.with(tKpK(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + cute::copy(gmem_tiled_copy_kv.with(tVpV(_0{}, k) && should_load), mV_paged_cur_copy(_, ki), tVsV(_, m, k)); } } } @@ -269,24 +281,24 @@ struct PagedKVManager { if constexpr (KV_Same_Iter) { compute_V_ptr(); } // Only for index calculation, since all the indices of thread 0 are known at compile time auto gmem_thr0_copy_kv = gmem_tiled_copy_kv.get_thread_slice(_0{}); - Tensor cK = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor cV = cute::make_identity_tensor(Shape, Int>{}); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts - Tensor tKcK = gmem_thr_copy_kv.partition_S(cK); - Tensor t0KcK = gmem_thr0_copy_kv.partition_S(cK); + Tensor tVcV = gmem_thr_copy_kv.partition_S(cV); + Tensor t0VcV = gmem_thr0_copy_kv.partition_S(cV); GmemTiledCopyKVStore gmem_tiled_copy_kv_store; - int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tKcK(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = std::min(seqlen_k - n_block * kBlockN, kBlockN) - get<0>(tVcV(_0{}, _0{}, _0{})); #pragma unroll for (int m = 0; m < size<1>(tVrV); ++m) { - bool const should_load = get<0>(t0KcK(_0{}, m, _0{})) < seqlenk_row_limit; + bool const should_load = get<0>(t0VcV(_0{}, m, _0{})) < seqlenk_row_limit; Element* v_ptr = reinterpret_cast(__shfl_sync(0xffffffff, reinterpret_cast(tPrVPtr(m / kGmemThreadsPerRow)), m % kGmemThreadsPerRow, kGmemThreadsPerRow)); - Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); + Tensor mV_paged_cur = make_tensor(make_gmem_ptr(v_ptr), Shape>{}); Tensor mV_paged_cur_copy = cute::tiled_divide(mV_paged_cur, Shape>{}); if (should_load) { #pragma unroll for (int k = 0; k < size<2>(tVrV); ++k) { - int const ki = get<1>(tKcK(_0{}, _0{}, k)) / kGmemElemsPerLoad; - if (tKpK(_0{}, k)) { + int const ki = get<1>(tVcV(_0{}, _0{}, k)) / kGmemElemsPerLoad; + if (tVpV(_0{}, k)) { cute::copy(gmem_tiled_copy_kv_store, tVrV(_, m, k), mV_paged_cur_copy(_, ki)); } } diff --git a/hopper/setup.py b/hopper/setup.py index d95be9ad409..f638558a0a9 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -150,6 +150,8 @@ def sanitize_flags(flags): flags.append(f'cuda_post_cflags_sm80 = {" ".join(cuda_post_cflags_sm80)}') cuda_post_cflags_sm80_sm90 = cuda_post_cflags + ['-gencode', 'arch=compute_80,code=sm_80'] flags.append(f'cuda_post_cflags_sm80_sm90 = {" ".join(cuda_post_cflags_sm80_sm90)}') + cuda_post_cflags_sm100 = [s if s != 'arch=compute_90a,code=sm_90a' else 'arch=compute_100a,code=sm_100a' for s in cuda_post_cflags] + flags.append(f'cuda_post_cflags_sm100 = {" ".join(cuda_post_cflags_sm100)}') flags.append(f'cuda_dlink_post_cflags = {" ".join(cuda_dlink_post_cflags)}') flags.append(f'ldflags = {" ".join(ldflags)}') @@ -182,10 +184,13 @@ def sanitize_flags(flags): # to make this work on Windows too. nvcc_gendeps = '--generate-dependencies-with-compile --dependency-output $out.d' cuda_compile_rule_sm80 = ['rule cuda_compile_sm80'] + cuda_compile_rule[1:] + [ - f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80' + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80' ] cuda_compile_rule_sm80_sm90 = ['rule cuda_compile_sm80_sm90'] + cuda_compile_rule[1:] + [ - f' command = $nvcc {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90' + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm80_sm90' + ] + cuda_compile_rule_sm100 = ['rule cuda_compile_sm100'] + cuda_compile_rule[1:] + [ + f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags_sm100' ] cuda_compile_rule.append( f' command = $nvcc_from_env {nvcc_gendeps} $cuda_cflags -c $in -o $out $cuda_post_cflags') @@ -199,6 +204,8 @@ def sanitize_flags(flags): rule = 'cuda_compile' elif source_file.endswith('_sm80.cu'): rule = 'cuda_compile_sm80' + elif source_file.endswith('_sm100.cu'): + rule = 'cuda_compile_sm100' else: rule = 'cuda_compile_sm80_sm90' else: @@ -244,6 +251,7 @@ def sanitize_flags(flags): blocks.append(cuda_compile_rule) # type: ignore[possibly-undefined] blocks.append(cuda_compile_rule_sm80) # type: ignore[possibly-undefined] blocks.append(cuda_compile_rule_sm80_sm90) # type: ignore[possibly-undefined] + blocks.append(cuda_compile_rule_sm100) # type: ignore[possibly-undefined] blocks += [devlink_rule, link_rule, build, devlink, link, default] content = "\n\n".join("\n".join(b) for b in blocks) # Ninja requires a new lines at the end of the .ninja file @@ -333,22 +341,19 @@ def open_url(url): return urllib.request.urlopen(request, timeout=300) -def download_and_copy(name, src_path, dst_path, version, url_func): +def download_and_copy(name, src_func, dst_path, version, url_func): if is_offline_build(): return flashattn_cache_path = get_flashattn_cache_path() base_dir = os.path.dirname(__file__) system = platform.system() - try: - arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()] - except KeyError: - arch = platform.machine() + arch = platform.machine() + arch = {"arm64": "aarch64"}.get(arch, arch) supported = {"Linux": "linux", "Darwin": "linux"} url = url_func(supported[system], arch, version) + src_path = src_func(supported[system], arch, version) tmp_path = os.path.join(flashattn_cache_path, "nvidia", name) # path to cache the download dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path) # final binary path - platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux" - src_path = src_path(platform_name, version) if callable(src_path) else src_path src_path = os.path.join(tmp_path, src_path) download = not os.path.exists(src_path) if download: @@ -364,11 +369,12 @@ def download_and_copy(name, src_path, dst_path, version, url_func): def nvcc_threads_args(): - nvcc_threads = os.getenv("NVCC_THREADS") or "4" + nvcc_threads = os.getenv("NVCC_THREADS") or "2" return ["--threads", nvcc_threads] -NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} +# NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.3.107"} +NVIDIA_TOOLCHAIN_VERSION = {"nvcc": "12.6.85", "ptxas": "12.8.61"} exe_extension = sysconfig.get_config_var("EXE") @@ -389,23 +395,39 @@ def nvcc_threads_args(): if bare_metal_version < Version("12.3"): raise RuntimeError("FlashAttention-3 is only supported on CUDA 12.3 and above") - if bare_metal_version != Version("12.3"): # nvcc 12.3 gives the best perf currently + # ptxas 12.8 gives the best perf currently + # We want to use the nvcc front end from 12.6 however, since if we use nvcc 12.8 + # Cutlass 3.8 will expect the new data types in cuda.h from CTK 12.8, which we don't have. + if bare_metal_version != Version("12.8"): download_and_copy( - name="nvcc", src_path=f"bin", dst_path="bin", - version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], url_func=lambda system, arch, version: - ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") - (*version.split('.')))) + name="nvcc", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin", + dst_path="bin", + version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) download_and_copy( - name="nvcc", src_path=f"nvvm/bin", dst_path="bin", - version=NVIDIA_TOOLCHAIN_VERSION["nvcc"], url_func=lambda system, arch, version: - ((lambda version_major, version_minor1, version_minor2: - f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2") - (*version.split('.')))) + name="ptxas", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/bin/ptxas", + dst_path="bin", + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) + download_and_copy( + name="ptxas", + src_func=lambda system, arch, version: f"cuda_nvcc-{system}-{arch}-{version}-archive/nvvm/bin", + dst_path="nvvm/bin", + version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], + url_func=lambda system, arch, version: + f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/{system}-{arch}/cuda_nvcc-{system}-{arch}-{version}-archive.tar.xz", + ) base_dir = os.path.dirname(__file__) ctk_path_new = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", "bin") nvcc_path_new = os.path.join(ctk_path_new, f"nvcc{exe_extension}") # Need to append to path otherwise nvcc can't find cicc in nvvm/bin/cicc + # nvcc 12.8 seems to hard-code looking for cicc in ../nvvm/bin/cicc os.environ["PATH"] = ctk_path_new + os.pathsep + os.environ["PATH"] os.environ["PYTORCH_NVCC"] = nvcc_path_new # Make nvcc executable, sometimes after the copy it loses its permissions @@ -456,7 +478,7 @@ def nvcc_threads_args(): + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) - HEAD_DIMENSIONS_FWD = ["all"] + HEAD_DIMENSIONS_FWD = ["all", "diff"] HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else []) PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else []) diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 1fe43e21fa2..e9cd8c9d6cb 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -50,6 +50,8 @@ # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -96,7 +98,7 @@ ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, mha_type, dtype + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype ): # sink_token_length = 0 if not local else 4 sink_token_length = 0 if not local else 0 @@ -113,90 +115,101 @@ def test_flash_attn_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) - if softcap > 0.0: - # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4) - q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - # window_size = (-1, -1) if not local else (16, 0) - if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] - else: - q_descale, k_descale, v_descale = None, None, None - q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] - if V_colmajor: - v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() - out_ref, attn_ref = attention_ref( - q_ref, - k_ref, - v_ref, - None, - None, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - sink_token_length=sink_token_length, - softcap=softcap - ) - out_pt, attn_pt = attention_ref( - q_ref, - k_ref, - v_ref, - None, - None, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - sink_token_length=sink_token_length, - softcap=softcap, - upcast=False, - reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, - ) - - # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # exp_sum = s_tmp.sum(-1) - # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) - # lse_ref = torch.logsumexp(qk, dim=-1) - - # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 - - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out, lse = flash_attn_func( - q, - k, - v, + for dv in [128, d] if d > 128 and d <= 192 else [d]: + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + sink_token_length=sink_token_length, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, causal=causal, + qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, - pack_gqa=pack_gqa, - num_splits=num_splits + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor: + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out, lse = flash_attn_func( + q, + k, + v, + causal=causal, + qv=qv, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + sink_token_length=sink_token_length, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv: g = torch.randn_like(out) do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) # import flash_attn_3_cuda @@ -248,7 +261,7 @@ def test_flash_attn_output( # breakpoint() - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv: dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) @@ -263,6 +276,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -307,7 +322,7 @@ def test_flash_attn_output( ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, mha_type, dtype + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype ): device = "cuda" # set seed @@ -320,135 +335,146 @@ def test_flash_attn_varlen_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) - if softcap > 0.0: - # Ensure the values of qk are at least within softcap range. - q_ref = (q_ref * softcap / 4).detach().requires_grad_() - q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() - k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - if dtype == torch.float8_e4m3fn: - q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] - else: - q_descale, k_descale, v_descale = None, None, None - q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] - query_padding_mask = generate_random_padding_mask( - seqlen_q, batch_size, device, mode="random", zero_lengths=False - ) - key_padding_mask = generate_random_padding_mask( - seqlen_k, batch_size, device, mode="random", zero_lengths=True - ) - - def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): - if add_unused: - another_mask = generate_random_padding_mask(max_seq_len, bs, device) - attn_mask = torch.logical_and(padding_mask, another_mask) - unused_mask = torch.logical_xor( - torch.logical_or(padding_mask, another_mask), attn_mask - ) + for dv in [128, d] if d > 128 and d <= 192 else [d]: + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) else: - attn_mask = padding_mask - unused_mask = None - return attn_mask, unused_mask - - query_padding_mask, query_unused_mask = _gen_unused_masks( - query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device - ) - key_padding_mask, key_unused_mask = _gen_unused_masks( - key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device - ) - - ( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - output_pad_fn, - dq_pad_fn, - dk_pad_fn, - ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False, - query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) - q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] - out_ref, attn_ref = attention_ref( - q_ref, - k_ref, - v_ref, - query_padding_mask, - key_padding_mask, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap - ) - out_pt, attn_pt = attention_ref( - q_ref, - k_ref, - v_ref, - query_padding_mask, - key_padding_mask, - causal=causal, - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - window_size=window_size, - softcap=softcap, - upcast=False, - reorder_ops=True, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, - ) - - - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) - if query_unused_mask is not None: - q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask - # Numerical error if we just do any arithmetic on out_ref - fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() - rtol = 2 if softcap == 0.0 else 3 + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) - pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] - for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - out_unpad, lse = flash_attn_varlen_func( + ( q_unpad, k_unpad, v_unpad, + qv_unpad, cu_seqlens_q, cu_seqlens_k, - seqused_q, seqused_k, + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, causal=causal, - q_descale=q_descale, - k_descale=k_descale, v_descale=v_descale, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, ) - out = output_pad_fn(out_unpad) + + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + if query_unused_mask is not None: - out.masked_fill_(q_zero_masking, 0.0) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") - # Check that FlashAttention's numerical error is at most 3x the numerical error - # of a Pytorch implementation. - assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad, lse = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, seqused_k, + max_seqlen_q, + max_seqlen_k, + causal=causal, + qv=qv_unpad, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + softcap=softcap, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn: + + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: g_unpad = torch.randn_like(out_unpad) do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) # import flash_attn_3_cuda @@ -516,7 +542,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) @@ -557,7 +583,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) -@pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -569,8 +596,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): (3, 799), (64, 2048), (16, 20000), - (1, 128 * 1024), - (16, 128 * 1024), + # (1, 128 * 1024), + # (16, 128 * 1024), (128, 128), (256, 512), # To test appending KV with more than 1 block (2048, 3577), # Enough tile to test persistent scheduler @@ -614,261 +641,275 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - if varlen_q: - query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") - q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) - output_pad_fn = lambda output_unpad: pad_input( - output_unpad, indices_q, batch_size, seqlen_q - ) - else: - query_padding_mask = None - q_unpad = q - cu_seqlens_q, max_seqlen_q = None, None - # Put window_size after QKV randn so that window_size changes from test to test - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + dv_vals = [128, d] if d > 128 and d <= 192 else [d] + has_qv_vals = [False] + for dv, has_qv in itertools.product(dv_vals, has_qv_vals): + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if has_qv: + qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() - cu_seqlens_k_new = None - key_new_padding_mask = None - if new_kv: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - if varlen_q: # k & v are also varlen - key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") - k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) - v_unpad, *rest = unpad_input(v, key_new_padding_mask) + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v else: - k_unpad, v_unpad = k, v - else: - k, v, k_unpad, v_unpad = None, None, None, None - if page_size is None: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) - page_table = None - else: - ( - k_cache, - v_cache, - page_table, - k_cache_paged, - v_cache_paged, - num_blocks, - ) = _generate_block_kvcache( - seqlen_k, page_size, batch_size_cache, nheads_k, d, device, dtype_ref - ) - cache_seqlens = torch.randint( - 0 if new_kv else 1, - # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough - ( - (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) - if new_kv - else (seqlen_k + 1) - ), - (batch_size,), - dtype=torch.int32, - device=device, - ) - if has_leftpad: - cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) - if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) - for i in range(batch_size)]) - else: - cache_leftpad = None - if has_batch_idx: - cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ - :batch_size - ] - else: - cache_batch_idx = None - arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - if not new_kv: - key_padding_mask = arange < cache_seqlens_expanded - else: - k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new - key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens - if has_leftpad: - key_padding_mask = torch.logical_and( - key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) - ) - # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) - if rotary_dim > 0: - angle = ( - torch.rand( - seqlen_k if page_size is None else num_blocks * page_size, - rotary_dim // 2, - device=device, + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype_ref ) - * 2 - * math.pi + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, ) - cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) - sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) - if causal or local: - q_ro = apply_rotary_emb( - q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved ) else: - q_ro = rearrange( - apply_rotary_emb( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=cache_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=seqlen_q, + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens ) - # q_ro = q - k_ro = apply_rotary_emb( - k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + key_leftpad=cache_leftpad, ) - else: - cos, sin = None, None - q_ro, k_ro = q, k - # k_cache[:, 64:] = -1 - k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() - v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() - if new_kv: - update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + causal=causal, + window_size=window_size, + rotary_interleaved=rotary_interleaved, + num_splits=num_splits, + return_softmax_lse=True ) - k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") - v_to_update = rearrange(v, "b s ... -> (b s) ...") if varlen_q: - k_to_update = k_to_update[indices_k] - v_to_update = v_to_update[indices_k] - k_cache_ref[update_mask] = k_to_update - v_cache_ref[update_mask] = v_to_update - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - query_padding_mask, - key_padding_mask, - causal=causal, - window_size=window_size, - key_leftpad=cache_leftpad, - ) - out_pt, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - query_padding_mask, - key_padding_mask, - causal=causal, - window_size=window_size, - upcast=False, - reorder_ops=True, - key_leftpad=cache_leftpad, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None - ) - q = q.to(dtype) - q_unpad = q_unpad.to(dtype) if varlen_q else None - k_cache = k_cache.to(dtype) - v_cache = v_cache.to(dtype) - k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None - v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None - k = k.to(dtype) if k is not None else None - v = v.to(dtype) if v is not None else None - k_unpad = k_unpad.to(dtype) if k_unpad is not None else None - v_unpad = v_unpad.to(dtype) if v_unpad is not None else None - cos = cos.to(dtype) if cos is not None else None - sin = sin.to(dtype) if sin is not None else None - out, lse, *rest = flash_attn_with_kvcache( - q if not varlen_q else q_unpad, - k_cache if page_size is None else k_cache_paged, - v_cache if page_size is None else v_cache_paged, - k if not new_kv or not varlen_q else k_unpad, - v if not new_kv or not varlen_q else v_unpad, - rotary_cos=cos, - rotary_sin=sin, - cache_seqlens=cache_seqlens, - cache_batch_idx=cache_batch_idx, - cache_leftpad=cache_leftpad, - page_table=page_table, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k_new=cu_seqlens_k_new, - max_seqlen_q=max_seqlen_q, - causal=causal, - window_size=window_size, - rotary_interleaved=rotary_interleaved, - num_splits=num_splits, - return_softmax_lse=True - ) - if varlen_q: - out = output_pad_fn(out) - # out = flash_attn_with_kvcache( - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size - # ) - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) - # m = qk.amax(-1, keepdim=True) - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) - # probs = torch.softmax(qk, dim=-1) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") - # breakpoint() - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. - if new_kv: - if page_size is None: - k_cache_select = ( - k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] - ) - v_cache_select = ( - v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] - ) - else: - k_cache_select = rearrange( - k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - v_cache_select = rearrange( - v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], - "(b nblocks) block_size ... -> b (nblocks block_size) ...", - b=batch_size, - )[:, :seqlen_k].to(dtype_ref) - k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) - v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) - if dtype is not torch.float8_e4m3fn: - assert torch.equal(v_cache_select, v_cache_ref) - else: - assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") # breakpoint() - # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: - if rotary_dim == 0: - assert torch.equal(k_cache_select, k_cache_ref) - else: - # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): - # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) if dtype is not torch.float8_e4m3fn: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + assert torch.equal(v_cache_select, v_cache_ref) else: - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) - mult = 4 if dtype == torch.float8_e4m3fn else 2 - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 - mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 - assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() -def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, device, dtype): +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype): num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 k_cache_paged = torch.randn( num_blocks, page_size, nheads_k, d, device=device, dtype=dtype ) v_cache_paged = torch.randn( - num_blocks, page_size, nheads_k, d, device=device, dtype=dtype + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype ) page_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), @@ -990,12 +1031,12 @@ def attention_combine_ref(out_partial, lse_partial): @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) # @pytest.mark.parametrize("dtype", [torch.float32]) # @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) -@pytest.mark.parametrize("d", [64, 96, 128, 192, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) # @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024, 2048]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) # @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) # @pytest.mark.parametrize("seqlen", [15]) -@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 155]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) # @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) # @pytest.mark.parametrize("num_splits", [128]) def test_flash_attn_combine(num_splits, seqlen, d, dtype): diff --git a/hopper/test_util.py b/hopper/test_util.py index 54eb195eb36..b7ea3d3b752 100644 --- a/hopper/test_util.py +++ b/hopper/test_util.py @@ -30,22 +30,23 @@ def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", def generate_qkv( - q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False, + q, k, v, query_padding_mask=None, key_padding_mask=None, qv=None, kvpacked=False, qkvpacked=False, query_unused_mask=None, key_unused_mask=None, ): """ Arguments: q: (batch_size, seqlen_q, nheads, d) k: (batch_size, seqlen_k, nheads_k, d) - v: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d_v) query_padding_mask: (batch_size, seqlen), bool key_padding_mask: (batch_size, seqlen), bool """ assert not (kvpacked and qkvpacked) batch_size, seqlen_q, nheads, d = q.shape + d_v = v.shape[-1] _, seqlen_k, nheads_k, _ = k.shape assert k.shape == (batch_size, seqlen_k, nheads_k, d) - assert v.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d_v) if query_unused_mask is not None or key_unused_mask is not None: assert not kvpacked assert not qkvpacked @@ -57,6 +58,7 @@ def generate_qkv( output_pad_fn = lambda output_unpad: pad_input( output_unpad, indices_q, batch_size, seqlen_q ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if qv is not None else None else: q_unpad = rearrange(q, "b s h d -> (b s) h d") cu_seqlens_q = torch.arange( @@ -67,6 +69,7 @@ def generate_qkv( output_pad_fn = lambda output_unpad: rearrange( output_unpad, "(b s) h d -> b s h d", b=batch_size ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...") if qv is not None else None if key_padding_mask is not None: k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( @@ -134,6 +137,7 @@ def generate_qkv( q_unpad.detach().requires_grad_(), k_unpad.detach().requires_grad_(), v_unpad.detach().requires_grad_(), + qv_unpad.detach() if qv is not None else None, cu_seqlens_q, cu_seqlens_k, seqused_q, @@ -143,6 +147,7 @@ def generate_qkv( q.detach().requires_grad_(), k.detach().requires_grad_(), v.detach().requires_grad_(), + qv.detach() if qv is not None else None, output_pad_fn, dq_pad_fn, dk_pad_fn, @@ -196,6 +201,7 @@ def attention_ref( dropout_p=0.0, dropout_mask=None, causal=False, + qv=None, q_descale=None, k_descale=None, v_descale=None, window_size=(-1, -1), # -1 means infinite window size sink_token_length=0, @@ -208,7 +214,8 @@ def attention_ref( Arguments: q: (batch_size, seqlen_q, nheads, head_dim) k: (batch_size, seqlen_k, nheads, head_dim) - v: (batch_size, seqlen_k, nheads, head_dim) + v: (batch_size, seqlen_k, nheads, head_dim_v) + qv: (batch_size, seqlen_q, nheads, head_dim_v) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) @@ -221,7 +228,7 @@ def attention_ref( without changing the math. This is to estimate the numerical error from operation reordering. Output: - output: (batch_size, seqlen_q, nheads, head_dim) + output: (batch_size, seqlen_q, nheads, head_dim_v) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ if causal: @@ -229,9 +236,11 @@ def attention_ref( dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() + qv = qv.float() if qv is not None else None if q_descale is not None: - q_descale = repeat(q_descale, "b h -> b (h g)", g = q.shape[2] // k.shape[2]) - q = (q.float() * rearrange(q_descale, "b h -> b 1 h 1")).to(dtype=q.dtype) + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g = q.shape[2] // k.shape[2]).to(dtype=q.dtype) + q = q.float() * q_descale + qv = qv.float() * q_descale if qv is not None else None if k_descale is not None: k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) if v_descale is not None: @@ -240,10 +249,14 @@ def attention_ref( k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) d = q.shape[-1] + dv = v.shape[-1] + softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) if not reorder_ops: - scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + scores = torch.einsum("bthd,bshd->bhts", q * softmax_scale, k) else: - scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if qv is not None: + scores = scores + torch.einsum("bthd,bshd->bhts", qv * softmax_scale, v) if softcap > 0: scores = torch.tanh(scores / softcap) * softcap if key_padding_mask is not None: diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 127f518bbb6..5d0bd6e2634 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -6,13 +6,18 @@ #include -// Return {kBlockM, kBlockN, Mma1_is_RS, IntraWGOverlap} +// Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} constexpr std::tuple tile_size_fwd_sm90( - int headdim, bool is_causal, bool is_local, int element_size=2, + int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, bool v_colmajor=false, bool paged_kv=false, bool softcap=false) { if (element_size == 2) { if (headdim <= 64) { - return {192, 128, true, true}; + bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512 + // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, true}; + // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why + // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 + // Switch to tile size 192 x 192 for now + return {same_hdim ? 192 : 64, same_hdim ? 192 : 64, false, true}; // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { @@ -20,9 +25,9 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 128) { return {128, is_causal || is_local || paged_kv ? 128 : 176, true, true}; // {128, 192, false, false} and {192, 128, false, true} are quite good too - // 128 x 192 hits the limit of smem if Mma1_is_RS, 128 x 144 hits the limit if !Mma1_is_RS + // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { - return {128, paged_kv || is_local ? 96 : 112, true, true}; // 128 x 112 hits the limit of smem + return {128, paged_kv || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem } else { return {128, is_local ? 64 : 80, true, true}; // 128 x 80 hits the limit of smem } @@ -43,7 +48,7 @@ constexpr std::tuple tile_size_fwd_sm90( // Return {kBlockM, kBlockN, kNWarps, kStages, Q_in_regs} constexpr std::tuple tile_size_fwd_sm8x( - bool sm86_or_89, int headdim, bool is_causal, bool is_local, int element_size=2, + bool sm86_or_89, int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, bool paged_kv=false, bool varlen_and_split=false, bool softcap=false, bool append_kv=false) { if (element_size == 2) { diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py index 81e2c22e57f..9bf430b3672 100644 --- a/vllm_flash_attn/flash_attn_interface.py +++ b/vllm_flash_attn/flash_attn_interface.py @@ -82,6 +82,7 @@ def flash_attn_varlen_func( max_seqlen_k, cu_seqlens_k=None, # only used for non-paged prefill seqused_k=None, + q_v=None, dropout_p=0.0, softmax_scale=None, causal=False, @@ -91,7 +92,6 @@ def flash_attn_varlen_func( deterministic=False, return_attn_probs=False, block_table=None, - *, return_softmax_lse=False, out=None, fa_version: int = DEFAULT_FA_VERSION, @@ -196,6 +196,7 @@ def flash_attn_varlen_func( out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd( q, k, v, None, None, # k_new, v_new + q_v, # out, cu_seqlens_q, cu_seqlens_k, # cu_seqlens_k @@ -369,6 +370,7 @@ def flash_attn_with_kvcache( out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd( q, k_cache, v_cache, # q, k, v k, v, # k_new, v_new + None, # q_v out, None, None, # cu_seqlens_q, cu_seqlens_k None, # cu_seqlens_k_new