From 490836427dfb97bb6bb763d43bd218ea6e4e6b10 Mon Sep 17 00:00:00 2001 From: Michael Date: Sat, 30 Aug 2025 21:45:50 -0500 Subject: [PATCH 1/7] Compress PA work narrow pa test ref works on most cases inplace ref with new_kv inplace paged attention add pa ref save pa basic paged works save fix swa + causal in pa. Also new_kv only on pa path passing build fa v3 import interface from fa v3 copy fa tests use v3 api clean up rename to match old test support different head sizes remove fp8 basisc passing v3 cases test_flash_attn_varlen_output v3 working isolate bad case for kvcache case passing save use decode is seqused/ cacheseql is given use decode if not varlen basci kvcache v3 working kvcache enable more cases detect kvcache case if seqused_q is non and sequese_k is not None skip failing test find fp8 failing case mha fp8 works fix fp8 MQA/GQA bug clean up more clean up clean up more don't need fp8 dead code remove train code with fp8 stuff fp8 working in kvcache paged + fp8 seems to be working new_kv allowed --- .github/workflows/amd_tests.yml | 65 - flash_attn/flash_attn_triton_amd/.gitignore | 2 - flash_attn/flash_attn_triton_amd/README.md | 113 -- .../bwd_prefill_fused_atomics.py | 8 +- .../bwd_prefill_fused_no_atomics.py | 299 +++-- .../bwd_prefill_split.py | 10 +- flash_attn/flash_attn_triton_amd/fp8.py | 716 ---------- .../flash_attn_triton_amd/fwd_decode.py | 669 +++++++--- .../flash_attn_triton_amd/fwd_prefill.py | 350 +++-- flash_attn/flash_attn_triton_amd/fwd_ref.py | 336 ++++- .../flash_attn_triton_amd/interface_fa.py | 31 +- .../flash_attn_triton_amd/interface_fa_v3.py | 660 ++++++++++ flash_attn/flash_attn_triton_amd/test.py | 777 ----------- flash_attn/flash_attn_triton_amd/train.py | 404 ------ flash_attn/flash_attn_triton_amd/utils.py | 91 +- hopper/flash_attn_interface.py | 26 +- hopper/setup.py | 8 +- hopper/test_flash_attn_triton_amd.py | 1173 +++++++++++++++++ tests/test_flash_attn_triton_amd.py | 56 +- 19 files changed, 3146 insertions(+), 2648 deletions(-) delete mode 100644 .github/workflows/amd_tests.yml delete mode 100644 flash_attn/flash_attn_triton_amd/.gitignore delete mode 100644 flash_attn/flash_attn_triton_amd/README.md mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/bwd_prefill_split.py delete mode 100644 flash_attn/flash_attn_triton_amd/fp8.py mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/fwd_decode.py mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/fwd_prefill.py mode change 100644 => 100755 flash_attn/flash_attn_triton_amd/fwd_ref.py create mode 100755 flash_attn/flash_attn_triton_amd/interface_fa_v3.py delete mode 100644 flash_attn/flash_attn_triton_amd/test.py delete mode 100644 flash_attn/flash_attn_triton_amd/train.py mode change 100644 => 100755 hopper/flash_attn_interface.py mode change 100644 => 100755 hopper/setup.py create mode 100755 hopper/test_flash_attn_triton_amd.py diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml deleted file mode 100644 index 2e3f061c78d..00000000000 --- a/.github/workflows/amd_tests.yml +++ /dev/null @@ -1,65 +0,0 @@ -name: AMD Perf Kernel Tests - -on: - workflow_dispatch: - pull_request: - branches: [main_perf] - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - Integration-Tests-AMD: - runs-on: ${{ matrix.runner }} - strategy: - matrix: - runner: [linux-mi300-gpu-1] - fail-fast: false # disables failing the entire job when one matrix entry fails - timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes - container: - image: rocm/pytorch:latest - options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 16G --group-add video --user root - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Show Device Info - run: | - rocminfo | grep gfx - - - name: Uninstall Triton - run: | - pip uninstall -y triton - rm -rf ~/.triton - rm -rf ./triton/python/build - - - name: Install Triton - run: | - pip install triton==3.3.0 - - - name: Show Triton version - run: | - pip show triton - - - name: Build - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install - - - name: Install dependencies for bench and misc - run: | - pip install matplotlib pandas tabulate - - - name: AMD Internal Tests - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest flash_attn/flash_attn_triton_amd/test.py - - - name: Flash Attention Tests - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest -n 8 tests/test_flash_attn_triton_amd.py - - - name: AMD Bench - run: | - python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_func - python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_varlen_func - python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_with_kvcache diff --git a/flash_attn/flash_attn_triton_amd/.gitignore b/flash_attn/flash_attn_triton_amd/.gitignore deleted file mode 100644 index 21538fc4e4a..00000000000 --- a/flash_attn/flash_attn_triton_amd/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -bwd_prefill_fused.py -bwd_prefill_onekernel.py \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/README.md b/flash_attn/flash_attn_triton_amd/README.md deleted file mode 100644 index 87213c1883c..00000000000 --- a/flash_attn/flash_attn_triton_amd/README.md +++ /dev/null @@ -1,113 +0,0 @@ -Flash Attention Triton Kernel -=============== - -#### Introduction -The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress. - -It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes. - -These features are supported in Fwd and Bwd -1) Fwd and Bwd with causal masking -2) Variable sequence lengths -3) Arbitrary Q and KV sequence lengths -4) Arbitrary head sizes -5) Multi and grouped query attention -6) Dropout -7) Rotary embeddings -8) ALiBi - -We are working on the following things -1) Paged Attention -2) Sliding Window -3) FP8 -4) Performance Improvements - -##### Getting Started -To get started with the triton backend for AMD, follow the steps below. - -First install the recommended Triton version - -``` -pip install triton==3.3.0 -``` -Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. - -``` -cd flash-attention -git checkout main_perf -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install -``` - -To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing. -``` -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py -``` - -You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` -``` -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE -``` - -###### Docker -You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image. -``` -FROM rocm/pytorch:latest - -WORKDIR /workspace - -# install triton -RUN pip install triton==3.3.0 - -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - -RUN git clone https://github.com/ROCm/flash-attention.git &&\ - cd flash-attention &&\ - git checkout main_perf &&\ - python setup.py install - -# set working dir -WORKDIR /workspace/flash-attention -``` - -To build the docker file -``` -docker build -t fa_triton . -``` - -To run the docker image -``` -docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton -``` - -###### FP8 -In our fork We have created the following api functions that use fp8 internally to compute their values. These functions are `flash_attn_fp8_func`, `flash_attn_varlen_fp8_func`, `flash_attn_qkvpacked_fp8_func` and `flash_attn_varlen_qkvpacked_fp8_func`. Here is a usage example - -``` -from flash_attn.flash_attn_triton_amd.fp8 import flash_attn_qkvpacked_fp8_func - -# forward pass -out, lse, S_dmask = flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - -# backward pass -do = torch.randn_like(out) -dqkv = torch.autograd.grad(out, (qkv), do) -``` - -You can use the other api functions in a similar way. - - - -##### Credits -AMD Triton kernels team - -OpenAI kernel team diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py old mode 100644 new mode 100755 index e969a3770b8..51e53daedc2 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py @@ -1,7 +1,7 @@ import torch import triton import triton.language as tl -from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors +from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors, DEBUG, is_fp8 from typing import Optional, Tuple @@ -1503,11 +1503,17 @@ def attention_prefill_backward_triton_fused_atomics_impl( descale_v: Optional[torch.Tensor] = None, descale_do: Optional[torch.Tensor] = None, fused: bool = False, + # seqused for FA v3 (currently ignored in this implementation) + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): IS_FP8 = is_fp8(q) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max descale_strides = (descale_q.stride(0),descale_k.stride(0),descale_v.stride(0),descale_do.stride(0) ) + + if DEBUG: + print(f"FP8 path triggered in bwd_prefill_fused_atomics.py") else: FP8_MAX = None stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_do_z = None diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py old mode 100644 new mode 100755 index 5b2f8858d11..0d3b3a6fdf4 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py @@ -3,10 +3,9 @@ import triton # type: ignore import triton.language as tl # type: ignore from typing import Literal, Optional -from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, \ +from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, DEBUG, compute_fp8_scaling_factors, \ create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_fp8, is_rdna, round_multiple -DEBUG= False # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) @@ -29,7 +28,7 @@ def get_autotune_configs(): ] preprocess_autotune_keys = [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] causal_autotune_configs = [ triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config @@ -39,7 +38,7 @@ def get_autotune_configs(): ] causal_autotune_keys = [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] noncausal_autotune_configs = [ triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config @@ -49,7 +48,7 @@ def get_autotune_configs(): ] noncausal_autotune_keys = [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) @@ -72,21 +71,21 @@ def get_autotune_configs(): ] preprocess_autotune_keys = [ "max_seqlen_q", - "ACTUAL_HEAD_DIM", "IS_VARLEN", + "ACTUAL_HEAD_DIM_V", "IS_VARLEN", ] causal_autotune_configs = [ triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] causal_autotune_keys = [ "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] noncausal_autotune_configs = [ triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] noncausal_autotune_keys = [ "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) @@ -117,8 +116,8 @@ def _bwd_preprocess( cu_seqlens_q, max_seqlen_q, Descale_do, PRE_BLOCK: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, IS_VARLEN: tl.constexpr, IS_FP8: tl.constexpr ): @@ -136,7 +135,7 @@ def _bwd_preprocess( # Compute offsets offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) - offs_d = tl.arange(0, HEAD_DIM) + offs_d = tl.arange(0, HEAD_DIM_V) # pointer offsets for O & DO off_o = ( bid * stride_ob + hid * stride_oh @@ -152,9 +151,9 @@ def _bwd_preprocess( # create masks mask_m = offs_m < seqlen_q mask_md = mask_m[:, None] - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) + if PADDED_HEAD_V: + mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM_V # load o = tl.load(O + off_o, mask=mask_md, other=0.0) do = tl.load(DO + off_do, mask=mask_md, other=0.0) @@ -185,8 +184,10 @@ def _bwd_dkdv_inner( stride_lse_m, stride_delta_m, BLOCK_M: tl.constexpr, # 16 BLOCK_N: tl.constexpr, # 128 - HEAD_DIM: tl.constexpr, # - ACTUAL_HEAD_DIM: tl.constexpr, # + HEAD_DIM_QK: tl.constexpr, # + HEAD_DIM_V: tl.constexpr, # + ACTUAL_HEAD_DIM_QK: tl.constexpr, # + ACTUAL_HEAD_DIM_V: tl.constexpr, # dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -203,18 +204,20 @@ def _bwd_dkdv_inner( DEBUG_TRITON_DETAIL: tl.constexpr, ): # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) delta_qk = seqlen_q - seqlen_k offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) - offs_k = tl.arange(0, HEAD_DIM) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) # mask to make sure not OOB of seqlen_q mask_n = offs_n < seqlen_k # Q and DO are (seqlen_q, head_dim) - # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q - qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk - # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed - do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM_QK, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k_qk[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM_V), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k_v[None, :] * stride_dok # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. tl.static_assert(BLOCK_N % BLOCK_M == 0) curr_m = start_m @@ -231,9 +234,10 @@ def _bwd_dkdv_inner( mask_qT = mask_m[None, :] mask_do = mask_m[:, None] mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM - mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM + if PADDED_HEAD_QK: + mask_qT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_do &= offs_k_v[None, :] < ACTUAL_HEAD_DIM_V qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) # generate dropout mask if ENABLE_DROPOUT: @@ -341,8 +345,10 @@ def _bwd_dq_inner( seqlen_q, seqlen_k, # BLOCK_M2: tl.constexpr, # BLOCK_N2: tl.constexpr, # - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, # + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, # dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, # Filled in by the wrapper. @@ -358,17 +364,19 @@ def _bwd_dq_inner( DEBUG_TRITON_DETAIL: tl.constexpr, ): # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) delta_qk = seqlen_q - seqlen_k offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, HEAD_DIM) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) # mask to make sure not OOB of seqlen_q mask_m = offs_m < seqlen_q - kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk - vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k_qk[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k_v[:, None] * stride_vk # D (= delta) is pre-divided by ds_scale. Di = tl.load(Delta + offs_m * stride_delta_m, mask=mask_m, other=0.0) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. @@ -387,12 +395,15 @@ def _bwd_dq_inner( if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 mask_kT = mask_n[None, :] + mask_vT = mask_n[None, :] mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) - if PADDED_HEAD: - mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM + if PADDED_HEAD_QK: + mask_kT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_vT &= offs_k_v[:, None] < ACTUAL_HEAD_DIM_V kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) - vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_vT, other=0.0) if ENABLE_DROPOUT: # NOTE: dropout is transposed because it is used to mask pT @@ -477,6 +488,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_az, stride_ah, HQ, HK, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Add seqused parameters max_seqlen_q, max_seqlen_k, Dropout_mask, dropout_p, philox_seed, philox_offset_base, Alibi_slopes, @@ -486,8 +498,10 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -495,6 +509,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, ): @@ -514,21 +529,31 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba q_end = tl.load(cu_seqlens_q + bid + 1) k_start = tl.load(cu_seqlens_k + bid) k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start delta_qk = seqlen_q - seqlen_k if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_d = tl.arange(0, HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) GROUP_SIZE: tl.constexpr = HQ // HK # align the delta_qk start_n = pid * BLOCK_N1 if start_n < seqlen_k: # This section does dk and dv - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) # q > k: diretcly skip all the way until the start of causal block start_delta_q_gt_k = delta_qk @@ -548,17 +573,21 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba offs_n = start_n + tl.arange(0, BLOCK_N1) # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_kv &= mask_d[None, :] + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d_qk[None, :] * stride_kd + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d_v[None, :] * stride_vd # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + k = tl.load(K + adj_k, mask=mask_k, other=0.0) + v = tl.load(V + adj_v, mask=mask_v, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. # hqid = hkid for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): @@ -628,7 +657,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_dropoutm, stride_dropoutn, # strides for dropout stride_lse_m, stride_delta_m, MASK_BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -658,7 +687,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_dropoutm, stride_dropoutn, # strides for dropout stride_lse_m, stride_delta_m, BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -676,13 +705,13 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba # end of GQA/MQA of dkdv # Write back dV adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn - offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd - tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) # write back dk adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn - offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd dk *= sm_scale - tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) # This part does dq start_m = pid * BLOCK_M2 @@ -696,11 +725,15 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba offs_m = start_m + tl.arange(0, BLOCK_M2) # Mask for loading K and V mask_q = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_q &= mask_d[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod # NOTE: don't assume that the strides for k and v are the same! K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn @@ -739,7 +772,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba dropout_offset = \ Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) m = m[:, None] @@ -757,7 +790,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba else: descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) dq = _bwd_dq_inner( dq, q, K, V, do, m, Delta_ptr, sm_scale, @@ -767,7 +800,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, MASK_BLOCK_N2, - HEAD_DIM, ACTUAL_HEAD_DIM, + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, start_m, start_n, end_n, num_steps, @@ -794,7 +827,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, BLOCK_N2, - HEAD_DIM, ACTUAL_HEAD_DIM, + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, start_m, start_n, end_n, num_steps, @@ -810,7 +843,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba ) # Write back dQ. adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) # end of GQA/MQA of dq @@ -838,6 +871,7 @@ def bwd_kernel_noncausal( stride_az, stride_ah, HQ, HK, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Add seqused parameters max_seqlen_q, max_seqlen_k, Dropout_mask, dropout_p, philox_seed, philox_offset_base, Alibi_slopes, @@ -847,8 +881,10 @@ def bwd_kernel_noncausal( BLOCK_M2: tl.constexpr, # 128 BLOCK_N2: tl.constexpr, # 32 BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -856,6 +892,7 @@ def bwd_kernel_noncausal( IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, ): @@ -875,31 +912,45 @@ def bwd_kernel_noncausal( q_end = tl.load(cu_seqlens_q + bid + 1) k_start = tl.load(cu_seqlens_k + bid) k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_d = tl.arange(0, HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) GROUP_SIZE: tl.constexpr = HQ // HK start_n = pid * BLOCK_N1 if start_n < seqlen_k: - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) offs_n = start_n + tl.arange(0, BLOCK_N1) # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_kv &= mask_d[None, :] + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] # NOTE: don't assume that the strides for k and v are the same! # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d_qk[None, :] * stride_kd + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d_v[None, :] * stride_vd # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + k = tl.load(K + adj_k, mask=mask_k, other=0.0) + v = tl.load(V + adj_v, mask=mask_v, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): # offset input and output tensor by batch and Q/K heads @@ -948,7 +999,7 @@ def bwd_kernel_noncausal( stride_lse_m, stride_delta_m, BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, # alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -966,13 +1017,13 @@ def bwd_kernel_noncausal( # Write back dV adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn - offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd - tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) # write back dk adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn - offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd dk *= sm_scale - tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) # THIS PART DOES DQ start_m = pid * BLOCK_M2 @@ -980,11 +1031,15 @@ def bwd_kernel_noncausal( offs_m = start_m + tl.arange(0, BLOCK_M2) # Mask for loading K and V mask_q = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_q &= mask_d[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn # If MQA / GQA, set the K and V head offsets appropriately. @@ -1016,7 +1071,7 @@ def bwd_kernel_noncausal( Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) m = m[:, None] @@ -1034,7 +1089,7 @@ def bwd_kernel_noncausal( end_n = seqlen_k num_steps = tl.cdiv(seqlen_k, BLOCK_N2) - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) dq = _bwd_dq_inner( dq, q, K, V, do, m, Delta_ptr, sm_scale, @@ -1044,7 +1099,7 @@ def bwd_kernel_noncausal( stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, BLOCK_N2, - HEAD_DIM, ACTUAL_HEAD_DIM, + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, start_m, start_n, end_n, num_steps, @@ -1060,7 +1115,7 @@ def bwd_kernel_noncausal( ) # Write back dQ. adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) @@ -1106,6 +1161,9 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( descale_dq: Optional[torch.Tensor], descale_dk: Optional[torch.Tensor], descale_dv: Optional[torch.Tensor], + # seqused for FA v3 + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): # get params, strides and shape IS_VARLEN = layout == "thd" @@ -1135,13 +1193,13 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( assert max_seqlen_k is not None, "max_seqlen_k must be provided for varlen layout" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" assert nheads_lse == nheads_q, f"softmax_lse heads {nheads_lse} != q heads {nheads_q}" # assert output shapes - assert o.shape == (total_seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_q)}" + assert o.shape == (total_seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" @@ -1157,7 +1215,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( # set vars batch = len(cu_seqlens_q) - 1 - head_size = head_size_q + head_size_qk = head_size_q # strides stride_qb, stride_qm, stride_qh, stride_qd = 0, q.stride(0), q.stride(1), q.stride(2) @@ -1180,7 +1238,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" @@ -1188,7 +1246,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" # assert output shapes - assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected" + assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected" assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" @@ -1199,7 +1257,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( # set vars batch = batch_q - head_size = head_size_q + head_size_qk = head_size_q max_seqlen_q = seqlen_q max_seqlen_k = seqlen_k @@ -1233,6 +1291,9 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None + + if DEBUG: + print(f"FP8 path triggered in bwd_prefill_fused_no_atomics.py (FP8_OUTPUT={FP8_OUTPUT})") else: FP8_MAX = None FP8_OUTPUT = False @@ -1242,10 +1303,14 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) # get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 32) - HEAD_DIM = padded_d_model - ACTUAL_HEAD_DIM = head_size + padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() + padded_d_model_qk = max(padded_d_model_qk, 32) + padded_d_model_v = 1 << (head_size_v - 1).bit_length() + padded_d_model_v = max(padded_d_model_v, 32) + HEAD_DIM_QK = padded_d_model_qk + HEAD_DIM_V = padded_d_model_v + ACTUAL_HEAD_DIM_QK = head_size_qk + ACTUAL_HEAD_DIM_V = head_size_v # init delta if OLD_LSE: @@ -1280,13 +1345,13 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_descale_do_z, cu_seqlens_q, max_seqlen_q, descale_do, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, IS_VARLEN=IS_VARLEN, IS_FP8=IS_FP8 ) - if DEBUG: + if False: print("delta:", delta, delta.shape) # dropout mask tensor for debugging. We dump the dropout mask created in @@ -1337,12 +1402,15 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Pass seqused tensors max_seqlen_q, max_seqlen_k, dropout_mask, dropout_p, philox_seed, philox_offset, alibi_slopes, descale_q, descale_k, descale_v, descale_do, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, USE_ALIBI=use_alibi, @@ -1350,6 +1418,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None), # Add flag for seqused DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) @@ -1371,12 +1440,15 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Pass seqused tensors max_seqlen_q, max_seqlen_k, dropout_mask, dropout_p, philox_seed, philox_offset, alibi_slopes, descale_q, descale_k, descale_v, descale_do, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, USE_ALIBI=use_alibi, @@ -1384,6 +1456,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None), # Add flag for seqused DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py old mode 100644 new mode 100755 index 56187ea71f0..9ffdc9dea1a --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py @@ -2,11 +2,9 @@ import triton # type: ignore import triton.language as tl # type: ignore from typing import Literal, Optional -from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, \ +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, DEBUG, compute_fp8_scaling_factors, get_shapes_from_layout, \ get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_fp8 -DEBUG = False - # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) @@ -1092,6 +1090,9 @@ def attention_prefill_backward_triton_split_impl( descale_dq: Optional[torch.Tensor], descale_dk: Optional[torch.Tensor], descale_dv: Optional[torch.Tensor], + # seqused for FA v3 (currently ignored in this implementation) + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): # debug DEBUG_TRITON: bool = False @@ -1117,6 +1118,9 @@ def attention_prefill_backward_triton_split_impl( stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None + + if DEBUG: + print(f"FP8 path triggered in bwd_prefill_split.py (FP8_OUTPUT={FP8_OUTPUT})") else: FP8_MAX = None FP8_OUTPUT = False diff --git a/flash_attn/flash_attn_triton_amd/fp8.py b/flash_attn/flash_attn_triton_amd/fp8.py deleted file mode 100644 index df79c7926b2..00000000000 --- a/flash_attn/flash_attn_triton_amd/fp8.py +++ /dev/null @@ -1,716 +0,0 @@ -from typing import Optional, Sequence, Tuple, Union -import torch -import torch.nn as nn -from .utils import cast_to_fp8, is_fp8 -from . import interface_fa as flash_attn_gpu - - -def maybe_contiguous(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - -class FlashAttnFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - alibi_slopes, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. It might cause issue doing an empty - - # check output type - assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) - dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") - dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None - dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None - dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dq, - dk, - dv, - ctx.alibi_slopes, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, # gen_ - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - -def flash_attn_fp8_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None -): - return FlashAttnFP8Func.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - descale_q, - descale_k, - descale_v, - descale_do - ) - -class FlashAttnVarlenFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - block_table, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=max_seqlen_q) - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - cu_seqlens_q, - cu_seqlens_k, - None, - None, - block_table, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - False, - causal, - window_size[0], - window_size[1], - softcap, - return_softmax, - None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. It might cause issue doing an empty - - # check output type - assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout_padded): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) - dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=ctx.max_seqlen_q) - dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None - dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None - dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.varlen_bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - ctx.alibi_slopes, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - False, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_varlen_fp8_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - block_table=None -): - return FlashAttnVarlenFP8Func.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - block_table, - torch.is_grad_enabled() - ) - -class FlashAttnQKVPackedFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach() - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - alibi_slopes, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o, - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") - dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None - - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dqkv[:, :, 0], - dqkv[:, :, 1], - dqkv[:, :, 2], - ctx.alibi_slopes, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, # gen_ - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - None, - None, - None, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None - - -def flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # <=0.0 means deactivate - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - return FlashAttnQKVPackedFP8Func.apply( - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -class FlashAttnVarlenQKVPackedFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach() - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - cu_seqlens, - cu_seqlens, - None, - None, - None, - alibi_slopes, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - False, - causal, - window_size[0], - window_size[1], - softcap, - return_softmax, - None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout_padded): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=ctx.max_seqlen) - dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.varlen_bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - cu_seqlens, - cu_seqlens, - ctx.alibi_slopes, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.softmax_scale, - False, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - None, - None, - None, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_varlen_qkvpacked_fp8_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - return FlashAttnVarlenQKVPackedFP8Func.apply( - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py old mode 100644 new mode 100755 index d3f7f9c32b9..c14e02b9b56 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -2,7 +2,7 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import AUTOTUNE, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna +from .utils import AUTOTUNE, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna, is_fp8 DEBUG = False @@ -59,6 +59,220 @@ def get_autotune_configs(): (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) = get_autotune_configs() +@triton.jit +def _attn_fwd_inner( + q, kT, v, start_n, + m_i, l_i, acc, + pid_m, hi, + q_descale, k_descale, v_descale, # FP8 scaling factors + IS_FP8: tl.constexpr, # FP8 flag + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + N_CTX_Q: tl.constexpr, + N_CTX_K_FINAL: tl.constexpr, + USE_ALIBI: tl.constexpr, + alibi_slope, + USE_SLIDING_WINDOW: tl.constexpr, + IS_CAUSAL: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, +): + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_FP8: + qk += (tl.dot(q, kT) * q_descale * k_descale) # Apply FP8 scaling + else: + qk += tl.dot(q, kT) # noqa: F821 + + if USE_ALIBI: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # Compute relative positions + relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) + relative_pos = tl.abs(relative_pos) + + # Compute ALiBi bias + alibi_bias = -1 * alibi_slope * relative_pos + qk += (alibi_bias * 1.44269504) + + # ------------------------------------------------------------------ + # masking + # ------------------------------------------------------------------ + if USE_SLIDING_WINDOW: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions + col_idx = start_n + tl.arange(0, BLOCK_N) # k positions + row = row_idx[:, None] # [M,1] + col = col_idx[None, :] # [1,N] + + if IS_CAUSAL: + # -------- causal + window -------- + diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq + causal_ok = col <= row + diag + if WINDOW_SIZE_LEFT < 0: # only right window + win_ok = col <= row + diag + WINDOW_SIZE_RIGHT + else: # both sides + win_ok = ((col >= row + diag - WINDOW_SIZE_LEFT) & + (col <= row + diag + WINDOW_SIZE_RIGHT)) + mask = ~(causal_ok & win_ok) # True ⇒ -inf + else: + # -------- non-causal window -------- + sk, sq = N_CTX_K_FINAL, N_CTX_Q + if WINDOW_SIZE_LEFT < 0: + mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT + else: + right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) + left = row + (sk - sq) - WINDOW_SIZE_LEFT + mask = (col > right) | (col < left) + qk = tl.where(mask, float("-inf"), qk) + else: + if IS_CAUSAL: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = start_n + tl.arange(0, BLOCK_N) + + # create a N_CTX_Q x kv_len causal mask + col_offset = N_CTX_K_FINAL - N_CTX_Q + causal_mask = row_idx[:, None] >= (col_idx[None, :] - col_offset) + + # Apply the mask + qk = tl.where(causal_mask, qk, float("-inf")) + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far + + # rows that are *all* -inf after masking + valid = m_i_new > float("-inf") + + # scale previous partial sums safely + alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) + + # subtract the row max only on valid rows + qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) + p = tl.math.exp2(qk) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(q.dtype) + + # -- scale and update acc -- + acc *= alpha[:, None] + if IS_FP8: + acc += tl.dot(p.to(v.dtype), v) * v_descale # Apply FP8 scaling for V + else: + acc += tl.dot(p.to(v.dtype), v) + + return m_i, l_i, acc + +@triton.jit +def _attn_fwd_inner_paged( + q, kT, v, seq_pos, valid_mask, + m_i, l_i, acc, + pid_m, + q_descale, k_descale, v_descale, # FP8 scaling factors + IS_FP8: tl.constexpr, # FP8 flag + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + N_CTX_Q: tl.constexpr, + N_CTX_K_FINAL: tl.constexpr, + USE_ALIBI: tl.constexpr, + alibi_slope, + USE_SLIDING_WINDOW: tl.constexpr, + IS_CAUSAL: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, +): + """ + Specialized attention computation for paged KV cache. + + Key differences from _attn_fwd_inner: + - Takes a valid_mask parameter to handle block boundaries + - No BOUNDS_CHECKS_N needed as masking is handled via valid_mask + - seq_pos represents the absolute position in the sequence + """ + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_FP8: + qk += (tl.dot(q, kT) * q_descale * k_descale) # Apply FP8 scaling + else: + qk += tl.dot(q, kT) + + # Apply ALiBi if needed + if USE_ALIBI: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = seq_pos + tl.arange(0, BLOCK_N) + + # Compute relative positions + relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) + relative_pos = tl.abs(relative_pos) + + # Compute ALiBi bias + alibi_bias = -1 * alibi_slope * relative_pos + qk += (alibi_bias * 1.44269504) + + # Apply sliding window if needed + if USE_SLIDING_WINDOW: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = seq_pos + tl.arange(0, BLOCK_N) + row = row_idx[:, None] + col = col_idx[None, :] + + if IS_CAUSAL: + # -------- causal + window -------- + diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq + causal_ok = col <= row + diag + if WINDOW_SIZE_LEFT < 0: # only right window + win_ok = col <= row + diag + WINDOW_SIZE_RIGHT + else: # both sides + win_ok = ((col >= row + diag - WINDOW_SIZE_LEFT) & + (col <= row + diag + WINDOW_SIZE_RIGHT)) + mask = ~(causal_ok & win_ok) # True ⇒ -inf + else: + # -------- non-causal window -------- + sk, sq = N_CTX_K_FINAL, N_CTX_Q + if WINDOW_SIZE_LEFT < 0: + mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT + else: + right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) + left = row + (sk - sq) - WINDOW_SIZE_LEFT + mask = (col > right) | (col < left) + qk = tl.where(mask, float("-inf"), qk) + elif IS_CAUSAL: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = seq_pos + tl.arange(0, BLOCK_N) + col_offset = N_CTX_K_FINAL - N_CTX_Q + causal_mask = row_idx[:, None] >= (col_idx[None, :] - col_offset) + qk = tl.where(causal_mask, qk, float("-inf")) + + # Mask out invalid positions (from block boundaries) + qk = tl.where(valid_mask[None, :], qk, float("-inf")) + + # Compute new m and do softmax + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + valid = m_i_new > float("-inf") + alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) + qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) + p = tl.math.exp2(qk) + + # Update m_i and l_i + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(q.dtype) + + # Scale and update acc + acc *= alpha[:, None] + if IS_FP8: + acc += tl.dot(p.to(v.dtype), v) * v_descale # Apply FP8 scaling for V + else: + acc += tl.dot(p.to(v.dtype), v) + + return m_i, l_i, acc + # @triton.autotune( # configs=fwd_auto_tune_configs, # key=fwd_autotune_keys, @@ -69,6 +283,9 @@ def _fwd_kernel_splitK( Q, K, V, + Q_Descale, # FP8 descale factors for Q + K_Descale, # FP8 descale factors for K + V_Descale, # FP8 descale factors for V sm_scale, Out_splitK, # [B*H*G, split_k, Mq, K] Metadata, # [B*H*G, 2, split_k, M_ceil] contains [mi, li] @@ -76,6 +293,7 @@ def _fwd_kernel_splitK( V_new, Cache_seqlens, Cache_batch_idx, + Block_table, Alibi_slopes, stride_qz, stride_qm, @@ -110,13 +328,22 @@ def _fwd_kernel_splitK( stride_vn_g, stride_vn_h, stride_vn_d, + stride_bt_b, + stride_bt_s, stride_az, stride_ah, + stride_q_descale_z, # FP8 descale strides + stride_q_descale_h, + stride_k_descale_z, + stride_k_descale_h, + stride_v_descale_z, + stride_v_descale_h, Z, N_CTX_Q, N_CTX_K, N_CTX_NEW, BLOCK_N_PER_SPLIT, + BLOCK_SIZE_K: tl.constexpr, H_q: tl.constexpr, H_kv: tl.constexpr, G_q: tl.constexpr, @@ -136,6 +363,8 @@ def _fwd_kernel_splitK( USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, + USE_BLOCK_TABLE: tl.constexpr, + IS_FP8: tl.constexpr, # FP8 flag ): # get program ids pid_m = tl.program_id(0) @@ -155,14 +384,24 @@ def _fwd_kernel_splitK( hk_id = hq_id hv_id = hq_id + # Load FP8 descale factors if needed + if IS_FP8: + if IS_GQA: + # For MQA/GQA, q_descale uses the same indexing as k/v (hk_id) + q_descale = tl.load(Q_Descale + z_id * stride_q_descale_z + hk_id * stride_q_descale_h) + else: + # For MHA, q_descale uses hq_id + q_descale = tl.load(Q_Descale + z_id * stride_q_descale_z + hq_id * stride_q_descale_h) + k_descale = tl.load(K_Descale + z_id * stride_k_descale_z + hk_id * stride_k_descale_h) + v_descale = tl.load(V_Descale + z_id * stride_v_descale_z + hv_id * stride_v_descale_h) + else: + q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 + # figure out seqlens lo = pid_splitk * BLOCK_N_PER_SPLIT if USE_CACHE_SEQLENs: cache_seqlen_last_idx = tl.load(Cache_seqlens + z_id) - if NEW_KV: - N_CTX_K_FINAL = cache_seqlen_last_idx + N_CTX_NEW - else: - N_CTX_K_FINAL = cache_seqlen_last_idx + N_CTX_K_FINAL = cache_seqlen_last_idx else: N_CTX_K_FINAL = N_CTX_K hi = tl.minimum((pid_splitk + 1) * BLOCK_N_PER_SPLIT, N_CTX_K_FINAL) @@ -181,8 +420,15 @@ def _fwd_kernel_splitK( # compute ptrs q_offset = Q + hq_id * stride_qh + z_id * stride_qz + g_id * stride_qg q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - k_offset = K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg - v_offset = V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg + + # Handle block table for paged attention + if USE_BLOCK_TABLE: + # K and V now point to paged cache + # Each batch has its own block table row + block_table_ptr = Block_table + z_id * stride_bt_b + else: + k_offset = K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg + v_offset = V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg # compute masks if PADDED_HEAD: @@ -212,160 +458,103 @@ def _fwd_kernel_splitK( else: alibi_slope = None - # Copy new Keys and Values into Cache - if NEW_KV: - knew_base = K_new + hk_id * stride_kn_h + z_id * stride_kn_z + g_id * stride_kn_g - - # Determine the starting position for new data in the cache - if USE_CACHE_SEQLENs: - start_idx = tl.load(Cache_seqlens + z_id) - else: - start_idx = N_CTX_K - N_CTX_NEW - - # Copy new Keys - for i in range(0, N_CTX_NEW, BLOCK_N): - # Load from K_new - k_new_block = tl.load( - knew_base + - tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d + - (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n, - mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), - other=0 - ) - - # Store to K - tl.store( - k_offset + - tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd + - (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn, - k_new_block, - mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), - ) - - # Copy new Values - vnew_base = V_new + hv_id * stride_vn_h + z_id * stride_vn_z + g_id * stride_vn_g - for i in range(0, N_CTX_NEW, BLOCK_N): - # Load from V_new - v_new_block = tl.load( - vnew_base + - (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n + - tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d, - mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), - other=0 - ) - - # Store to V - tl.store( - v_offset + - (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn + - tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd, - v_new_block, - mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), - ) - - # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - kT_ptrs = k_offset + offs_d[:, None] * stride_kd + (start_n + offs_n)[None, :] * stride_kn - V_ptrs = v_offset + (start_n + offs_n)[:, None] * stride_vn + offs_d[None, :] * stride_vd - - # load k - kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) - v = tl.load(V_ptrs, mask=v_mask, other=0.0) - - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, kT) # noqa: F821 - - if USE_ALIBI: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = start_n + tl.arange(0, BLOCK_N) - - # Compute relative positions - relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) - relative_pos = tl.abs(relative_pos) + if USE_BLOCK_TABLE: + # Paged attention: process all KV blocks from cache + # Note: Cache should be updated externally before calling this kernel + num_kv_blocks = (N_CTX_K_FINAL + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K + + for block_idx in range(num_kv_blocks): + # Calculate sequence range for this block + block_start = block_idx * BLOCK_SIZE_K + block_end = tl.minimum(block_start + BLOCK_SIZE_K, N_CTX_K_FINAL) - # Compute ALiBi bias - alibi_bias = -1 * alibi_slope * relative_pos - qk += (alibi_bias * 1.44269504) - - # ------------------------------------------------------------------ - # masking - # ------------------------------------------------------------------ - if USE_SLIDING_WINDOW: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions - col_idx = start_n + tl.arange(0, BLOCK_N) # k positions - row = row_idx[:, None] # [M,1] - col = col_idx[None, :] # [1,N] - - if IS_CAUSAL: - # -------- causal + window -------- - diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq - causal_ok = col <= row + diag - if WINDOW_SIZE_LEFT < 0: # only right window - win_ok = col <= row + diag + WINDOW_SIZE_RIGHT - else: # both sides - win_ok = ((col >= row + diag - WINDOW_SIZE_LEFT) & - (col <= row + diag + WINDOW_SIZE_RIGHT)) - mask = ~(causal_ok & win_ok) # True ⇒ -inf - else: - # -------- non-causal window -------- - sk, sq = N_CTX_K_FINAL, N_CTX_Q - if WINDOW_SIZE_LEFT < 0: - mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT - else: - right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) - left = row + (sk - sq) - WINDOW_SIZE_LEFT - mask = (col > right) | (col < left) - qk = tl.where(mask, float("-inf"), qk) - else: - if IS_CAUSAL: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = start_n + tl.arange(0, BLOCK_N) - - # create a N_CTX_Q x kv_len causal mask - col_offset = N_CTX_Q - N_CTX_K_FINAL - causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) - - # Apply the mask - qk = tl.where(causal_mask, qk, float("-inf")) - - # TODO: This is slow, and only needed at the last iteration. - # Maybe we can unroll the last iteration instead? - if BOUNDS_CHECKS_N: - qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) - - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far - - # rows that are *all* -inf after masking - valid = m_i_new > float("-inf") - - # scale previous partial sums safely - alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) - - # subtract the row max only on valid rows - qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) - p = tl.math.exp2(qk) - - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - p = p.to(Q.dtype.element_ty) - - # -- scale and update acc -- - acc *= alpha[:, None] - acc += tl.dot(p.to(v.dtype), v) + # Check if block overlaps with our split-k range [lo, hi) + if block_end > lo and block_start < hi: + # Load physical block number + physical_block = tl.load(block_table_ptr + block_idx * stride_bt_s) + + # Calculate the range within this block that overlaps with [lo, hi) + process_start = tl.maximum(lo - block_start, 0) + process_end = tl.minimum(hi - block_start, BLOCK_SIZE_K) + process_end = tl.minimum(process_end, block_end - block_start) + + # Align to BLOCK_N boundaries + process_start = (process_start // BLOCK_N) * BLOCK_N + + for offset in range(process_start, process_end, BLOCK_N): + # Current position in the sequence + seq_pos = block_start + offset + + # Only process if in range + if seq_pos < hi and seq_pos >= lo: + # Calculate base addresses for K and V in this physical block + k_base = K + physical_block * BLOCK_SIZE_K * stride_kn + hk_id * stride_kh + g_id * stride_kg + v_base = V + physical_block * BLOCK_SIZE_K * stride_vn + hv_id * stride_vh + g_id * stride_vg + + # Offsets within the current block + block_offs = offset + offs_n + + # Masks for valid data + seq_mask = ((seq_pos + offs_n) < N_CTX_K_FINAL) + block_mask = (block_offs < BLOCK_SIZE_K) + valid_mask = seq_mask & block_mask + + # Apply masks + kT_mask_final = kT_mask & valid_mask[None, :] + v_mask_final = v_mask & valid_mask[:, None] + + # Load K and V + kT_ptrs = k_base + offs_d[:, None] * stride_kd + block_offs[None, :] * stride_kn + v_ptrs = v_base + block_offs[:, None] * stride_vn + offs_d[None, :] * stride_vd + + kT = tl.load(kT_ptrs, mask=kT_mask_final, other=0.0) + v = tl.load(v_ptrs, mask=v_mask_final, other=0.0) + + # Use the specialized paged attention inner function + m_i, l_i, acc = _attn_fwd_inner_paged( + q, kT, v, seq_pos, valid_mask, + m_i, l_i, acc, + pid_m, + q_descale, k_descale, v_descale, # FP8 scaling + IS_FP8, # FP8 flag + BLOCK_M, BLOCK_N, + N_CTX_Q, N_CTX_K_FINAL, + USE_ALIBI, alibi_slope, + USE_SLIDING_WINDOW, IS_CAUSAL, + WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, + ) + else: + # Non-paged attention: process KV from cache + # Note: Cache should be updated externally before calling this kernel + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + kT_ptrs = k_offset + offs_d[:, None] * stride_kd + (start_n + offs_n)[None, :] * stride_kn + V_ptrs = v_offset + (start_n + offs_n)[:, None] * stride_vn + offs_d[None, :] * stride_vd + + # load k + kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) + v = tl.load(V_ptrs, mask=v_mask, other=0.0) + + # Use the same inner loop logic + m_i, l_i, acc = _attn_fwd_inner( + q, kT, v, start_n, + m_i, l_i, acc, + pid_m, hi, + q_descale, k_descale, v_descale, # FP8 scaling + IS_FP8, # FP8 flag + BLOCK_M, BLOCK_N, + N_CTX_Q, N_CTX_K_FINAL, + USE_ALIBI, alibi_slope, + USE_SLIDING_WINDOW, IS_CAUSAL, + WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, + BOUNDS_CHECKS_N, + ) # write back O osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_splitk * stride_osk_s @@ -598,7 +787,79 @@ def attention_decode_forward_triton_impl( layout: Literal["bshd"], cache_seqlens: Optional[torch.Tensor], cache_batch_idx: Optional[torch.Tensor], + block_table: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, ): + # Check for unsupported configuration + seqlen_q = q.shape[1] + if block_table is not None and alibi_slopes is not None and seqlen_q > 1: + raise NotImplementedError( + "Paged attention with ALiBi and multiple queries (seqlen_q > 1) is not supported " + "due to numerical precision issues. Please use non-paged attention or single query decode." + ) + + # Handle cache updates externally before calling the kernel + if k_new is not None and v_new is not None: + # Update cache with new KV values + if block_table is None: + # Non-paged attention: update cache directly + batch_size = k_new.shape[0] + seqlen_new = k_new.shape[1] + + if cache_seqlens is not None: + # Use cache_seqlens to determine where to insert new KV + for b in range(batch_size): + start_idx = int(cache_seqlens[b].item()) + end_idx = start_idx + seqlen_new + k_cache[b, start_idx:end_idx] = k_new[b] + v_cache[b, start_idx:end_idx] = v_new[b] + cache_seqlens[b] = end_idx + else: + # Append at the end of existing cache + seqlen_cache = k_cache.shape[1] + k_cache[:, seqlen_cache - seqlen_new:] = k_new + v_cache[:, seqlen_cache - seqlen_new:] = v_new + else: + # Paged attention: update cache using block table + batch_size = k_new.shape[0] + seqlen_new = k_new.shape[1] + block_size = k_cache.shape[1] # k_cache shape: [num_blocks, block_size, nheads, head_dim] + + # Update cache for each batch element + for b in range(batch_size): + if cache_seqlens is not None: + start_idx = int(cache_seqlens[b].item()) + else: + # If no cache_seqlens, assume we're appending at the end + # Find the last used position from block table + start_idx = 0 + for block_idx in range(block_table.shape[1]): + if block_table[b, block_idx] >= 0: + start_idx = (block_idx + 1) * block_size + else: + start_idx = block_idx * block_size + break + + # Copy new KV values into the paged cache + for i in range(seqlen_new): + pos = start_idx + i + block_idx = pos // block_size + within_block_idx = pos % block_size + + # Get the physical block number from block table + if block_idx < block_table.shape[1]: + physical_block = int(block_table[b, block_idx].item()) + + # Update k_cache and v_cache at the physical block location + k_cache[physical_block, within_block_idx] = k_new[b, i] + v_cache[physical_block, within_block_idx] = v_new[b, i] + + # Update cache_seqlens if provided + if cache_seqlens is not None: + cache_seqlens[b] = start_idx + seqlen_new + # triton configs BLOCK_M = 16 BLOCK_N = 64 @@ -607,17 +868,45 @@ def attention_decode_forward_triton_impl( num_warps_reduce = 4 # kernel_configs - is_new_kv = True if k_new is not None and v_new is not None else False + is_new_kv = False # Cache has been updated, so no new KV in kernel use_alibi, (stride_az, stride_ah) = True if alibi_slopes is not None else False, alibi_slopes.stride() if alibi_slopes is not None else (None, None) use_cache_seqlens = cache_seqlens is not None use_sliding_window = window_size_left != -1 or window_size_right != -1 + use_block_table = block_table is not None SPLIT_K = None NUM_QUANT_GROUPS = 1 # get shapes and strides (batch_size, seqlen_q, nheads_q, dim_q), (stride_qz, stride_qh, stride_qm, stride_qd) = get_shape_and_strides_from_layout(q, layout) - (_, seqlen_kc, nheads_kc, dim_kc), (stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d) = get_shape_and_strides_from_layout(k_cache, layout) - (_, seqlen_vc, nheads_vc, dim_vc), (stride_vc_z, stride_vc_h, stride_vc_n, stride_vc_d) = get_shape_and_strides_from_layout(v_cache, layout) + + # Handle paged KV cache layout + if use_block_table: + # For paged attention, k_cache and v_cache have shape [num_blocks, block_size, nheads, head_dim] + num_blocks_kc, block_size_k, nheads_kc, dim_kc = k_cache.shape + num_blocks_vc, block_size_v, nheads_vc, dim_vc = v_cache.shape + # Get the actual sequence length from cache_seqlens or block_table + if cache_seqlens is not None: + seqlen_kc = int(cache_seqlens.max().item()) + else: + # Infer from block_table shape [batch_size, num_blocks_per_seq] + num_blocks_per_seq = block_table.shape[1] + seqlen_kc = num_blocks_per_seq * block_size_k + seqlen_vc = seqlen_kc + + # Strides for paged layout + stride_kc_z = 0 # No batch dimension in paged cache + stride_kc_n = k_cache.stride(1) # Sequence stride + stride_kc_h = k_cache.stride(2) # Head stride + stride_kc_d = k_cache.stride(3) # Dim stride + + stride_vc_z = 0 + stride_vc_n = v_cache.stride(1) + stride_vc_h = v_cache.stride(2) + stride_vc_d = v_cache.stride(3) + else: + (_, seqlen_kc, nheads_kc, dim_kc), (stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d) = get_shape_and_strides_from_layout(k_cache, layout) + (_, seqlen_vc, nheads_vc, dim_vc), (stride_vc_z, stride_vc_h, stride_vc_n, stride_vc_d) = get_shape_and_strides_from_layout(v_cache, layout) + block_size_k = 0 # Not used if is_new_kv: ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = get_shape_and_strides_from_layout(k_new, layout) (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d) = get_shape_and_strides_from_layout(v_new, layout) @@ -657,7 +946,12 @@ def attention_decode_forward_triton_impl( split_k = SPLIT_K else: # Use heuristics - split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) # NOTE: should the split think about seqlens? + if use_block_table: + # For paged attention, use the actual sequence length from cache_seqlens + max_seqlen = int(cache_seqlens.max().item()) if cache_seqlens is not None else block_size_k + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, max_seqlen) + else: + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) split_size = (seqlen_kc + split_k - 1) // split_k # setup grid @@ -673,6 +967,39 @@ def attention_decode_forward_triton_impl( stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d = out_splitk.stride() stride_mzhg, stride_m2, stride_ms, stride_mm = metadata.stride() stride_lse_zhg, stride_lse_m = lse.stride() + + # Block table strides + if use_block_table: + stride_bt_b, stride_bt_s = block_table.stride() + else: + stride_bt_b, stride_bt_s = 0, 0 + + # FP8 support + IS_FP8 = is_fp8(q) + if IS_FP8: + if (q_descale is None) or (k_descale is None) or (v_descale is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) + # Create default descale tensors if not provided + if q_descale is None: + q_descale = torch.ones(batch_size, nheads_q, dtype=torch.float32, device=q.device) + if k_descale is None: + k_descale = torch.ones(batch_size, nheads_kc, dtype=torch.float32, device=q.device) + if v_descale is None: + v_descale = torch.ones(batch_size, nheads_vc, dtype=torch.float32, device=q.device) + stride_q_descale_z, stride_q_descale_h = q_descale.stride() + stride_k_descale_z, stride_k_descale_h = k_descale.stride() + stride_v_descale_z, stride_v_descale_h = v_descale.stride() + else: + q_descale = None + k_descale = None + v_descale = None + stride_q_descale_z = 0 + stride_q_descale_h = 0 + stride_k_descale_z = 0 + stride_k_descale_h = 0 + stride_v_descale_z = 0 + stride_v_descale_h = 0 if DEBUG: print("batch_size, seqlen_q, nheads_q, dim_q", (batch_size, seqlen_q, nheads_q, dim_q)) @@ -689,18 +1016,21 @@ def attention_decode_forward_triton_impl( print("stride_mzhg, stride_m2, stride_ms, stride_mm", (stride_mzhg, stride_m2, stride_ms, stride_mm)) print("stride_lse_zhg, stride_lse_m", (stride_lse_zhg, stride_lse_m)) - # TODO: enable quantization _fwd_kernel_splitK[grid]( Q=q, K=k_cache, V=v_cache, + Q_Descale=q_descale, + K_Descale=k_descale, + V_Descale=v_descale, sm_scale=sm_scale, Out_splitK=out_splitk, Metadata=metadata, - K_new=k_new, - V_new=v_new, + K_new=None, + V_new=None, Cache_seqlens=cache_seqlens, Cache_batch_idx=cache_batch_idx, + Block_table=block_table, Alibi_slopes=alibi_slopes, # q strides stride_qz=stride_qz, @@ -742,17 +1072,28 @@ def attention_decode_forward_triton_impl( stride_vn_g=stride_vn_g, stride_vn_h=stride_vn_h, stride_vn_d=stride_vn_d, + # block table strides + stride_bt_b=stride_bt_b, + stride_bt_s=stride_bt_s, # alibi strides stride_az=stride_az, stride_ah=stride_ah, + # FP8 descale strides + stride_q_descale_z=stride_q_descale_z, + stride_q_descale_h=stride_q_descale_h, + stride_k_descale_z=stride_k_descale_z, + stride_k_descale_h=stride_k_descale_h, + stride_v_descale_z=stride_v_descale_z, + stride_v_descale_h=stride_v_descale_h, Z=batch_size, H_q=heads_per_group_q, H_kv=heads_per_group_k, G_q=n_group_q, N_CTX_Q=seqlen_q, N_CTX_K=seqlen_kc, - N_CTX_NEW=seqlen_kn, + N_CTX_NEW=0, # No new KV, cache already updated BLOCK_N_PER_SPLIT=split_size, + BLOCK_SIZE_K=block_size_k if use_block_table else 256, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=dim_padded, @@ -760,7 +1101,7 @@ def attention_decode_forward_triton_impl( BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, USE_CACHE_SEQLENs=use_cache_seqlens, USE_CACHE_BATCH_IDX=cache_batch_idx is not None, - NEW_KV=is_new_kv, + NEW_KV=False, # Cache already updated IS_GQA=is_gqa, IS_CAUSAL=causal, USE_ALIBI=use_alibi, @@ -769,6 +1110,8 @@ def attention_decode_forward_triton_impl( USE_SLIDING_WINDOW=use_sliding_window, WINDOW_SIZE_LEFT=window_size_left, WINDOW_SIZE_RIGHT=window_size_right, + USE_BLOCK_TABLE=use_block_table, + IS_FP8=IS_FP8, num_warps=num_warps_fwd, num_stages=num_stages, ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py old mode 100644 new mode 100755 index 3a2bd56fda4..bb0301c7700 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -3,14 +3,99 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, is_cdna, is_fp8, is_rdna, create_dropout_mask, get_fwd_prefill_autotune_configs - -DEBUG = False +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, is_cdna, is_fp8, is_rdna, create_dropout_mask, DEBUG # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) +# ------------------------------- +# Autotune +# ------------------------------- +def get_fwd_prefill_cdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL_QK', 'ACTUAL_BLOCK_DMODEL_V', 'IS_VARLEN', 'HQ', 'HK'] + + +def get_fwd_prefill_rdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL_QK', 'ACTUAL_BLOCK_DMODEL_V', 'IS_VARLEN', 'HQ', 'HK'] + + +def get_fwd_prefill_autotune_configs(): + if AUTOTUNE: + if is_rdna(): + return get_fwd_prefill_rdna_autotune_configs() + elif is_cdna(): + return get_fwd_prefill_cdna_autotune_configs() + else: + raise ValueError("Unknown Device Type") + else: + arch = get_arch() + if arch == "gfx950": + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + elif arch == "gfx942" and False: # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + else: + default_config = triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + + return [ + default_config + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL_QK", + "ACTUAL_BLOCK_DMODEL_V", + "IS_VARLEN", + "HQ", + "HK", + ] + + fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = get_fwd_prefill_autotune_configs() @triton.jit @@ -20,13 +105,14 @@ def _attn_fwd_no_mask(acc, l_i, m_i, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_base_ptrs, sd_mask_base_ptrs, dropout_mask_base_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, block_max, alibi_slope, - descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + q_descale, k_descale, v_descale, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD_QK: tl.constexpr, PADDED_HEAD_V: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, RETURN_SCORES: tl.constexpr, ACCUMULATOR_TYPE): if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 @@ -38,21 +124,27 @@ def _attn_fwd_no_mask(acc, l_i, m_i, v_ptrs = v_base_ptrs + start_n * stride_vk kv_offs_n = start_n + tl.arange(0, BLOCK_N) - if PADDED_HEAD: - k_mask, k_mask_other = (offs_d[:, None] < ACTUAL_BLOCK_DMODEL), 0.0 - v_mask, v_mask_other = (offs_d[None, :] < ACTUAL_BLOCK_DMODEL), 0.0 + if PADDED_HEAD_QK: + k_mask, k_mask_other = (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK), 0.0 + else: + k_mask, k_mask_other = None, None + + if PADDED_HEAD_V: + v_mask, v_mask_other = (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V), 0.0 + else: + v_mask, v_mask_other = None, None # load k and if preload_v then v - k = tl.load(k_ptrs, mask=k_mask, other=k_mask_other) if PADDED_HEAD else tl.load(k_ptrs) + k = tl.load(k_ptrs, mask=k_mask, other=k_mask_other) if PADDED_HEAD_QK else tl.load(k_ptrs) if PRE_LOAD_V: - v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD else tl.load(v_ptrs) + v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD_V else tl.load(v_ptrs) # setup qk accumlator qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) # -- compute qk ---- if IS_FP8 : - qk += (tl.dot(q, k) * descale_q * descale_k) + qk += (tl.dot(q, k) * q_descale * k_descale) else: qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE @@ -124,15 +216,18 @@ def _attn_fwd_no_mask(acc, l_i, m_i, alpha = tl.math.exp(m_diff) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD else tl.load(v_ptrs) + v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD_V else tl.load(v_ptrs) # -- update m_i and l_i l_i = l_i * alpha + l_ij m_i = m_ij if IS_FP8: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + if FP8_P_DESCALE: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * v_descale) + else: + acc += tl.dot(p.to(v.type.element_ty), v) * v_descale else: acc += tl.dot(p.to(v.type.element_ty), v) @@ -145,13 +240,14 @@ def _attn_fwd_mask(acc, l_i, m_i, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_base_ptrs, sd_mask_base_ptrs, dropout_mask_base_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, block_max, n_extra_tokens, alibi_slope, - descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + q_descale, k_descale, v_descale, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD_QK: tl.constexpr, PADDED_HEAD_V: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, ACCUMULATOR_TYPE): @@ -172,9 +268,10 @@ def _attn_fwd_mask(acc, l_i, m_i, kv_offs_n = start_n + tl.arange(0, BLOCK_N) k_mask = (kv_offs_n[None, :] < seqlen_k) v_mask = (kv_offs_n[:, None] < seqlen_k) - if PADDED_HEAD: - k_mask = k_mask & (offs_d[:, None] < ACTUAL_BLOCK_DMODEL) - v_mask = v_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + if PADDED_HEAD_QK: + k_mask = k_mask & (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK) + if PADDED_HEAD_V: + v_mask = v_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) # load k and if preload_v then v k = tl.load(k_ptrs, mask=k_mask, other = 0.0) @@ -200,7 +297,7 @@ def _attn_fwd_mask(acc, l_i, m_i, # -- compute qk ---- if IS_FP8 : - qk += (tl.dot(q, k) * descale_q * descale_k) + qk += (tl.dot(q, k) * q_descale * k_descale) else: qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE @@ -373,8 +470,11 @@ def _attn_fwd_mask(acc, l_i, m_i, m_i = m_ij if IS_FP8: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + if FP8_P_DESCALE: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * v_descale) + else: + acc += tl.dot(p.to(v.type.element_ty), v) * v_descale else: acc += tl.dot(p.to(v.type.element_ty), v) @@ -626,18 +726,19 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, ) @triton.jit def attn_fwd(Q, K, V, bias, - Descale_Q, Descale_K, Descale_V, Descale_O, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, + Q_Descale, K_Descale, V_Descale, stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Add seqused parameters dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, - HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, NEEDS_SDMASK : tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr): + IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, USE_SEQUSED: tl.constexpr): # set params ACCUMULATOR_TYPE = tl.float32 @@ -652,24 +753,38 @@ def attn_fwd(Q, K, V, bias, else: off_h_k = off_h_q # Determine if we need to mask the heads - PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_BLOCK_DMODEL_QK != BLOCK_DMODEL_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_BLOCK_DMODEL_V != BLOCK_DMODEL_V) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d_qk = tl.arange(0, BLOCK_DMODEL_QK) + offs_d_v = tl.arange(0, BLOCK_DMODEL_V) # handle seqlen if IS_VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = tl.load(seqused_q + off_z) if seqused_q is not None else cu_seqlens_q_end - cu_seqlens_q_start + seqlen_q = tl.minimum(actual_seqlen_q, cu_seqlens_q_end - cu_seqlens_q_start) + else: + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start # we have a one-size-fits-all grid in id(0). Some seqlens might be too small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + + # If seqused is provided, use it to limit the actual sequence length for keys + if USE_SEQUSED: + actual_seqlen_k = tl.load(seqused_k + off_z) if seqused_k is not None else cu_seqlens_k_end - cu_seqlens_k_start + seqlen_k = tl.minimum(actual_seqlen_k, cu_seqlens_k_end - cu_seqlens_k_start) + else: + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start else: cu_seqlens_q_start = 0 cu_seqlens_k_start = 0 @@ -678,11 +793,16 @@ def attn_fwd(Q, K, V, bias, # Load scale factors if IS_FP8. if IS_FP8: - descale_q = tl.load(Descale_Q + off_z * stride_descale_q_z + off_h_q) - descale_k = tl.load(Descale_K + off_z * stride_descale_k_z + off_h_k) - descale_v = tl.load(Descale_V + off_z * stride_descale_v_z + off_h_k) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (off_h_k) + # For MHA (GROUP_SIZE == 1), q_descale uses off_h_q (same as off_h_k) + if GROUP_SIZE != 1: + q_descale = tl.load(Q_Descale + off_z * stride_q_descale_z + off_h_k) # MQA/GQA: broadcast using k/v head index + else: + q_descale = tl.load(Q_Descale + off_z * stride_q_descale_z + off_h_q) # MHA: use q head index + k_descale = tl.load(K_Descale + off_z * stride_k_descale_z + off_h_k) + v_descale = tl.load(V_Descale + off_z * stride_v_descale_z + off_h_k) else: - descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 # figure out masking pattern @@ -701,11 +821,11 @@ def attn_fwd(Q, K, V, bias, """ # Write zeros to output o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d_v[None, :] * stride_on o_mask = (offs_m[:, None] < seqlen_q) - if PADDED_HEAD: - o_mask = o_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) - tl.store(o_ptrs, tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty), mask=o_mask) + if PADDED_HEAD_V: + o_mask = o_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) + tl.store(o_ptrs, tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=Out.type.element_ty), mask=o_mask) # Write zeros to LSE l_ptrs = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m + offs_m * stride_lse_m @@ -723,11 +843,11 @@ def attn_fwd(Q, K, V, bias, # Initialize for processing # Compute pointers for all the tensors used in this kernel. q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qk k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + k_ptrs = k_offset + offs_d_qk[:, None] * stride_kk + offs_n[None, :] * stride_kn v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d_v[None, :] * stride_vn if USE_BIAS: # Note: this might get large enough to overflow on some configs bias_offset = off_h_q * stride_bh @@ -759,12 +879,12 @@ def attn_fwd(Q, K, V, bias, # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE) l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=ACCUMULATOR_TYPE) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=ACCUMULATOR_TYPE) # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + if PADDED_HEAD_QK: + q_ptrs_mask = q_ptrs_mask & (offs_d_qk[None, :] < ACTUAL_BLOCK_DMODEL_QK) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) @@ -781,17 +901,17 @@ def attn_fwd(Q, K, V, bias, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, # Start of front masked blocks block_max, # End of front masked blocks 0, # n_extra_tokens (0 for front blocks, only relevant for last block) alibi_slope, - descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, + q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, IS_CAUSAL, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, + BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, + ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, @@ -812,15 +932,15 @@ def attn_fwd(Q, K, V, bias, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, # Start of range: 0 block_max, # End of range: n_full_blocks * BLOCK_N alibi_slope, - descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, + q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, + BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, + ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE ) @@ -837,17 +957,17 @@ def attn_fwd(Q, K, V, bias, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, # Start of range: n_full_blocks * BLOCK_N block_max, # End of range: n_visible_k_blocks * BLOCK_N n_extra_tokens, # Padding tokens in last block alibi_slope, - descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, + q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, IS_CAUSAL, # Use actual causal flag - BLOCK_M, BLOCK_DMODEL, BLOCK_N, + BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, + ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, @@ -933,7 +1053,7 @@ def attn_fwd(Q, K, V, bias, end_m_idx = (start_m + 1) * BLOCK_M if causal_start_idx < end_m_idx: # This block contains the boundary - need to mask acc - out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + out_mask_boundary = tl.full((BLOCK_DMODEL_V, ), causal_start_idx, dtype=tl.int32) out_ptrs_mask = row_indices[:, None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) @@ -958,19 +1078,14 @@ def attn_fwd(Q, K, V, bias, # write back O o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on - o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d_v[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL_V], 1, dtype=tl.int1) if overflow_size > 0: o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) - if PADDED_HEAD: - o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + if PADDED_HEAD_V: + o_ptrs_mask = o_ptrs_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) - if FP8_OUTPUT: - scale_acc, descale_acc = compute_fp8_scaling_factors(acc, FP8_MAX) - tl.store(Descale_O + off_z * stride_descale_o_z + off_h_q, descale_acc) - tl.store(o_ptrs, (acc * scale_acc).to(Out.type.element_ty), mask=o_ptrs_mask) - else: - tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) def attention_prefill_forward_triton_impl( q: torch.Tensor, @@ -997,10 +1112,12 @@ def attention_prefill_forward_triton_impl( return_softmax: bool, use_exp2: bool, # fp8 - descale_q: Optional[torch.Tensor], - descale_k: Optional[torch.Tensor], - descale_v: Optional[torch.Tensor], - descale_o: Optional[torch.Tensor], + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + # seqused for FA v3 + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): # get params, strides and shape IS_VARLEN = layout == "thd" @@ -1027,12 +1144,12 @@ def attention_prefill_forward_triton_impl( assert max_seqlens_k is not None and max_seqlens_k > 0, "max_seqlens_k must be provided and positive for varlen layout" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" # assert output shapes - assert o.shape == (total_seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_q)}" + assert o.shape == (total_seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" # assert cu_seqlens assert cu_seqlens_q.dtype == torch.int32, f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" @@ -1044,7 +1161,7 @@ def attention_prefill_forward_triton_impl( # set vars batch = len(cu_seqlens_q) - 1 - head_size = head_size_q + head_size_qk = head_size_q # softmax_lse shape softmax_lse = torch.zeros((nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32) @@ -1065,7 +1182,7 @@ def attention_prefill_forward_triton_impl( assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" @@ -1073,11 +1190,11 @@ def attention_prefill_forward_triton_impl( assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" # assert output shapes - assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_q)}" + assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_v)}" # set vars batch = batch_q - head_size = head_size_q + head_size_qk = head_size_q max_seqlens_q = seqlen_q max_seqlens_k = seqlen_k @@ -1098,29 +1215,32 @@ def attention_prefill_forward_triton_impl( FP8_MAX = torch.finfo(q.dtype).max - # Check descale tensors - assert descale_q is not None, "descale_q must be provided when using fp8" - assert descale_k is not None, "descale_k must be provided when using fp8" - assert descale_v is not None, "descale_v must be provided when using fp8" + # Check and create default descale tensors if not provided + if (q_descale is None) or (k_descale is None) or (v_descale is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) + # Create default descale tensors if not provided + if q_descale is None: + q_descale = torch.ones(batch, nheads_q, dtype=torch.float32, device=q.device) + if k_descale is None: + k_descale = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + if v_descale is None: + v_descale = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) - if is_fp8(o): - FP8_OUTPUT = True - assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." - else: - FP8_OUTPUT = False - # o should be fp32 or fp16/bf16 - assert o.dtype in [torch.float16, torch.bfloat16, torch.float32], \ - f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}" - - stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None - stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None - stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None - stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None + # o should be fp32 or fp16/bf16 + assert o.dtype in [torch.float16, torch.bfloat16, torch.float32], \ + f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}" + + stride_q_descale_z = q_descale.stride(0) if q_descale is not None else 0 + stride_k_descale_z = k_descale.stride(0) if k_descale is not None else 0 + stride_v_descale_z = v_descale.stride(0) if v_descale is not None else 0 + + if DEBUG: + print(f"FP8 path triggered in fwd_prefill.py") else: FP8_MAX = None - FP8_OUTPUT = False - descale_q = descale_k = descale_v = descale_o = None - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = None + q_descale = k_descale = v_descale = None + stride_q_descale_z = stride_k_descale_z = stride_v_descale_z = None # check output dtype matches input dtype when not using fp8 assert o.dtype == q.dtype, f"Output dtype {o.dtype} must match input dtype {q.dtype} when not using fp8" @@ -1132,11 +1252,13 @@ def attention_prefill_forward_triton_impl( if (bias is not None): assert (bias.numel() < 2**31) - # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() + # Get closest power of 2 over or equal to 32 for both QK and V dimensions + padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() + padded_d_model_v = 1 << (head_size_v - 1).bit_length() # Smallest head_dim supported is 16. If smaller, the tile in the # kernel is padded - there is no padding in memory for any dims. - padded_d_model = max(padded_d_model, 16) + padded_d_model_qk = max(padded_d_model_qk, 16) + padded_d_model_v = max(padded_d_model_v, 16) # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according @@ -1166,7 +1288,7 @@ def attention_prefill_forward_triton_impl( # launch kernel grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META['BLOCK_M'])) attn_fwd[grid](q, k, v, bias, - descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, + q_descale, k_descale, v_descale, stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, sm_scale, softmax_lse, o, stride_qb, stride_qh, stride_qm, stride_qd, stride_kb, stride_kh, stride_kn, stride_kd, @@ -1177,13 +1299,15 @@ def attention_prefill_forward_triton_impl( stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Pass seqused tensors dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, - HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, + HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL_QK=head_size_qk, ACTUAL_BLOCK_DMODEL_V=head_size_v, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, USE_SLIDING_WINDOW=use_sliding_window, WINDOW_SIZE_LEFT=window_size_left, WINDOW_SIZE_RIGHT=window_size_right, IS_VARLEN=IS_VARLEN, - BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, + BLOCK_DMODEL_QK=padded_d_model_qk, BLOCK_DMODEL_V=padded_d_model_v, USE_BIAS=False if bias is None else True, USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, NEEDS_SDMASK=NEEDS_SDMASK, - IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT) + IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_P_DESCALE=False, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None)) # Add flag for seqused return softmax_lse, sd_mask if return_softmax else None \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py old mode 100644 new mode 100755 index 2265af12096..a8ca54a7ec3 --- a/flash_attn/flash_attn_triton_amd/fwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py @@ -9,7 +9,7 @@ def attention_forward_core_ref_impl( q, k, v, sm_scale, causal, window_size_left, window_size_right, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2, - cache_seqlens=None + cache_seqlens=None, block_table=None, paged_kv_block_size=None ): if DEBUG_CORE: print() @@ -26,17 +26,99 @@ def attention_forward_core_ref_impl( print("philox_offset:", philox_offset) print("use_exp2:", use_exp2) print("cache_seqlens:", cache_seqlens) + print("block_table:", block_table) + print("paged_kv_block_size:", paged_kv_block_size) # cast to float32 q = q.to(torch.float32) - k = k.to(torch.float32) - v = v.to(torch.float32) - - # get seqlens - L_q, L_k = q.shape[1], k.shape[1] - # Compute attention scores - attention_scores = torch.matmul(q, k.transpose(-2, -1)) + # Check if we're in paged KV mode + is_paged = block_table is not None and paged_kv_block_size is not None + + if False: # Debug paged attention (disabled for production) + print(f"\n=== attention_forward_core_ref_impl DEBUG ===") + print(f"is_paged: {is_paged}") + print(f"block_table: {block_table.shape if block_table is not None else None}") + print(f"paged_kv_block_size: {paged_kv_block_size}") + if is_paged: + print(f"k shape (paged): {k.shape}") + print(f"v shape (paged): {v.shape}") + print(f"cache_seqlens: {cache_seqlens}") + + if is_paged: + # In paged mode, k and v are [num_blocks, block_size, nheads_k, head_dim] + # We'll compute attention on-the-fly without reconstructing + nheads_q = q.shape[0] + L_q = q.shape[1] + head_dim = q.shape[2] + + # Get number of KV heads from the cache + nheads_k = k.shape[2] # k shape: [num_blocks, block_size, nheads_k, head_dim] + + # Handle MQA/GQA + assert nheads_q % nheads_k == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_k ({nheads_k})" + group_size = nheads_q // nheads_k + + # Determine the actual KV sequence length from cache_seqlens + L_k = cache_seqlens if isinstance(cache_seqlens, int) else cache_seqlens.item() + + if False: # Debug disabled + print(f"L_q: {L_q}, L_k: {L_k}, nheads_q: {nheads_q}, nheads_k: {nheads_k}, group_size: {group_size}, head_dim: {head_dim}") + print(f"block_table contents: {block_table if block_table is not None else 'None'}") + + # Initialize attention scores + attention_scores = torch.zeros((nheads_q, L_q, L_k), dtype=torch.float32, device=q.device) + + # Compute attention scores on-the-fly by accessing blocks directly + for kv_pos in range(L_k): + # Calculate which block and position within block + block_idx = kv_pos // paged_kv_block_size + within_block_idx = kv_pos % paged_kv_block_size + + # Get the physical block number from block_table + # block_table is [1, num_blocks] for single batch in core function + if block_table.dim() == 2: + physical_block = block_table[0, block_idx].item() + else: + physical_block = block_table[block_idx].item() + + # Debug output disabled + # if kv_pos == 0: + # print(f"First KV access: block_idx={block_idx}, within_block={within_block_idx}, physical_block={physical_block}") + # print(f"k_vec shape will be: {k[physical_block, within_block_idx, :, :].shape}") + + # Access k values directly from paged cache + # k shape: [num_blocks, block_size, nheads_k, head_dim] + k_vec = k[physical_block, within_block_idx, :, :].to(torch.float32) # [nheads_k, head_dim] + + # For GQA/MQA, we need to repeat k_vec for each group + if group_size > 1: + # Expand k_vec to match query heads + # k_vec: [nheads_k, head_dim] -> [nheads_q, head_dim] + k_vec = k_vec.repeat_interleave(group_size, dim=0) + + # Compute dot product with all query positions + # q is [nheads_q, L_q, head_dim], k_vec is [nheads_q, head_dim] + # Result should be [nheads_q, L_q] for this kv_pos + attention_scores[:, :, kv_pos] = torch.sum(q * k_vec.unsqueeze(1), dim=-1) + + # Keep k and v in original format for later v computation + k_paged = k + v_paged = v + + # Debug output disabled + # print(f"attention_scores computed shape: {attention_scores.shape}") + # print(f"attention_scores sample values: {attention_scores[0, 0, :5]}") + else: + # Standard non-paged mode + k = k.to(torch.float32) + v = v.to(torch.float32) + + # get seqlens + L_q, L_k = q.shape[1], k.shape[1] + + # Compute attention scores + attention_scores = torch.matmul(q, k.transpose(-2, -1)) if DEBUG_CORE: print("attention_scores:", attention_scores, attention_scores.shape) @@ -256,7 +338,55 @@ def attention_forward_core_ref_impl( print("softmax_lse:", softmax_lse, softmax_lse.shape) # Compute output - o = torch.matmul(p, v) + if is_paged: + # Compute output on-the-fly using paged v cache + nheads_q = p.shape[0] + L_q = p.shape[1] + nheads_v = v_paged.shape[2] # [num_blocks, block_size, nheads_v, head_dim] + head_dim = v_paged.shape[3] + + # Handle MQA/GQA for v + assert nheads_q % nheads_v == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_v ({nheads_v})" + v_group_size = nheads_q // nheads_v + + o = torch.zeros((nheads_q, L_q, head_dim), dtype=torch.float32, device=p.device) + + # Accumulate weighted v values + for kv_pos in range(L_k): + # Calculate which block and position within block + block_idx = kv_pos // paged_kv_block_size + within_block_idx = kv_pos % paged_kv_block_size + + # Get the physical block number from block_table + if block_table.dim() == 2: + physical_block = block_table[0, block_idx].item() + else: + physical_block = block_table[block_idx].item() + + # Access v values directly from paged cache + # v_paged shape: [num_blocks, block_size, nheads_v, head_dim] + v_vec = v_paged[physical_block, within_block_idx, :, :].to(torch.float32) # [nheads_v, head_dim] + + # For GQA/MQA, we need to repeat v_vec for each group + if v_group_size > 1: + # Expand v_vec to match query heads + # v_vec: [nheads_v, head_dim] -> [nheads_q, head_dim] + v_vec = v_vec.repeat_interleave(v_group_size, dim=0) + + # Weight by attention probabilities + # p is [nheads_q, L_q, L_k], we need p[:, :, kv_pos] which is [nheads_q, L_q] + # v_vec is [nheads_q, head_dim] + # We want to add p[:, :, kv_pos] * v_vec to each query position + weights = p[:, :, kv_pos].unsqueeze(-1) # [nheads_q, L_q, 1] + o += weights * v_vec.unsqueeze(1) # [nheads_q, L_q, head_dim] + else: + o = torch.matmul(p, v) + + # Debug output disabled + # if False: + # print(f"Output o shape: {o.shape}") + # print(f"Output o sample values: {o[0, 0, :5]}") + if DEBUG_CORE: print("o:", o, o.shape) @@ -439,7 +569,7 @@ def attention_varlen_forward_pytorch_ref_impl( return o, softmax_lse, sd_mask -def attention_forward_pytorch_ref_impl( +def attention_prefill_forward_ref_impl( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -516,12 +646,61 @@ def attention_decode_forward_ref_impl( layout: Literal["bshd"], cache_seqlens: Optional[torch.Tensor], cache_batch_idx: Optional[torch.Tensor], + block_table: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, ): """Compute reference output for decode attention using PyTorch's built-in functions""" + if False: # Permanently disabled old debug output + pass + # print("\n========== attention_decode_forward_ref_impl inputs ==========") + # print(f"q shape: {q.shape}, dtype: {q.dtype}, device: {q.device}") + print(f"q values:\n{q}") + print(f"\nk_cache shape: {k_cache.shape}, dtype: {k_cache.dtype}, device: {k_cache.device}") + print(f"k_cache values:\n{k_cache}") + print(f"\nv_cache shape: {v_cache.shape}, dtype: {v_cache.dtype}, device: {v_cache.device}") + print(f"v_cache values:\n{v_cache}") + print(f"\nk_new: {k_new.shape if k_new is not None else None}, dtype: {k_new.dtype if k_new is not None else None}") + if k_new is not None: + print(f"k_new values:\n{k_new}") + print(f"\nv_new: {v_new.shape if v_new is not None else None}, dtype: {v_new.dtype if v_new is not None else None}") + if v_new is not None: + print(f"v_new values:\n{v_new}") + print(f"\nout shape: {out.shape}, dtype: {out.dtype}, device: {out.device}") + print(f"out values:\n{out}") + print(f"\nsm_scale: {sm_scale}") + print(f"causal: {causal}") + print(f"window_size_left: {window_size_left}") + print(f"window_size_right: {window_size_right}") + print(f"\nalibi_slopes: {alibi_slopes.shape if alibi_slopes is not None else None}, dtype: {alibi_slopes.dtype if alibi_slopes is not None else None}") + if alibi_slopes is not None: + print(f"alibi_slopes values:\n{alibi_slopes}") + print(f"\nlayout: {layout}") + print(f"cache_seqlens: {cache_seqlens}") + if cache_seqlens is not None and torch.is_tensor(cache_seqlens): + print(f"cache_seqlens values: {cache_seqlens}") + print(f"cache_batch_idx: {cache_batch_idx}") + if cache_batch_idx is not None: + print(f"cache_batch_idx values: {cache_batch_idx}") + print(f"\nblock_table: {block_table.shape if block_table is not None else None}, dtype: {block_table.dtype if block_table is not None else None}") + if block_table is not None: + print(f"block_table values:\n{block_table}") + print("=" * 60) + # get batch size before any layout conversion batch_size = q.shape[0] + # Determine if we're in paged KV mode + is_paged = block_table is not None + if is_paged: + # Infer block size from cache shape + # k_cache shape for paged: [num_blocks, block_size, nheads, head_dim] + paged_kv_block_size = k_cache.shape[1] + else: + paged_kv_block_size = None + # handle cache_batch_idx if cache_batch_idx is not None: # remap batch indices for cache access @@ -531,39 +710,79 @@ def attention_decode_forward_ref_impl( # copy new keys and values into cache if provided (before any layout conversion) if k_new is not None and v_new is not None: - _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout - - for b in range(batch_size): - cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices + if is_paged: + # For paged KV cache, we need to update the blocks with new k/v values + _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout - # determine where to place new k/v in cache - if cache_seqlens is not None: - if torch.is_tensor(cache_seqlens): - start_pos = cache_seqlens[b].item() + for b in range(batch_size): + # Determine where to place new k/v in cache + if cache_seqlens is not None: + if torch.is_tensor(cache_seqlens): + start_pos = cache_seqlens[b].item() + else: + start_pos = cache_seqlens else: - start_pos = cache_seqlens - else: - # if no cache_seqlens, assume we're filling from the beginning - start_pos = 0 - - end_pos = start_pos + seq_len_new + start_pos = 0 + + # For each new position, find the corresponding block and update it + for pos_offset in range(seq_len_new): + kv_pos = start_pos + pos_offset + + # Calculate which block and position within block + block_idx = kv_pos // paged_kv_block_size + within_block_idx = kv_pos % paged_kv_block_size + + # Get the physical block number from block_table + physical_block = block_table[b, block_idx].item() + + # Update the k and v values in the paged cache + # k_cache shape: [num_blocks, block_size, nheads, head_dim] + # k_new shape: [batch, seq_len, nheads, head_dim] + k_cache[physical_block, within_block_idx, :, :] = k_new[b, pos_offset, :, :] + v_cache[physical_block, within_block_idx, :, :] = v_new[b, pos_offset, :, :] + else: + _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout - # copy new keys and values into cache (both are in bshd layout) - k_cache[cache_idx, start_pos:end_pos, :, :] = k_new[b, :, :, :] - v_cache[cache_idx, start_pos:end_pos, :, :] = v_new[b, :, :, :] + for b in range(batch_size): + cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices + + # determine where to place new k/v in cache + if cache_seqlens is not None: + if torch.is_tensor(cache_seqlens): + start_pos = cache_seqlens[b].item() + else: + start_pos = cache_seqlens + else: + # if no cache_seqlens, assume we're filling from the beginning + start_pos = 0 + + end_pos = start_pos + seq_len_new + + # copy new keys and values into cache (both are in bshd layout) + k_cache[cache_idx, start_pos:end_pos, :, :] = k_new[b, :, :, :] + v_cache[cache_idx, start_pos:end_pos, :, :] = v_new[b, :, :, :] # ensure the layout is 'bhsd' if layout == "bshd": q = q.transpose(1, 2).contiguous() - k_cache = k_cache.transpose(1, 2).contiguous() - v_cache = v_cache.transpose(1, 2).contiguous() + if not is_paged: + k_cache = k_cache.transpose(1, 2).contiguous() + v_cache = v_cache.transpose(1, 2).contiguous() elif layout != "bhsd": raise ValueError(f"Unknown layout {layout}") # prepare tensors batch_size_q, nheads_q, seq_len_q, head_dim = q.shape - batch_size_cache, nheads_k, max_cache_len, head_dim_k = k_cache.shape - _, nheads_v, _, head_dim_v = v_cache.shape + + if is_paged: + # For paged cache: [num_blocks, block_size, nheads, head_dim] + num_blocks, block_size, nheads_k, head_dim_k = k_cache.shape + _, _, nheads_v, head_dim_v = v_cache.shape + max_cache_len = None # Not directly available in paged mode + batch_size_cache = None # Not applicable in paged mode + else: + batch_size_cache, nheads_k, max_cache_len, head_dim_k = k_cache.shape + _, nheads_v, _, head_dim_v = v_cache.shape # validate dimensions assert head_dim == head_dim_k == head_dim_v, f"Head dimensions must match: {head_dim}, {head_dim_k}, {head_dim_v}" @@ -586,7 +805,8 @@ def attention_decode_forward_ref_impl( # process each batch element for b in range(batch_size): - cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices + if not is_paged: + cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices # determine valid cache length for this batch element if cache_seqlens is not None: @@ -601,25 +821,41 @@ def attention_decode_forward_ref_impl( _, seq_len_new, _, _ = k_new.shape cache_len += seq_len_new else: - cache_len = max_cache_len - - # CHANGE: Extract the full cache, not just valid portion - # This matches what the test does - it uses full k_cache_rep/v_cache_rep - k_b = k_cache[cache_idx, :, :, :] # [nheads_k, max_cache_len, head_dim] - v_b = v_cache[cache_idx, :, :, :] # [nheads_v, max_cache_len, head_dim] - q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] + if is_paged: + # For paged mode, we need cache_seqlens to know the valid length + raise ValueError("cache_seqlens must be provided for paged KV cache") + else: + cache_len = max_cache_len - # handle MQA/GQA by expanding k and v - if group_size != 1: - # expand k and v to match q's number of heads - k_b = k_b.unsqueeze(1).expand(-1, group_size, -1, -1) - k_b = k_b.reshape(nheads_q, max_cache_len, head_dim) + if is_paged: + # For paged KV cache, pass the cache and block table directly + # Extract block table for this batch element + block_table_b = block_table[b:b+1, :] # [1, num_blocks] + k_b = k_cache # Pass entire paged cache + v_b = v_cache # Pass entire paged cache + q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] - v_b = v_b.unsqueeze(1).expand(-1, group_size, -1, -1) - v_b = v_b.reshape(nheads_q, max_cache_len, head_dim) - - # reshape for attention_forward_core_ref_impl - q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) + # For paged mode with MQA/GQA, we handle expansion in the core function + # Just reshape q for now + q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) + else: + # Standard non-paged mode + k_b = k_cache[cache_idx, :, :, :] # [nheads_k, max_cache_len, head_dim] + v_b = v_cache[cache_idx, :, :, :] # [nheads_v, max_cache_len, head_dim] + q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] + block_table_b = None + + # handle MQA/GQA by expanding k and v + if group_size != 1: + # expand k and v to match q's number of heads + k_b = k_b.unsqueeze(1).expand(-1, group_size, -1, -1) + k_b = k_b.reshape(nheads_q, max_cache_len, head_dim) + + v_b = v_b.unsqueeze(1).expand(-1, group_size, -1, -1) + v_b = v_b.reshape(nheads_q, max_cache_len, head_dim) + + # reshape for attention_forward_core_ref_impl + q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) # handle alibi slopes for this batch alibi_slopes_b = None @@ -635,6 +871,8 @@ def attention_decode_forward_ref_impl( dropout_p=0.0, philox_seed=None, philox_offset=None, alibi_slopes=alibi_slopes_b, use_exp2=True, cache_seqlens=cache_len, # Pass valid cache length + block_table=block_table_b, # Pass block table for paged mode + paged_kv_block_size=paged_kv_block_size, # Pass block size for paged mode ) # store outputs diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index a6b83f64365..3dc443abf67 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -5,7 +5,7 @@ from .bwd_prefill_fused_atomics import attention_prefill_backward_triton_fused_atomics_impl from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl from .fwd_decode import attention_decode_forward_triton_impl -from .fwd_ref import attention_forward_pytorch_ref_impl, attention_decode_forward_ref_impl +from .fwd_ref import attention_prefill_forward_ref_impl, attention_decode_forward_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl from .utils import DEBUG, USE_REF, MetaData, is_fp8 from einops import rearrange, repeat @@ -31,8 +31,7 @@ def fwd(q: torch.Tensor, gen_: Optional[torch.Tensor] = None, descale_q: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None + descale_v: Optional[torch.Tensor] = None ): if DEBUG: @@ -53,11 +52,12 @@ def fwd(q: torch.Tensor, print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) if is_fp8(q): assert out is not None, "fp8 output tensor should be passed in." - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" + if (descale_q is None) or (descale_k is None) or (descale_v is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) else: out = torch.zeros_like(q) if out is None else out.zero_() @@ -93,7 +93,7 @@ def fwd(q: torch.Tensor, if USE_REF: if DEBUG: print("Using reference implementation") - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( q, k, v, @@ -140,16 +140,13 @@ def fwd(q: torch.Tensor, USE_EXP2, descale_q, descale_k, - descale_v, - descale_o) + descale_v) softmax_lse=softmax_lse_triton sd_mask=sd_mask_triton if DEBUG: print("flash_attn_triton_amd.py::fwd outputs") print("o:", out, out.shape) - if is_fp8(out): - print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) print("softmax_lse:", softmax_lse, softmax_lse.shape) print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) print("rng_state:", rng_state) @@ -405,8 +402,7 @@ def varlen_fwd( gen_: Optional[torch.Tensor] = None, descale_q: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None + descale_v: Optional[torch.Tensor] = None ): if DEBUG: @@ -432,7 +428,9 @@ def varlen_fwd( if is_fp8(q): assert out is not None, "fp8 output tensor should be passed in." - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" + if (descale_q is None) or (descale_k is None) or (descale_v is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) else: out = torch.zeros_like(q) if out is None else out.zero_() @@ -468,7 +466,7 @@ def varlen_fwd( if USE_REF: if DEBUG: print("Using reference implementation") - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( q, k, v, @@ -515,8 +513,7 @@ def varlen_fwd( USE_EXP2, descale_q, descale_k, - descale_v, - descale_o) + descale_v) softmax_lse=softmax_lse_triton sd_mask=sd_mask_triton @@ -899,6 +896,7 @@ def fwd_kvcache( metadata.layout, metadata.cache_seqlens, metadata.cache_batch_idx, + block_table, ) softmax_lse=softmax_lse_ref else: @@ -919,6 +917,7 @@ def fwd_kvcache( metadata.layout, metadata.cache_seqlens, metadata.cache_batch_idx, + block_table, ) softmax_lse = softmax_lse_triton diff --git a/flash_attn/flash_attn_triton_amd/interface_fa_v3.py b/flash_attn/flash_attn_triton_amd/interface_fa_v3.py new file mode 100755 index 00000000000..be8e2d3cbeb --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/interface_fa_v3.py @@ -0,0 +1,660 @@ +import torch +import os +from .fwd_prefill import attention_prefill_forward_triton_impl +from .bwd_prefill_split import attention_prefill_backward_triton_split_impl +from .bwd_prefill_fused_atomics import attention_prefill_backward_triton_fused_atomics_impl +from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl +from .fwd_decode import attention_decode_forward_triton_impl +from .fwd_ref import attention_prefill_forward_ref_impl, attention_decode_forward_ref_impl +from .bwd_ref import attention_backward_pytorch_ref_impl +from .utils import DEBUG, USE_REF, MetaData, is_fp8 +from einops import rearrange, repeat +from flash_attn.layers.rotary import apply_rotary_emb +from typing import Optional, Union, Tuple + +USE_EXP2 = True +BWD_MODE = os.environ.get('BWD_MODE', 'fused_no_atomics').lower() +USE_DECODE_PATH = os.environ.get('FLASH_ATTENTION_V3_USE_DECODE', '0') == '1' + +def fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + qv: Optional[torch.Tensor], + out: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + cu_seqlens_k_new: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + page_table: Optional[torch.Tensor], + kv_batch_idx: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + rotary_cos: Optional[torch.Tensor], + rotary_sin: Optional[torch.Tensor], + seqlens_rotary: Optional[torch.Tensor], + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + attention_chunk: int, + softcap: float, + rotary_interleaved: bool, + scheduler_metadata=None, + num_splits: int = 1, + pack_gqa=None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Flash Attention v3 forward pass compatible interface for AMD Triton implementation. + + This function maps v3 parameters to the existing AMD Triton implementation. + """ + + if DEBUG: + print() + print("interface_fa_v3.py::fwd inputs") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("k_new:", k_new, k_new.shape if k_new is not None else None) + print("v_new:", v_new, v_new.shape if v_new is not None else None) + print("qv:", qv, qv.shape if qv is not None else None) + print("out:", out, out.shape if out is not None else None) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape if cu_seqlens_q is not None else None) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape if cu_seqlens_k is not None else None) + print("cu_seqlens_k_new:", cu_seqlens_k_new, cu_seqlens_k_new.shape if cu_seqlens_k_new is not None else None) + print("seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None) + print("seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("page_table:", page_table, page_table.shape if page_table is not None else None) + print("kv_batch_idx:", kv_batch_idx, kv_batch_idx.shape if kv_batch_idx is not None else None) + print("leftpad_k:", leftpad_k, leftpad_k.shape if leftpad_k is not None else None) + print("rotary_cos:", rotary_cos, rotary_cos.shape if rotary_cos is not None else None) + print("rotary_sin:", rotary_sin, rotary_sin.shape if rotary_sin is not None else None) + print("seqlens_rotary:", seqlens_rotary, seqlens_rotary.shape if seqlens_rotary is not None else None) + print("q_descale:", q_descale, q_descale.shape if q_descale is not None else None) + print("k_descale:", k_descale, k_descale.shape if k_descale is not None else None) + print("v_descale:", v_descale, v_descale.shape if v_descale is not None else None) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("attention_chunk:", attention_chunk) + print("softcap:", softcap) + print("rotary_interleaved:", rotary_interleaved) + print("scheduler_metadata:", scheduler_metadata) + print("num_splits:", num_splits) + print("pack_gqa:", pack_gqa) + print("sm_margin:", sm_margin) + + # Handle qv packed input + if qv is not None: + raise NotImplementedError("QV packed input is not yet supported in the AMD Triton backend") + + + # Handle softcap + if softcap != 0.0: + raise NotImplementedError(f"Softcap is not yet supported in the AMD Triton backend (got softcap={softcap}, expected 0.0)") + + # Handle attention_chunk + if attention_chunk != 0 and attention_chunk != 1: + raise NotImplementedError(f"attention_chunk is not yet supported in the AMD Triton backend (got attention_chunk={attention_chunk})") + + + # Handle scheduler metadata + if scheduler_metadata is not None: + raise NotImplementedError("Scheduler metadata is not yet supported in the AMD Triton backend") + + # Handle pack_gqa + if pack_gqa is not None and pack_gqa is not False: + raise NotImplementedError(f"pack_gqa is not yet supported in the AMD Triton backend (got pack_gqa={pack_gqa})") + + # Handle num_splits + if num_splits != 1: + raise NotImplementedError(f"Split attention (num_splits > 1) is not yet supported in the AMD Triton backend (got num_splits={num_splits})") + + # Handle sm_margin + if sm_margin != 0: + raise NotImplementedError(f"sm_margin is not yet supported in the AMD Triton backend (got sm_margin={sm_margin}, expected 0)") + + # Handle leftpad_k + if leftpad_k is not None: + raise NotImplementedError("Left padding (leftpad_k) is not yet supported in the AMD Triton backend") + + # Handle cu_seqlens_k_new + if cu_seqlens_k_new is not None: + raise NotImplementedError("cu_seqlens_k_new is not yet supported in the AMD Triton backend") + + # if seqlens_rotary is not None: + # raise NotImplementedError("seqlens_rotary is not yet supported in the AMD Triton backend") + + # Setup metadata + metadata = MetaData(sm_scale=softmax_scale) + + + # Handle variable length sequences first to determine layout + # Determine layout based on tensor dimensions and cu_seqlens presence + if cu_seqlens_q is not None: + # Q has variable length - check tensor dimensions to confirm + if len(q.shape) == 3: # [total_seqlen, nheads, head_dim] + metadata.layout = "thd" + metadata.varlen = True + metadata.cu_seqlens_q = cu_seqlens_q + metadata.max_seqlens_q = max_seqlen_q + + # K might be varlen or batch mode + if cu_seqlens_k is not None: + metadata.cu_seqlens_k = cu_seqlens_k + metadata.max_seqlens_k = max_seqlen_k + else: + # K is in batch mode while Q is varlen (KV cache scenario) + metadata.cu_seqlens_k = None + metadata.max_seqlens_k = k.shape[1] if len(k.shape) == 4 else max_seqlen_k + else: + raise ValueError(f"cu_seqlens_q provided but q has shape {q.shape}, expected 3D tensor for varlen") + else: + # Regular batch mode + metadata.layout = "bshd" + metadata.varlen = False + metadata.cu_seqlens_q = None + metadata.cu_seqlens_k = None + metadata.max_seqlens_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q + metadata.max_seqlens_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k + + # Now determine if we should use decode or prefill kernel + # Decode kernel should be used for KV cache scenarios where: + # 1. k_new/v_new are provided - incremental KV cache update (primary KV cache indicator) + # 2. kv_batch_idx is provided - KV cache batch indexing (primary KV cache indicator) + # 3. seqused_k without seqused_q - indicates KV cache fill levels (not varlen masking) + # Note: In varlen, both seqused_q and seqused_k are used for sequence masking + # In KV cache, only seqused_k is used to track cache fill levels + if USE_DECODE_PATH: + # Force decode path + use_decode = True + else: + # Detect KV cache scenarios: + # - Clear KV cache indicators (k_new, v_new, kv_batch_idx) + # - OR seqused_k without seqused_q (KV cache fill tracking, not varlen masking) + use_decode = ( + k_new is not None or # Have new KV to append (KV cache indicator) + v_new is not None or # Have new KV to append (KV cache indicator) + kv_batch_idx is not None or # Have KV cache batch indexing (KV cache indicator) + (seqused_k is not None and seqused_q is None) # KV cache fill levels (not varlen) + ) + + # Check for unsupported features with decode kernel + if use_decode: + if metadata.layout == "thd": + raise NotImplementedError("Varlen is not yet supported with the decode kernel in the AMD Triton backend") + if kv_batch_idx is not None: + raise NotImplementedError("kv_batch_idx is not yet supported with the decode kernel in the AMD Triton backend") + + + if out is None: + out_dtype = torch.float32 if is_fp8(q) else q.dtype + if metadata.layout == "bshd": + out = torch.zeros(q.shape[0], q.shape[1], q.shape[2], v.shape[-1], dtype=out_dtype, device=q.device) + elif metadata.layout == "thd": + out = torch.zeros(q.shape[0], q.shape[1], v.shape[-1], dtype=out_dtype, device=q.device) + else: + raise ValueError(f"Unsupported layout: {metadata.layout}. Only 'bshd' and 'thd' layouts are supported.") + else: + out = out.zero_() + + if is_fp8(q): + if (q_descale is None) or (k_descale is None) or (v_descale is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) + + # Get shape + if metadata.layout == "bshd": + batch, _, nheads_q, _ = q.shape + else: # "thd" layout for varlen + _, nheads_q, _ = q.shape + batch = len(cu_seqlens_q) - 1 if cu_seqlens_q is not None else 1 + + # Handle causal mask + if causal: + metadata.need_causal(True) + + # Handle alibi slopes (not directly supported in v3 interface, but we'll keep the logic) + alibi_slopes = None # V3 doesn't have alibi_slopes in the signature + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError(f"Alibi can be (nheads,) or (batch_size, nheads). Given tensor with shape {alibi_slopes.shape}") + metadata.need_alibi(alibi_slopes, batch, nheads_q) + + # Handle dropout (v3 doesn't have dropout in forward) + dropout_p = 0.0 + return_softmax = False + metadata.need_dropout(dropout_p, return_softmax) + + # Handle rotary embeddings + if rotary_cos is not None and rotary_sin is not None: + metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved) + + # Apply rotary embeddings if provided + if metadata.causal or window_size_left != -1 or window_size_right != -1: + q_rot = apply_rotary_emb( + q, + rotary_cos, + rotary_sin, + seqlen_offsets=seqlens_rotary, + interleaved=rotary_interleaved, + ) + q = q_rot.to(q.dtype) + + if k_new is not None: + k_rot = apply_rotary_emb( + k_new, + rotary_cos, + rotary_sin, + seqlen_offsets=seqlens_rotary, + interleaved=rotary_interleaved, + ) + k_new = k_rot.to(k.dtype) + + # Store RNG state + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) + + # Call implementation + if USE_REF: + if DEBUG: + print("Using reference implementation") + + if use_decode: + if DEBUG: + print(f"Using decode reference implementation ( layout={metadata.layout}, cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})") + # Use decode reference implementation + softmax_lse = attention_decode_forward_ref_impl( + q, + k, # k_cache + v, # v_cache + k_new, + v_new, + out, + metadata.sm_scale, + metadata.causal, + window_size_left, + window_size_right, + metadata.alibi_slopes, + metadata.layout, + seqused_k, # cache_seqlens + kv_batch_idx, # cache_batch_idx + page_table, # block_table + q_descale, + k_descale, + v_descale, + ) + else: + if DEBUG: + print("Using prefill reference implementation") + # Use prefill reference implementation + softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( + q, k, v, out, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + window_size_left, + window_size_right, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + USE_EXP2 + ) + softmax_lse = softmax_lse_ref + else: + if DEBUG: + print("Using Triton implementation") + + if use_decode: + if DEBUG: + print(f"Using Decode Triton implementation (cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})") + + # Use decode kernel for KV cache scenarios + # Note: seqused_k can serve as cache_seqlens in v3 + softmax_lse = attention_decode_forward_triton_impl( + q, + k, # k_cache in v2 terminology + v, # v_cache in v2 terminology + k_new, # New KV values to append to cache + v_new, # New KV values to append to cache + out, + metadata.sm_scale, + metadata.causal, + window_size_left, + window_size_right, + metadata.alibi_slopes, + metadata.layout, + seqused_k, # cache_seqlens + kv_batch_idx, # cache_batch_idx + page_table, # block_table for paged attention + q_descale, + k_descale, + v_descale, + ) + # Decode kernel returns only softmax_lse, not sd_mask + sd_mask_triton = None + else: + if DEBUG: + print("Using prefill Triton implementation") + # Use prefill kernel + softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q, k, v, out, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + window_size_left, + window_size_right, + None, # block_table + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_softmax, + USE_EXP2, + q_descale, + k_descale, + v_descale, + seqused_q, + seqused_k, + ) + softmax_lse = softmax_lse_triton + + if DEBUG: + print("interface_fa_v3.py::fwd outputs") + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + + # Return format compatible with v3 + # V3 returns (out, softmax_lse, *rest) where rest can be empty or contain additional outputs + return out, softmax_lse + + +def bwd( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Flash Attention v3 backward pass compatible interface for AMD Triton implementation. + + This function maps v3 parameters to the existing AMD Triton implementation. + """ + + if DEBUG: + print() + print("interface_fa_v3.py::bwd inputs") + print("dout:", dout, dout.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("dq:", dq, dq.shape if dq is not None else None) + print("dk:", dk, dk.shape if dk is not None else None) + print("dv:", dv, dv.shape if dv is not None else None) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape if cu_seqlens_q is not None else None) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape if cu_seqlens_k is not None else None) + print("seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None) + print("seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("deterministic:", deterministic) + print("sm_margin:", sm_margin) + + # Check for unsupported features in backward pass + + # Handle softcap + if softcap != 0.0: + raise NotImplementedError(f"Softcap is not yet supported in the AMD Triton backend backward pass (got softcap={softcap}, expected 0.0)") + + # Handle sm_margin + if sm_margin != 0: + raise NotImplementedError(f"sm_margin is not yet supported in the AMD Triton backend backward pass (got sm_margin={sm_margin}, expected 0)") + + # Initialize gradient tensors if not provided + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() + + # Determine layout based on cu_seqlens + if cu_seqlens_q is not None and cu_seqlens_k is not None: + # Variable length sequence mode + layout = "thd" + batch = len(cu_seqlens_q) - 1 + _, nheads_q, _ = q.shape + else: + # Regular batch mode + layout = "bshd" + batch, _, nheads_q, _ = q.shape + max_seqlen_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k + + # V3 backward doesn't have dropout or alibi slopes + dropout_p = 0.0 + philox_seed, philox_offset = None, None + alibi_slopes = None + + # For fp8, we would need descale factors, but v3 interface doesn't expose them + # So we'll pass None for now + descale_q = None + descale_k = None + descale_v = None + descale_o = None + descale_do = None + descale_dq = None + descale_dk = None + descale_dv = None + + # Call implementation + if USE_REF: + if DEBUG: + print("Using reference implementation") + delta_ref = attention_backward_pytorch_ref_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + window_size_left, + window_size_right, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + ) + delta = delta_ref + else: + if DEBUG: + print("Using Triton implementation") + + if BWD_MODE == "split": + delta_triton = attention_prefill_backward_triton_split_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + descale_q, descale_k, descale_v, descale_o, + descale_do, descale_dq, descale_dk, descale_dv, + seqused_q, seqused_k, + ) + delta = delta_triton + elif BWD_MODE == "fused_atomics": + delta_triton = attention_prefill_backward_triton_fused_atomics_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + descale_q, descale_k, descale_v, descale_o, + True, + ) + delta = delta_triton + elif BWD_MODE == "fused_no_atomics": + delta_triton = attention_prefill_backward_triton_split_fused_no_atomics_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + descale_q, descale_k, descale_v, descale_o, + descale_do, descale_dq, descale_dk, descale_dv, + seqused_q, seqused_k, + ) + delta = delta_triton + else: + raise ValueError(f"Unknown bwd mode {BWD_MODE}") + + if DEBUG: + print("interface_fa_v3.py::bwd outputs") + print("dq:", dq, dq.shape) + print("dk:", dk, dk.shape) + print("dv:", dv, dv.shape) + print("delta:", delta, delta.shape if delta is not None else None) + + # V3 expects (dq, dk, dv, softmax_d, *rest) + # delta is the softmax_d in this case + return dq, dk, dv, delta + + +def fwd_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + Combine partial outputs from split attention computation. + + This is used when num_splits > 1 to combine the partial results. + + Args: + out_partial: Partial output tensor from split computation + lse_partial: Partial log-sum-exp tensor + out: Optional output tensor to write to + out_dtype: Optional dtype for output + + Returns: + Combined output tensor + """ + raise NotImplementedError("fwd_combine is not yet implemented in the AMD Triton backend") + + +def get_scheduler_metadata( + batch_size: int, + max_seqlen_q: int, + max_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + headdim: int, + headdim_v: int, + qkv_dtype: torch.dtype, + cache_seqlens: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new: int = 0, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + has_softcap: bool = False, + num_splits: int = 0, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +): + """ + Get scheduler metadata for optimized kernel selection. + + This function is used to precompute metadata for kernel scheduling in FA3. + The AMD Triton backend currently doesn't use scheduler metadata, so this + raises an error. + + Args: + Various attention parameters used for scheduling decisions + + Returns: + None - scheduler metadata is not used in AMD Triton backend + """ + raise NotImplementedError("get_scheduler_metadata is not supported in the AMD Triton backend yet.") \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py deleted file mode 100644 index 15fa8f2f07a..00000000000 --- a/flash_attn/flash_attn_triton_amd/test.py +++ /dev/null @@ -1,777 +0,0 @@ -import os -import glob -import shutil -import time -import torch -import pytest -import logging -import numpy as np -from pathlib import Path -import triton -import flash_attn -import flash_attn.flash_attn_triton_amd.fp8 - -from .utils import generate_bshd_kv_packed, generate_bshd_qkv_packed, generate_bshd_tensor, generate_varlen_kv_packed, generate_varlen_qkv_packed, input_helper, arch_supports_fp8, generate_varlen_tensor - -# debugging - -logging.basicConfig(level=logging.INFO, format='%(message)s', force=True) -logger = logging.getLogger(__name__) -DEBUG = False - -# defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html -ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose. -# ATOL, RTOL = 1e-3, 1e-3 # catchs fa mismatch issues -# ATOL, RTOL = 1e-4, 1e-3 # to strict. there will be small diffs -# ATOL, RTOL = 1e-5, 1e-3 # # default fp16. there will be small diffs -# ATOL_fp8, RTOL_fp8 = 1e-1, 1e-1 # to strict for larger tensors in fp8 -ATOL_fp8, RTOL_fp8 = 2.5e-1, 2.5e-1 # fp8 -# ATOL_fp8, RTOL_fp8 = 2e-2, 2e-2 # fp8 -EQUAL_NAN = True - -def fp8_assert_close(tensor_a, tensor_b, atol=ATOL_fp8, rtol=RTOL_fp8, max_diff_percentage=0.5): - """Assert tensors are close with tolerance for small percentage of elements""" - # standard comparison - abs_diff = torch.abs(tensor_a - tensor_b) - rel_diff = abs_diff / torch.abs(tensor_b.clamp(min=1e-6)) - - # calculate elements that exceed tolerance - abs_check = abs_diff > atol - rel_check = rel_diff > rtol - failed_check = torch.logical_and(abs_check, rel_check) - - # calculate percentage of failed elements - failed_percentage = failed_check.sum().item() / failed_check.numel() * 100 - - # if percentage is small enough, test passes - if failed_percentage <= max_diff_percentage: - return True - - # Otherwise, provide diagnostic information - max_abs_idx = torch.argmax(abs_diff).item() - max_rel_idx = torch.argmax(rel_diff).item() - - flat_to_idx = lambda flat_idx, shape: np.unravel_index(flat_idx, shape) - - max_abs_pos = flat_to_idx(max_abs_idx, tensor_a.shape) - max_rel_pos = flat_to_idx(max_rel_idx, tensor_a.shape) - - max_abs_diff = abs_diff.flatten()[max_abs_idx].item() - max_rel_diff = rel_diff.flatten()[max_rel_idx].item() - - raise AssertionError( - f"Tensors not close enough! {failed_percentage:.6f}% elements exceed tolerance.\n" - f"Greatest absolute difference: {max_abs_diff} at index {max_abs_pos} (up to {atol} allowed)\n" - f"Greatest relative difference: {max_rel_diff} at index {max_rel_pos} (up to {rtol} allowed)" - ) - -@pytest.mark.parametrize( - "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - # seqlen q == k - (1, 1, 1, 1, 1, 1), - (1, 1, 1, 2, 2, 2), # small enough to debug - (1, 1, 1, 4, 4, 16), - (1, 2, 2, 4, 4, 16), - (2, 1, 1, 4, 4, 16), - (2, 2, 2, 4, 4, 16), - (1, 1, 1, 128, 128, 32), # only one block - (3, 3, 3, 128, 128, 64), - (1, 1, 1, 127, 127, 32), # only one block but with masking - # (1, 1, 1, 129, 129, 1), # two blocks with 2nd block small enough to debug # fails - (1, 2, 2, 129, 129, 32), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 32), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 68), # generic masking on q, k and head - (4, 1, 1, 512, 512, 128), # batch > 1 - (4, 2, 2, 512, 512, 128), - (4, 2, 2, 512, 512, 68), - (4, 2, 2, 500, 500, 68), - (2, 4, 4, 1024, 1024, 64), - (4, 8, 8, 2048, 2048, 128), - (2, 8, 8, 4096, 4096, 64), - (2, 4, 4, 8192, 8192, 32), - # seqlen q > k - (1, 1, 1, 4, 2, 16), - (1, 1, 1, 64, 32, 8), - (1, 1, 1, 128, 64, 16), - (1, 1, 1, 192, 128, 32), - (1, 2, 2, 1024, 512, 68), - (1, 4, 4, 729, 516, 68), - (2, 4, 4, 2753, 1528, 68), # a comprehensive seqlen_q > seqlen_k - # seqlen q < k - (1, 1, 1, 2, 4, 16), - (1, 2, 2, 2, 4, 16), - (1, 4, 1, 2, 4, 16), - (1, 4, 2, 2, 4, 16), - (2, 2, 2, 2, 128, 1), - (2, 3, 3, 2, 128, 16), - (1, 1, 1, 32, 64, 8), - (1, 1, 1, 128, 192, 32), - (4, 6, 6, 108, 256, 32), - (3, 2, 2, 256, 512, 16), - (2, 2, 2, 512, 1024, 68), - (1, 1, 1, 200, 413, 32), - (1, 1, 1, 782, 1546, 32), - # gqa/mqa # mismatch issue on varlen - (4, 8, 2, 500, 500, 68), - (4, 8, 2, 512, 512, 68), - (4, 8, 2, 512, 512, 128), - (4, 8, 2, 512, 1024, 68), - (4, 8, 2, 1024, 512, 64), - (4, 16, 4, 1528, 2753, 68), - # fa configs - (2, 4, 1, 113, 203, 64), - (2, 4, 2, 128, 217, 64), - (2, 6, 2, 113, 211, 128), - (2, 6, 2, 108, 256, 128), - (2, 6, 2, 256, 512, 64), - (2, 6, 2, 512, 256, 64), - (2, 6, 2, 1024, 1024, 32), - (2, 6, 2, 1023, 1024, 32), - (2, 6, 6, 1024, 1023, 32), - (2, 6, 6, 2048, 2048, 32), - ], -) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('layout', ["bshd", "thd"]) -@pytest.mark.parametrize('packing', [None, "qkv"]) -@pytest.mark.parametrize('DEBUG_INPUT', [False]) -@pytest.mark.flaky(reruns=3, reason="Retry failures") -@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") -def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, DEBUG_INPUT): - torch.manual_seed(20) - test_backward = True - device = "cuda" - window_size = (-1, -1) - softcap = 0.0 - alibi_slopes = None - deterministic = False - ref_dtype = torch.float32 - is_varlen = True if layout == "thd" else False - - # skip QKV packing tests for uneven sequence lengths and head sizes - if packing == 'qkv': - if N_CTX_Q != N_CTX_K: - pytest.skip("QKV packing requires N_CTX_Q == N_CTX_K") - if HQ != HK: - pytest.skip("QKV packing requires HQ == HK") - - # test apis - if packing == 'qkv': - # generate inputs - qkv, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, packing=packing, device=device) - - # ---------------------------------------------------------------- - # --- FP8 --- - # ---------------------------------------------------------------- - qkv_fp8 = qkv.clone() - do_fp8= do.clone() - - if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_qkvpacked_fp8_func( - qkv_fp8, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( - qkv_fp8, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Reference --- - # ---------------------------------------------------------------- - # reference forward pass - qkv_ref = qkv.clone() - do_ref= do.clone() - - if is_varlen: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv_ref, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_qkvpacked_func( - qkv_ref, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Compare --- - # ---------------------------------------------------------------- - # compare forward - if DEBUG: - print() - print(f"Compare fp8 against ref with dtype {ref_dtype}") - - if DEBUG: - print("out_ref:", out_ref, out_ref.shape) - print("out_fp8:", out_fp8, out_fp8.shape) - fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - - if DEBUG: - print("lse_ref:", lse_ref, lse_ref.shape) - print("lse_fp8:", lse_fp8, lse_fp8.shape) - fp8_assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - - if dropout_p > 0.0: - if DEBUG: - print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) - print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) - fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - if not test_backward: - return - - # fp8 backward pass - dqkv_fp8, = torch.autograd.grad(out_fp8, (qkv_fp8), do_fp8) - - # ref backward pass - dqkv_ref, = torch.autograd.grad(out_ref, (qkv_ref), do_ref) - - # compare backward gradients - if DEBUG: - print("dqkv_ref:", dqkv_ref, dqkv_ref.shape) - print("dqkv_fp8:", dqkv_fp8, dqkv_fp8.shape) - fp8_assert_close(dqkv_ref, dqkv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - elif packing is None: - # generate inputs - q, k, v, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, device=device) - - # ---------------------------------------------------------------- - # --- FP8 --- - # ---------------------------------------------------------------- - if DEBUG: - print() - print(f"Compute Fp8 Forward") - q_fp8 = q.clone() - k_fp8 = k.clone() - v_fp8 = v.clone() - do_fp8= do.clone() - - if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_fp8_func( - q_fp8, - k_fp8, - v_fp8, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_fp8_func( - q_fp8, - k_fp8, - v_fp8, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Reference --- - # ---------------------------------------------------------------- - if DEBUG: - print() - print(f"Compute Reference Forward") - # reference forward pass - q_ref = q.clone() - k_ref = k.clone() - v_ref = v.clone() - do_ref = do.clone() - - if is_varlen: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_varlen_func( - q_ref, - k_ref, - v_ref, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_func( - q_ref, - k_ref, - v_ref, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Compare --- - # ---------------------------------------------------------------- - # compare forward - if DEBUG: - print() - print(f"Compare fp8 against ref with dtype {ref_dtype}") - - if DEBUG: - print("out_ref:", out_ref, out_ref.shape) - print("out_fp8:", out_fp8, out_fp8.shape) - # torch.testing.assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - - if DEBUG: - print("lse_ref:", lse_ref, lse_ref.shape) - print("lse_fp8:", lse_fp8, lse_fp8.shape) - # torch.testing.assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - - if dropout_p > 0.0: - if DEBUG: - print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) - print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) - # torch.testing.assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - if not test_backward: - return - - if DEBUG: - print() - print(f"Compute Fp8 Backward") - # fp8 backward pass - dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8) - - if DEBUG: - print() - print(f"Compute Reference Backward") - # ref backward pass - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), do_ref) - - # compare backward gradients - if DEBUG: - print("dv_ref:", dv_ref, dv_ref.shape) - print("dv_fp8:", dv_fp8, dv_fp8.shape) - # torch.testing.assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - if DEBUG: - print("dk_ref:", dk_ref, dk_ref.shape) - print("dk_fp8:", dk_fp8, dk_fp8.shape) - # torch.testing.assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - if DEBUG: - print("dq_ref:", dq_ref, dq_ref.shape) - print("dq_fp8:", dq_fp8, dq_fp8.shape) - # torch.testing.assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - -@pytest.mark.parametrize( - "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - (2, 4, 4, 512, 512, 128), - ], -) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('dropout_p', [0.0, 0.1]) -@pytest.mark.parametrize('layout', ['bshd']) -@pytest.mark.parametrize('packing', [None]) -@pytest.mark.parametrize('test_backward', [False, True]) -@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") -@pytest.mark.skip("Breaks on CI but works locally") -def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, test_backward): # Don't run this test in parallel. It clears the cache so it doesnot work properly if run in parallel. - torch.manual_seed(20) - device = "cuda" - window_size = (-1, -1) - softcap = 0.0 - alibi_slopes = None - deterministic = False - ref_dtype = torch.float32 - is_varlen = True if layout == "thd" else False - - # remove cache - cache_path = Path(os.path.expanduser("~/.triton/cache")) - if cache_path.exists(): - shutil.rmtree(cache_path) - os.makedirs(cache_path) - - # inputs - q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout=layout, packing=packing, device=device) - - if packing == None: - # fp8 forward pass - if is_varlen: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_fp8_func( - q, - k, - v, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_fp8_func( - q, - k, - v, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # fp8 backward pass - if test_backward: - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do) - elif packing == "qkv": - # qkv packing path - # pack input tensors (use dim=1 for varlen, else dim=2) - if is_varlen: - qkv = torch.stack([q, k, v], dim=1) - else: - qkv = torch.stack([q, k, v], dim=2) - - # fp8 forward pass for qkv-packed input - if is_varlen: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_qkvpacked_fp8_func( - qkv, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # fp8 backward pass for qkv-packed input - if test_backward: - dqkv, = torch.autograd.grad(out, (qkv,), do) - else: - raise ValueError(f"unknown packing type {packing}") - - # search for .ttir files - max_retries = 5 - retry_delay = 0.5 - ttir_files = [] - logging.info(f"Checking for .ttir files in {cache_path}...") - for attempt in range(max_retries): - # search for .ttir files recursively within the cache path - ttir_files = glob.glob(str(cache_path) + "/**/*.ttir", recursive=True) - - if ttir_files: - # Files found, log success and exit the loop - logging.info(f"Found {len(ttir_files)} .ttir files on attempt {attempt + 1}.") - break - else: - # Files not found yet - if attempt < max_retries - 1: - # If not the last attempt, wait and log before retrying - logging.warning( - f"No .ttir files found on attempt {attempt + 1}. " - f"Retrying in {retry_delay}s..." - ) - time.sleep(retry_delay) - else: - pytest.fail( - f"FATAL: No .ttir files found in cache {cache_path} " - f"after {max_retries} attempts." - ) - - # check if there is fp8 - ttir_files_fp8_found_status = {} - fp8_types = ['f8E4M3', 'f8E5M2'] - for ttir_file in ttir_files: - base_name = os.path.basename(ttir_file) - with open(ttir_file, 'r') as f: - content = f.read() - - # check content for fp8 - fp8_found = False - for f8_type in fp8_types: - if f8_type in content: - fp8_found = True - ttir_files_fp8_found_status[base_name] = fp8_found - - for file, fp8_found in ttir_files_fp8_found_status.items(): - assert fp8_found, f"{fp8_types} not found in {file}" - - -def clear_compile_cache(): - """Clear torch compile caches to prevent graph merging""" - if hasattr(torch._dynamo, 'reset'): - torch._dynamo.reset() - torch.cuda.synchronize() - - -@pytest.mark.parametrize( - "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - # (4, 8, 8, 128, 128, 64), # small test - (32, 32, 32, 531, 531, 128), # original test - # (16, 48, 16, 256, 512, 64), # MQA test (HQ > HK) - ], -) -@pytest.mark.skip() # this works or not depending on the torch and triton version compatibility. Keep the test but use it in docker containers where things are in sync -def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): - print(f"\n\nTesting with BATCH={BATCH}, HQ={HQ}, HK={HK}, N_CTX_Q={N_CTX_Q}, N_CTX_K={N_CTX_K}, D_HEAD={D_HEAD}") - - try: - # Test 1: flash_attn_func - print("\n1. Testing flash_attn_func...") - clear_compile_cache() - - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD) - k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD) - v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD) - - flash_attn_func_compiled = torch.compile(flash_attn.flash_attn_func) - o = flash_attn_func_compiled(q, k, v, causal=True) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_func SUCCESS") - - # cleanup - del q, k, v, o - torch.cuda.empty_cache() - - - # Test 2: flash_attn_varlen_func - print("\n2. Testing flash_attn_varlen_func...") - clear_compile_cache() - - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(BATCH * N_CTX_Q, HQ, D_HEAD, BATCH) - k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(BATCH * N_CTX_K, HK, D_HEAD, BATCH) - v, _, _ = generate_varlen_tensor(BATCH * N_CTX_K, HK, D_HEAD, BATCH) - - flash_attn_varlen_func_compiled = torch.compile(flash_attn.flash_attn_varlen_func) - o = flash_attn_varlen_func_compiled( - q, k, v, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, causal=True - ) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_varlen_func SUCCESS") - - # cleanup - del q, k, v, o, cu_seqlens_q, cu_seqlens_k - torch.cuda.empty_cache() - - - # Test 3: flash_attn_qkvpacked_func - print("\n3. Testing flash_attn_qkvpacked_func...") - clear_compile_cache() - - qkv = generate_bshd_qkv_packed(BATCH, N_CTX_Q, HQ, D_HEAD) - - flash_attn_qkvpacked_func_compiled = torch.compile(flash_attn.flash_attn_qkvpacked_func) - o = flash_attn_qkvpacked_func_compiled(qkv, causal=True) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_qkvpacked_func SUCCESS") - - # cleanup - del qkv, o - torch.cuda.empty_cache() - - - # Test 4: flash_attn_varlen_qkvpacked_func - print("\n4. Testing flash_attn_varlen_qkvpacked_func...") - clear_compile_cache() - - total_q = BATCH * N_CTX_Q - qkv, cu_seqlens, max_seqlen = generate_varlen_qkv_packed(total_q, HQ, D_HEAD, BATCH) - - flash_attn_varlen_qkvpacked_func_compiled = torch.compile(flash_attn.flash_attn_varlen_qkvpacked_func) - o = flash_attn_varlen_qkvpacked_func_compiled( - qkv, cu_seqlens, max_seqlen, causal=True - ) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_varlen_qkvpacked_func SUCCESS") - - # cleanup - del qkv, o, cu_seqlens - torch.cuda.empty_cache() - - - # Test 5: flash_attn_kvpacked_func - print("\n5. Testing flash_attn_kvpacked_func...") - clear_compile_cache() - - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD) - kv = generate_bshd_kv_packed(BATCH, N_CTX_K, HK, D_HEAD) - - flash_attn_kvpacked_func_compiled = torch.compile(flash_attn.flash_attn_kvpacked_func) - o = flash_attn_kvpacked_func_compiled(q, kv, causal=True) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_kvpacked_func SUCCESS") - - # cleanup - del q, kv, o - torch.cuda.empty_cache() - - - # Test 6: flash_attn_varlen_kvpacked_func - print("\n6. Testing flash_attn_varlen_kvpacked_func...") - clear_compile_cache() - - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(BATCH * N_CTX_Q, HQ, D_HEAD, BATCH) - kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed(BATCH * N_CTX_K, HK, D_HEAD, BATCH) - - flash_attn_varlen_kvpacked_func_compiled = torch.compile(flash_attn.flash_attn_varlen_kvpacked_func) - o = flash_attn_varlen_kvpacked_func_compiled( - q, kv, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, causal=True - ) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_varlen_kvpacked_func SUCCESS") - - # cleanup - del q, kv, o, cu_seqlens_q, cu_seqlens_k - torch.cuda.empty_cache() - - - # Test 7: flash_attn_with_kvcache - print("\n7. Testing flash_attn_with_kvcache...") - clear_compile_cache() - - # setup cache dimensions - CACHE_SEQLEN = 1024 # max cache size - NEW_SEQLEN = 1 # for incremental decoding, usually 1 token at a time - - # create query for new tokens - q = generate_bshd_tensor(BATCH, NEW_SEQLEN, HQ, D_HEAD, dtype=torch.float16) - - # create kv cache using generators - k_cache = generate_bshd_tensor(BATCH, CACHE_SEQLEN, HK, D_HEAD, dtype=torch.float16) - v_cache = generate_bshd_tensor(BATCH, CACHE_SEQLEN, HK, D_HEAD, dtype=torch.float16) - - # cache sequence lengths - cache_seqlens = torch.full((BATCH,), 100, dtype=torch.int32, device='cuda') - - # new k,v to append to cache (optional) - k_new = generate_bshd_tensor(BATCH, NEW_SEQLEN, HK, D_HEAD, dtype=torch.float16) - v_new = generate_bshd_tensor(BATCH, NEW_SEQLEN, HK, D_HEAD, dtype=torch.float16) - - # Note: flash_attn_with_kvcache doesn't support backward pass - flash_attn_with_kvcache_compiled = torch.compile(flash_attn.flash_attn_with_kvcache) - - # Test with new k,v (append to cache and do attention) - with torch.no_grad(): - o = flash_attn_with_kvcache_compiled( - q, k_cache, v_cache, - k=k_new, v=v_new, - cache_seqlens=cache_seqlens, - causal=True - ) - print(f"Output shape (with new kv): {o.shape}, dtype: {o.dtype}") - - print("✓ flash_attn_with_kvcache SUCCESS") - - print("\n\n✅ ALL TESTS PASSED! ✅") - - except Exception as e: - print(f"\n❌ ERROR: {str(e)}") - # ensure we sync even on error to get proper error message - torch.cuda.synchronize() - raise e - finally: - # final cleanup - torch.cuda.empty_cache() - clear_compile_cache() - -# log env -if os.environ.get('PYTEST_XDIST_WORKER') in (None, 'gw0'): - logger.info("\n" + "="*80) - logger.info("ENVIRONMENT INFORMATION") - logger.info("="*80) - - # triton - logger.info(f" Triton Version: {triton.__version__}") - - # torch - logger.info(f" PyTorch Version: {torch.__version__}") - - # flash attention - logger.info(f" Flash Attention Version: {flash_attn.__version__}") - - logger.info("="*80 + "\n") diff --git a/flash_attn/flash_attn_triton_amd/train.py b/flash_attn/flash_attn_triton_amd/train.py deleted file mode 100644 index 8f39f627bfb..00000000000 --- a/flash_attn/flash_attn_triton_amd/train.py +++ /dev/null @@ -1,404 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader, Dataset, random_split -import numpy as np -import pandas as pd -from tqdm import tqdm -import matplotlib.pyplot as plt -from datasets import load_dataset -import flash_attn -import flash_attn.flash_attn_triton_amd.fp8 - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print(f"using device: {device}") - -# ------------------------------- -# Model -# ------------------------------- -class FlashAttention(nn.Module): - def __init__(self, dim, num_heads=8, causal=True, dropout=0.1, qkv_bias=True, use_fp8=False): - super().__init__() - self.use_fp8 = use_fp8 - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.causal = causal - self.dropout_p = dropout - - # qkv and output projections - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - - def forward(self, x): - b, n, c = x.shape - # project to qkv - qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - # reshape for flash attention function - qkv_packed = torch.stack([q, k, v], dim=2).reshape(b, n, 3, self.num_heads, self.head_dim) - - # use the appropriate flash attention function - if self.use_fp8: - context = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( - qkv_packed, - dropout_p=self.dropout_p, - causal=self.causal - ) - else: - context = flash_attn.flash_attn_qkvpacked_func( - qkv_packed, - dropout_p=self.dropout_p, - causal=self.causal - ) - - # convert back to original shape and type - context = context.reshape(b, n, c) - - # output projection - x = self.proj(context) - - return x - -class TransformerBlock(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4.0, causal=True, dropout=0.1, use_fp8=False): - super().__init__() - self.norm1 = nn.LayerNorm(dim) - self.attn = FlashAttention(dim, num_heads=num_heads, causal=causal, dropout=dropout, use_fp8=use_fp8) - - self.norm2 = nn.LayerNorm(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = nn.Sequential( - nn.Linear(dim, mlp_hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(mlp_hidden_dim, dim), - nn.Dropout(dropout) - ) - - def forward(self, x): - x = x + self.attn(self.norm1(x)) - x = x + self.mlp(self.norm2(x)) - return x - -class FlashLM(nn.Module): - def __init__( - self, - vocab_size, - dim=256, - depth=6, - num_heads=8, - mlp_ratio=4.0, - causal=True, - dropout=0.1, - max_seq_len=256, - use_fp8=False - ): - super().__init__() - - # embedding layers - self.token_embedding = nn.Embedding(vocab_size, dim) - self.position_embedding = nn.Parameter(torch.zeros(1, max_seq_len, dim)) - self.dropout = nn.Dropout(dropout) - - # transformer blocks - self.blocks = nn.ModuleList([ - TransformerBlock(dim, num_heads, mlp_ratio, causal=causal, dropout=dropout, use_fp8=use_fp8) - for _ in range(depth) - ]) - - # lm head: project back to vocabulary dimension for each token - self.norm = nn.LayerNorm(dim) - self.lm_head = nn.Linear(dim, vocab_size) - - def forward(self, x): - b, n = x.shape - - # token + positional embedding - x = self.token_embedding(x) - x = x + self.position_embedding[:, :n, :] - x = self.dropout(x) - - # transformer blocks - for block in self.blocks: - x = block(x) - - # language modeling head - x = self.norm(x) - logits = self.lm_head(x) # shape: (b, n, vocab_size) - return logits - -# ------------------------------- -# Data -# ------------------------------- -class TextDataset(Dataset): - def __init__(self, sequences, max_len=None): - self.sequences = sequences - self.max_len = max_len - - def __len__(self): - return len(self.sequences) - - def __getitem__(self, idx): - seq = self.sequences[idx] - # input: all tokens except the last, target: all tokens except the first - return (torch.tensor(seq[:-1], dtype=torch.long), - torch.tensor(seq[1:], dtype=torch.long)) - -class VarLenTextDataset(Dataset): - def __init__(self, sequences, max_len=256): - self.sequences = sequences - self.max_len = max_len - - def __len__(self): - return len(self.sequences) - - def __getitem__(self, idx): - seq = self.sequences[idx] - # Ensure the sequence doesn't exceed max_len+1 - seq = seq[:self.max_len+1] - # input: all tokens except the last, target: all tokens except the first - return (torch.tensor(seq[:-1], dtype=torch.long), - torch.tensor(seq[1:], dtype=torch.long)) - -def prepare_dataset(batch_size, is_varlen=False, min_len=10, max_len=256, ratio_shorter=0.7): - # load the WikiText-2 - dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") - - # build vocabulary - corpus = " ".join([line for line in dataset["text"] if line.strip() != ""]) # join non-empty lines into a single corpus string - tokens = corpus.split() - vocab = sorted(set(tokens)) - word2idx = {word: idx for idx, word in enumerate(vocab)} - token_ids = [word2idx[word] for word in tokens] - - num_workers = 2 - if is_varlen: - # VARIABLE LENGTH: create sequences of different lengths - sequences = [] - for i in range(0, len(token_ids) - max_len, max_len // 2): # overlap to get more sequences - # Decide target length for this sequence - if np.random.random() < ratio_shorter: - # Shorter sequence - target_len = np.random.randint(min_len + 1, max_len + 1) - else: - # Full length sequence - target_len = max_len + 1 - - # Extract sequence up to target length or whatever's available - seq_end = min(i + target_len, len(token_ids)) - seq = token_ids[i:seq_end] - - # Only keep sequences that are long enough - if len(seq) > min_len + 1: # +1 because we need both input and target - sequences.append(seq) - - print(f"Created {len(sequences)} variable-length sequences") - - # Get some statistics - lens = [len(seq) for seq in sequences] - print(f"Sequence length stats: min={min(lens)}, max={max(lens)}, mean={np.mean(lens):.1f}") - - # split dataset - num_samples = len(sequences) - num_train = int(0.8 * num_samples) - num_val = num_samples - num_train - - # Use appropriate dataset class based on whether we need variable length - dataset_class = VarLenTextDataset - train_sequences = sequences[:num_train] - val_sequences = sequences[num_train:] - - train_dataset = dataset_class(train_sequences, max_len) - val_dataset = dataset_class(val_sequences, max_len) - - - # collate function - def collate_fn(batch): - """ - Collate function that creates a flat representation for variable length flash attention. - """ - # Separate inputs and targets - inputs, targets = zip(*batch) - - # Get sequence lengths - seq_lens = torch.tensor([len(x) for x in inputs], dtype=torch.int32) - - # Concatenate inputs and targets into single tensors - flat_inputs = torch.cat(inputs) - flat_targets = torch.cat(targets) - - # Create cumulative sequence lengths tensor - cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32) - cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0) - - # Calculate max sequence length for this batch - max_seqlen = seq_lens.max().item() - - return flat_inputs, flat_targets, seq_lens, cu_seqlens, max_seqlen - - # data loaders - train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn) - val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn) - else: - # FIXED LENGTH: create sequences of length max_len+1 - sequences = [] - for i in range(0, len(token_ids) - max_len, max_len): - seq = token_ids[i : i + max_len + 1] - if len(seq) == max_len + 1: - sequences.append(seq) - - # split dataset - num_samples = len(sequences) - num_train = int(0.8 * num_samples) - num_val = num_samples - num_train - train_dataset, val_dataset = random_split(TextDataset(sequences), [num_train, num_val]) - - # data loaders - train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) - val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) - - vocab_size = len(vocab) - print(f"vocab size: {vocab_size}, train samples: {len(train_dataset)}, validation samples: {len(val_dataset)}") - return train_dataloader, val_dataloader, vocab_size - -# ------------------------------- -# Training -# ------------------------------- -def train_lm(model, train_dataloader, val_dataloader, optimizer, criterion, num_epochs): - train_losses = [] - val_losses = [] - for epoch in range(num_epochs): - # Training phase - model.train() - epoch_train_loss = 0.0 - for inputs, targets in tqdm(train_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [train]"): - inputs, targets = inputs.to(device), targets.to(device) - - optimizer.zero_grad() - logits = model(inputs) - loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) - loss.backward() - optimizer.step() - - epoch_train_loss += loss.item() - - epoch_train_loss /= len(train_dataloader) - train_losses.append(epoch_train_loss) - print(f"epoch {epoch+1}/{num_epochs} - train loss: {epoch_train_loss:.4f}") - - # Validation phase - model.eval() - epoch_val_loss = 0.0 - with torch.no_grad(): - for inputs, targets in tqdm(val_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [validation]"): - inputs, targets = inputs.to(device), targets.to(device) - logits = model(inputs) - loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) - epoch_val_loss += loss.item() - epoch_val_loss /= len(val_dataloader) - val_losses.append(epoch_val_loss) - print(f"epoch {epoch+1}/{num_epochs} - validation loss: {epoch_val_loss:.4f}") - - return train_losses, val_losses - -# ------------------------------- -# Main -# ------------------------------- -def main(): - # hyperparameters - batch_size = 16 - num_epochs = 20 - learning_rate = 3e-4 - max_len = 128 # total length including both input and target tokens - is_varlen = False - causal=True - dropout=0.1 - - # prep data - print("Preparing Dataset") - train_dataloader, val_dataloader, vocab_size = prepare_dataset(batch_size, max_len=max_len, is_varlen=is_varlen) - - # create language models - print("Creating Models") - model_normal = FlashLM( - vocab_size=vocab_size, - dim=256, - depth=3, - num_heads=8, - causal=causal, - dropout=dropout, - max_seq_len=max_len, - ).to(device) - - model_fp8 = FlashLM( - vocab_size=vocab_size, - dim=256, - depth=3, - num_heads=8, - causal=causal, - dropout=dropout, - max_seq_len=max_len, - use_fp8=True - ).to(device) - - # Train Normal model - print("Starting training for Normal model...") - optimizer_normal = optim.AdamW(model_normal.parameters(), lr=learning_rate) - criterion = nn.CrossEntropyLoss() - normal_train_losses, normal_val_losses = train_lm( - model_normal, train_dataloader, val_dataloader, optimizer_normal, criterion, num_epochs - ) - torch.save(model_normal.state_dict(), 'flash_lm_normal.pth') - print("Normal model training complete and saved.") - - # Train FP8 model - print("Starting training for FP8 model...") - optimizer_fp8 = optim.AdamW(model_fp8.parameters(), lr=learning_rate) - fp8_train_losses, fp8_val_losses = train_lm( - model_fp8, train_dataloader, val_dataloader, optimizer_fp8, criterion, num_epochs - ) - torch.save(model_fp8.state_dict(), 'flash_lm_fp8.pth') - print("FP8 model training complete and saved.") - - # save losses to csv - epochs = range(1, num_epochs+1) - loss_data = { - "Epoch": epochs, - "Normal_Training_Loss": normal_train_losses, - "Normal_Validation_Loss": normal_val_losses, - "FP8_Training_Loss": fp8_train_losses, - "FP8_Validation_Loss": fp8_val_losses, - } - df_losses = pd.DataFrame(loss_data) - df_losses.to_csv("losses.csv", index=False) - print("Loss data saved to losses.csv") - - # plot Training Loss - plt.figure(figsize=(10, 6)) - plt.plot(epochs, normal_train_losses, label="Normal Training Loss", marker='o') - plt.plot(epochs, fp8_train_losses, label="FP8 Training Loss", marker='x') - plt.xlabel("Epoch") - plt.ylabel("Training Loss") - plt.title("Training Loss Comparison: Normal vs FP8 Flash Attention") - plt.legend() - plt.grid(True) - plt.savefig("training_loss.png") # Saves the training loss plot to disk - plt.show() - - # Plot Validation Loss - plt.figure(figsize=(10, 6)) - plt.plot(epochs, normal_val_losses, label="Normal Validation Loss", marker='o') - plt.plot(epochs, fp8_val_losses, label="FP8 Validation Loss", marker='x') - plt.xlabel("Epoch") - plt.ylabel("Validation Loss") - plt.title("Validation Loss Comparison: Normal vs FP8 Flash Attention") - plt.legend() - plt.grid(True) - plt.savefig("validation_loss.png") # Saves the validation loss plot to disk - plt.show() - - -if __name__ == "__main__": - main() diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 96d4f662567..44502785a35 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -139,11 +139,11 @@ def check_args(self, q, k, v, o): assert q.dim() == 4 assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 assert self.cu_seqlens_q is None and self.cu_seqlens_k is None - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] # and q.shape[-1] == v.shape[-1] # TODO: Change assert if we support qkl f8 and v f16 assert q.dtype == k.dtype and q.dtype == v.dtype - assert o.shape == q.shape + assert o.shape[:-1] == q.shape[:-1] and o.shape[-1] == v.shape[-1] assert (nheads_q % nheads_k) == 0 assert self.layout is not None assert self.layout == 'thd' or not self.varlen @@ -1064,91 +1064,6 @@ def write_dropout_mask(x, tensor_name = "tensor"): else: writer.writerows(dropout_mask) -# ------------------------------- -# Autotune -# ------------------------------- -def get_fwd_prefill_cdna_autotune_configs(): - return [ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] - - -def get_fwd_prefill_rdna_autotune_configs(): - return [ - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] - - -def get_fwd_prefill_autotune_configs(): - if AUTOTUNE: - if is_rdna(): - return get_fwd_prefill_rdna_autotune_configs() - elif is_cdna(): - return get_fwd_prefill_cdna_autotune_configs() - else: - raise ValueError("Unknown Device Type") - else: - arch = get_arch() - if arch == "gfx950": - default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) - elif arch == "gfx942" and False: # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. - default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) - else: - default_config = triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) - - return [ - default_config - ], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL", - "IS_VARLEN", - "HQ", - "HK", - ] - # ------------------------------- # Runtime info # ------------------------------- diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py old mode 100644 new mode 100755 index 0e93f234aa3..644e86e4b13 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -2,16 +2,24 @@ from typing import Optional, Union +import os import torch import torch.nn as nn -# isort: off -# We need to import the CUDA kernels after importing torch -import flash_attn_3._C # Registers operators with PyTorch -# isort: on +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +if USE_TRITON_ROCM: + import sys + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from flash_attn.flash_attn_triton_amd import interface_fa_v3 as flash_attn_3_gpu +else: + # isort: off + # We need to import the CUDA kernels after importing torch + import flash_attn_3._C # Registers operators with PyTorch -flash_attn_3_cuda = torch.ops.flash_attn_3 + # isort: on + + flash_attn_3_gpu = torch.ops.flash_attn_3 def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -62,7 +70,7 @@ def _flash_attn_forward( ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] seqlens_rotary = maybe_contiguous(seqlens_rotary) - out, softmax_lse, *rest = flash_attn_3_cuda.fwd( + out, softmax_lse, *rest = flash_attn_3_gpu.fwd( q, k, v, @@ -126,7 +134,7 @@ def _flash_attn_backward( ): # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd( + dq, dk, dv, softmax_d, *rest = flash_attn_3_gpu.bwd( dout, q, k, @@ -625,7 +633,7 @@ def flash_attn_varlen_func( def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None): - return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype) + return flash_attn_3_gpu.fwd_combine(out_partial, lse_partial, out, out_dtype) def flash_attn_with_kvcache( @@ -812,7 +820,7 @@ def get_scheduler_metadata( cache_seqlens = maybe_contiguous(cache_seqlens) if headdim_v is None: headdim_v = headdim - scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata( + scheduler_metadata = flash_attn_3_gpu.get_scheduler_metadata( batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v, qkv_dtype, cache_seqlens, diff --git a/hopper/setup.py b/hopper/setup.py old mode 100644 new mode 100755 index 10894252db0..437ddc8c9aa --- a/hopper/setup.py +++ b/hopper/setup.py @@ -43,6 +43,10 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" +# ROCm specific settings +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +if USE_TRITON_ROCM: + SKIP_CUDA_BUILD = True DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" @@ -383,10 +387,10 @@ def nvcc_threads_args(): cmdclass = {} ext_modules = [] - # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source. -subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"]) +if not USE_TRITON_ROCM: + subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"]) if not SKIP_CUDA_BUILD: print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) diff --git a/hopper/test_flash_attn_triton_amd.py b/hopper/test_flash_attn_triton_amd.py new file mode 100755 index 00000000000..02141a72ed7 --- /dev/null +++ b/hopper/test_flash_attn_triton_amd.py @@ -0,0 +1,1173 @@ +import os +import math +import itertools + +import pytest +import torch +import torch.nn.functional as F +from torch._C import parse_schema + +from einops import rearrange, repeat +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from padding import pad_input, unpad_input +from test_util import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, +) + +from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata + + +DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "TRUE") == "TRUE" +DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" +DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "TRUE") == "TRUE" +DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "TRUE") == "TRUE" +DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "TRUE") == "TRUE" +DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "TRUE") == "TRUE" +DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" +DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 +DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" +DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" +DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" +DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" +DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" + +COMPILED_HDIMS = ( + [] + + ([64] if not DISABLE_HDIM64 else []) + + ([96] if not DISABLE_HDIM96 else []) + + ([128] if not DISABLE_HDIM128 else []) + + ([192] if not DISABLE_HDIM192 else []) + + ([256] if not DISABLE_HDIM256 else []) +) + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("V_colmajor", [False, True]) +@pytest.mark.parametrize("V_colmajor", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128, 192]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype +): + if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): + pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + device = "cuda" + # set seed + torch.random.manual_seed(0) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out = flash_attn_func( + q, + k, + v, + causal=causal, + qv=qv, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g = torch.randn_like(out) + do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + # import flash_attn_3_cuda + # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd( + # g, + # q, + # k, + # v, + # out, + # lse, + # None, + # None, + # None, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype +): + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + nheads = 6 + # batch_size = 2 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + qv=qv_unpad, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) +# @pytest.mark.parametrize("new_kv", [True]) +@pytest.mark.parametrize("new_kv", [False, True]) +@pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +# @pytest.mark.parametrize("causal,local", [(False, False)]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False]) +@pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) +# @pytest.mark.parametrize("rotary_interleaved", [True]) +@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) +# @pytest.mark.parametrize("page_size", [None]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_leftpad", [False]) +@pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False]) +@pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +@pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # (1, 128 * 1024), + # (16, 128 * 1024), + (128, 128), + (256, 512), # To test appending KV with more than 1 block + (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + mha_type, + dtype, +): + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + has_qv = d == 64 and dv >= 256 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if has_qv: + qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + if precompute_metadata: + scheduler_metadata = get_scheduler_metadata( + batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + max_seqlen_k_new=seqlen_new, page_size=page_size, + causal=causal, window_size=window_size, attention_chunk=attention_chunk, + num_splits=num_splits + ) + else: + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + + +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + v_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize('d', [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (64, 8192), + ], +) +def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype): + device = "cuda" + torch.random.manual_seed(0) + batch_size = 2 + nheads = 16 + nheads_kv = 4 + # There was a bug where this would cause "unspecified launch failure" due to Cluster + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + for _ in range(100): + flash_attn_func(q, k, v, causal=causal) + + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128]) +# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [80]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (239, 1), + (3, 799), + (799, 3), + (1024, 128), + (97, 97), + (128, 128), + (200, 200), + (256, 256), + (257, 257), + (384, 384), + (512, 512), + (768, 768), + (1024, 1024), + (2048, 2048), + ], +) +def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(0) + # Simulate under memory load + dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device) + batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger + nheads = 4 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + torch.random.manual_seed(42) + out0 = flash_attn_func(q, k, v, causal=causal) + g = torch.randn_like(out0) + dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g) + # Numerical error if we just do any arithmetic on dq + dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() + + for i in range(1000): + torch.random.manual_seed(42) + out = flash_attn_func(q, k, v, causal=causal) + assert torch.equal(out, out0) + # assert torch.equal(lse, lse0) + + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_equal = torch.allclose(dq, dq0, atol=dq_atol) + if not dq_equal: + print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") + # breakpoint() + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert dq_equal + + +def attention_combine_ref(out_partial, lse_partial): + """ + out_partial: (num_splits, batch_size, seqlen, nheads, d) + lse_partial: (num_splits, batch_size, nheads, seqlen) + """ + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float32]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) +# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) +# @pytest.mark.parametrize("seqlen", [15]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) +# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) +# @pytest.mark.parametrize("num_splits", [128]) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + if DISABLE_SPLIT: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(1) + batch_size = 5 + nheads = 16 + # batch_size = 1 + # nheads = 1 + out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + # To test short-circuiting based on num_splits + lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf") + out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype) + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + out_pt = out_ref.to(dtype) + + print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + multiple = 2 + assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + + # from flash_attn.utils.benchmark import pytorch_profiler + # # pytorch_profiler(torch.sum, lse_partial) + # pytorch_profiler(flash_attn_combine, out_partial, lse_partial) + # pytorch_profiler(torch.sum, out_partial) + +@pytest.mark.skip(reason="AMD Triton backend doesn't use torch ops registration") +def test_flash3_bw_compatibility() -> None: + # Let's try to always stay backward compatible! This will make life easier + # for downstream libaries, users, and exported models. + # 1/ Instead of removing arguments, error out if their value is no longer supported + # 2/ When adding arguments, add them at the end with a default value + assert torch.ops.flash_attn_3.fwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, " + "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, " + "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, " + "int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, " + "Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, " + "Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, " + "float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, " + "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) " + "-> (Tensor(out!), Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.bwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, " + "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, " + "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, " + "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, " + "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) " + "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.fwd_combine.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, " + "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)" + )) + assert torch.ops.flash_attn_3.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, " + "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, " + "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, " + "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? pack_gqa=None, " + "int sm_margin=0) -> Tensor" + )) diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index adebe088a79..046edadb8db 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -1921,7 +1921,7 @@ def test_flash_attn_splitkv( # @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) -@pytest.mark.parametrize("paged_kv_block_size", [None]) +@pytest.mark.parametrize("paged_kv_block_size", [None, 256]) # @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) # @pytest.mark.parametrize("paged_kv_block_size", [None]) @pytest.mark.parametrize("has_leftpad", [False]) @@ -1950,6 +1950,7 @@ def test_flash_attn_splitkv( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +@pytest.mark.parametrize("use_generated_tensors", [False]) def test_flash_attn_kvcache( seqlen_q, seqlen_k, @@ -1967,6 +1968,7 @@ def test_flash_attn_kvcache( mha_type, num_splits, dtype, + use_generated_tensors, ): if seqlen_q > seqlen_k and new_kv: pytest.skip() @@ -1976,6 +1978,9 @@ def test_flash_attn_kvcache( pytest.skip() if has_leftpad and paged_kv_block_size is not None: pytest.skip() + # Skip problematic case: paged attention with alibi and multiple queries + if USE_TRITON_ROCM and paged_kv_block_size is not None and alibi and seqlen_q > 1: + pytest.skip("Paged attention with alibi and multiple queries has numerical issues") device = "cuda" # set seed torch.random.manual_seed(0) @@ -1987,16 +1992,31 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + + if use_generated_tensors: + # Use generate_bshd_tensor with incremental mode + q = generate_bshd_tensor(batch_size, seqlen_q, nheads, d, dtype=dtype, device=device, mode="incremental") + else: + # Original random tensor generation + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() if new_kv: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) - v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) + if use_generated_tensors: + k = generate_bshd_tensor(batch_size, seqlen_new, nheads_k, d, dtype=dtype, device=device, mode="incremental") + v = generate_bshd_tensor(batch_size, seqlen_new, nheads_k, d, dtype=dtype, device=device, mode="incremental") + else: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) else: k, v = None, None if paged_kv_block_size is None: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) + if use_generated_tensors: + k_cache = generate_bshd_tensor(batch_size_cache, seqlen_k, nheads_k, d, dtype=dtype, device=device, mode="incremental") + v_cache = generate_bshd_tensor(batch_size_cache, seqlen_k, nheads_k, d, dtype=dtype, device=device, mode="incremental") + else: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) block_table = None else: ( @@ -2007,7 +2027,7 @@ def test_flash_attn_kvcache( v_cache_paged, num_blocks, ) = _generate_block_kvcache( - seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype + seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype, use_generated_tensors ) cache_seqlens = torch.randint( 0 if new_kv else 1, @@ -2186,14 +2206,22 @@ def test_flash_attn_kvcache( assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 -def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype): +def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype, use_generated_tensors=False): num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3 - k_cache_paged = torch.randn( - num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype - ) - v_cache_paged = torch.randn( - num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype - ) + if use_generated_tensors: + k_cache_paged = generate_bshd_tensor( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype, mode="incremental" + ) + v_cache_paged = generate_bshd_tensor( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype, mode="incremental" + ) + else: + k_cache_paged = torch.randn( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype + ) + v_cache_paged = torch.randn( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype + ) block_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", From 9dba28dd37c9279bd1d840938b48a59acc459f7b Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 11 Sep 2025 14:10:56 -0500 Subject: [PATCH 2/7] clean up --- tests/test_flash_attn_triton_amd.py | 56 ++++++++--------------------- 1 file changed, 15 insertions(+), 41 deletions(-) diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index 046edadb8db..a8d71653850 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -1950,7 +1950,6 @@ def test_flash_attn_splitkv( ], ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) -@pytest.mark.parametrize("use_generated_tensors", [False]) def test_flash_attn_kvcache( seqlen_q, seqlen_k, @@ -1968,7 +1967,6 @@ def test_flash_attn_kvcache( mha_type, num_splits, dtype, - use_generated_tensors, ): if seqlen_q > seqlen_k and new_kv: pytest.skip() @@ -1978,9 +1976,8 @@ def test_flash_attn_kvcache( pytest.skip() if has_leftpad and paged_kv_block_size is not None: pytest.skip() - # Skip problematic case: paged attention with alibi and multiple queries - if USE_TRITON_ROCM and paged_kv_block_size is not None and alibi and seqlen_q > 1: - pytest.skip("Paged attention with alibi and multiple queries has numerical issues") + if USE_TRITON_ROCM and paged_kv_block_size is not None and alibi: + pytest.skip("Paged attention with alibi has numerical issues") device = "cuda" # set seed torch.random.manual_seed(0) @@ -1992,31 +1989,16 @@ def test_flash_attn_kvcache( nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) assert nheads % nheads_k == 0 window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) - - if use_generated_tensors: - # Use generate_bshd_tensor with incremental mode - q = generate_bshd_tensor(batch_size, seqlen_q, nheads, d, dtype=dtype, device=device, mode="incremental") - else: - # Original random tensor generation - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) - + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() if new_kv: - if use_generated_tensors: - k = generate_bshd_tensor(batch_size, seqlen_new, nheads_k, d, dtype=dtype, device=device, mode="incremental") - v = generate_bshd_tensor(batch_size, seqlen_new, nheads_k, d, dtype=dtype, device=device, mode="incremental") - else: - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) - v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) else: k, v = None, None if paged_kv_block_size is None: - if use_generated_tensors: - k_cache = generate_bshd_tensor(batch_size_cache, seqlen_k, nheads_k, d, dtype=dtype, device=device, mode="incremental") - v_cache = generate_bshd_tensor(batch_size_cache, seqlen_k, nheads_k, d, dtype=dtype, device=device, mode="incremental") - else: - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) block_table = None else: ( @@ -2027,7 +2009,7 @@ def test_flash_attn_kvcache( v_cache_paged, num_blocks, ) = _generate_block_kvcache( - seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype, use_generated_tensors + seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype ) cache_seqlens = torch.randint( 0 if new_kv else 1, @@ -2206,22 +2188,14 @@ def test_flash_attn_kvcache( assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 -def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype, use_generated_tensors=False): +def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype): num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3 - if use_generated_tensors: - k_cache_paged = generate_bshd_tensor( - num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype, mode="incremental" - ) - v_cache_paged = generate_bshd_tensor( - num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype, mode="incremental" - ) - else: - k_cache_paged = torch.randn( - num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype - ) - v_cache_paged = torch.randn( - num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype - ) + k_cache_paged = torch.randn( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype + ) + v_cache_paged = torch.randn( + num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype + ) block_table = rearrange( torch.randperm(num_blocks, dtype=torch.int32, device=device), "(b nblocks) -> b nblocks", From 14cbdfc5c128ac37e7e24aadbdcdb8cfa8f2fe19 Mon Sep 17 00:00:00 2001 From: Michael Date: Thu, 11 Sep 2025 19:41:46 -0500 Subject: [PATCH 3/7] skip hopper race test --- hopper/test_flash_attn_triton_amd.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hopper/test_flash_attn_triton_amd.py b/hopper/test_flash_attn_triton_amd.py index 02141a72ed7..738ec1d8c13 100755 --- a/hopper/test_flash_attn_triton_amd.py +++ b/hopper/test_flash_attn_triton_amd.py @@ -1039,6 +1039,7 @@ def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype): (2048, 2048), ], ) +@pytest.mark.skip(reason="Cannot be run in parallel with other tests due to memory usage") def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): device = "cuda" # set seed From a33c85ba0fb1092628f9e6ea4628fef48892acca Mon Sep 17 00:00:00 2001 From: Michael Date: Fri, 12 Sep 2025 10:06:57 -0500 Subject: [PATCH 4/7] clean up more --- README.md | 14 +++++++------- flash_attn/flash_attn_triton_amd/Dockerfile | 17 ----------------- 2 files changed, 7 insertions(+), 24 deletions(-) delete mode 100644 flash_attn/flash_attn_triton_amd/Dockerfile diff --git a/README.md b/README.md index 0c4929bfe84..ed4e05dbd29 100644 --- a/README.md +++ b/README.md @@ -160,7 +160,6 @@ Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` s ``` cd flash-attention -git checkout main_perf FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install ``` @@ -184,16 +183,17 @@ WORKDIR /workspace # install triton RUN pip install triton==3.3.0 -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - -RUN git clone https://github.com/ROCm/flash-attention.git &&\ +# build flash attention with triton backend +RUN git clone https://github.com/Dao-AILab/flash-attention &&\ cd flash-attention &&\ - git checkout main_perf &&\ - python setup.py install + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install # set working dir WORKDIR /workspace/flash-attention + +# set env variable to use triton backend +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + ``` To build the docker file diff --git a/flash_attn/flash_attn_triton_amd/Dockerfile b/flash_attn/flash_attn_triton_amd/Dockerfile deleted file mode 100644 index 8df939a0886..00000000000 --- a/flash_attn/flash_attn_triton_amd/Dockerfile +++ /dev/null @@ -1,17 +0,0 @@ -FROM rocm/pytorch:latest - -WORKDIR /workspace - -# install triton -RUN pip install triton==3.3.0 - -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - -RUN git clone https://github.com/ROCm/flash-attention.git &&\ - cd flash-attention &&\ - git checkout main_perf &&\ - python setup.py install - -# set working dir -WORKDIR /workspace/flash-attention \ No newline at end of file From d9b6fda32e62061a9777c501feefda530944fca6 Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 17 Sep 2025 12:03:09 -0500 Subject: [PATCH 5/7] fix paged + alibi --- .../flash_attn_triton_amd/fwd_decode.py | 106 +++++++++--------- tests/test_flash_attn_triton_amd.py | 2 - 2 files changed, 56 insertions(+), 52 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index c14e02b9b56..6dbec65e284 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -484,51 +484,65 @@ def _fwd_kernel_splitK( process_end = tl.minimum(hi - block_start, BLOCK_SIZE_K) process_end = tl.minimum(process_end, block_end - block_start) - # Align to BLOCK_N boundaries - process_start = (process_start // BLOCK_N) * BLOCK_N + # Instead of forcing a floor alignment to BLOCK_N (which can still skip + # part of the intended range if start falls mid-tile for small splits), + # start from the raw (possibly unaligned) process_start rounded *down* but + # allow the loop to begin earlier (at most BLOCK_N before) so that any + # partial tile overlapping [lo, hi) is covered. Masking below will remove + # columns < lo or >= hi ensuring numerically identical coverage without + # duplication. + aligned_start = (process_start // BLOCK_N) * BLOCK_N + if aligned_start > 0 and aligned_start + BLOCK_N > process_start: + # ensure we include the tile that contains process_start + process_start = aligned_start + else: + process_start = aligned_start for offset in range(process_start, process_end, BLOCK_N): - # Current position in the sequence + # Current position (may begin slightly before logical split range; masking fixes it) seq_pos = block_start + offset + # Proceed unconditionally; masking below enforces [lo, hi) + # Calculate base addresses for K and V in this physical block + k_base = K + physical_block * BLOCK_SIZE_K * stride_kn + hk_id * stride_kh + g_id * stride_kg + v_base = V + physical_block * BLOCK_SIZE_K * stride_vn + hv_id * stride_vh + g_id * stride_vg - # Only process if in range - if seq_pos < hi and seq_pos >= lo: - # Calculate base addresses for K and V in this physical block - k_base = K + physical_block * BLOCK_SIZE_K * stride_kn + hk_id * stride_kh + g_id * stride_kg - v_base = V + physical_block * BLOCK_SIZE_K * stride_vn + hv_id * stride_vh + g_id * stride_vg - - # Offsets within the current block - block_offs = offset + offs_n - - # Masks for valid data - seq_mask = ((seq_pos + offs_n) < N_CTX_K_FINAL) - block_mask = (block_offs < BLOCK_SIZE_K) - valid_mask = seq_mask & block_mask - - # Apply masks - kT_mask_final = kT_mask & valid_mask[None, :] - v_mask_final = v_mask & valid_mask[:, None] - - # Load K and V - kT_ptrs = k_base + offs_d[:, None] * stride_kd + block_offs[None, :] * stride_kn - v_ptrs = v_base + block_offs[:, None] * stride_vn + offs_d[None, :] * stride_vd - - kT = tl.load(kT_ptrs, mask=kT_mask_final, other=0.0) - v = tl.load(v_ptrs, mask=v_mask_final, other=0.0) - - # Use the specialized paged attention inner function - m_i, l_i, acc = _attn_fwd_inner_paged( - q, kT, v, seq_pos, valid_mask, - m_i, l_i, acc, - pid_m, - q_descale, k_descale, v_descale, # FP8 scaling - IS_FP8, # FP8 flag - BLOCK_M, BLOCK_N, - N_CTX_Q, N_CTX_K_FINAL, - USE_ALIBI, alibi_slope, - USE_SLIDING_WINDOW, IS_CAUSAL, - WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, - ) + # Offsets within the current block + block_offs = offset + offs_n + + # Masks for valid data respecting: + # (1) global key length (seq_mask) + # (2) block bounds (block_mask) + # (3) current split range [lo, hi) + seq_mask = ((seq_pos + offs_n) < N_CTX_K_FINAL) + block_mask = (block_offs < BLOCK_SIZE_K) + end_mask = (block_offs < process_end) + split_mask = ((seq_pos + offs_n) >= lo) & ((seq_pos + offs_n) < hi) + valid_mask = seq_mask & block_mask & end_mask & split_mask + + # Apply masks + kT_mask_final = kT_mask & valid_mask[None, :] + v_mask_final = v_mask & valid_mask[:, None] + + # Load K and V + kT_ptrs = k_base + offs_d[:, None] * stride_kd + block_offs[None, :] * stride_kn + v_ptrs = v_base + block_offs[:, None] * stride_vn + offs_d[None, :] * stride_vd + + kT = tl.load(kT_ptrs, mask=kT_mask_final, other=0.0) + v = tl.load(v_ptrs, mask=v_mask_final, other=0.0) + + # Use the specialized paged attention inner function + m_i, l_i, acc = _attn_fwd_inner_paged( + q, kT, v, seq_pos, valid_mask, + m_i, l_i, acc, + pid_m, + q_descale, k_descale, v_descale, # FP8 scaling + IS_FP8, # FP8 flag + BLOCK_M, BLOCK_N, + N_CTX_Q, N_CTX_K_FINAL, + USE_ALIBI, alibi_slope, + USE_SLIDING_WINDOW, IS_CAUSAL, + WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, + ) else: # Non-paged attention: process KV from cache # Note: Cache should be updated externally before calling this kernel @@ -792,15 +806,7 @@ def attention_decode_forward_triton_impl( k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, ): - # Check for unsupported configuration - seqlen_q = q.shape[1] - if block_table is not None and alibi_slopes is not None and seqlen_q > 1: - raise NotImplementedError( - "Paged attention with ALiBi and multiple queries (seqlen_q > 1) is not supported " - "due to numerical precision issues. Please use non-paged attention or single query decode." - ) - - # Handle cache updates externally before calling the kernel + # handle cache updates if k_new is not None and v_new is not None: # Update cache with new KV values if block_table is None: diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index a8d71653850..ba1932438e2 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -1976,8 +1976,6 @@ def test_flash_attn_kvcache( pytest.skip() if has_leftpad and paged_kv_block_size is not None: pytest.skip() - if USE_TRITON_ROCM and paged_kv_block_size is not None and alibi: - pytest.skip("Paged attention with alibi has numerical issues") device = "cuda" # set seed torch.random.manual_seed(0) From 2dd76fe49e454d7348851ec77d5af734bc3929a4 Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 17 Sep 2025 12:53:59 -0500 Subject: [PATCH 6/7] similar inner paged api --- .../flash_attn_triton_amd/fwd_decode.py | 63 ++++++++++--------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index 6dbec65e284..c8210576216 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -61,9 +61,9 @@ def get_autotune_configs(): @triton.jit def _attn_fwd_inner( - q, kT, v, start_n, + q, kT, v, pos, col_mask, m_i, l_i, acc, - pid_m, hi, + pid_m, q_descale, k_descale, v_descale, # FP8 scaling factors IS_FP8: tl.constexpr, # FP8 flag BLOCK_M: tl.constexpr, @@ -76,7 +76,7 @@ def _attn_fwd_inner( IS_CAUSAL: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, - BOUNDS_CHECKS_N: tl.constexpr, + APPLY_COL_MASK: tl.constexpr, # apply provided col_mask when True ): # -- compute qk --- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -87,7 +87,7 @@ def _attn_fwd_inner( if USE_ALIBI: row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = start_n + tl.arange(0, BLOCK_N) + col_idx = pos + tl.arange(0, BLOCK_N) # Compute relative positions relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) @@ -102,7 +102,7 @@ def _attn_fwd_inner( # ------------------------------------------------------------------ if USE_SLIDING_WINDOW: row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions - col_idx = start_n + tl.arange(0, BLOCK_N) # k positions + col_idx = pos + tl.arange(0, BLOCK_N) # k positions row = row_idx[:, None] # [M,1] col = col_idx[None, :] # [1,N] @@ -129,7 +129,7 @@ def _attn_fwd_inner( else: if IS_CAUSAL: row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = start_n + tl.arange(0, BLOCK_N) + col_idx = pos + tl.arange(0, BLOCK_N) # create a N_CTX_Q x kv_len causal mask col_offset = N_CTX_K_FINAL - N_CTX_Q @@ -138,10 +138,11 @@ def _attn_fwd_inner( # Apply the mask qk = tl.where(causal_mask, qk, float("-inf")) - # TODO: This is slow, and only needed at the last iteration. - # Maybe we can unroll the last iteration instead? - if BOUNDS_CHECKS_N: - qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + # Column mask (tail / variable-length). Instead of recomputing an arange each time, + # we accept a precomputed mask from the caller (col_valid_mask). + if APPLY_COL_MASK: + # Expect col_mask shape: [BLOCK_N]. True where column is within sequence. + qk = tl.where(col_mask[None, :], qk, float("-inf")) m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far @@ -171,11 +172,11 @@ def _attn_fwd_inner( @triton.jit def _attn_fwd_inner_paged( - q, kT, v, seq_pos, valid_mask, + q, kT, v, pos, col_mask, m_i, l_i, acc, pid_m, - q_descale, k_descale, v_descale, # FP8 scaling factors - IS_FP8: tl.constexpr, # FP8 flag + q_descale, k_descale, v_descale, + IS_FP8: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, N_CTX_Q: tl.constexpr, @@ -186,14 +187,14 @@ def _attn_fwd_inner_paged( IS_CAUSAL: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, + APPLY_COL_MASK: tl.constexpr, ): """ Specialized attention computation for paged KV cache. - Key differences from _attn_fwd_inner: - - Takes a valid_mask parameter to handle block boundaries - - No BOUNDS_CHECKS_N needed as masking is handled via valid_mask - - seq_pos represents the absolute position in the sequence + Key differences from _attn_fwd_inner (pre-unification): + - Takes a col_mask parameter to handle block boundaries / padding + - pos represents the absolute starting column position in the sequence """ # -- compute qk --- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -205,7 +206,7 @@ def _attn_fwd_inner_paged( # Apply ALiBi if needed if USE_ALIBI: row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = seq_pos + tl.arange(0, BLOCK_N) + col_idx = pos + tl.arange(0, BLOCK_N) # Compute relative positions relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) @@ -218,7 +219,7 @@ def _attn_fwd_inner_paged( # Apply sliding window if needed if USE_SLIDING_WINDOW: row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = seq_pos + tl.arange(0, BLOCK_N) + col_idx = pos + tl.arange(0, BLOCK_N) row = row_idx[:, None] col = col_idx[None, :] @@ -244,13 +245,14 @@ def _attn_fwd_inner_paged( qk = tl.where(mask, float("-inf"), qk) elif IS_CAUSAL: row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = seq_pos + tl.arange(0, BLOCK_N) + col_idx = pos + tl.arange(0, BLOCK_N) col_offset = N_CTX_K_FINAL - N_CTX_Q causal_mask = row_idx[:, None] >= (col_idx[None, :] - col_offset) qk = tl.where(causal_mask, qk, float("-inf")) - # Mask out invalid positions (from block boundaries) - qk = tl.where(valid_mask[None, :], qk, float("-inf")) + # Mask out invalid positions (from block / range boundaries) if enabled + if APPLY_COL_MASK: + qk = tl.where(col_mask[None, :], qk, float("-inf")) # Compute new m and do softmax m_i_new = tl.maximum(m_i, tl.max(qk, 1)) @@ -535,13 +537,14 @@ def _fwd_kernel_splitK( q, kT, v, seq_pos, valid_mask, m_i, l_i, acc, pid_m, - q_descale, k_descale, v_descale, # FP8 scaling - IS_FP8, # FP8 flag + q_descale, k_descale, v_descale, + IS_FP8, BLOCK_M, BLOCK_N, N_CTX_Q, N_CTX_K_FINAL, USE_ALIBI, alibi_slope, USE_SLIDING_WINDOW, IS_CAUSAL, WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, + True, ) else: # Non-paged attention: process KV from cache @@ -556,12 +559,16 @@ def _fwd_kernel_splitK( v = tl.load(V_ptrs, mask=v_mask, other=0.0) # Use the same inner loop logic + # Precompute column validity mask for this tile (all True for full tiles). + # hi is the upper bound of the overall split range; start_n marks this tile's base. + col_valid_mask = offs_n < (hi - start_n) + m_i, l_i, acc = _attn_fwd_inner( - q, kT, v, start_n, + q, kT, v, start_n, col_valid_mask, m_i, l_i, acc, - pid_m, hi, - q_descale, k_descale, v_descale, # FP8 scaling - IS_FP8, # FP8 flag + pid_m, + q_descale, k_descale, v_descale, + IS_FP8, BLOCK_M, BLOCK_N, N_CTX_Q, N_CTX_K_FINAL, USE_ALIBI, alibi_slope, From 4fc028c11cb7bf6c41d808cfc34721f7289fd9db Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 17 Sep 2025 16:35:00 -0500 Subject: [PATCH 7/7] unify _attn_fwd_inner --- .../flash_attn_triton_amd/fwd_decode.py | 122 ++---------------- 1 file changed, 9 insertions(+), 113 deletions(-) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py index c8210576216..327967cebf7 100755 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -170,110 +170,6 @@ def _attn_fwd_inner( return m_i, l_i, acc -@triton.jit -def _attn_fwd_inner_paged( - q, kT, v, pos, col_mask, - m_i, l_i, acc, - pid_m, - q_descale, k_descale, v_descale, - IS_FP8: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - N_CTX_Q: tl.constexpr, - N_CTX_K_FINAL: tl.constexpr, - USE_ALIBI: tl.constexpr, - alibi_slope, - USE_SLIDING_WINDOW: tl.constexpr, - IS_CAUSAL: tl.constexpr, - WINDOW_SIZE_LEFT: tl.constexpr, - WINDOW_SIZE_RIGHT: tl.constexpr, - APPLY_COL_MASK: tl.constexpr, -): - """ - Specialized attention computation for paged KV cache. - - Key differences from _attn_fwd_inner (pre-unification): - - Takes a col_mask parameter to handle block boundaries / padding - - pos represents the absolute starting column position in the sequence - """ - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - if IS_FP8: - qk += (tl.dot(q, kT) * q_descale * k_descale) # Apply FP8 scaling - else: - qk += tl.dot(q, kT) - - # Apply ALiBi if needed - if USE_ALIBI: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = pos + tl.arange(0, BLOCK_N) - - # Compute relative positions - relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) - relative_pos = tl.abs(relative_pos) - - # Compute ALiBi bias - alibi_bias = -1 * alibi_slope * relative_pos - qk += (alibi_bias * 1.44269504) - - # Apply sliding window if needed - if USE_SLIDING_WINDOW: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = pos + tl.arange(0, BLOCK_N) - row = row_idx[:, None] - col = col_idx[None, :] - - if IS_CAUSAL: - # -------- causal + window -------- - diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq - causal_ok = col <= row + diag - if WINDOW_SIZE_LEFT < 0: # only right window - win_ok = col <= row + diag + WINDOW_SIZE_RIGHT - else: # both sides - win_ok = ((col >= row + diag - WINDOW_SIZE_LEFT) & - (col <= row + diag + WINDOW_SIZE_RIGHT)) - mask = ~(causal_ok & win_ok) # True ⇒ -inf - else: - # -------- non-causal window -------- - sk, sq = N_CTX_K_FINAL, N_CTX_Q - if WINDOW_SIZE_LEFT < 0: - mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT - else: - right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) - left = row + (sk - sq) - WINDOW_SIZE_LEFT - mask = (col > right) | (col < left) - qk = tl.where(mask, float("-inf"), qk) - elif IS_CAUSAL: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = pos + tl.arange(0, BLOCK_N) - col_offset = N_CTX_K_FINAL - N_CTX_Q - causal_mask = row_idx[:, None] >= (col_idx[None, :] - col_offset) - qk = tl.where(causal_mask, qk, float("-inf")) - - # Mask out invalid positions (from block / range boundaries) if enabled - if APPLY_COL_MASK: - qk = tl.where(col_mask[None, :], qk, float("-inf")) - - # Compute new m and do softmax - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) - valid = m_i_new > float("-inf") - alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) - qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) - p = tl.math.exp2(qk) - - # Update m_i and l_i - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - p = p.to(q.dtype) - - # Scale and update acc - acc *= alpha[:, None] - if IS_FP8: - acc += tl.dot(p.to(v.dtype), v) * v_descale # Apply FP8 scaling for V - else: - acc += tl.dot(p.to(v.dtype), v) - - return m_i, l_i, acc # @triton.autotune( # configs=fwd_auto_tune_configs, @@ -502,7 +398,7 @@ def _fwd_kernel_splitK( for offset in range(process_start, process_end, BLOCK_N): # Current position (may begin slightly before logical split range; masking fixes it) - seq_pos = block_start + offset + pos = block_start + offset # Proceed unconditionally; masking below enforces [lo, hi) # Calculate base addresses for K and V in this physical block k_base = K + physical_block * BLOCK_SIZE_K * stride_kn + hk_id * stride_kh + g_id * stride_kg @@ -515,15 +411,15 @@ def _fwd_kernel_splitK( # (1) global key length (seq_mask) # (2) block bounds (block_mask) # (3) current split range [lo, hi) - seq_mask = ((seq_pos + offs_n) < N_CTX_K_FINAL) + seq_mask = ((pos + offs_n) < N_CTX_K_FINAL) block_mask = (block_offs < BLOCK_SIZE_K) end_mask = (block_offs < process_end) - split_mask = ((seq_pos + offs_n) >= lo) & ((seq_pos + offs_n) < hi) - valid_mask = seq_mask & block_mask & end_mask & split_mask + split_mask = ((pos + offs_n) >= lo) & ((pos + offs_n) < hi) + col_mask = seq_mask & block_mask & end_mask & split_mask # Apply masks - kT_mask_final = kT_mask & valid_mask[None, :] - v_mask_final = v_mask & valid_mask[:, None] + kT_mask_final = kT_mask & col_mask[None, :] + v_mask_final = v_mask & col_mask[:, None] # Load K and V kT_ptrs = k_base + offs_d[:, None] * stride_kd + block_offs[None, :] * stride_kn @@ -532,9 +428,9 @@ def _fwd_kernel_splitK( kT = tl.load(kT_ptrs, mask=kT_mask_final, other=0.0) v = tl.load(v_ptrs, mask=v_mask_final, other=0.0) - # Use the specialized paged attention inner function - m_i, l_i, acc = _attn_fwd_inner_paged( - q, kT, v, seq_pos, valid_mask, + # Unified inner function handles both paged and contiguous + m_i, l_i, acc = _attn_fwd_inner( + q, kT, v, pos, col_mask, m_i, l_i, acc, pid_m, q_descale, k_descale, v_descale,