diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml deleted file mode 100644 index 2e3f061c78d..00000000000 --- a/.github/workflows/amd_tests.yml +++ /dev/null @@ -1,65 +0,0 @@ -name: AMD Perf Kernel Tests - -on: - workflow_dispatch: - pull_request: - branches: [main_perf] - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - Integration-Tests-AMD: - runs-on: ${{ matrix.runner }} - strategy: - matrix: - runner: [linux-mi300-gpu-1] - fail-fast: false # disables failing the entire job when one matrix entry fails - timeout-minutes: 720 # self hosted runners can run jobs for longer than the default of 360 minutes - container: - image: rocm/pytorch:latest - options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 16G --group-add video --user root - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Show Device Info - run: | - rocminfo | grep gfx - - - name: Uninstall Triton - run: | - pip uninstall -y triton - rm -rf ~/.triton - rm -rf ./triton/python/build - - - name: Install Triton - run: | - pip install triton==3.3.0 - - - name: Show Triton version - run: | - pip show triton - - - name: Build - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install - - - name: Install dependencies for bench and misc - run: | - pip install matplotlib pandas tabulate - - - name: AMD Internal Tests - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest flash_attn/flash_attn_triton_amd/test.py - - - name: Flash Attention Tests - run: | - FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0 pytest -n 8 tests/test_flash_attn_triton_amd.py - - - name: AMD Bench - run: | - python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_func - python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_varlen_func - python flash_attn/flash_attn_triton_amd/bench.py -benchmark_fn flash_attn_with_kvcache diff --git a/README.md b/README.md index 0c4929bfe84..ed4e05dbd29 100644 --- a/README.md +++ b/README.md @@ -160,7 +160,6 @@ Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` s ``` cd flash-attention -git checkout main_perf FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install ``` @@ -184,16 +183,17 @@ WORKDIR /workspace # install triton RUN pip install triton==3.3.0 -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - -RUN git clone https://github.com/ROCm/flash-attention.git &&\ +# build flash attention with triton backend +RUN git clone https://github.com/Dao-AILab/flash-attention &&\ cd flash-attention &&\ - git checkout main_perf &&\ - python setup.py install + FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install # set working dir WORKDIR /workspace/flash-attention + +# set env variable to use triton backend +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" + ``` To build the docker file diff --git a/flash_attn/flash_attn_triton_amd/.gitignore b/flash_attn/flash_attn_triton_amd/.gitignore deleted file mode 100644 index 21538fc4e4a..00000000000 --- a/flash_attn/flash_attn_triton_amd/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -bwd_prefill_fused.py -bwd_prefill_onekernel.py \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/Dockerfile b/flash_attn/flash_attn_triton_amd/Dockerfile deleted file mode 100644 index 8df939a0886..00000000000 --- a/flash_attn/flash_attn_triton_amd/Dockerfile +++ /dev/null @@ -1,17 +0,0 @@ -FROM rocm/pytorch:latest - -WORKDIR /workspace - -# install triton -RUN pip install triton==3.3.0 - -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - -RUN git clone https://github.com/ROCm/flash-attention.git &&\ - cd flash-attention &&\ - git checkout main_perf &&\ - python setup.py install - -# set working dir -WORKDIR /workspace/flash-attention \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/README.md b/flash_attn/flash_attn_triton_amd/README.md deleted file mode 100644 index 87213c1883c..00000000000 --- a/flash_attn/flash_attn_triton_amd/README.md +++ /dev/null @@ -1,113 +0,0 @@ -Flash Attention Triton Kernel -=============== - -#### Introduction -The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress. - -It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes. - -These features are supported in Fwd and Bwd -1) Fwd and Bwd with causal masking -2) Variable sequence lengths -3) Arbitrary Q and KV sequence lengths -4) Arbitrary head sizes -5) Multi and grouped query attention -6) Dropout -7) Rotary embeddings -8) ALiBi - -We are working on the following things -1) Paged Attention -2) Sliding Window -3) FP8 -4) Performance Improvements - -##### Getting Started -To get started with the triton backend for AMD, follow the steps below. - -First install the recommended Triton version - -``` -pip install triton==3.3.0 -``` -Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`. - -``` -cd flash-attention -git checkout main_perf -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install -``` - -To test that things are working, you can run our tests. These tests take hours so you don't need to run the full thing. -``` -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" pytest tests/test_flash_attn_triton_amd.py -``` - -You can use autotune for better performance by using this flag `FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE"` -``` -FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" FLASH_ATTENTION_TRITON_AMD_AUTOTUNE="TRUE" python $PATH_TO_CODE -``` - -###### Docker -You can also use the Dockerfile below which does the above steps on top of the latest rocm/pytorch image. -``` -FROM rocm/pytorch:latest - -WORKDIR /workspace - -# install triton -RUN pip install triton==3.3.0 - -# install flash attention -ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" - -RUN git clone https://github.com/ROCm/flash-attention.git &&\ - cd flash-attention &&\ - git checkout main_perf &&\ - python setup.py install - -# set working dir -WORKDIR /workspace/flash-attention -``` - -To build the docker file -``` -docker build -t fa_triton . -``` - -To run the docker image -``` -docker run -it --network=host --user root --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --ipc=host --shm-size 16G --device=/dev/kfd --device=/dev/dri fa_triton -``` - -###### FP8 -In our fork We have created the following api functions that use fp8 internally to compute their values. These functions are `flash_attn_fp8_func`, `flash_attn_varlen_fp8_func`, `flash_attn_qkvpacked_fp8_func` and `flash_attn_varlen_qkvpacked_fp8_func`. Here is a usage example - -``` -from flash_attn.flash_attn_triton_amd.fp8 import flash_attn_qkvpacked_fp8_func - -# forward pass -out, lse, S_dmask = flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - -# backward pass -do = torch.randn_like(out) -dqkv = torch.autograd.grad(out, (qkv), do) -``` - -You can use the other api functions in a similar way. - - - -##### Credits -AMD Triton kernels team - -OpenAI kernel team diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py old mode 100644 new mode 100755 index e969a3770b8..51e53daedc2 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py @@ -1,7 +1,7 @@ import torch import triton import triton.language as tl -from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors +from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors, DEBUG, is_fp8 from typing import Optional, Tuple @@ -1503,11 +1503,17 @@ def attention_prefill_backward_triton_fused_atomics_impl( descale_v: Optional[torch.Tensor] = None, descale_do: Optional[torch.Tensor] = None, fused: bool = False, + # seqused for FA v3 (currently ignored in this implementation) + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): IS_FP8 = is_fp8(q) if IS_FP8: FP8_MAX = torch.finfo(q.dtype).max descale_strides = (descale_q.stride(0),descale_k.stride(0),descale_v.stride(0),descale_do.stride(0) ) + + if DEBUG: + print(f"FP8 path triggered in bwd_prefill_fused_atomics.py") else: FP8_MAX = None stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_do_z = None diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py old mode 100644 new mode 100755 index 5b2f8858d11..0d3b3a6fdf4 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_fused_no_atomics.py @@ -3,10 +3,9 @@ import triton # type: ignore import triton.language as tl # type: ignore from typing import Literal, Optional -from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, \ +from .utils import AUTOTUNE, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, DEBUG, compute_fp8_scaling_factors, \ create_dropout_mask, create_dropout_mask_varlen, is_cdna, is_fp8, is_rdna, round_multiple -DEBUG= False # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) @@ -29,7 +28,7 @@ def get_autotune_configs(): ] preprocess_autotune_keys = [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] causal_autotune_configs = [ triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config @@ -39,7 +38,7 @@ def get_autotune_configs(): ] causal_autotune_keys = [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] noncausal_autotune_configs = [ triton.Config({"BLOCK_M1": 32, "BLOCK_N1": 128, "BLOCK_M2": 128, "BLOCK_N2": 32, "BLK_SLICE_FACTOR": 2, "waves_per_eu": WAVES_PER_EU, "matrix_instr_nonkdim": MATRIX_INSTR_NONKDIM}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), # og config @@ -49,7 +48,7 @@ def get_autotune_configs(): ] noncausal_autotune_keys = [ "IS_CAUSAL", "dropout_p", "MAX_SEQLENS_Q", "MAX_SEQLENS_K", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) @@ -72,21 +71,21 @@ def get_autotune_configs(): ] preprocess_autotune_keys = [ "max_seqlen_q", - "ACTUAL_HEAD_DIM", "IS_VARLEN", + "ACTUAL_HEAD_DIM_V", "IS_VARLEN", ] causal_autotune_configs = [ triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] causal_autotune_keys = [ "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] noncausal_autotune_configs = [ triton.Config({"BLOCK_M1": BLOCK_M1, "BLOCK_N1": BLOCK_N1, "BLOCK_M2": BLOCK_M2, "BLOCK_N2": BLOCK_N2, "BLK_SLICE_FACTOR": BLK_SLICE_FACTOR, "waves_per_eu": WAVES_PER_EU}, num_stages=NUM_STAGES, num_warps=NUM_WARPS), ] noncausal_autotune_keys = [ "dropout_p", "max_seqlen_q", "max_seqlen_k", - "ACTUAL_HEAD_DIM", "IS_VARLEN", "HQ", "HK", + "ACTUAL_HEAD_DIM_QK", "ACTUAL_HEAD_DIM_V", "IS_VARLEN", "HQ", "HK", ] return (preprocess_autotune_configs, preprocess_autotune_keys), (causal_autotune_configs, causal_autotune_keys), (noncausal_autotune_configs, noncausal_autotune_keys) @@ -117,8 +116,8 @@ def _bwd_preprocess( cu_seqlens_q, max_seqlen_q, Descale_do, PRE_BLOCK: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, IS_VARLEN: tl.constexpr, IS_FP8: tl.constexpr ): @@ -136,7 +135,7 @@ def _bwd_preprocess( # Compute offsets offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) - offs_d = tl.arange(0, HEAD_DIM) + offs_d = tl.arange(0, HEAD_DIM_V) # pointer offsets for O & DO off_o = ( bid * stride_ob + hid * stride_oh @@ -152,9 +151,9 @@ def _bwd_preprocess( # create masks mask_m = offs_m < seqlen_q mask_md = mask_m[:, None] - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - if PADDED_HEAD: - mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) + if PADDED_HEAD_V: + mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM_V # load o = tl.load(O + off_o, mask=mask_md, other=0.0) do = tl.load(DO + off_do, mask=mask_md, other=0.0) @@ -185,8 +184,10 @@ def _bwd_dkdv_inner( stride_lse_m, stride_delta_m, BLOCK_M: tl.constexpr, # 16 BLOCK_N: tl.constexpr, # 128 - HEAD_DIM: tl.constexpr, # - ACTUAL_HEAD_DIM: tl.constexpr, # + HEAD_DIM_QK: tl.constexpr, # + HEAD_DIM_V: tl.constexpr, # + ACTUAL_HEAD_DIM_QK: tl.constexpr, # + ACTUAL_HEAD_DIM_V: tl.constexpr, # dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -203,18 +204,20 @@ def _bwd_dkdv_inner( DEBUG_TRITON_DETAIL: tl.constexpr, ): # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) delta_qk = seqlen_q - seqlen_k offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) - offs_k = tl.arange(0, HEAD_DIM) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) # mask to make sure not OOB of seqlen_q mask_n = offs_n < seqlen_k # Q and DO are (seqlen_q, head_dim) - # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q - qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk - # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed - do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM_QK, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k_qk[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM_V), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k_v[None, :] * stride_dok # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. tl.static_assert(BLOCK_N % BLOCK_M == 0) curr_m = start_m @@ -231,9 +234,10 @@ def _bwd_dkdv_inner( mask_qT = mask_m[None, :] mask_do = mask_m[:, None] mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) - if PADDED_HEAD: - mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM - mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM + if PADDED_HEAD_QK: + mask_qT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_do &= offs_k_v[None, :] < ACTUAL_HEAD_DIM_V qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) # generate dropout mask if ENABLE_DROPOUT: @@ -341,8 +345,10 @@ def _bwd_dq_inner( seqlen_q, seqlen_k, # BLOCK_M2: tl.constexpr, # BLOCK_N2: tl.constexpr, # - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, # + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, # dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, # Filled in by the wrapper. @@ -358,17 +364,19 @@ def _bwd_dq_inner( DEBUG_TRITON_DETAIL: tl.constexpr, ): # if HEAD_DIM is padded - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) delta_qk = seqlen_q - seqlen_k offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) - offs_k = tl.arange(0, HEAD_DIM) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) # mask to make sure not OOB of seqlen_q mask_m = offs_m < seqlen_q - kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk - vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k_qk[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k_v[:, None] * stride_vk # D (= delta) is pre-divided by ds_scale. Di = tl.load(Delta + offs_m * stride_delta_m, mask=mask_m, other=0.0) # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. @@ -387,12 +395,15 @@ def _bwd_dq_inner( if DEBUG_TRITON_DETAIL: print(f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}") # noqa: E701 if DEBUG_TRITON_DETAIL: print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 mask_kT = mask_n[None, :] + mask_vT = mask_n[None, :] mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) - if PADDED_HEAD: - mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM + if PADDED_HEAD_QK: + mask_kT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_vT &= offs_k_v[:, None] < ACTUAL_HEAD_DIM_V kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) - vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_vT, other=0.0) if ENABLE_DROPOUT: # NOTE: dropout is transposed because it is used to mask pT @@ -477,6 +488,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_az, stride_ah, HQ, HK, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Add seqused parameters max_seqlen_q, max_seqlen_k, Dropout_mask, dropout_p, philox_seed, philox_offset_base, Alibi_slopes, @@ -486,8 +498,10 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -495,6 +509,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, ): @@ -514,21 +529,31 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba q_end = tl.load(cu_seqlens_q + bid + 1) k_start = tl.load(cu_seqlens_k + bid) k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start delta_qk = seqlen_q - seqlen_k if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_d = tl.arange(0, HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) GROUP_SIZE: tl.constexpr = HQ // HK # align the delta_qk start_n = pid * BLOCK_N1 if start_n < seqlen_k: # This section does dk and dv - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) # q > k: diretcly skip all the way until the start of causal block start_delta_q_gt_k = delta_qk @@ -548,17 +573,21 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba offs_n = start_n + tl.arange(0, BLOCK_N1) # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_kv &= mask_d[None, :] + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d_qk[None, :] * stride_kd + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d_v[None, :] * stride_vd # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + k = tl.load(K + adj_k, mask=mask_k, other=0.0) + v = tl.load(V + adj_v, mask=mask_v, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. # hqid = hkid for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): @@ -628,7 +657,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_dropoutm, stride_dropoutn, # strides for dropout stride_lse_m, stride_delta_m, MASK_BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -658,7 +687,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_dropoutm, stride_dropoutn, # strides for dropout stride_lse_m, stride_delta_m, BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -676,13 +705,13 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba # end of GQA/MQA of dkdv # Write back dV adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn - offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd - tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) # write back dk adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn - offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd dk *= sm_scale - tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) # This part does dq start_m = pid * BLOCK_M2 @@ -696,11 +725,15 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba offs_m = start_m + tl.arange(0, BLOCK_M2) # Mask for loading K and V mask_q = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_q &= mask_d[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod # NOTE: don't assume that the strides for k and v are the same! K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn @@ -739,7 +772,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba dropout_offset = \ Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) m = m[:, None] @@ -757,7 +790,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba else: descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) dq = _bwd_dq_inner( dq, q, K, V, do, m, Delta_ptr, sm_scale, @@ -767,7 +800,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, MASK_BLOCK_N2, - HEAD_DIM, ACTUAL_HEAD_DIM, + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, start_m, start_n, end_n, num_steps, @@ -794,7 +827,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, BLOCK_N2, - HEAD_DIM, ACTUAL_HEAD_DIM, + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, start_m, start_n, end_n, num_steps, @@ -810,7 +843,7 @@ def bwd_kernel_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), ba ) # Write back dQ. adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) # end of GQA/MQA of dq @@ -838,6 +871,7 @@ def bwd_kernel_noncausal( stride_az, stride_ah, HQ, HK, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Add seqused parameters max_seqlen_q, max_seqlen_k, Dropout_mask, dropout_p, philox_seed, philox_offset_base, Alibi_slopes, @@ -847,8 +881,10 @@ def bwd_kernel_noncausal( BLOCK_M2: tl.constexpr, # 128 BLOCK_N2: tl.constexpr, # 32 BLK_SLICE_FACTOR: tl.constexpr, - HEAD_DIM: tl.constexpr, - ACTUAL_HEAD_DIM: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -856,6 +892,7 @@ def bwd_kernel_noncausal( IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused DEBUG_TRITON: tl.constexpr, DEBUG_TRITON_DETAIL: tl.constexpr, ): @@ -875,31 +912,45 @@ def bwd_kernel_noncausal( q_end = tl.load(cu_seqlens_q + bid + 1) k_start = tl.load(cu_seqlens_k + bid) k_end = tl.load(cu_seqlens_k + bid + 1) - seqlen_q = q_end - q_start - seqlen_k = k_end - k_start + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start - PADDED_HEAD: tl.constexpr = (ACTUAL_HEAD_DIM != HEAD_DIM) - offs_d = tl.arange(0, HEAD_DIM) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_HEAD_DIM_V != HEAD_DIM_V) + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) GROUP_SIZE: tl.constexpr = HQ // HK start_n = pid * BLOCK_N1 if start_n < seqlen_k: - dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) - dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) offs_n = start_n + tl.arange(0, BLOCK_N1) # Mask for loading K and V - mask_kv = offs_n[:, None] < seqlen_k - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_kv &= mask_d[None, :] + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] # NOTE: don't assume that the strides for k and v are the same! # K/V tensors not changed for the group - adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd - adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd + adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn + offs_n[:, None] * stride_kn + offs_d_qk[None, :] * stride_kd + adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn + offs_n[:, None] * stride_vn + offs_d_v[None, :] * stride_vd # load K and V: they stay in SRAM throughout the inner loop. - k = tl.load(K + adj_k, mask=mask_kv, other=0.0) - v = tl.load(V + adj_v, mask=mask_kv, other=0.0) + k = tl.load(K + adj_k, mask=mask_k, other=0.0) + v = tl.load(V + adj_v, mask=mask_v, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): # offset input and output tensor by batch and Q/K heads @@ -948,7 +999,7 @@ def bwd_kernel_noncausal( stride_lse_m, stride_delta_m, BLOCK_M1, BLOCK_N1, # block dim - HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, # head dim dropout_p, philox_seed, batch_philox_offset, dropout_offset, # alibi_slope, seqlen_q, seqlen_k, # max sequence length for q and k @@ -966,13 +1017,13 @@ def bwd_kernel_noncausal( # Write back dV adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn - offs_dv = offs_n[:, None] * stride_dvn + offs_d[None, :] * stride_dvd - tl.store(DV + adj_dv + offs_dv, dv, mask=mask_kv) + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) # write back dk adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn - offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd dk *= sm_scale - tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) # THIS PART DOES DQ start_m = pid * BLOCK_M2 @@ -980,11 +1031,15 @@ def bwd_kernel_noncausal( offs_m = start_m + tl.arange(0, BLOCK_M2) # Mask for loading K and V mask_q = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - mask_d = offs_d < ACTUAL_HEAD_DIM - mask_q &= mask_d[None, :] - offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn # If MQA / GQA, set the K and V head offsets appropriately. @@ -1016,7 +1071,7 @@ def bwd_kernel_noncausal( Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) - do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) m = m[:, None] @@ -1034,7 +1089,7 @@ def bwd_kernel_noncausal( end_n = seqlen_k num_steps = tl.cdiv(seqlen_k, BLOCK_N2) - dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) dq = _bwd_dq_inner( dq, q, K, V, do, m, Delta_ptr, sm_scale, @@ -1044,7 +1099,7 @@ def bwd_kernel_noncausal( stride_delta_m, seqlen_q, seqlen_k, BLOCK_M2, BLOCK_N2, - HEAD_DIM, ACTUAL_HEAD_DIM, + HEAD_DIM_QK, HEAD_DIM_V, ACTUAL_HEAD_DIM_QK, ACTUAL_HEAD_DIM_V, dropout_p, philox_seed, batch_philox_offset, dropout_offset, alibi_slope, start_m, start_n, end_n, num_steps, @@ -1060,7 +1115,7 @@ def bwd_kernel_noncausal( ) # Write back dQ. adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm - offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) @@ -1106,6 +1161,9 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( descale_dq: Optional[torch.Tensor], descale_dk: Optional[torch.Tensor], descale_dv: Optional[torch.Tensor], + # seqused for FA v3 + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): # get params, strides and shape IS_VARLEN = layout == "thd" @@ -1135,13 +1193,13 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( assert max_seqlen_k is not None, "max_seqlen_k must be provided for varlen layout" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" assert nheads_lse == nheads_q, f"softmax_lse heads {nheads_lse} != q heads {nheads_q}" # assert output shapes - assert o.shape == (total_seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_q)}" + assert o.shape == (total_seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" @@ -1157,7 +1215,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( # set vars batch = len(cu_seqlens_q) - 1 - head_size = head_size_q + head_size_qk = head_size_q # strides stride_qb, stride_qm, stride_qh, stride_qd = 0, q.stride(0), q.stride(1), q.stride(2) @@ -1180,7 +1238,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" @@ -1188,7 +1246,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" # assert output shapes - assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected" + assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected" assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" @@ -1199,7 +1257,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( # set vars batch = batch_q - head_size = head_size_q + head_size_qk = head_size_q max_seqlen_q = seqlen_q max_seqlen_k = seqlen_k @@ -1233,6 +1291,9 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None + + if DEBUG: + print(f"FP8 path triggered in bwd_prefill_fused_no_atomics.py (FP8_OUTPUT={FP8_OUTPUT})") else: FP8_MAX = None FP8_OUTPUT = False @@ -1242,10 +1303,14 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( use_alibi, (stride_az, stride_ah) = (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) # get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() - padded_d_model = max(padded_d_model, 32) - HEAD_DIM = padded_d_model - ACTUAL_HEAD_DIM = head_size + padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() + padded_d_model_qk = max(padded_d_model_qk, 32) + padded_d_model_v = 1 << (head_size_v - 1).bit_length() + padded_d_model_v = max(padded_d_model_v, 32) + HEAD_DIM_QK = padded_d_model_qk + HEAD_DIM_V = padded_d_model_v + ACTUAL_HEAD_DIM_QK = head_size_qk + ACTUAL_HEAD_DIM_V = head_size_v # init delta if OLD_LSE: @@ -1280,13 +1345,13 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_descale_do_z, cu_seqlens_q, max_seqlen_q, descale_do, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, IS_VARLEN=IS_VARLEN, IS_FP8=IS_FP8 ) - if DEBUG: + if False: print("delta:", delta, delta.shape) # dropout mask tensor for debugging. We dump the dropout mask created in @@ -1337,12 +1402,15 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Pass seqused tensors max_seqlen_q, max_seqlen_k, dropout_mask, dropout_p, philox_seed, philox_offset, alibi_slopes, descale_q, descale_k, descale_v, descale_do, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, USE_ALIBI=use_alibi, @@ -1350,6 +1418,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None), # Add flag for seqused DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) @@ -1371,12 +1440,15 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( stride_az, stride_ah, nheads_q, nheads_k, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Pass seqused tensors max_seqlen_q, max_seqlen_k, dropout_mask, dropout_p, philox_seed, philox_offset, alibi_slopes, descale_q, descale_k, descale_v, descale_do, - HEAD_DIM=HEAD_DIM, - ACTUAL_HEAD_DIM=ACTUAL_HEAD_DIM, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, USE_ALIBI=use_alibi, @@ -1384,6 +1456,7 @@ def attention_prefill_backward_triton_split_fused_no_atomics_impl( IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None), # Add flag for seqused DEBUG_TRITON=DEBUG_TRITON, DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, ) diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py old mode 100644 new mode 100755 index 56187ea71f0..9ffdc9dea1a --- a/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill_split.py @@ -2,11 +2,9 @@ import triton # type: ignore import triton.language as tl # type: ignore from typing import Literal, Optional -from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, compute_fp8_scaling_factors, get_shapes_from_layout, \ +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, DEBUG, compute_fp8_scaling_factors, get_shapes_from_layout, \ get_strides_from_layout, create_dropout_mask, create_dropout_mask_varlen, is_fp8 -DEBUG = False - # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) @@ -1092,6 +1090,9 @@ def attention_prefill_backward_triton_split_impl( descale_dq: Optional[torch.Tensor], descale_dk: Optional[torch.Tensor], descale_dv: Optional[torch.Tensor], + # seqused for FA v3 (currently ignored in this implementation) + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): # debug DEBUG_TRITON: bool = False @@ -1117,6 +1118,9 @@ def attention_prefill_backward_triton_split_impl( stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None stride_descale_do_z = descale_do.stride(0) if descale_do is not None else None + + if DEBUG: + print(f"FP8 path triggered in bwd_prefill_split.py (FP8_OUTPUT={FP8_OUTPUT})") else: FP8_MAX = None FP8_OUTPUT = False diff --git a/flash_attn/flash_attn_triton_amd/fp8.py b/flash_attn/flash_attn_triton_amd/fp8.py deleted file mode 100644 index df79c7926b2..00000000000 --- a/flash_attn/flash_attn_triton_amd/fp8.py +++ /dev/null @@ -1,716 +0,0 @@ -from typing import Optional, Sequence, Tuple, Union -import torch -import torch.nn as nn -from .utils import cast_to_fp8, is_fp8 -from . import interface_fa as flash_attn_gpu - - -def maybe_contiguous(x): - return x.contiguous() if x is not None and x.stride(-1) != 1 else x - -class FlashAttnFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - alibi_slopes, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. It might cause issue doing an empty - - # check output type - assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) - dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") - dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None - dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None - dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dq, - dk, - dv, - ctx.alibi_slopes, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, # gen_ - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - -def flash_attn_fp8_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None -): - return FlashAttnFP8Func.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - descale_q, - descale_k, - descale_v, - descale_do - ) - -class FlashAttnVarlenFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - block_table, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and any( - x.requires_grad for x in [q, k, v] - ) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=max_seqlen_q) - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_k, max_seqlen=max_seqlen_k) - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - cu_seqlens_q, - cu_seqlens_k, - None, - None, - block_table, - alibi_slopes, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - False, - causal, - window_size[0], - window_size[1], - softcap, - return_softmax, - None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] # NOTE: this used to be out_padded. It might cause issue doing an empty - - # check output type - assert out.dtype == q.dtype, "Input and output type must match otherwise there will be implicit casting by autograd" - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout_padded): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dq, descale_dq = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - dk, descale_dk = torch.zeros_like(k_fp8), torch.zeros_like(descale_k) - dv, descale_dv = torch.zeros_like(v_fp8), torch.zeros_like(descale_v) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens_q, max_seqlen=ctx.max_seqlen_q) - dq, descale_dq = torch.zeros_like(q_fp8, dtype=torch.float32), None - dk, descale_dk = torch.zeros_like(k_fp8, dtype=torch.float32), None - dv, descale_dv = torch.zeros_like(v_fp8, dtype=torch.float32), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.varlen_bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_k, - ctx.alibi_slopes, - ctx.max_seqlen_q, - ctx.max_seqlen_k, - ctx.dropout_p, - ctx.softmax_scale, - False, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - descale_dq, - descale_dk, - descale_dv, - ) - dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., : dout.shape[-1]] - dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_varlen_fp8_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, - block_table=None -): - return FlashAttnVarlenFP8Func.apply( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - block_table, - torch.is_grad_enabled() - ) - -class FlashAttnQKVPackedFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, :, 0].detach(), qkv[:, :, 1].detach(), qkv[:, :, 2].detach() - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "bshd") - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "bshd") - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "bshd") - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - alibi_slopes, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o, - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) - head_size_og = dout.size(3) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "bshd") - dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None - - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dqkv[:, :, 0], - dqkv[:, :, 1], - dqkv[:, :, 2], - ctx.alibi_slopes, - ctx.dropout_p, - ctx.softmax_scale, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, # gen_ - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - None, - None, - None, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None - - -def flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # <=0.0 means deactivate - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - return FlashAttnQKVPackedFP8Func.apply( - qkv, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) - - -class FlashAttnVarlenQKVPackedFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_softmax, - is_grad_enabled, - descale_q: Optional[torch.Tensor] = None, - descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_do: Optional[torch.Tensor] = None - ): - is_grad = is_grad_enabled and qkv.requires_grad - if softmax_scale is None: - softmax_scale = qkv.shape[-1] ** (-0.5) - q, k, v = qkv[:, 0].detach(), qkv[:, 1].detach(), qkv[:, 2].detach() - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # figure out fwd parameters - if is_fp8(q) or is_fp8(k) or is_fp8(v): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"You need to pass descale factors for q, k and v" - q_fp8 = q - k_fp8 = k - v_fp8 = v - out_fp8, descale_o = torch.zeros_like(q_fp8), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_q is None) and (descale_k is None) and (descale_v is None), f"Found {q.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - q_fp8, descale_q = cast_to_fp8(q, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - k_fp8, descale_k = cast_to_fp8(k, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - v_fp8, descale_v = cast_to_fp8(v, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) - out_fp8, descale_o = torch.zeros_like(q_fp8, dtype=torch.float32), None - - q_fp8, k_fp8, v_fp8 = [maybe_contiguous(x) for x in (q_fp8, k_fp8, v_fp8)] - _, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - cu_seqlens, - cu_seqlens, - None, - None, - None, - alibi_slopes, - max_seqlen, - max_seqlen, - dropout_p, - softmax_scale, - False, - causal, - window_size[0], - window_size[1], - softcap, - return_softmax, - None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_o=descale_o - ) - if is_grad: - ctx.save_for_backward(q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do) - ctx.dropout_p = dropout_p - ctx.max_seqlen = max_seqlen - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.softcap = softcap - ctx.alibi_slopes = alibi_slopes - ctx.deterministic = deterministic - out = out_fp8[..., :head_size_og] - return out if not return_softmax else (out, softmax_lse, S_dmask) - - @staticmethod - def backward(ctx, dout, *args): - q_fp8, k_fp8, v_fp8, out_fp8, softmax_lse, cu_seqlens, rng_state, descale_q, descale_k, descale_v, descale_o, descale_do = ctx.saved_tensors - qkv_shape = q_fp8.shape[:-2] + (3, *q_fp8.shape[-2:]) - head_size_og = dout.size(2) - dout_padded = dout - if head_size_og % 8 != 0: - dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) - - # figure out bwd parameters - if is_fp8(dout_padded): # fp8 input and output - raise ValueError("fp8 input and out not supported yet for this function.") - assert (descale_do is not None), f"You need to pass descale factors for do" - dout_padded_fp8 = dout_padded - dqkv, descale_dqkv = torch.zeros(qkv_shape, device=q_fp8.device), torch.zeros_like(descale_q) - else: # cast to fp8 and return output in the fp32. (accumulator type) - assert (descale_do is None), f"Found {dout.dtype} input tensor with descale factors. In this case, we cast to fp8 and compute the descale factors. You can pass an fp8 tensor with its descale factors if desired." - dout_padded_fp8, descale_do = cast_to_fp8(dout_padded, torch.float8_e4m3fnuz, "thd", cu_seqlens=cu_seqlens, max_seqlen=ctx.max_seqlen) - dqkv, descale_dqkv = torch.zeros(qkv_shape, dtype=torch.float32, device=q_fp8.device), None - - # dq, dk, dv are allocated by us so they should already be contiguous - dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8 = [maybe_contiguous(x) for x in (dout_padded_fp8, q_fp8, k_fp8, v_fp8, out_fp8)] - flash_attn_gpu.varlen_bwd( - dout_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - softmax_lse, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - cu_seqlens, - cu_seqlens, - ctx.alibi_slopes, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.softmax_scale, - False, - ctx.causal, - ctx.window_size[0], - ctx.window_size[1], - ctx.softcap, - ctx.deterministic, - None, - rng_state, - descale_q, - descale_k, - descale_v, - descale_o, - descale_do, - None, - None, - None, - ) - dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None, None, None - - -def flash_attn_varlen_qkvpacked_fp8_func( - qkv, - cu_seqlens, - max_seqlen, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - alibi_slopes=None, - deterministic=False, - return_attn_probs=False, -): - return FlashAttnVarlenQKVPackedFP8Func.apply( - qkv, - cu_seqlens, - max_seqlen, - dropout_p, - softmax_scale, - causal, - window_size, - softcap, - alibi_slopes, - deterministic, - return_attn_probs, - torch.is_grad_enabled(), - ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py old mode 100644 new mode 100755 index d3f7f9c32b9..327967cebf7 --- a/flash_attn/flash_attn_triton_amd/fwd_decode.py +++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py @@ -2,7 +2,7 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import AUTOTUNE, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna +from .utils import AUTOTUNE, get_padded_headsize, get_shape_and_strides_from_layout, is_cdna, is_fp8 DEBUG = False @@ -59,6 +59,118 @@ def get_autotune_configs(): (fwd_auto_tune_configs, fwd_autotune_keys), (reduce_auto_tune_configs, reduce_autotune_keys) = get_autotune_configs() +@triton.jit +def _attn_fwd_inner( + q, kT, v, pos, col_mask, + m_i, l_i, acc, + pid_m, + q_descale, k_descale, v_descale, # FP8 scaling factors + IS_FP8: tl.constexpr, # FP8 flag + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + N_CTX_Q: tl.constexpr, + N_CTX_K_FINAL: tl.constexpr, + USE_ALIBI: tl.constexpr, + alibi_slope, + USE_SLIDING_WINDOW: tl.constexpr, + IS_CAUSAL: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + APPLY_COL_MASK: tl.constexpr, # apply provided col_mask when True +): + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_FP8: + qk += (tl.dot(q, kT) * q_descale * k_descale) # Apply FP8 scaling + else: + qk += tl.dot(q, kT) # noqa: F821 + + if USE_ALIBI: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = pos + tl.arange(0, BLOCK_N) + + # Compute relative positions + relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) + relative_pos = tl.abs(relative_pos) + + # Compute ALiBi bias + alibi_bias = -1 * alibi_slope * relative_pos + qk += (alibi_bias * 1.44269504) + + # ------------------------------------------------------------------ + # masking + # ------------------------------------------------------------------ + if USE_SLIDING_WINDOW: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions + col_idx = pos + tl.arange(0, BLOCK_N) # k positions + row = row_idx[:, None] # [M,1] + col = col_idx[None, :] # [1,N] + + if IS_CAUSAL: + # -------- causal + window -------- + diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq + causal_ok = col <= row + diag + if WINDOW_SIZE_LEFT < 0: # only right window + win_ok = col <= row + diag + WINDOW_SIZE_RIGHT + else: # both sides + win_ok = ((col >= row + diag - WINDOW_SIZE_LEFT) & + (col <= row + diag + WINDOW_SIZE_RIGHT)) + mask = ~(causal_ok & win_ok) # True ⇒ -inf + else: + # -------- non-causal window -------- + sk, sq = N_CTX_K_FINAL, N_CTX_Q + if WINDOW_SIZE_LEFT < 0: + mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT + else: + right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) + left = row + (sk - sq) - WINDOW_SIZE_LEFT + mask = (col > right) | (col < left) + qk = tl.where(mask, float("-inf"), qk) + else: + if IS_CAUSAL: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = pos + tl.arange(0, BLOCK_N) + + # create a N_CTX_Q x kv_len causal mask + col_offset = N_CTX_K_FINAL - N_CTX_Q + causal_mask = row_idx[:, None] >= (col_idx[None, :] - col_offset) + + # Apply the mask + qk = tl.where(causal_mask, qk, float("-inf")) + + # Column mask (tail / variable-length). Instead of recomputing an arange each time, + # we accept a precomputed mask from the caller (col_valid_mask). + if APPLY_COL_MASK: + # Expect col_mask shape: [BLOCK_N]. True where column is within sequence. + qk = tl.where(col_mask[None, :], qk, float("-inf")) + + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far + + # rows that are *all* -inf after masking + valid = m_i_new > float("-inf") + + # scale previous partial sums safely + alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) + + # subtract the row max only on valid rows + qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) + p = tl.math.exp2(qk) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(q.dtype) + + # -- scale and update acc -- + acc *= alpha[:, None] + if IS_FP8: + acc += tl.dot(p.to(v.dtype), v) * v_descale # Apply FP8 scaling for V + else: + acc += tl.dot(p.to(v.dtype), v) + + return m_i, l_i, acc + + # @triton.autotune( # configs=fwd_auto_tune_configs, # key=fwd_autotune_keys, @@ -69,6 +181,9 @@ def _fwd_kernel_splitK( Q, K, V, + Q_Descale, # FP8 descale factors for Q + K_Descale, # FP8 descale factors for K + V_Descale, # FP8 descale factors for V sm_scale, Out_splitK, # [B*H*G, split_k, Mq, K] Metadata, # [B*H*G, 2, split_k, M_ceil] contains [mi, li] @@ -76,6 +191,7 @@ def _fwd_kernel_splitK( V_new, Cache_seqlens, Cache_batch_idx, + Block_table, Alibi_slopes, stride_qz, stride_qm, @@ -110,13 +226,22 @@ def _fwd_kernel_splitK( stride_vn_g, stride_vn_h, stride_vn_d, + stride_bt_b, + stride_bt_s, stride_az, stride_ah, + stride_q_descale_z, # FP8 descale strides + stride_q_descale_h, + stride_k_descale_z, + stride_k_descale_h, + stride_v_descale_z, + stride_v_descale_h, Z, N_CTX_Q, N_CTX_K, N_CTX_NEW, BLOCK_N_PER_SPLIT, + BLOCK_SIZE_K: tl.constexpr, H_q: tl.constexpr, H_kv: tl.constexpr, G_q: tl.constexpr, @@ -136,6 +261,8 @@ def _fwd_kernel_splitK( USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, + USE_BLOCK_TABLE: tl.constexpr, + IS_FP8: tl.constexpr, # FP8 flag ): # get program ids pid_m = tl.program_id(0) @@ -155,14 +282,24 @@ def _fwd_kernel_splitK( hk_id = hq_id hv_id = hq_id + # Load FP8 descale factors if needed + if IS_FP8: + if IS_GQA: + # For MQA/GQA, q_descale uses the same indexing as k/v (hk_id) + q_descale = tl.load(Q_Descale + z_id * stride_q_descale_z + hk_id * stride_q_descale_h) + else: + # For MHA, q_descale uses hq_id + q_descale = tl.load(Q_Descale + z_id * stride_q_descale_z + hq_id * stride_q_descale_h) + k_descale = tl.load(K_Descale + z_id * stride_k_descale_z + hk_id * stride_k_descale_h) + v_descale = tl.load(V_Descale + z_id * stride_v_descale_z + hv_id * stride_v_descale_h) + else: + q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 + # figure out seqlens lo = pid_splitk * BLOCK_N_PER_SPLIT if USE_CACHE_SEQLENs: cache_seqlen_last_idx = tl.load(Cache_seqlens + z_id) - if NEW_KV: - N_CTX_K_FINAL = cache_seqlen_last_idx + N_CTX_NEW - else: - N_CTX_K_FINAL = cache_seqlen_last_idx + N_CTX_K_FINAL = cache_seqlen_last_idx else: N_CTX_K_FINAL = N_CTX_K hi = tl.minimum((pid_splitk + 1) * BLOCK_N_PER_SPLIT, N_CTX_K_FINAL) @@ -181,8 +318,15 @@ def _fwd_kernel_splitK( # compute ptrs q_offset = Q + hq_id * stride_qh + z_id * stride_qz + g_id * stride_qg q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd - k_offset = K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg - v_offset = V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg + + # Handle block table for paged attention + if USE_BLOCK_TABLE: + # K and V now point to paged cache + # Each batch has its own block table row + block_table_ptr = Block_table + z_id * stride_bt_b + else: + k_offset = K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg + v_offset = V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg # compute masks if PADDED_HEAD: @@ -212,160 +356,122 @@ def _fwd_kernel_splitK( else: alibi_slope = None - # Copy new Keys and Values into Cache - if NEW_KV: - knew_base = K_new + hk_id * stride_kn_h + z_id * stride_kn_z + g_id * stride_kn_g - - # Determine the starting position for new data in the cache - if USE_CACHE_SEQLENs: - start_idx = tl.load(Cache_seqlens + z_id) - else: - start_idx = N_CTX_K - N_CTX_NEW - - # Copy new Keys - for i in range(0, N_CTX_NEW, BLOCK_N): - # Load from K_new - k_new_block = tl.load( - knew_base + - tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kn_d + - (tl.arange(0, BLOCK_N) + i)[None, :] * stride_kn_n, - mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), - other=0 - ) - - # Store to K - tl.store( - k_offset + - tl.arange(0, BLOCK_DMODEL)[:, None] * stride_kd + - (tl.arange(0, BLOCK_N) + i + start_idx)[None, :] * stride_kn, - k_new_block, - mask=(tl.arange(0, BLOCK_N)[None, :] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[:, None] < ACTUAL_BLOCK_DMODEL), - ) - - # Copy new Values - vnew_base = V_new + hv_id * stride_vn_h + z_id * stride_vn_z + g_id * stride_vn_g - for i in range(0, N_CTX_NEW, BLOCK_N): - # Load from V_new - v_new_block = tl.load( - vnew_base + - (tl.arange(0, BLOCK_N) + i)[:, None] * stride_vn_n + - tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vn_d, - mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), - other=0 - ) - - # Store to V - tl.store( - v_offset + - (tl.arange(0, BLOCK_N) + i + start_idx)[:, None] * stride_vn + - tl.arange(0, BLOCK_DMODEL)[None, :] * stride_vd, - v_new_block, - mask=(tl.arange(0, BLOCK_N)[:, None] + i < N_CTX_NEW) & - (tl.arange(0, BLOCK_DMODEL)[None, :] < ACTUAL_BLOCK_DMODEL), - ) - - # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - kT_ptrs = k_offset + offs_d[:, None] * stride_kd + (start_n + offs_n)[None, :] * stride_kn - V_ptrs = v_offset + (start_n + offs_n)[:, None] * stride_vn + offs_d[None, :] * stride_vd - - # load k - kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) - v = tl.load(V_ptrs, mask=v_mask, other=0.0) - - # -- compute qk --- - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, kT) # noqa: F821 - - if USE_ALIBI: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = start_n + tl.arange(0, BLOCK_N) - - # Compute relative positions - relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) - relative_pos = tl.abs(relative_pos) + if USE_BLOCK_TABLE: + # Paged attention: process all KV blocks from cache + # Note: Cache should be updated externally before calling this kernel + num_kv_blocks = (N_CTX_K_FINAL + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K + + for block_idx in range(num_kv_blocks): + # Calculate sequence range for this block + block_start = block_idx * BLOCK_SIZE_K + block_end = tl.minimum(block_start + BLOCK_SIZE_K, N_CTX_K_FINAL) - # Compute ALiBi bias - alibi_bias = -1 * alibi_slope * relative_pos - qk += (alibi_bias * 1.44269504) - - # ------------------------------------------------------------------ - # masking - # ------------------------------------------------------------------ - if USE_SLIDING_WINDOW: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions - col_idx = start_n + tl.arange(0, BLOCK_N) # k positions - row = row_idx[:, None] # [M,1] - col = col_idx[None, :] # [1,N] - - if IS_CAUSAL: - # -------- causal + window -------- - diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq - causal_ok = col <= row + diag - if WINDOW_SIZE_LEFT < 0: # only right window - win_ok = col <= row + diag + WINDOW_SIZE_RIGHT - else: # both sides - win_ok = ((col >= row + diag - WINDOW_SIZE_LEFT) & - (col <= row + diag + WINDOW_SIZE_RIGHT)) - mask = ~(causal_ok & win_ok) # True ⇒ -inf - else: - # -------- non-causal window -------- - sk, sq = N_CTX_K_FINAL, N_CTX_Q - if WINDOW_SIZE_LEFT < 0: - mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT + # Check if block overlaps with our split-k range [lo, hi) + if block_end > lo and block_start < hi: + # Load physical block number + physical_block = tl.load(block_table_ptr + block_idx * stride_bt_s) + + # Calculate the range within this block that overlaps with [lo, hi) + process_start = tl.maximum(lo - block_start, 0) + process_end = tl.minimum(hi - block_start, BLOCK_SIZE_K) + process_end = tl.minimum(process_end, block_end - block_start) + + # Instead of forcing a floor alignment to BLOCK_N (which can still skip + # part of the intended range if start falls mid-tile for small splits), + # start from the raw (possibly unaligned) process_start rounded *down* but + # allow the loop to begin earlier (at most BLOCK_N before) so that any + # partial tile overlapping [lo, hi) is covered. Masking below will remove + # columns < lo or >= hi ensuring numerically identical coverage without + # duplication. + aligned_start = (process_start // BLOCK_N) * BLOCK_N + if aligned_start > 0 and aligned_start + BLOCK_N > process_start: + # ensure we include the tile that contains process_start + process_start = aligned_start else: - right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) - left = row + (sk - sq) - WINDOW_SIZE_LEFT - mask = (col > right) | (col < left) - qk = tl.where(mask, float("-inf"), qk) - else: - if IS_CAUSAL: - row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - col_idx = start_n + tl.arange(0, BLOCK_N) - - # create a N_CTX_Q x kv_len causal mask - col_offset = N_CTX_Q - N_CTX_K_FINAL - causal_mask = row_idx[:, None] >= (col_offset + col_idx[None, :]) - - # Apply the mask - qk = tl.where(causal_mask, qk, float("-inf")) - - # TODO: This is slow, and only needed at the last iteration. - # Maybe we can unroll the last iteration instead? - if BOUNDS_CHECKS_N: - qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) - - m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far - - # rows that are *all* -inf after masking - valid = m_i_new > float("-inf") - - # scale previous partial sums safely - alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) - - # subtract the row max only on valid rows - qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) - p = tl.math.exp2(qk) - - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new - p = p.to(Q.dtype.element_ty) - - # -- scale and update acc -- - acc *= alpha[:, None] - acc += tl.dot(p.to(v.dtype), v) + process_start = aligned_start + + for offset in range(process_start, process_end, BLOCK_N): + # Current position (may begin slightly before logical split range; masking fixes it) + pos = block_start + offset + # Proceed unconditionally; masking below enforces [lo, hi) + # Calculate base addresses for K and V in this physical block + k_base = K + physical_block * BLOCK_SIZE_K * stride_kn + hk_id * stride_kh + g_id * stride_kg + v_base = V + physical_block * BLOCK_SIZE_K * stride_vn + hv_id * stride_vh + g_id * stride_vg + + # Offsets within the current block + block_offs = offset + offs_n + + # Masks for valid data respecting: + # (1) global key length (seq_mask) + # (2) block bounds (block_mask) + # (3) current split range [lo, hi) + seq_mask = ((pos + offs_n) < N_CTX_K_FINAL) + block_mask = (block_offs < BLOCK_SIZE_K) + end_mask = (block_offs < process_end) + split_mask = ((pos + offs_n) >= lo) & ((pos + offs_n) < hi) + col_mask = seq_mask & block_mask & end_mask & split_mask + + # Apply masks + kT_mask_final = kT_mask & col_mask[None, :] + v_mask_final = v_mask & col_mask[:, None] + + # Load K and V + kT_ptrs = k_base + offs_d[:, None] * stride_kd + block_offs[None, :] * stride_kn + v_ptrs = v_base + block_offs[:, None] * stride_vn + offs_d[None, :] * stride_vd + + kT = tl.load(kT_ptrs, mask=kT_mask_final, other=0.0) + v = tl.load(v_ptrs, mask=v_mask_final, other=0.0) + + # Unified inner function handles both paged and contiguous + m_i, l_i, acc = _attn_fwd_inner( + q, kT, v, pos, col_mask, + m_i, l_i, acc, + pid_m, + q_descale, k_descale, v_descale, + IS_FP8, + BLOCK_M, BLOCK_N, + N_CTX_Q, N_CTX_K_FINAL, + USE_ALIBI, alibi_slope, + USE_SLIDING_WINDOW, IS_CAUSAL, + WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, + True, + ) + else: + # Non-paged attention: process KV from cache + # Note: Cache should be updated externally before calling this kernel + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + kT_ptrs = k_offset + offs_d[:, None] * stride_kd + (start_n + offs_n)[None, :] * stride_kn + V_ptrs = v_offset + (start_n + offs_n)[:, None] * stride_vn + offs_d[None, :] * stride_vd + + # load k + kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) + v = tl.load(V_ptrs, mask=v_mask, other=0.0) + + # Use the same inner loop logic + # Precompute column validity mask for this tile (all True for full tiles). + # hi is the upper bound of the overall split range; start_n marks this tile's base. + col_valid_mask = offs_n < (hi - start_n) + + m_i, l_i, acc = _attn_fwd_inner( + q, kT, v, start_n, col_valid_mask, + m_i, l_i, acc, + pid_m, + q_descale, k_descale, v_descale, + IS_FP8, + BLOCK_M, BLOCK_N, + N_CTX_Q, N_CTX_K_FINAL, + USE_ALIBI, alibi_slope, + USE_SLIDING_WINDOW, IS_CAUSAL, + WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, + BOUNDS_CHECKS_N, + ) # write back O osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_splitk * stride_osk_s @@ -598,7 +704,71 @@ def attention_decode_forward_triton_impl( layout: Literal["bshd"], cache_seqlens: Optional[torch.Tensor], cache_batch_idx: Optional[torch.Tensor], + block_table: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, ): + # handle cache updates + if k_new is not None and v_new is not None: + # Update cache with new KV values + if block_table is None: + # Non-paged attention: update cache directly + batch_size = k_new.shape[0] + seqlen_new = k_new.shape[1] + + if cache_seqlens is not None: + # Use cache_seqlens to determine where to insert new KV + for b in range(batch_size): + start_idx = int(cache_seqlens[b].item()) + end_idx = start_idx + seqlen_new + k_cache[b, start_idx:end_idx] = k_new[b] + v_cache[b, start_idx:end_idx] = v_new[b] + cache_seqlens[b] = end_idx + else: + # Append at the end of existing cache + seqlen_cache = k_cache.shape[1] + k_cache[:, seqlen_cache - seqlen_new:] = k_new + v_cache[:, seqlen_cache - seqlen_new:] = v_new + else: + # Paged attention: update cache using block table + batch_size = k_new.shape[0] + seqlen_new = k_new.shape[1] + block_size = k_cache.shape[1] # k_cache shape: [num_blocks, block_size, nheads, head_dim] + + # Update cache for each batch element + for b in range(batch_size): + if cache_seqlens is not None: + start_idx = int(cache_seqlens[b].item()) + else: + # If no cache_seqlens, assume we're appending at the end + # Find the last used position from block table + start_idx = 0 + for block_idx in range(block_table.shape[1]): + if block_table[b, block_idx] >= 0: + start_idx = (block_idx + 1) * block_size + else: + start_idx = block_idx * block_size + break + + # Copy new KV values into the paged cache + for i in range(seqlen_new): + pos = start_idx + i + block_idx = pos // block_size + within_block_idx = pos % block_size + + # Get the physical block number from block table + if block_idx < block_table.shape[1]: + physical_block = int(block_table[b, block_idx].item()) + + # Update k_cache and v_cache at the physical block location + k_cache[physical_block, within_block_idx] = k_new[b, i] + v_cache[physical_block, within_block_idx] = v_new[b, i] + + # Update cache_seqlens if provided + if cache_seqlens is not None: + cache_seqlens[b] = start_idx + seqlen_new + # triton configs BLOCK_M = 16 BLOCK_N = 64 @@ -607,17 +777,45 @@ def attention_decode_forward_triton_impl( num_warps_reduce = 4 # kernel_configs - is_new_kv = True if k_new is not None and v_new is not None else False + is_new_kv = False # Cache has been updated, so no new KV in kernel use_alibi, (stride_az, stride_ah) = True if alibi_slopes is not None else False, alibi_slopes.stride() if alibi_slopes is not None else (None, None) use_cache_seqlens = cache_seqlens is not None use_sliding_window = window_size_left != -1 or window_size_right != -1 + use_block_table = block_table is not None SPLIT_K = None NUM_QUANT_GROUPS = 1 # get shapes and strides (batch_size, seqlen_q, nheads_q, dim_q), (stride_qz, stride_qh, stride_qm, stride_qd) = get_shape_and_strides_from_layout(q, layout) - (_, seqlen_kc, nheads_kc, dim_kc), (stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d) = get_shape_and_strides_from_layout(k_cache, layout) - (_, seqlen_vc, nheads_vc, dim_vc), (stride_vc_z, stride_vc_h, stride_vc_n, stride_vc_d) = get_shape_and_strides_from_layout(v_cache, layout) + + # Handle paged KV cache layout + if use_block_table: + # For paged attention, k_cache and v_cache have shape [num_blocks, block_size, nheads, head_dim] + num_blocks_kc, block_size_k, nheads_kc, dim_kc = k_cache.shape + num_blocks_vc, block_size_v, nheads_vc, dim_vc = v_cache.shape + # Get the actual sequence length from cache_seqlens or block_table + if cache_seqlens is not None: + seqlen_kc = int(cache_seqlens.max().item()) + else: + # Infer from block_table shape [batch_size, num_blocks_per_seq] + num_blocks_per_seq = block_table.shape[1] + seqlen_kc = num_blocks_per_seq * block_size_k + seqlen_vc = seqlen_kc + + # Strides for paged layout + stride_kc_z = 0 # No batch dimension in paged cache + stride_kc_n = k_cache.stride(1) # Sequence stride + stride_kc_h = k_cache.stride(2) # Head stride + stride_kc_d = k_cache.stride(3) # Dim stride + + stride_vc_z = 0 + stride_vc_n = v_cache.stride(1) + stride_vc_h = v_cache.stride(2) + stride_vc_d = v_cache.stride(3) + else: + (_, seqlen_kc, nheads_kc, dim_kc), (stride_kc_z, stride_kc_h, stride_kc_n, stride_kc_d) = get_shape_and_strides_from_layout(k_cache, layout) + (_, seqlen_vc, nheads_vc, dim_vc), (stride_vc_z, stride_vc_h, stride_vc_n, stride_vc_d) = get_shape_and_strides_from_layout(v_cache, layout) + block_size_k = 0 # Not used if is_new_kv: ( _, seqlen_kn, nheads_kn, dim_kn), (stride_kn_z, stride_kn_h, stride_kn_n, stride_kn_d) = get_shape_and_strides_from_layout(k_new, layout) (_, seqlen_vn, nheads_vn, dim_vn), (stride_vn_z, stride_vn_h, stride_vn_n, stride_vn_d) = get_shape_and_strides_from_layout(v_new, layout) @@ -657,7 +855,12 @@ def attention_decode_forward_triton_impl( split_k = SPLIT_K else: # Use heuristics - split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) # NOTE: should the split think about seqlens? + if use_block_table: + # For paged attention, use the actual sequence length from cache_seqlens + max_seqlen = int(cache_seqlens.max().item()) if cache_seqlens is not None else block_size_k + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, max_seqlen) + else: + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) split_size = (seqlen_kc + split_k - 1) // split_k # setup grid @@ -673,6 +876,39 @@ def attention_decode_forward_triton_impl( stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d = out_splitk.stride() stride_mzhg, stride_m2, stride_ms, stride_mm = metadata.stride() stride_lse_zhg, stride_lse_m = lse.stride() + + # Block table strides + if use_block_table: + stride_bt_b, stride_bt_s = block_table.stride() + else: + stride_bt_b, stride_bt_s = 0, 0 + + # FP8 support + IS_FP8 = is_fp8(q) + if IS_FP8: + if (q_descale is None) or (k_descale is None) or (v_descale is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) + # Create default descale tensors if not provided + if q_descale is None: + q_descale = torch.ones(batch_size, nheads_q, dtype=torch.float32, device=q.device) + if k_descale is None: + k_descale = torch.ones(batch_size, nheads_kc, dtype=torch.float32, device=q.device) + if v_descale is None: + v_descale = torch.ones(batch_size, nheads_vc, dtype=torch.float32, device=q.device) + stride_q_descale_z, stride_q_descale_h = q_descale.stride() + stride_k_descale_z, stride_k_descale_h = k_descale.stride() + stride_v_descale_z, stride_v_descale_h = v_descale.stride() + else: + q_descale = None + k_descale = None + v_descale = None + stride_q_descale_z = 0 + stride_q_descale_h = 0 + stride_k_descale_z = 0 + stride_k_descale_h = 0 + stride_v_descale_z = 0 + stride_v_descale_h = 0 if DEBUG: print("batch_size, seqlen_q, nheads_q, dim_q", (batch_size, seqlen_q, nheads_q, dim_q)) @@ -689,18 +925,21 @@ def attention_decode_forward_triton_impl( print("stride_mzhg, stride_m2, stride_ms, stride_mm", (stride_mzhg, stride_m2, stride_ms, stride_mm)) print("stride_lse_zhg, stride_lse_m", (stride_lse_zhg, stride_lse_m)) - # TODO: enable quantization _fwd_kernel_splitK[grid]( Q=q, K=k_cache, V=v_cache, + Q_Descale=q_descale, + K_Descale=k_descale, + V_Descale=v_descale, sm_scale=sm_scale, Out_splitK=out_splitk, Metadata=metadata, - K_new=k_new, - V_new=v_new, + K_new=None, + V_new=None, Cache_seqlens=cache_seqlens, Cache_batch_idx=cache_batch_idx, + Block_table=block_table, Alibi_slopes=alibi_slopes, # q strides stride_qz=stride_qz, @@ -742,17 +981,28 @@ def attention_decode_forward_triton_impl( stride_vn_g=stride_vn_g, stride_vn_h=stride_vn_h, stride_vn_d=stride_vn_d, + # block table strides + stride_bt_b=stride_bt_b, + stride_bt_s=stride_bt_s, # alibi strides stride_az=stride_az, stride_ah=stride_ah, + # FP8 descale strides + stride_q_descale_z=stride_q_descale_z, + stride_q_descale_h=stride_q_descale_h, + stride_k_descale_z=stride_k_descale_z, + stride_k_descale_h=stride_k_descale_h, + stride_v_descale_z=stride_v_descale_z, + stride_v_descale_h=stride_v_descale_h, Z=batch_size, H_q=heads_per_group_q, H_kv=heads_per_group_k, G_q=n_group_q, N_CTX_Q=seqlen_q, N_CTX_K=seqlen_kc, - N_CTX_NEW=seqlen_kn, + N_CTX_NEW=0, # No new KV, cache already updated BLOCK_N_PER_SPLIT=split_size, + BLOCK_SIZE_K=block_size_k if use_block_table else 256, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=dim_padded, @@ -760,7 +1010,7 @@ def attention_decode_forward_triton_impl( BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, USE_CACHE_SEQLENs=use_cache_seqlens, USE_CACHE_BATCH_IDX=cache_batch_idx is not None, - NEW_KV=is_new_kv, + NEW_KV=False, # Cache already updated IS_GQA=is_gqa, IS_CAUSAL=causal, USE_ALIBI=use_alibi, @@ -769,6 +1019,8 @@ def attention_decode_forward_triton_impl( USE_SLIDING_WINDOW=use_sliding_window, WINDOW_SIZE_LEFT=window_size_left, WINDOW_SIZE_RIGHT=window_size_right, + USE_BLOCK_TABLE=use_block_table, + IS_FP8=IS_FP8, num_warps=num_warps_fwd, num_stages=num_stages, ) diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py old mode 100644 new mode 100755 index 3a2bd56fda4..bb0301c7700 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -3,14 +3,99 @@ import triton import triton.language as tl from typing import Literal, Optional, Union -from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, is_cdna, is_fp8, is_rdna, create_dropout_mask, get_fwd_prefill_autotune_configs - -DEBUG = False +from .utils import DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, compute_alibi_block, compute_fp8_scaling_factors, get_arch, is_cdna, is_fp8, is_rdna, create_dropout_mask, DEBUG # NOTE: triton fails to import tl.constexprs so create them here for the file tl_DROPOUT_USE_PYTORCH: tl.constexpr = triton.language.constexpr(DROPOUT_USE_PYTORCH) tl_DROPOUT_DUMP: tl.constexpr = triton.language.constexpr(DROPOUT_DUMP) +# ------------------------------- +# Autotune +# ------------------------------- +def get_fwd_prefill_cdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=4), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL_QK', 'ACTUAL_BLOCK_DMODEL_V', 'IS_VARLEN', 'HQ', 'HK'] + + +def get_fwd_prefill_rdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, + num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL_QK', 'ACTUAL_BLOCK_DMODEL_V', 'IS_VARLEN', 'HQ', 'HK'] + + +def get_fwd_prefill_autotune_configs(): + if AUTOTUNE: + if is_rdna(): + return get_fwd_prefill_rdna_autotune_configs() + elif is_cdna(): + return get_fwd_prefill_cdna_autotune_configs() + else: + raise ValueError("Unknown Device Type") + else: + arch = get_arch() + if arch == "gfx950": + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + elif arch == "gfx942" and False: # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. + default_config = triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + else: + default_config = triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ) + + return [ + default_config + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL_QK", + "ACTUAL_BLOCK_DMODEL_V", + "IS_VARLEN", + "HQ", + "HK", + ] + + fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = get_fwd_prefill_autotune_configs() @triton.jit @@ -20,13 +105,14 @@ def _attn_fwd_no_mask(acc, l_i, m_i, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_base_ptrs, sd_mask_base_ptrs, dropout_mask_base_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, block_max, alibi_slope, - descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + q_descale, k_descale, v_descale, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD_QK: tl.constexpr, PADDED_HEAD_V: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, RETURN_SCORES: tl.constexpr, ACCUMULATOR_TYPE): if USE_EXP2: RCP_LN2: tl.constexpr = 1.4426950408889634 @@ -38,21 +124,27 @@ def _attn_fwd_no_mask(acc, l_i, m_i, v_ptrs = v_base_ptrs + start_n * stride_vk kv_offs_n = start_n + tl.arange(0, BLOCK_N) - if PADDED_HEAD: - k_mask, k_mask_other = (offs_d[:, None] < ACTUAL_BLOCK_DMODEL), 0.0 - v_mask, v_mask_other = (offs_d[None, :] < ACTUAL_BLOCK_DMODEL), 0.0 + if PADDED_HEAD_QK: + k_mask, k_mask_other = (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK), 0.0 + else: + k_mask, k_mask_other = None, None + + if PADDED_HEAD_V: + v_mask, v_mask_other = (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V), 0.0 + else: + v_mask, v_mask_other = None, None # load k and if preload_v then v - k = tl.load(k_ptrs, mask=k_mask, other=k_mask_other) if PADDED_HEAD else tl.load(k_ptrs) + k = tl.load(k_ptrs, mask=k_mask, other=k_mask_other) if PADDED_HEAD_QK else tl.load(k_ptrs) if PRE_LOAD_V: - v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD else tl.load(v_ptrs) + v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD_V else tl.load(v_ptrs) # setup qk accumlator qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) # -- compute qk ---- if IS_FP8 : - qk += (tl.dot(q, k) * descale_q * descale_k) + qk += (tl.dot(q, k) * q_descale * k_descale) else: qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE @@ -124,15 +216,18 @@ def _attn_fwd_no_mask(acc, l_i, m_i, alpha = tl.math.exp(m_diff) acc = acc * alpha[:, None] if not PRE_LOAD_V: - v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD else tl.load(v_ptrs) + v = tl.load(v_ptrs, mask=v_mask, other=v_mask_other) if PADDED_HEAD_V else tl.load(v_ptrs) # -- update m_i and l_i l_i = l_i * alpha + l_ij m_i = m_ij if IS_FP8: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + if FP8_P_DESCALE: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * v_descale) + else: + acc += tl.dot(p.to(v.type.element_ty), v) * v_descale else: acc += tl.dot(p.to(v.type.element_ty), v) @@ -145,13 +240,14 @@ def _attn_fwd_mask(acc, l_i, m_i, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_base_ptrs, sd_mask_base_ptrs, dropout_mask_base_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, block_max, n_extra_tokens, alibi_slope, - descale_q, descale_k, descale_v, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, - IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + q_descale, k_descale, v_descale, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, + IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD_QK: tl.constexpr, PADDED_HEAD_V: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + SM_SCALE: tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, RETURN_SCORES: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, ACCUMULATOR_TYPE): @@ -172,9 +268,10 @@ def _attn_fwd_mask(acc, l_i, m_i, kv_offs_n = start_n + tl.arange(0, BLOCK_N) k_mask = (kv_offs_n[None, :] < seqlen_k) v_mask = (kv_offs_n[:, None] < seqlen_k) - if PADDED_HEAD: - k_mask = k_mask & (offs_d[:, None] < ACTUAL_BLOCK_DMODEL) - v_mask = v_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + if PADDED_HEAD_QK: + k_mask = k_mask & (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK) + if PADDED_HEAD_V: + v_mask = v_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) # load k and if preload_v then v k = tl.load(k_ptrs, mask=k_mask, other = 0.0) @@ -200,7 +297,7 @@ def _attn_fwd_mask(acc, l_i, m_i, # -- compute qk ---- if IS_FP8 : - qk += (tl.dot(q, k) * descale_q * descale_k) + qk += (tl.dot(q, k) * q_descale * k_descale) else: qk += tl.dot(q, k) qk_scaled = qk * SM_SCALE @@ -373,8 +470,11 @@ def _attn_fwd_mask(acc, l_i, m_i, m_i = m_ij if IS_FP8: - scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) - acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * descale_v) + if FP8_P_DESCALE: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += (tl.dot((p * scale_p).to(v.type.element_ty), v) * descale_p * v_descale) + else: + acc += tl.dot(p.to(v.type.element_ty), v) * v_descale else: acc += tl.dot(p.to(v.type.element_ty), v) @@ -626,18 +726,19 @@ def compute_block_masking(seqlen_k, seqlen_q, start_m, ) @triton.jit def attn_fwd(Q, K, V, bias, - Descale_Q, Descale_K, Descale_V, Descale_O, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, + Q_Descale, K_Descale, V_Descale, stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, SM_SCALE: tl.constexpr, LSE, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Add seqused parameters dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr, - HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, ACTUAL_BLOCK_DMODEL_V: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, IS_VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, BLOCK_DMODEL_V: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_SCORES: tl.constexpr, NEEDS_SDMASK : tl.constexpr, USE_ALIBI: tl.constexpr, USE_EXP2: tl.constexpr, - IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_OUTPUT: tl.constexpr): + IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, FP8_P_DESCALE: tl.constexpr, USE_SEQUSED: tl.constexpr): # set params ACCUMULATOR_TYPE = tl.float32 @@ -652,24 +753,38 @@ def attn_fwd(Q, K, V, bias, else: off_h_k = off_h_q # Determine if we need to mask the heads - PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + PADDED_HEAD_QK: tl.constexpr = (ACTUAL_BLOCK_DMODEL_QK != BLOCK_DMODEL_QK) + PADDED_HEAD_V: tl.constexpr = (ACTUAL_BLOCK_DMODEL_V != BLOCK_DMODEL_V) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_d_qk = tl.arange(0, BLOCK_DMODEL_QK) + offs_d_v = tl.arange(0, BLOCK_DMODEL_V) # handle seqlen if IS_VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) - seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = tl.load(seqused_q + off_z) if seqused_q is not None else cu_seqlens_q_end - cu_seqlens_q_start + seqlen_q = tl.minimum(actual_seqlen_q, cu_seqlens_q_end - cu_seqlens_q_start) + else: + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start # we have a one-size-fits-all grid in id(0). Some seqlens might be too small for all start_m so for those we return early. if start_m * BLOCK_M > seqlen_q: return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) - seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + + # If seqused is provided, use it to limit the actual sequence length for keys + if USE_SEQUSED: + actual_seqlen_k = tl.load(seqused_k + off_z) if seqused_k is not None else cu_seqlens_k_end - cu_seqlens_k_start + seqlen_k = tl.minimum(actual_seqlen_k, cu_seqlens_k_end - cu_seqlens_k_start) + else: + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start else: cu_seqlens_q_start = 0 cu_seqlens_k_start = 0 @@ -678,11 +793,16 @@ def attn_fwd(Q, K, V, bias, # Load scale factors if IS_FP8. if IS_FP8: - descale_q = tl.load(Descale_Q + off_z * stride_descale_q_z + off_h_q) - descale_k = tl.load(Descale_K + off_z * stride_descale_k_z + off_h_k) - descale_v = tl.load(Descale_V + off_z * stride_descale_v_z + off_h_k) + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (off_h_k) + # For MHA (GROUP_SIZE == 1), q_descale uses off_h_q (same as off_h_k) + if GROUP_SIZE != 1: + q_descale = tl.load(Q_Descale + off_z * stride_q_descale_z + off_h_k) # MQA/GQA: broadcast using k/v head index + else: + q_descale = tl.load(Q_Descale + off_z * stride_q_descale_z + off_h_q) # MHA: use q head index + k_descale = tl.load(K_Descale + off_z * stride_k_descale_z + off_h_k) + v_descale = tl.load(V_Descale + off_z * stride_v_descale_z + off_h_k) else: - descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 # figure out masking pattern @@ -701,11 +821,11 @@ def attn_fwd(Q, K, V, bias, """ # Write zeros to output o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d_v[None, :] * stride_on o_mask = (offs_m[:, None] < seqlen_q) - if PADDED_HEAD: - o_mask = o_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) - tl.store(o_ptrs, tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty), mask=o_mask) + if PADDED_HEAD_V: + o_mask = o_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) + tl.store(o_ptrs, tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=Out.type.element_ty), mask=o_mask) # Write zeros to LSE l_ptrs = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m + offs_m * stride_lse_m @@ -723,11 +843,11 @@ def attn_fwd(Q, K, V, bias, # Initialize for processing # Compute pointers for all the tensors used in this kernel. q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm - q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qk k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn - k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + k_ptrs = k_offset + offs_d_qk[:, None] * stride_kk + offs_n[None, :] * stride_kn v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk - v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d_v[None, :] * stride_vn if USE_BIAS: # Note: this might get large enough to overflow on some configs bias_offset = off_h_q * stride_bh @@ -759,12 +879,12 @@ def attn_fwd(Q, K, V, bias, # initialize pointer to m and l m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE) l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=ACCUMULATOR_TYPE) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=ACCUMULATOR_TYPE) # Q is loaded once at the beginning and shared by all N blocks. q_ptrs_mask = offs_m[:, None] < seqlen_q - if PADDED_HEAD: - q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + if PADDED_HEAD_QK: + q_ptrs_mask = q_ptrs_mask & (offs_d_qk[None, :] < ACTUAL_BLOCK_DMODEL_QK) q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) @@ -781,17 +901,17 @@ def attn_fwd(Q, K, V, bias, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, # Start of front masked blocks block_max, # End of front masked blocks 0, # n_extra_tokens (0 for front blocks, only relevant for last block) alibi_slope, - descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, + q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, IS_CAUSAL, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, + BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, + ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, @@ -812,15 +932,15 @@ def attn_fwd(Q, K, V, bias, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, # Start of range: 0 block_max, # End of range: n_full_blocks * BLOCK_N alibi_slope, - descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, + q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, + BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, + ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, ACCUMULATOR_TYPE=ACCUMULATOR_TYPE ) @@ -837,17 +957,17 @@ def attn_fwd(Q, K, V, bias, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs, - offs_m, offs_n, offs_d, + offs_m, offs_n, offs_d_qk, offs_d_v, block_min, # Start of range: n_full_blocks * BLOCK_N block_max, # End of range: n_visible_k_blocks * BLOCK_N n_extra_tokens, # Padding tokens in last block alibi_slope, - descale_q, descale_k, descale_v, IS_FP8, FP8_MAX, + q_descale, k_descale, v_descale, IS_FP8, FP8_MAX, FP8_P_DESCALE, IS_CAUSAL, # Use actual causal flag - BLOCK_M, BLOCK_DMODEL, BLOCK_N, + BLOCK_M, BLOCK_DMODEL_QK, BLOCK_DMODEL_V, BLOCK_N, PRE_LOAD_V, - ENABLE_DROPOUT, PADDED_HEAD, - ACTUAL_BLOCK_DMODEL, SM_SCALE, + ENABLE_DROPOUT, PADDED_HEAD_QK, PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, ACTUAL_BLOCK_DMODEL_V, SM_SCALE, USE_ALIBI=USE_ALIBI, USE_EXP2=USE_EXP2, RETURN_SCORES=RETURN_SCORES, USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, @@ -933,7 +1053,7 @@ def attn_fwd(Q, K, V, bias, end_m_idx = (start_m + 1) * BLOCK_M if causal_start_idx < end_m_idx: # This block contains the boundary - need to mask acc - out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + out_mask_boundary = tl.full((BLOCK_DMODEL_V, ), causal_start_idx, dtype=tl.int32) out_ptrs_mask = row_indices[:, None] >= out_mask_boundary[None, :] z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) @@ -958,19 +1078,14 @@ def attn_fwd(Q, K, V, bias, # write back O o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om - o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on - o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d_v[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL_V], 1, dtype=tl.int1) if overflow_size > 0: o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) - if PADDED_HEAD: - o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + if PADDED_HEAD_V: + o_ptrs_mask = o_ptrs_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) - if FP8_OUTPUT: - scale_acc, descale_acc = compute_fp8_scaling_factors(acc, FP8_MAX) - tl.store(Descale_O + off_z * stride_descale_o_z + off_h_q, descale_acc) - tl.store(o_ptrs, (acc * scale_acc).to(Out.type.element_ty), mask=o_ptrs_mask) - else: - tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) def attention_prefill_forward_triton_impl( q: torch.Tensor, @@ -997,10 +1112,12 @@ def attention_prefill_forward_triton_impl( return_softmax: bool, use_exp2: bool, # fp8 - descale_q: Optional[torch.Tensor], - descale_k: Optional[torch.Tensor], - descale_v: Optional[torch.Tensor], - descale_o: Optional[torch.Tensor], + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + # seqused for FA v3 + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, ): # get params, strides and shape IS_VARLEN = layout == "thd" @@ -1027,12 +1144,12 @@ def attention_prefill_forward_triton_impl( assert max_seqlens_k is not None and max_seqlens_k > 0, "max_seqlens_k must be provided and positive for varlen layout" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" # assert output shapes - assert o.shape == (total_seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_q)}" + assert o.shape == (total_seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" # assert cu_seqlens assert cu_seqlens_q.dtype == torch.int32, f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" @@ -1044,7 +1161,7 @@ def attention_prefill_forward_triton_impl( # set vars batch = len(cu_seqlens_q) - 1 - head_size = head_size_q + head_size_qk = head_size_q # softmax_lse shape softmax_lse = torch.zeros((nheads_q, total_seqlen_q), device=q.device, dtype=torch.float32) @@ -1065,7 +1182,7 @@ def attention_prefill_forward_triton_impl( assert batch_q == batch_k == batch_v, f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" # assert head dimensions - assert head_size_q == head_size_k == head_size_v, f"head sizes must match: q={head_size_q}, k={head_size_k}, v={head_size_v}" + assert head_size_q == head_size_k, f"head sizes must match: q={head_size_q}, k={head_size_k}" assert nheads_k == nheads_v, f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" assert nheads_q % nheads_k == 0, f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" @@ -1073,11 +1190,11 @@ def attention_prefill_forward_triton_impl( assert seqlen_k == seqlen_v, f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" # assert output shapes - assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_q), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_q)}" + assert o.shape == (batch_q, seqlen_q, nheads_q, head_size_v), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_v)}" # set vars batch = batch_q - head_size = head_size_q + head_size_qk = head_size_q max_seqlens_q = seqlen_q max_seqlens_k = seqlen_k @@ -1098,29 +1215,32 @@ def attention_prefill_forward_triton_impl( FP8_MAX = torch.finfo(q.dtype).max - # Check descale tensors - assert descale_q is not None, "descale_q must be provided when using fp8" - assert descale_k is not None, "descale_k must be provided when using fp8" - assert descale_v is not None, "descale_v must be provided when using fp8" + # Check and create default descale tensors if not provided + if (q_descale is None) or (k_descale is None) or (v_descale is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) + # Create default descale tensors if not provided + if q_descale is None: + q_descale = torch.ones(batch, nheads_q, dtype=torch.float32, device=q.device) + if k_descale is None: + k_descale = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) + if v_descale is None: + v_descale = torch.ones(batch, nheads_k, dtype=torch.float32, device=q.device) - if is_fp8(o): - FP8_OUTPUT = True - assert descale_o is not None, f"descale_o is None. In fp8, you need to pass a tensor for descale_o along with a tensor o." - else: - FP8_OUTPUT = False - # o should be fp32 or fp16/bf16 - assert o.dtype in [torch.float16, torch.bfloat16, torch.float32], \ - f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}" - - stride_descale_q_z = descale_q.stride(0) if descale_q is not None else None - stride_descale_k_z = descale_k.stride(0) if descale_k is not None else None - stride_descale_v_z = descale_v.stride(0) if descale_v is not None else None - stride_descale_o_z = descale_o.stride(0) if descale_o is not None else None + # o should be fp32 or fp16/bf16 + assert o.dtype in [torch.float16, torch.bfloat16, torch.float32], \ + f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}" + + stride_q_descale_z = q_descale.stride(0) if q_descale is not None else 0 + stride_k_descale_z = k_descale.stride(0) if k_descale is not None else 0 + stride_v_descale_z = v_descale.stride(0) if v_descale is not None else 0 + + if DEBUG: + print(f"FP8 path triggered in fwd_prefill.py") else: FP8_MAX = None - FP8_OUTPUT = False - descale_q = descale_k = descale_v = descale_o = None - stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_o_z = None + q_descale = k_descale = v_descale = None + stride_q_descale_z = stride_k_descale_z = stride_v_descale_z = None # check output dtype matches input dtype when not using fp8 assert o.dtype == q.dtype, f"Output dtype {o.dtype} must match input dtype {q.dtype} when not using fp8" @@ -1132,11 +1252,13 @@ def attention_prefill_forward_triton_impl( if (bias is not None): assert (bias.numel() < 2**31) - # Get closest power of 2 over or equal to 32. - padded_d_model = 1 << (head_size - 1).bit_length() + # Get closest power of 2 over or equal to 32 for both QK and V dimensions + padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() + padded_d_model_v = 1 << (head_size_v - 1).bit_length() # Smallest head_dim supported is 16. If smaller, the tile in the # kernel is padded - there is no padding in memory for any dims. - padded_d_model = max(padded_d_model, 16) + padded_d_model_qk = max(padded_d_model_qk, 16) + padded_d_model_v = max(padded_d_model_v, 16) # sd_mask is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according @@ -1166,7 +1288,7 @@ def attention_prefill_forward_triton_impl( # launch kernel grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META['BLOCK_M'])) attn_fwd[grid](q, k, v, bias, - descale_q, descale_k, descale_v, descale_o, stride_descale_q_z, stride_descale_k_z, stride_descale_v_z, stride_descale_o_z, + q_descale, k_descale, v_descale, stride_q_descale_z, stride_k_descale_z, stride_v_descale_z, sm_scale, softmax_lse, o, stride_qb, stride_qh, stride_qm, stride_qd, stride_kb, stride_kh, stride_kn, stride_kd, @@ -1177,13 +1299,15 @@ def attention_prefill_forward_triton_impl( stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k, + seqused_q, seqused_k, # Pass seqused tensors dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes, - HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, + HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL_QK=head_size_qk, ACTUAL_BLOCK_DMODEL_V=head_size_v, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, USE_SLIDING_WINDOW=use_sliding_window, WINDOW_SIZE_LEFT=window_size_left, WINDOW_SIZE_RIGHT=window_size_right, IS_VARLEN=IS_VARLEN, - BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True, + BLOCK_DMODEL_QK=padded_d_model_qk, BLOCK_DMODEL_V=padded_d_model_v, USE_BIAS=False if bias is None else True, USE_ALIBI=use_alibi, ENABLE_DROPOUT=dropout_p > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, NEEDS_SDMASK=NEEDS_SDMASK, - IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_OUTPUT=FP8_OUTPUT) + IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, FP8_P_DESCALE=False, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None)) # Add flag for seqused return softmax_lse, sd_mask if return_softmax else None \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py old mode 100644 new mode 100755 index 2265af12096..a8ca54a7ec3 --- a/flash_attn/flash_attn_triton_amd/fwd_ref.py +++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py @@ -9,7 +9,7 @@ def attention_forward_core_ref_impl( q, k, v, sm_scale, causal, window_size_left, window_size_right, dropout_p, philox_seed, philox_offset, alibi_slopes, use_exp2, - cache_seqlens=None + cache_seqlens=None, block_table=None, paged_kv_block_size=None ): if DEBUG_CORE: print() @@ -26,17 +26,99 @@ def attention_forward_core_ref_impl( print("philox_offset:", philox_offset) print("use_exp2:", use_exp2) print("cache_seqlens:", cache_seqlens) + print("block_table:", block_table) + print("paged_kv_block_size:", paged_kv_block_size) # cast to float32 q = q.to(torch.float32) - k = k.to(torch.float32) - v = v.to(torch.float32) - - # get seqlens - L_q, L_k = q.shape[1], k.shape[1] - # Compute attention scores - attention_scores = torch.matmul(q, k.transpose(-2, -1)) + # Check if we're in paged KV mode + is_paged = block_table is not None and paged_kv_block_size is not None + + if False: # Debug paged attention (disabled for production) + print(f"\n=== attention_forward_core_ref_impl DEBUG ===") + print(f"is_paged: {is_paged}") + print(f"block_table: {block_table.shape if block_table is not None else None}") + print(f"paged_kv_block_size: {paged_kv_block_size}") + if is_paged: + print(f"k shape (paged): {k.shape}") + print(f"v shape (paged): {v.shape}") + print(f"cache_seqlens: {cache_seqlens}") + + if is_paged: + # In paged mode, k and v are [num_blocks, block_size, nheads_k, head_dim] + # We'll compute attention on-the-fly without reconstructing + nheads_q = q.shape[0] + L_q = q.shape[1] + head_dim = q.shape[2] + + # Get number of KV heads from the cache + nheads_k = k.shape[2] # k shape: [num_blocks, block_size, nheads_k, head_dim] + + # Handle MQA/GQA + assert nheads_q % nheads_k == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_k ({nheads_k})" + group_size = nheads_q // nheads_k + + # Determine the actual KV sequence length from cache_seqlens + L_k = cache_seqlens if isinstance(cache_seqlens, int) else cache_seqlens.item() + + if False: # Debug disabled + print(f"L_q: {L_q}, L_k: {L_k}, nheads_q: {nheads_q}, nheads_k: {nheads_k}, group_size: {group_size}, head_dim: {head_dim}") + print(f"block_table contents: {block_table if block_table is not None else 'None'}") + + # Initialize attention scores + attention_scores = torch.zeros((nheads_q, L_q, L_k), dtype=torch.float32, device=q.device) + + # Compute attention scores on-the-fly by accessing blocks directly + for kv_pos in range(L_k): + # Calculate which block and position within block + block_idx = kv_pos // paged_kv_block_size + within_block_idx = kv_pos % paged_kv_block_size + + # Get the physical block number from block_table + # block_table is [1, num_blocks] for single batch in core function + if block_table.dim() == 2: + physical_block = block_table[0, block_idx].item() + else: + physical_block = block_table[block_idx].item() + + # Debug output disabled + # if kv_pos == 0: + # print(f"First KV access: block_idx={block_idx}, within_block={within_block_idx}, physical_block={physical_block}") + # print(f"k_vec shape will be: {k[physical_block, within_block_idx, :, :].shape}") + + # Access k values directly from paged cache + # k shape: [num_blocks, block_size, nheads_k, head_dim] + k_vec = k[physical_block, within_block_idx, :, :].to(torch.float32) # [nheads_k, head_dim] + + # For GQA/MQA, we need to repeat k_vec for each group + if group_size > 1: + # Expand k_vec to match query heads + # k_vec: [nheads_k, head_dim] -> [nheads_q, head_dim] + k_vec = k_vec.repeat_interleave(group_size, dim=0) + + # Compute dot product with all query positions + # q is [nheads_q, L_q, head_dim], k_vec is [nheads_q, head_dim] + # Result should be [nheads_q, L_q] for this kv_pos + attention_scores[:, :, kv_pos] = torch.sum(q * k_vec.unsqueeze(1), dim=-1) + + # Keep k and v in original format for later v computation + k_paged = k + v_paged = v + + # Debug output disabled + # print(f"attention_scores computed shape: {attention_scores.shape}") + # print(f"attention_scores sample values: {attention_scores[0, 0, :5]}") + else: + # Standard non-paged mode + k = k.to(torch.float32) + v = v.to(torch.float32) + + # get seqlens + L_q, L_k = q.shape[1], k.shape[1] + + # Compute attention scores + attention_scores = torch.matmul(q, k.transpose(-2, -1)) if DEBUG_CORE: print("attention_scores:", attention_scores, attention_scores.shape) @@ -256,7 +338,55 @@ def attention_forward_core_ref_impl( print("softmax_lse:", softmax_lse, softmax_lse.shape) # Compute output - o = torch.matmul(p, v) + if is_paged: + # Compute output on-the-fly using paged v cache + nheads_q = p.shape[0] + L_q = p.shape[1] + nheads_v = v_paged.shape[2] # [num_blocks, block_size, nheads_v, head_dim] + head_dim = v_paged.shape[3] + + # Handle MQA/GQA for v + assert nheads_q % nheads_v == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_v ({nheads_v})" + v_group_size = nheads_q // nheads_v + + o = torch.zeros((nheads_q, L_q, head_dim), dtype=torch.float32, device=p.device) + + # Accumulate weighted v values + for kv_pos in range(L_k): + # Calculate which block and position within block + block_idx = kv_pos // paged_kv_block_size + within_block_idx = kv_pos % paged_kv_block_size + + # Get the physical block number from block_table + if block_table.dim() == 2: + physical_block = block_table[0, block_idx].item() + else: + physical_block = block_table[block_idx].item() + + # Access v values directly from paged cache + # v_paged shape: [num_blocks, block_size, nheads_v, head_dim] + v_vec = v_paged[physical_block, within_block_idx, :, :].to(torch.float32) # [nheads_v, head_dim] + + # For GQA/MQA, we need to repeat v_vec for each group + if v_group_size > 1: + # Expand v_vec to match query heads + # v_vec: [nheads_v, head_dim] -> [nheads_q, head_dim] + v_vec = v_vec.repeat_interleave(v_group_size, dim=0) + + # Weight by attention probabilities + # p is [nheads_q, L_q, L_k], we need p[:, :, kv_pos] which is [nheads_q, L_q] + # v_vec is [nheads_q, head_dim] + # We want to add p[:, :, kv_pos] * v_vec to each query position + weights = p[:, :, kv_pos].unsqueeze(-1) # [nheads_q, L_q, 1] + o += weights * v_vec.unsqueeze(1) # [nheads_q, L_q, head_dim] + else: + o = torch.matmul(p, v) + + # Debug output disabled + # if False: + # print(f"Output o shape: {o.shape}") + # print(f"Output o sample values: {o[0, 0, :5]}") + if DEBUG_CORE: print("o:", o, o.shape) @@ -439,7 +569,7 @@ def attention_varlen_forward_pytorch_ref_impl( return o, softmax_lse, sd_mask -def attention_forward_pytorch_ref_impl( +def attention_prefill_forward_ref_impl( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -516,12 +646,61 @@ def attention_decode_forward_ref_impl( layout: Literal["bshd"], cache_seqlens: Optional[torch.Tensor], cache_batch_idx: Optional[torch.Tensor], + block_table: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, ): """Compute reference output for decode attention using PyTorch's built-in functions""" + if False: # Permanently disabled old debug output + pass + # print("\n========== attention_decode_forward_ref_impl inputs ==========") + # print(f"q shape: {q.shape}, dtype: {q.dtype}, device: {q.device}") + print(f"q values:\n{q}") + print(f"\nk_cache shape: {k_cache.shape}, dtype: {k_cache.dtype}, device: {k_cache.device}") + print(f"k_cache values:\n{k_cache}") + print(f"\nv_cache shape: {v_cache.shape}, dtype: {v_cache.dtype}, device: {v_cache.device}") + print(f"v_cache values:\n{v_cache}") + print(f"\nk_new: {k_new.shape if k_new is not None else None}, dtype: {k_new.dtype if k_new is not None else None}") + if k_new is not None: + print(f"k_new values:\n{k_new}") + print(f"\nv_new: {v_new.shape if v_new is not None else None}, dtype: {v_new.dtype if v_new is not None else None}") + if v_new is not None: + print(f"v_new values:\n{v_new}") + print(f"\nout shape: {out.shape}, dtype: {out.dtype}, device: {out.device}") + print(f"out values:\n{out}") + print(f"\nsm_scale: {sm_scale}") + print(f"causal: {causal}") + print(f"window_size_left: {window_size_left}") + print(f"window_size_right: {window_size_right}") + print(f"\nalibi_slopes: {alibi_slopes.shape if alibi_slopes is not None else None}, dtype: {alibi_slopes.dtype if alibi_slopes is not None else None}") + if alibi_slopes is not None: + print(f"alibi_slopes values:\n{alibi_slopes}") + print(f"\nlayout: {layout}") + print(f"cache_seqlens: {cache_seqlens}") + if cache_seqlens is not None and torch.is_tensor(cache_seqlens): + print(f"cache_seqlens values: {cache_seqlens}") + print(f"cache_batch_idx: {cache_batch_idx}") + if cache_batch_idx is not None: + print(f"cache_batch_idx values: {cache_batch_idx}") + print(f"\nblock_table: {block_table.shape if block_table is not None else None}, dtype: {block_table.dtype if block_table is not None else None}") + if block_table is not None: + print(f"block_table values:\n{block_table}") + print("=" * 60) + # get batch size before any layout conversion batch_size = q.shape[0] + # Determine if we're in paged KV mode + is_paged = block_table is not None + if is_paged: + # Infer block size from cache shape + # k_cache shape for paged: [num_blocks, block_size, nheads, head_dim] + paged_kv_block_size = k_cache.shape[1] + else: + paged_kv_block_size = None + # handle cache_batch_idx if cache_batch_idx is not None: # remap batch indices for cache access @@ -531,39 +710,79 @@ def attention_decode_forward_ref_impl( # copy new keys and values into cache if provided (before any layout conversion) if k_new is not None and v_new is not None: - _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout - - for b in range(batch_size): - cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices + if is_paged: + # For paged KV cache, we need to update the blocks with new k/v values + _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout - # determine where to place new k/v in cache - if cache_seqlens is not None: - if torch.is_tensor(cache_seqlens): - start_pos = cache_seqlens[b].item() + for b in range(batch_size): + # Determine where to place new k/v in cache + if cache_seqlens is not None: + if torch.is_tensor(cache_seqlens): + start_pos = cache_seqlens[b].item() + else: + start_pos = cache_seqlens else: - start_pos = cache_seqlens - else: - # if no cache_seqlens, assume we're filling from the beginning - start_pos = 0 - - end_pos = start_pos + seq_len_new + start_pos = 0 + + # For each new position, find the corresponding block and update it + for pos_offset in range(seq_len_new): + kv_pos = start_pos + pos_offset + + # Calculate which block and position within block + block_idx = kv_pos // paged_kv_block_size + within_block_idx = kv_pos % paged_kv_block_size + + # Get the physical block number from block_table + physical_block = block_table[b, block_idx].item() + + # Update the k and v values in the paged cache + # k_cache shape: [num_blocks, block_size, nheads, head_dim] + # k_new shape: [batch, seq_len, nheads, head_dim] + k_cache[physical_block, within_block_idx, :, :] = k_new[b, pos_offset, :, :] + v_cache[physical_block, within_block_idx, :, :] = v_new[b, pos_offset, :, :] + else: + _, seq_len_new, _, _ = k_new.shape # shape is [batch, seq_len, nheads, head_dim] for bshd layout - # copy new keys and values into cache (both are in bshd layout) - k_cache[cache_idx, start_pos:end_pos, :, :] = k_new[b, :, :, :] - v_cache[cache_idx, start_pos:end_pos, :, :] = v_new[b, :, :, :] + for b in range(batch_size): + cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices + + # determine where to place new k/v in cache + if cache_seqlens is not None: + if torch.is_tensor(cache_seqlens): + start_pos = cache_seqlens[b].item() + else: + start_pos = cache_seqlens + else: + # if no cache_seqlens, assume we're filling from the beginning + start_pos = 0 + + end_pos = start_pos + seq_len_new + + # copy new keys and values into cache (both are in bshd layout) + k_cache[cache_idx, start_pos:end_pos, :, :] = k_new[b, :, :, :] + v_cache[cache_idx, start_pos:end_pos, :, :] = v_new[b, :, :, :] # ensure the layout is 'bhsd' if layout == "bshd": q = q.transpose(1, 2).contiguous() - k_cache = k_cache.transpose(1, 2).contiguous() - v_cache = v_cache.transpose(1, 2).contiguous() + if not is_paged: + k_cache = k_cache.transpose(1, 2).contiguous() + v_cache = v_cache.transpose(1, 2).contiguous() elif layout != "bhsd": raise ValueError(f"Unknown layout {layout}") # prepare tensors batch_size_q, nheads_q, seq_len_q, head_dim = q.shape - batch_size_cache, nheads_k, max_cache_len, head_dim_k = k_cache.shape - _, nheads_v, _, head_dim_v = v_cache.shape + + if is_paged: + # For paged cache: [num_blocks, block_size, nheads, head_dim] + num_blocks, block_size, nheads_k, head_dim_k = k_cache.shape + _, _, nheads_v, head_dim_v = v_cache.shape + max_cache_len = None # Not directly available in paged mode + batch_size_cache = None # Not applicable in paged mode + else: + batch_size_cache, nheads_k, max_cache_len, head_dim_k = k_cache.shape + _, nheads_v, _, head_dim_v = v_cache.shape # validate dimensions assert head_dim == head_dim_k == head_dim_v, f"Head dimensions must match: {head_dim}, {head_dim_k}, {head_dim_v}" @@ -586,7 +805,8 @@ def attention_decode_forward_ref_impl( # process each batch element for b in range(batch_size): - cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices + if not is_paged: + cache_idx = batch_indices[b].item() if torch.is_tensor(batch_indices) else batch_indices # determine valid cache length for this batch element if cache_seqlens is not None: @@ -601,25 +821,41 @@ def attention_decode_forward_ref_impl( _, seq_len_new, _, _ = k_new.shape cache_len += seq_len_new else: - cache_len = max_cache_len - - # CHANGE: Extract the full cache, not just valid portion - # This matches what the test does - it uses full k_cache_rep/v_cache_rep - k_b = k_cache[cache_idx, :, :, :] # [nheads_k, max_cache_len, head_dim] - v_b = v_cache[cache_idx, :, :, :] # [nheads_v, max_cache_len, head_dim] - q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] + if is_paged: + # For paged mode, we need cache_seqlens to know the valid length + raise ValueError("cache_seqlens must be provided for paged KV cache") + else: + cache_len = max_cache_len - # handle MQA/GQA by expanding k and v - if group_size != 1: - # expand k and v to match q's number of heads - k_b = k_b.unsqueeze(1).expand(-1, group_size, -1, -1) - k_b = k_b.reshape(nheads_q, max_cache_len, head_dim) + if is_paged: + # For paged KV cache, pass the cache and block table directly + # Extract block table for this batch element + block_table_b = block_table[b:b+1, :] # [1, num_blocks] + k_b = k_cache # Pass entire paged cache + v_b = v_cache # Pass entire paged cache + q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] - v_b = v_b.unsqueeze(1).expand(-1, group_size, -1, -1) - v_b = v_b.reshape(nheads_q, max_cache_len, head_dim) - - # reshape for attention_forward_core_ref_impl - q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) + # For paged mode with MQA/GQA, we handle expansion in the core function + # Just reshape q for now + q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) + else: + # Standard non-paged mode + k_b = k_cache[cache_idx, :, :, :] # [nheads_k, max_cache_len, head_dim] + v_b = v_cache[cache_idx, :, :, :] # [nheads_v, max_cache_len, head_dim] + q_b = q[b:b+1, :, :, :] # [1, nheads_q, seq_len_q, head_dim] + block_table_b = None + + # handle MQA/GQA by expanding k and v + if group_size != 1: + # expand k and v to match q's number of heads + k_b = k_b.unsqueeze(1).expand(-1, group_size, -1, -1) + k_b = k_b.reshape(nheads_q, max_cache_len, head_dim) + + v_b = v_b.unsqueeze(1).expand(-1, group_size, -1, -1) + v_b = v_b.reshape(nheads_q, max_cache_len, head_dim) + + # reshape for attention_forward_core_ref_impl + q_b = q_b.reshape(nheads_q, seq_len_q, head_dim) # handle alibi slopes for this batch alibi_slopes_b = None @@ -635,6 +871,8 @@ def attention_decode_forward_ref_impl( dropout_p=0.0, philox_seed=None, philox_offset=None, alibi_slopes=alibi_slopes_b, use_exp2=True, cache_seqlens=cache_len, # Pass valid cache length + block_table=block_table_b, # Pass block table for paged mode + paged_kv_block_size=paged_kv_block_size, # Pass block size for paged mode ) # store outputs diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py index a6b83f64365..3dc443abf67 100644 --- a/flash_attn/flash_attn_triton_amd/interface_fa.py +++ b/flash_attn/flash_attn_triton_amd/interface_fa.py @@ -5,7 +5,7 @@ from .bwd_prefill_fused_atomics import attention_prefill_backward_triton_fused_atomics_impl from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl from .fwd_decode import attention_decode_forward_triton_impl -from .fwd_ref import attention_forward_pytorch_ref_impl, attention_decode_forward_ref_impl +from .fwd_ref import attention_prefill_forward_ref_impl, attention_decode_forward_ref_impl from .bwd_ref import attention_backward_pytorch_ref_impl from .utils import DEBUG, USE_REF, MetaData, is_fp8 from einops import rearrange, repeat @@ -31,8 +31,7 @@ def fwd(q: torch.Tensor, gen_: Optional[torch.Tensor] = None, descale_q: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None + descale_v: Optional[torch.Tensor] = None ): if DEBUG: @@ -53,11 +52,12 @@ def fwd(q: torch.Tensor, print("descale_q:", descale_q, descale_q.shape if descale_q is not None else None) print("descale_k:", descale_k, descale_k.shape if descale_k is not None else None) print("descale_v:", descale_v, descale_v.shape if descale_v is not None else None) - print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) if is_fp8(q): assert out is not None, "fp8 output tensor should be passed in." - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" + if (descale_q is None) or (descale_k is None) or (descale_v is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) else: out = torch.zeros_like(q) if out is None else out.zero_() @@ -93,7 +93,7 @@ def fwd(q: torch.Tensor, if USE_REF: if DEBUG: print("Using reference implementation") - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( q, k, v, @@ -140,16 +140,13 @@ def fwd(q: torch.Tensor, USE_EXP2, descale_q, descale_k, - descale_v, - descale_o) + descale_v) softmax_lse=softmax_lse_triton sd_mask=sd_mask_triton if DEBUG: print("flash_attn_triton_amd.py::fwd outputs") print("o:", out, out.shape) - if is_fp8(out): - print("descale_o:", descale_o, descale_o.shape if descale_o is not None else None) print("softmax_lse:", softmax_lse, softmax_lse.shape) print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None ) print("rng_state:", rng_state) @@ -405,8 +402,7 @@ def varlen_fwd( gen_: Optional[torch.Tensor] = None, descale_q: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None, - descale_v: Optional[torch.Tensor] = None, - descale_o: Optional[torch.Tensor] = None + descale_v: Optional[torch.Tensor] = None ): if DEBUG: @@ -432,7 +428,9 @@ def varlen_fwd( if is_fp8(q): assert out is not None, "fp8 output tensor should be passed in." - assert (descale_q is not None) and (descale_k is not None) and (descale_v is not None), f"For fp8, you need to pass descale factors for q, k and v" + if (descale_q is None) or (descale_k is None) or (descale_v is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) else: out = torch.zeros_like(q) if out is None else out.zero_() @@ -468,7 +466,7 @@ def varlen_fwd( if USE_REF: if DEBUG: print("Using reference implementation") - softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl( + softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( q, k, v, @@ -515,8 +513,7 @@ def varlen_fwd( USE_EXP2, descale_q, descale_k, - descale_v, - descale_o) + descale_v) softmax_lse=softmax_lse_triton sd_mask=sd_mask_triton @@ -899,6 +896,7 @@ def fwd_kvcache( metadata.layout, metadata.cache_seqlens, metadata.cache_batch_idx, + block_table, ) softmax_lse=softmax_lse_ref else: @@ -919,6 +917,7 @@ def fwd_kvcache( metadata.layout, metadata.cache_seqlens, metadata.cache_batch_idx, + block_table, ) softmax_lse = softmax_lse_triton diff --git a/flash_attn/flash_attn_triton_amd/interface_fa_v3.py b/flash_attn/flash_attn_triton_amd/interface_fa_v3.py new file mode 100755 index 00000000000..be8e2d3cbeb --- /dev/null +++ b/flash_attn/flash_attn_triton_amd/interface_fa_v3.py @@ -0,0 +1,660 @@ +import torch +import os +from .fwd_prefill import attention_prefill_forward_triton_impl +from .bwd_prefill_split import attention_prefill_backward_triton_split_impl +from .bwd_prefill_fused_atomics import attention_prefill_backward_triton_fused_atomics_impl +from .bwd_prefill_fused_no_atomics import attention_prefill_backward_triton_split_fused_no_atomics_impl +from .fwd_decode import attention_decode_forward_triton_impl +from .fwd_ref import attention_prefill_forward_ref_impl, attention_decode_forward_ref_impl +from .bwd_ref import attention_backward_pytorch_ref_impl +from .utils import DEBUG, USE_REF, MetaData, is_fp8 +from einops import rearrange, repeat +from flash_attn.layers.rotary import apply_rotary_emb +from typing import Optional, Union, Tuple + +USE_EXP2 = True +BWD_MODE = os.environ.get('BWD_MODE', 'fused_no_atomics').lower() +USE_DECODE_PATH = os.environ.get('FLASH_ATTENTION_V3_USE_DECODE', '0') == '1' + +def fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + qv: Optional[torch.Tensor], + out: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + cu_seqlens_k_new: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + page_table: Optional[torch.Tensor], + kv_batch_idx: Optional[torch.Tensor], + leftpad_k: Optional[torch.Tensor], + rotary_cos: Optional[torch.Tensor], + rotary_sin: Optional[torch.Tensor], + seqlens_rotary: Optional[torch.Tensor], + q_descale: Optional[torch.Tensor], + k_descale: Optional[torch.Tensor], + v_descale: Optional[torch.Tensor], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + attention_chunk: int, + softcap: float, + rotary_interleaved: bool, + scheduler_metadata=None, + num_splits: int = 1, + pack_gqa=None, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Flash Attention v3 forward pass compatible interface for AMD Triton implementation. + + This function maps v3 parameters to the existing AMD Triton implementation. + """ + + if DEBUG: + print() + print("interface_fa_v3.py::fwd inputs") + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("k_new:", k_new, k_new.shape if k_new is not None else None) + print("v_new:", v_new, v_new.shape if v_new is not None else None) + print("qv:", qv, qv.shape if qv is not None else None) + print("out:", out, out.shape if out is not None else None) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape if cu_seqlens_q is not None else None) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape if cu_seqlens_k is not None else None) + print("cu_seqlens_k_new:", cu_seqlens_k_new, cu_seqlens_k_new.shape if cu_seqlens_k_new is not None else None) + print("seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None) + print("seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("page_table:", page_table, page_table.shape if page_table is not None else None) + print("kv_batch_idx:", kv_batch_idx, kv_batch_idx.shape if kv_batch_idx is not None else None) + print("leftpad_k:", leftpad_k, leftpad_k.shape if leftpad_k is not None else None) + print("rotary_cos:", rotary_cos, rotary_cos.shape if rotary_cos is not None else None) + print("rotary_sin:", rotary_sin, rotary_sin.shape if rotary_sin is not None else None) + print("seqlens_rotary:", seqlens_rotary, seqlens_rotary.shape if seqlens_rotary is not None else None) + print("q_descale:", q_descale, q_descale.shape if q_descale is not None else None) + print("k_descale:", k_descale, k_descale.shape if k_descale is not None else None) + print("v_descale:", v_descale, v_descale.shape if v_descale is not None else None) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("attention_chunk:", attention_chunk) + print("softcap:", softcap) + print("rotary_interleaved:", rotary_interleaved) + print("scheduler_metadata:", scheduler_metadata) + print("num_splits:", num_splits) + print("pack_gqa:", pack_gqa) + print("sm_margin:", sm_margin) + + # Handle qv packed input + if qv is not None: + raise NotImplementedError("QV packed input is not yet supported in the AMD Triton backend") + + + # Handle softcap + if softcap != 0.0: + raise NotImplementedError(f"Softcap is not yet supported in the AMD Triton backend (got softcap={softcap}, expected 0.0)") + + # Handle attention_chunk + if attention_chunk != 0 and attention_chunk != 1: + raise NotImplementedError(f"attention_chunk is not yet supported in the AMD Triton backend (got attention_chunk={attention_chunk})") + + + # Handle scheduler metadata + if scheduler_metadata is not None: + raise NotImplementedError("Scheduler metadata is not yet supported in the AMD Triton backend") + + # Handle pack_gqa + if pack_gqa is not None and pack_gqa is not False: + raise NotImplementedError(f"pack_gqa is not yet supported in the AMD Triton backend (got pack_gqa={pack_gqa})") + + # Handle num_splits + if num_splits != 1: + raise NotImplementedError(f"Split attention (num_splits > 1) is not yet supported in the AMD Triton backend (got num_splits={num_splits})") + + # Handle sm_margin + if sm_margin != 0: + raise NotImplementedError(f"sm_margin is not yet supported in the AMD Triton backend (got sm_margin={sm_margin}, expected 0)") + + # Handle leftpad_k + if leftpad_k is not None: + raise NotImplementedError("Left padding (leftpad_k) is not yet supported in the AMD Triton backend") + + # Handle cu_seqlens_k_new + if cu_seqlens_k_new is not None: + raise NotImplementedError("cu_seqlens_k_new is not yet supported in the AMD Triton backend") + + # if seqlens_rotary is not None: + # raise NotImplementedError("seqlens_rotary is not yet supported in the AMD Triton backend") + + # Setup metadata + metadata = MetaData(sm_scale=softmax_scale) + + + # Handle variable length sequences first to determine layout + # Determine layout based on tensor dimensions and cu_seqlens presence + if cu_seqlens_q is not None: + # Q has variable length - check tensor dimensions to confirm + if len(q.shape) == 3: # [total_seqlen, nheads, head_dim] + metadata.layout = "thd" + metadata.varlen = True + metadata.cu_seqlens_q = cu_seqlens_q + metadata.max_seqlens_q = max_seqlen_q + + # K might be varlen or batch mode + if cu_seqlens_k is not None: + metadata.cu_seqlens_k = cu_seqlens_k + metadata.max_seqlens_k = max_seqlen_k + else: + # K is in batch mode while Q is varlen (KV cache scenario) + metadata.cu_seqlens_k = None + metadata.max_seqlens_k = k.shape[1] if len(k.shape) == 4 else max_seqlen_k + else: + raise ValueError(f"cu_seqlens_q provided but q has shape {q.shape}, expected 3D tensor for varlen") + else: + # Regular batch mode + metadata.layout = "bshd" + metadata.varlen = False + metadata.cu_seqlens_q = None + metadata.cu_seqlens_k = None + metadata.max_seqlens_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q + metadata.max_seqlens_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k + + # Now determine if we should use decode or prefill kernel + # Decode kernel should be used for KV cache scenarios where: + # 1. k_new/v_new are provided - incremental KV cache update (primary KV cache indicator) + # 2. kv_batch_idx is provided - KV cache batch indexing (primary KV cache indicator) + # 3. seqused_k without seqused_q - indicates KV cache fill levels (not varlen masking) + # Note: In varlen, both seqused_q and seqused_k are used for sequence masking + # In KV cache, only seqused_k is used to track cache fill levels + if USE_DECODE_PATH: + # Force decode path + use_decode = True + else: + # Detect KV cache scenarios: + # - Clear KV cache indicators (k_new, v_new, kv_batch_idx) + # - OR seqused_k without seqused_q (KV cache fill tracking, not varlen masking) + use_decode = ( + k_new is not None or # Have new KV to append (KV cache indicator) + v_new is not None or # Have new KV to append (KV cache indicator) + kv_batch_idx is not None or # Have KV cache batch indexing (KV cache indicator) + (seqused_k is not None and seqused_q is None) # KV cache fill levels (not varlen) + ) + + # Check for unsupported features with decode kernel + if use_decode: + if metadata.layout == "thd": + raise NotImplementedError("Varlen is not yet supported with the decode kernel in the AMD Triton backend") + if kv_batch_idx is not None: + raise NotImplementedError("kv_batch_idx is not yet supported with the decode kernel in the AMD Triton backend") + + + if out is None: + out_dtype = torch.float32 if is_fp8(q) else q.dtype + if metadata.layout == "bshd": + out = torch.zeros(q.shape[0], q.shape[1], q.shape[2], v.shape[-1], dtype=out_dtype, device=q.device) + elif metadata.layout == "thd": + out = torch.zeros(q.shape[0], q.shape[1], v.shape[-1], dtype=out_dtype, device=q.device) + else: + raise ValueError(f"Unsupported layout: {metadata.layout}. Only 'bshd' and 'thd' layouts are supported.") + else: + out = out.zero_() + + if is_fp8(q): + if (q_descale is None) or (k_descale is None) or (v_descale is None): + import warnings + warnings.warn("FP8 tensors detected but descale factors not provided. Using default scale of 1.0", UserWarning) + + # Get shape + if metadata.layout == "bshd": + batch, _, nheads_q, _ = q.shape + else: # "thd" layout for varlen + _, nheads_q, _ = q.shape + batch = len(cu_seqlens_q) - 1 if cu_seqlens_q is not None else 1 + + # Handle causal mask + if causal: + metadata.need_causal(True) + + # Handle alibi slopes (not directly supported in v3 interface, but we'll keep the logic) + alibi_slopes = None # V3 doesn't have alibi_slopes in the signature + if alibi_slopes is not None: + if alibi_slopes.dim() == 2: + pass + elif alibi_slopes.dim() == 1: + alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1) + else: + raise ValueError(f"Alibi can be (nheads,) or (batch_size, nheads). Given tensor with shape {alibi_slopes.shape}") + metadata.need_alibi(alibi_slopes, batch, nheads_q) + + # Handle dropout (v3 doesn't have dropout in forward) + dropout_p = 0.0 + return_softmax = False + metadata.need_dropout(dropout_p, return_softmax) + + # Handle rotary embeddings + if rotary_cos is not None and rotary_sin is not None: + metadata.need_rotary(rotary_sin, rotary_cos, rotary_interleaved) + + # Apply rotary embeddings if provided + if metadata.causal or window_size_left != -1 or window_size_right != -1: + q_rot = apply_rotary_emb( + q, + rotary_cos, + rotary_sin, + seqlen_offsets=seqlens_rotary, + interleaved=rotary_interleaved, + ) + q = q_rot.to(q.dtype) + + if k_new is not None: + k_rot = apply_rotary_emb( + k_new, + rotary_cos, + rotary_sin, + seqlen_offsets=seqlens_rotary, + interleaved=rotary_interleaved, + ) + k_new = k_rot.to(k.dtype) + + # Store RNG state + rng_state = torch.as_tensor([metadata.philox_seed, metadata.philox_offset]) + + # Call implementation + if USE_REF: + if DEBUG: + print("Using reference implementation") + + if use_decode: + if DEBUG: + print(f"Using decode reference implementation ( layout={metadata.layout}, cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})") + # Use decode reference implementation + softmax_lse = attention_decode_forward_ref_impl( + q, + k, # k_cache + v, # v_cache + k_new, + v_new, + out, + metadata.sm_scale, + metadata.causal, + window_size_left, + window_size_right, + metadata.alibi_slopes, + metadata.layout, + seqused_k, # cache_seqlens + kv_batch_idx, # cache_batch_idx + page_table, # block_table + q_descale, + k_descale, + v_descale, + ) + else: + if DEBUG: + print("Using prefill reference implementation") + # Use prefill reference implementation + softmax_lse_ref, sd_mask_ref = attention_prefill_forward_ref_impl( + q, k, v, out, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + window_size_left, + window_size_right, + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + USE_EXP2 + ) + softmax_lse = softmax_lse_ref + else: + if DEBUG: + print("Using Triton implementation") + + if use_decode: + if DEBUG: + print(f"Using Decode Triton implementation (cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})") + + # Use decode kernel for KV cache scenarios + # Note: seqused_k can serve as cache_seqlens in v3 + softmax_lse = attention_decode_forward_triton_impl( + q, + k, # k_cache in v2 terminology + v, # v_cache in v2 terminology + k_new, # New KV values to append to cache + v_new, # New KV values to append to cache + out, + metadata.sm_scale, + metadata.causal, + window_size_left, + window_size_right, + metadata.alibi_slopes, + metadata.layout, + seqused_k, # cache_seqlens + kv_batch_idx, # cache_batch_idx + page_table, # block_table for paged attention + q_descale, + k_descale, + v_descale, + ) + # Decode kernel returns only softmax_lse, not sd_mask + sd_mask_triton = None + else: + if DEBUG: + print("Using prefill Triton implementation") + # Use prefill kernel + softmax_lse_triton, sd_mask_triton = attention_prefill_forward_triton_impl( + q, k, v, out, + metadata.sm_scale, + metadata.alibi_slopes, + metadata.causal, + window_size_left, + window_size_right, + None, # block_table + metadata.layout, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, + metadata.max_seqlens_q, + metadata.max_seqlens_k, + metadata.dropout_p, + metadata.philox_seed, + metadata.philox_offset, + metadata.return_softmax, + USE_EXP2, + q_descale, + k_descale, + v_descale, + seqused_q, + seqused_k, + ) + softmax_lse = softmax_lse_triton + + if DEBUG: + print("interface_fa_v3.py::fwd outputs") + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + + # Return format compatible with v3 + # V3 returns (out, softmax_lse, *rest) where rest can be empty or contain additional outputs + return out, softmax_lse + + +def bwd( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor], + seqused_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + deterministic: bool, + sm_margin: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Flash Attention v3 backward pass compatible interface for AMD Triton implementation. + + This function maps v3 parameters to the existing AMD Triton implementation. + """ + + if DEBUG: + print() + print("interface_fa_v3.py::bwd inputs") + print("dout:", dout, dout.shape) + print("q:", q, q.shape) + print("k:", k, k.shape) + print("v:", v, v.shape) + print("out:", out, out.shape) + print("softmax_lse:", softmax_lse, softmax_lse.shape) + print("dq:", dq, dq.shape if dq is not None else None) + print("dk:", dk, dk.shape if dk is not None else None) + print("dv:", dv, dv.shape if dv is not None else None) + print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape if cu_seqlens_q is not None else None) + print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape if cu_seqlens_k is not None else None) + print("seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None) + print("seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("deterministic:", deterministic) + print("sm_margin:", sm_margin) + + # Check for unsupported features in backward pass + + # Handle softcap + if softcap != 0.0: + raise NotImplementedError(f"Softcap is not yet supported in the AMD Triton backend backward pass (got softcap={softcap}, expected 0.0)") + + # Handle sm_margin + if sm_margin != 0: + raise NotImplementedError(f"sm_margin is not yet supported in the AMD Triton backend backward pass (got sm_margin={sm_margin}, expected 0)") + + # Initialize gradient tensors if not provided + dq = torch.zeros_like(q) if dq is None else dq.zero_() + dk = torch.zeros_like(k) if dk is None else dk.zero_() + dv = torch.zeros_like(v) if dv is None else dv.zero_() + + # Determine layout based on cu_seqlens + if cu_seqlens_q is not None and cu_seqlens_k is not None: + # Variable length sequence mode + layout = "thd" + batch = len(cu_seqlens_q) - 1 + _, nheads_q, _ = q.shape + else: + # Regular batch mode + layout = "bshd" + batch, _, nheads_q, _ = q.shape + max_seqlen_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k + + # V3 backward doesn't have dropout or alibi slopes + dropout_p = 0.0 + philox_seed, philox_offset = None, None + alibi_slopes = None + + # For fp8, we would need descale factors, but v3 interface doesn't expose them + # So we'll pass None for now + descale_q = None + descale_k = None + descale_v = None + descale_o = None + descale_do = None + descale_dq = None + descale_dk = None + descale_dv = None + + # Call implementation + if USE_REF: + if DEBUG: + print("Using reference implementation") + delta_ref = attention_backward_pytorch_ref_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + window_size_left, + window_size_right, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + ) + delta = delta_ref + else: + if DEBUG: + print("Using Triton implementation") + + if BWD_MODE == "split": + delta_triton = attention_prefill_backward_triton_split_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + descale_q, descale_k, descale_v, descale_o, + descale_do, descale_dq, descale_dk, descale_dv, + seqused_q, seqused_k, + ) + delta = delta_triton + elif BWD_MODE == "fused_atomics": + delta_triton = attention_prefill_backward_triton_fused_atomics_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + descale_q, descale_k, descale_v, descale_o, + True, + ) + delta = delta_triton + elif BWD_MODE == "fused_no_atomics": + delta_triton = attention_prefill_backward_triton_split_fused_no_atomics_impl( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, + softmax_scale, + alibi_slopes, + causal, + layout, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + philox_seed, + philox_offset, + USE_EXP2, + descale_q, descale_k, descale_v, descale_o, + descale_do, descale_dq, descale_dk, descale_dv, + seqused_q, seqused_k, + ) + delta = delta_triton + else: + raise ValueError(f"Unknown bwd mode {BWD_MODE}") + + if DEBUG: + print("interface_fa_v3.py::bwd outputs") + print("dq:", dq, dq.shape) + print("dk:", dk, dk.shape) + print("dv:", dv, dv.shape) + print("delta:", delta, delta.shape if delta is not None else None) + + # V3 expects (dq, dk, dv, softmax_d, *rest) + # delta is the softmax_d in this case + return dq, dk, dv, delta + + +def fwd_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + Combine partial outputs from split attention computation. + + This is used when num_splits > 1 to combine the partial results. + + Args: + out_partial: Partial output tensor from split computation + lse_partial: Partial log-sum-exp tensor + out: Optional output tensor to write to + out_dtype: Optional dtype for output + + Returns: + Combined output tensor + """ + raise NotImplementedError("fwd_combine is not yet implemented in the AMD Triton backend") + + +def get_scheduler_metadata( + batch_size: int, + max_seqlen_q: int, + max_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + headdim: int, + headdim_v: int, + qkv_dtype: torch.dtype, + cache_seqlens: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new: int = 0, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + has_softcap: bool = False, + num_splits: int = 0, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +): + """ + Get scheduler metadata for optimized kernel selection. + + This function is used to precompute metadata for kernel scheduling in FA3. + The AMD Triton backend currently doesn't use scheduler metadata, so this + raises an error. + + Args: + Various attention parameters used for scheduling decisions + + Returns: + None - scheduler metadata is not used in AMD Triton backend + """ + raise NotImplementedError("get_scheduler_metadata is not supported in the AMD Triton backend yet.") \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py deleted file mode 100644 index 15fa8f2f07a..00000000000 --- a/flash_attn/flash_attn_triton_amd/test.py +++ /dev/null @@ -1,777 +0,0 @@ -import os -import glob -import shutil -import time -import torch -import pytest -import logging -import numpy as np -from pathlib import Path -import triton -import flash_attn -import flash_attn.flash_attn_triton_amd.fp8 - -from .utils import generate_bshd_kv_packed, generate_bshd_qkv_packed, generate_bshd_tensor, generate_varlen_kv_packed, generate_varlen_qkv_packed, input_helper, arch_supports_fp8, generate_varlen_tensor - -# debugging - -logging.basicConfig(level=logging.INFO, format='%(message)s', force=True) -logger = logging.getLogger(__name__) -DEBUG = False - -# defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html -ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose. -# ATOL, RTOL = 1e-3, 1e-3 # catchs fa mismatch issues -# ATOL, RTOL = 1e-4, 1e-3 # to strict. there will be small diffs -# ATOL, RTOL = 1e-5, 1e-3 # # default fp16. there will be small diffs -# ATOL_fp8, RTOL_fp8 = 1e-1, 1e-1 # to strict for larger tensors in fp8 -ATOL_fp8, RTOL_fp8 = 2.5e-1, 2.5e-1 # fp8 -# ATOL_fp8, RTOL_fp8 = 2e-2, 2e-2 # fp8 -EQUAL_NAN = True - -def fp8_assert_close(tensor_a, tensor_b, atol=ATOL_fp8, rtol=RTOL_fp8, max_diff_percentage=0.5): - """Assert tensors are close with tolerance for small percentage of elements""" - # standard comparison - abs_diff = torch.abs(tensor_a - tensor_b) - rel_diff = abs_diff / torch.abs(tensor_b.clamp(min=1e-6)) - - # calculate elements that exceed tolerance - abs_check = abs_diff > atol - rel_check = rel_diff > rtol - failed_check = torch.logical_and(abs_check, rel_check) - - # calculate percentage of failed elements - failed_percentage = failed_check.sum().item() / failed_check.numel() * 100 - - # if percentage is small enough, test passes - if failed_percentage <= max_diff_percentage: - return True - - # Otherwise, provide diagnostic information - max_abs_idx = torch.argmax(abs_diff).item() - max_rel_idx = torch.argmax(rel_diff).item() - - flat_to_idx = lambda flat_idx, shape: np.unravel_index(flat_idx, shape) - - max_abs_pos = flat_to_idx(max_abs_idx, tensor_a.shape) - max_rel_pos = flat_to_idx(max_rel_idx, tensor_a.shape) - - max_abs_diff = abs_diff.flatten()[max_abs_idx].item() - max_rel_diff = rel_diff.flatten()[max_rel_idx].item() - - raise AssertionError( - f"Tensors not close enough! {failed_percentage:.6f}% elements exceed tolerance.\n" - f"Greatest absolute difference: {max_abs_diff} at index {max_abs_pos} (up to {atol} allowed)\n" - f"Greatest relative difference: {max_rel_diff} at index {max_rel_pos} (up to {rtol} allowed)" - ) - -@pytest.mark.parametrize( - "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - # seqlen q == k - (1, 1, 1, 1, 1, 1), - (1, 1, 1, 2, 2, 2), # small enough to debug - (1, 1, 1, 4, 4, 16), - (1, 2, 2, 4, 4, 16), - (2, 1, 1, 4, 4, 16), - (2, 2, 2, 4, 4, 16), - (1, 1, 1, 128, 128, 32), # only one block - (3, 3, 3, 128, 128, 64), - (1, 1, 1, 127, 127, 32), # only one block but with masking - # (1, 1, 1, 129, 129, 1), # two blocks with 2nd block small enough to debug # fails - (1, 2, 2, 129, 129, 32), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 32), # two blocks with 2nd block small enough to debug - (1, 1, 1, 350, 350, 68), # generic masking on q, k and head - (4, 1, 1, 512, 512, 128), # batch > 1 - (4, 2, 2, 512, 512, 128), - (4, 2, 2, 512, 512, 68), - (4, 2, 2, 500, 500, 68), - (2, 4, 4, 1024, 1024, 64), - (4, 8, 8, 2048, 2048, 128), - (2, 8, 8, 4096, 4096, 64), - (2, 4, 4, 8192, 8192, 32), - # seqlen q > k - (1, 1, 1, 4, 2, 16), - (1, 1, 1, 64, 32, 8), - (1, 1, 1, 128, 64, 16), - (1, 1, 1, 192, 128, 32), - (1, 2, 2, 1024, 512, 68), - (1, 4, 4, 729, 516, 68), - (2, 4, 4, 2753, 1528, 68), # a comprehensive seqlen_q > seqlen_k - # seqlen q < k - (1, 1, 1, 2, 4, 16), - (1, 2, 2, 2, 4, 16), - (1, 4, 1, 2, 4, 16), - (1, 4, 2, 2, 4, 16), - (2, 2, 2, 2, 128, 1), - (2, 3, 3, 2, 128, 16), - (1, 1, 1, 32, 64, 8), - (1, 1, 1, 128, 192, 32), - (4, 6, 6, 108, 256, 32), - (3, 2, 2, 256, 512, 16), - (2, 2, 2, 512, 1024, 68), - (1, 1, 1, 200, 413, 32), - (1, 1, 1, 782, 1546, 32), - # gqa/mqa # mismatch issue on varlen - (4, 8, 2, 500, 500, 68), - (4, 8, 2, 512, 512, 68), - (4, 8, 2, 512, 512, 128), - (4, 8, 2, 512, 1024, 68), - (4, 8, 2, 1024, 512, 64), - (4, 16, 4, 1528, 2753, 68), - # fa configs - (2, 4, 1, 113, 203, 64), - (2, 4, 2, 128, 217, 64), - (2, 6, 2, 113, 211, 128), - (2, 6, 2, 108, 256, 128), - (2, 6, 2, 256, 512, 64), - (2, 6, 2, 512, 256, 64), - (2, 6, 2, 1024, 1024, 32), - (2, 6, 2, 1023, 1024, 32), - (2, 6, 6, 1024, 1023, 32), - (2, 6, 6, 2048, 2048, 32), - ], -) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('dropout_p', [0.0]) -@pytest.mark.parametrize('layout', ["bshd", "thd"]) -@pytest.mark.parametrize('packing', [None, "qkv"]) -@pytest.mark.parametrize('DEBUG_INPUT', [False]) -@pytest.mark.flaky(reruns=3, reason="Retry failures") -@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") -def test_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, DEBUG_INPUT): - torch.manual_seed(20) - test_backward = True - device = "cuda" - window_size = (-1, -1) - softcap = 0.0 - alibi_slopes = None - deterministic = False - ref_dtype = torch.float32 - is_varlen = True if layout == "thd" else False - - # skip QKV packing tests for uneven sequence lengths and head sizes - if packing == 'qkv': - if N_CTX_Q != N_CTX_K: - pytest.skip("QKV packing requires N_CTX_Q == N_CTX_K") - if HQ != HK: - pytest.skip("QKV packing requires HQ == HK") - - # test apis - if packing == 'qkv': - # generate inputs - qkv, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, packing=packing, device=device) - - # ---------------------------------------------------------------- - # --- FP8 --- - # ---------------------------------------------------------------- - qkv_fp8 = qkv.clone() - do_fp8= do.clone() - - if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_qkvpacked_fp8_func( - qkv_fp8, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( - qkv_fp8, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Reference --- - # ---------------------------------------------------------------- - # reference forward pass - qkv_ref = qkv.clone() - do_ref= do.clone() - - if is_varlen: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_varlen_qkvpacked_func( - qkv_ref, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_qkvpacked_func( - qkv_ref, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Compare --- - # ---------------------------------------------------------------- - # compare forward - if DEBUG: - print() - print(f"Compare fp8 against ref with dtype {ref_dtype}") - - if DEBUG: - print("out_ref:", out_ref, out_ref.shape) - print("out_fp8:", out_fp8, out_fp8.shape) - fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - - if DEBUG: - print("lse_ref:", lse_ref, lse_ref.shape) - print("lse_fp8:", lse_fp8, lse_fp8.shape) - fp8_assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - - if dropout_p > 0.0: - if DEBUG: - print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) - print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) - fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - if not test_backward: - return - - # fp8 backward pass - dqkv_fp8, = torch.autograd.grad(out_fp8, (qkv_fp8), do_fp8) - - # ref backward pass - dqkv_ref, = torch.autograd.grad(out_ref, (qkv_ref), do_ref) - - # compare backward gradients - if DEBUG: - print("dqkv_ref:", dqkv_ref, dqkv_ref.shape) - print("dqkv_fp8:", dqkv_fp8, dqkv_fp8.shape) - fp8_assert_close(dqkv_ref, dqkv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - elif packing is None: - # generate inputs - q, k, v, do, metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout, device=device) - - # ---------------------------------------------------------------- - # --- FP8 --- - # ---------------------------------------------------------------- - if DEBUG: - print() - print(f"Compute Fp8 Forward") - q_fp8 = q.clone() - k_fp8 = k.clone() - v_fp8 = v.clone() - do_fp8= do.clone() - - if is_varlen: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_fp8_func( - q_fp8, - k_fp8, - v_fp8, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_fp8, lse_fp8, S_dmask_fp8 = flash_attn.flash_attn_triton_amd.fp8.flash_attn_fp8_func( - q_fp8, - k_fp8, - v_fp8, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Reference --- - # ---------------------------------------------------------------- - if DEBUG: - print() - print(f"Compute Reference Forward") - # reference forward pass - q_ref = q.clone() - k_ref = k.clone() - v_ref = v.clone() - do_ref = do.clone() - - if is_varlen: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_varlen_func( - q_ref, - k_ref, - v_ref, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out_ref, lse_ref, S_dmask_ref = flash_attn.flash_attn_func( - q_ref, - k_ref, - v_ref, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # ---------------------------------------------------------------- - # --- Compare --- - # ---------------------------------------------------------------- - # compare forward - if DEBUG: - print() - print(f"Compare fp8 against ref with dtype {ref_dtype}") - - if DEBUG: - print("out_ref:", out_ref, out_ref.shape) - print("out_fp8:", out_fp8, out_fp8.shape) - # torch.testing.assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(out_ref, out_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - - if DEBUG: - print("lse_ref:", lse_ref, lse_ref.shape) - print("lse_fp8:", lse_fp8, lse_fp8.shape) - # torch.testing.assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(lse_ref, lse_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - - if dropout_p > 0.0: - if DEBUG: - print("S_dmask_ref:", S_dmask_ref, S_dmask_ref.shape) - print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape) - # torch.testing.assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - fp8_assert_close(S_dmask_ref, S_dmask_fp8, atol=ATOL_fp8, rtol=RTOL_fp8) - - if not test_backward: - return - - if DEBUG: - print() - print(f"Compute Fp8 Backward") - # fp8 backward pass - dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8) - - if DEBUG: - print() - print(f"Compute Reference Backward") - # ref backward pass - dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), do_ref) - - # compare backward gradients - if DEBUG: - print("dv_ref:", dv_ref, dv_ref.shape) - print("dv_fp8:", dv_fp8, dv_fp8.shape) - # torch.testing.assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dv_ref, dv_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - if DEBUG: - print("dk_ref:", dk_ref, dk_ref.shape) - print("dk_fp8:", dk_fp8, dk_fp8.shape) - # torch.testing.assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dk_ref, dk_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - - if DEBUG: - print("dq_ref:", dq_ref, dq_ref.shape) - print("dq_fp8:", dq_fp8, dq_fp8.shape) - # torch.testing.assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8, equal_nan=EQUAL_NAN) - fp8_assert_close(dq_ref, dq_fp8, atol=ATOL_fp8, rtol=RTOL_fp8 ) - -@pytest.mark.parametrize( - "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - (2, 4, 4, 512, 512, 128), - ], -) -@pytest.mark.parametrize('causal', [False, True]) -@pytest.mark.parametrize('dropout_p', [0.0, 0.1]) -@pytest.mark.parametrize('layout', ['bshd']) -@pytest.mark.parametrize('packing', [None]) -@pytest.mark.parametrize('test_backward', [False, True]) -@pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device") -@pytest.mark.skip("Breaks on CI but works locally") -def test_ir(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, packing, test_backward): # Don't run this test in parallel. It clears the cache so it doesnot work properly if run in parallel. - torch.manual_seed(20) - device = "cuda" - window_size = (-1, -1) - softcap = 0.0 - alibi_slopes = None - deterministic = False - ref_dtype = torch.float32 - is_varlen = True if layout == "thd" else False - - # remove cache - cache_path = Path(os.path.expanduser("~/.triton/cache")) - if cache_path.exists(): - shutil.rmtree(cache_path) - os.makedirs(cache_path) - - # inputs - q, k, v, do, metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, ref_dtype, layout=layout, packing=packing, device=device) - - if packing == None: - # fp8 forward pass - if is_varlen: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_fp8_func( - q, - k, - v, - metadata.cu_seqlens_q, - metadata.cu_seqlens_k, - metadata.max_seqlens_q, - metadata.max_seqlens_k, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_fp8_func( - q, - k, - v, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # fp8 backward pass - if test_backward: - dq, dk, dv = torch.autograd.grad(out, (q, k, v), do) - elif packing == "qkv": - # qkv packing path - # pack input tensors (use dim=1 for varlen, else dim=2) - if is_varlen: - qkv = torch.stack([q, k, v], dim=1) - else: - qkv = torch.stack([q, k, v], dim=2) - - # fp8 forward pass for qkv-packed input - if is_varlen: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_varlen_qkvpacked_fp8_func( - qkv, - metadata.cu_seqlens_q, - metadata.max_seqlens_q, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - else: - out, lse, S_dmask = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( - qkv, - dropout_p, - causal=causal, - window_size=window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - deterministic=deterministic, - return_attn_probs=True, - ) - - # fp8 backward pass for qkv-packed input - if test_backward: - dqkv, = torch.autograd.grad(out, (qkv,), do) - else: - raise ValueError(f"unknown packing type {packing}") - - # search for .ttir files - max_retries = 5 - retry_delay = 0.5 - ttir_files = [] - logging.info(f"Checking for .ttir files in {cache_path}...") - for attempt in range(max_retries): - # search for .ttir files recursively within the cache path - ttir_files = glob.glob(str(cache_path) + "/**/*.ttir", recursive=True) - - if ttir_files: - # Files found, log success and exit the loop - logging.info(f"Found {len(ttir_files)} .ttir files on attempt {attempt + 1}.") - break - else: - # Files not found yet - if attempt < max_retries - 1: - # If not the last attempt, wait and log before retrying - logging.warning( - f"No .ttir files found on attempt {attempt + 1}. " - f"Retrying in {retry_delay}s..." - ) - time.sleep(retry_delay) - else: - pytest.fail( - f"FATAL: No .ttir files found in cache {cache_path} " - f"after {max_retries} attempts." - ) - - # check if there is fp8 - ttir_files_fp8_found_status = {} - fp8_types = ['f8E4M3', 'f8E5M2'] - for ttir_file in ttir_files: - base_name = os.path.basename(ttir_file) - with open(ttir_file, 'r') as f: - content = f.read() - - # check content for fp8 - fp8_found = False - for f8_type in fp8_types: - if f8_type in content: - fp8_found = True - ttir_files_fp8_found_status[base_name] = fp8_found - - for file, fp8_found in ttir_files_fp8_found_status.items(): - assert fp8_found, f"{fp8_types} not found in {file}" - - -def clear_compile_cache(): - """Clear torch compile caches to prevent graph merging""" - if hasattr(torch._dynamo, 'reset'): - torch._dynamo.reset() - torch.cuda.synchronize() - - -@pytest.mark.parametrize( - "BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD", - [ - # (4, 8, 8, 128, 128, 64), # small test - (32, 32, 32, 531, 531, 128), # original test - # (16, 48, 16, 256, 512, 64), # MQA test (HQ > HK) - ], -) -@pytest.mark.skip() # this works or not depending on the torch and triton version compatibility. Keep the test but use it in docker containers where things are in sync -def test_torch_compile(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD): - print(f"\n\nTesting with BATCH={BATCH}, HQ={HQ}, HK={HK}, N_CTX_Q={N_CTX_Q}, N_CTX_K={N_CTX_K}, D_HEAD={D_HEAD}") - - try: - # Test 1: flash_attn_func - print("\n1. Testing flash_attn_func...") - clear_compile_cache() - - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD) - k = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD) - v = generate_bshd_tensor(BATCH, N_CTX_K, HK, D_HEAD) - - flash_attn_func_compiled = torch.compile(flash_attn.flash_attn_func) - o = flash_attn_func_compiled(q, k, v, causal=True) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_func SUCCESS") - - # cleanup - del q, k, v, o - torch.cuda.empty_cache() - - - # Test 2: flash_attn_varlen_func - print("\n2. Testing flash_attn_varlen_func...") - clear_compile_cache() - - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(BATCH * N_CTX_Q, HQ, D_HEAD, BATCH) - k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(BATCH * N_CTX_K, HK, D_HEAD, BATCH) - v, _, _ = generate_varlen_tensor(BATCH * N_CTX_K, HK, D_HEAD, BATCH) - - flash_attn_varlen_func_compiled = torch.compile(flash_attn.flash_attn_varlen_func) - o = flash_attn_varlen_func_compiled( - q, k, v, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, causal=True - ) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_varlen_func SUCCESS") - - # cleanup - del q, k, v, o, cu_seqlens_q, cu_seqlens_k - torch.cuda.empty_cache() - - - # Test 3: flash_attn_qkvpacked_func - print("\n3. Testing flash_attn_qkvpacked_func...") - clear_compile_cache() - - qkv = generate_bshd_qkv_packed(BATCH, N_CTX_Q, HQ, D_HEAD) - - flash_attn_qkvpacked_func_compiled = torch.compile(flash_attn.flash_attn_qkvpacked_func) - o = flash_attn_qkvpacked_func_compiled(qkv, causal=True) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_qkvpacked_func SUCCESS") - - # cleanup - del qkv, o - torch.cuda.empty_cache() - - - # Test 4: flash_attn_varlen_qkvpacked_func - print("\n4. Testing flash_attn_varlen_qkvpacked_func...") - clear_compile_cache() - - total_q = BATCH * N_CTX_Q - qkv, cu_seqlens, max_seqlen = generate_varlen_qkv_packed(total_q, HQ, D_HEAD, BATCH) - - flash_attn_varlen_qkvpacked_func_compiled = torch.compile(flash_attn.flash_attn_varlen_qkvpacked_func) - o = flash_attn_varlen_qkvpacked_func_compiled( - qkv, cu_seqlens, max_seqlen, causal=True - ) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_varlen_qkvpacked_func SUCCESS") - - # cleanup - del qkv, o, cu_seqlens - torch.cuda.empty_cache() - - - # Test 5: flash_attn_kvpacked_func - print("\n5. Testing flash_attn_kvpacked_func...") - clear_compile_cache() - - q = generate_bshd_tensor(BATCH, N_CTX_Q, HQ, D_HEAD) - kv = generate_bshd_kv_packed(BATCH, N_CTX_K, HK, D_HEAD) - - flash_attn_kvpacked_func_compiled = torch.compile(flash_attn.flash_attn_kvpacked_func) - o = flash_attn_kvpacked_func_compiled(q, kv, causal=True) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_kvpacked_func SUCCESS") - - # cleanup - del q, kv, o - torch.cuda.empty_cache() - - - # Test 6: flash_attn_varlen_kvpacked_func - print("\n6. Testing flash_attn_varlen_kvpacked_func...") - clear_compile_cache() - - q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(BATCH * N_CTX_Q, HQ, D_HEAD, BATCH) - kv, cu_seqlens_k, max_seqlen_k = generate_varlen_kv_packed(BATCH * N_CTX_K, HK, D_HEAD, BATCH) - - flash_attn_varlen_kvpacked_func_compiled = torch.compile(flash_attn.flash_attn_varlen_kvpacked_func) - o = flash_attn_varlen_kvpacked_func_compiled( - q, kv, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, causal=True - ) - print(f"Output shape: {o.shape}, dtype: {o.dtype}") - o.sum().backward() - print("✓ flash_attn_varlen_kvpacked_func SUCCESS") - - # cleanup - del q, kv, o, cu_seqlens_q, cu_seqlens_k - torch.cuda.empty_cache() - - - # Test 7: flash_attn_with_kvcache - print("\n7. Testing flash_attn_with_kvcache...") - clear_compile_cache() - - # setup cache dimensions - CACHE_SEQLEN = 1024 # max cache size - NEW_SEQLEN = 1 # for incremental decoding, usually 1 token at a time - - # create query for new tokens - q = generate_bshd_tensor(BATCH, NEW_SEQLEN, HQ, D_HEAD, dtype=torch.float16) - - # create kv cache using generators - k_cache = generate_bshd_tensor(BATCH, CACHE_SEQLEN, HK, D_HEAD, dtype=torch.float16) - v_cache = generate_bshd_tensor(BATCH, CACHE_SEQLEN, HK, D_HEAD, dtype=torch.float16) - - # cache sequence lengths - cache_seqlens = torch.full((BATCH,), 100, dtype=torch.int32, device='cuda') - - # new k,v to append to cache (optional) - k_new = generate_bshd_tensor(BATCH, NEW_SEQLEN, HK, D_HEAD, dtype=torch.float16) - v_new = generate_bshd_tensor(BATCH, NEW_SEQLEN, HK, D_HEAD, dtype=torch.float16) - - # Note: flash_attn_with_kvcache doesn't support backward pass - flash_attn_with_kvcache_compiled = torch.compile(flash_attn.flash_attn_with_kvcache) - - # Test with new k,v (append to cache and do attention) - with torch.no_grad(): - o = flash_attn_with_kvcache_compiled( - q, k_cache, v_cache, - k=k_new, v=v_new, - cache_seqlens=cache_seqlens, - causal=True - ) - print(f"Output shape (with new kv): {o.shape}, dtype: {o.dtype}") - - print("✓ flash_attn_with_kvcache SUCCESS") - - print("\n\n✅ ALL TESTS PASSED! ✅") - - except Exception as e: - print(f"\n❌ ERROR: {str(e)}") - # ensure we sync even on error to get proper error message - torch.cuda.synchronize() - raise e - finally: - # final cleanup - torch.cuda.empty_cache() - clear_compile_cache() - -# log env -if os.environ.get('PYTEST_XDIST_WORKER') in (None, 'gw0'): - logger.info("\n" + "="*80) - logger.info("ENVIRONMENT INFORMATION") - logger.info("="*80) - - # triton - logger.info(f" Triton Version: {triton.__version__}") - - # torch - logger.info(f" PyTorch Version: {torch.__version__}") - - # flash attention - logger.info(f" Flash Attention Version: {flash_attn.__version__}") - - logger.info("="*80 + "\n") diff --git a/flash_attn/flash_attn_triton_amd/train.py b/flash_attn/flash_attn_triton_amd/train.py deleted file mode 100644 index 8f39f627bfb..00000000000 --- a/flash_attn/flash_attn_triton_amd/train.py +++ /dev/null @@ -1,404 +0,0 @@ -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader, Dataset, random_split -import numpy as np -import pandas as pd -from tqdm import tqdm -import matplotlib.pyplot as plt -from datasets import load_dataset -import flash_attn -import flash_attn.flash_attn_triton_amd.fp8 - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print(f"using device: {device}") - -# ------------------------------- -# Model -# ------------------------------- -class FlashAttention(nn.Module): - def __init__(self, dim, num_heads=8, causal=True, dropout=0.1, qkv_bias=True, use_fp8=False): - super().__init__() - self.use_fp8 = use_fp8 - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.causal = causal - self.dropout_p = dropout - - # qkv and output projections - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim) - - def forward(self, x): - b, n, c = x.shape - # project to qkv - qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - - # reshape for flash attention function - qkv_packed = torch.stack([q, k, v], dim=2).reshape(b, n, 3, self.num_heads, self.head_dim) - - # use the appropriate flash attention function - if self.use_fp8: - context = flash_attn.flash_attn_triton_amd.fp8.flash_attn_qkvpacked_fp8_func( - qkv_packed, - dropout_p=self.dropout_p, - causal=self.causal - ) - else: - context = flash_attn.flash_attn_qkvpacked_func( - qkv_packed, - dropout_p=self.dropout_p, - causal=self.causal - ) - - # convert back to original shape and type - context = context.reshape(b, n, c) - - # output projection - x = self.proj(context) - - return x - -class TransformerBlock(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4.0, causal=True, dropout=0.1, use_fp8=False): - super().__init__() - self.norm1 = nn.LayerNorm(dim) - self.attn = FlashAttention(dim, num_heads=num_heads, causal=causal, dropout=dropout, use_fp8=use_fp8) - - self.norm2 = nn.LayerNorm(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = nn.Sequential( - nn.Linear(dim, mlp_hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(mlp_hidden_dim, dim), - nn.Dropout(dropout) - ) - - def forward(self, x): - x = x + self.attn(self.norm1(x)) - x = x + self.mlp(self.norm2(x)) - return x - -class FlashLM(nn.Module): - def __init__( - self, - vocab_size, - dim=256, - depth=6, - num_heads=8, - mlp_ratio=4.0, - causal=True, - dropout=0.1, - max_seq_len=256, - use_fp8=False - ): - super().__init__() - - # embedding layers - self.token_embedding = nn.Embedding(vocab_size, dim) - self.position_embedding = nn.Parameter(torch.zeros(1, max_seq_len, dim)) - self.dropout = nn.Dropout(dropout) - - # transformer blocks - self.blocks = nn.ModuleList([ - TransformerBlock(dim, num_heads, mlp_ratio, causal=causal, dropout=dropout, use_fp8=use_fp8) - for _ in range(depth) - ]) - - # lm head: project back to vocabulary dimension for each token - self.norm = nn.LayerNorm(dim) - self.lm_head = nn.Linear(dim, vocab_size) - - def forward(self, x): - b, n = x.shape - - # token + positional embedding - x = self.token_embedding(x) - x = x + self.position_embedding[:, :n, :] - x = self.dropout(x) - - # transformer blocks - for block in self.blocks: - x = block(x) - - # language modeling head - x = self.norm(x) - logits = self.lm_head(x) # shape: (b, n, vocab_size) - return logits - -# ------------------------------- -# Data -# ------------------------------- -class TextDataset(Dataset): - def __init__(self, sequences, max_len=None): - self.sequences = sequences - self.max_len = max_len - - def __len__(self): - return len(self.sequences) - - def __getitem__(self, idx): - seq = self.sequences[idx] - # input: all tokens except the last, target: all tokens except the first - return (torch.tensor(seq[:-1], dtype=torch.long), - torch.tensor(seq[1:], dtype=torch.long)) - -class VarLenTextDataset(Dataset): - def __init__(self, sequences, max_len=256): - self.sequences = sequences - self.max_len = max_len - - def __len__(self): - return len(self.sequences) - - def __getitem__(self, idx): - seq = self.sequences[idx] - # Ensure the sequence doesn't exceed max_len+1 - seq = seq[:self.max_len+1] - # input: all tokens except the last, target: all tokens except the first - return (torch.tensor(seq[:-1], dtype=torch.long), - torch.tensor(seq[1:], dtype=torch.long)) - -def prepare_dataset(batch_size, is_varlen=False, min_len=10, max_len=256, ratio_shorter=0.7): - # load the WikiText-2 - dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") - - # build vocabulary - corpus = " ".join([line for line in dataset["text"] if line.strip() != ""]) # join non-empty lines into a single corpus string - tokens = corpus.split() - vocab = sorted(set(tokens)) - word2idx = {word: idx for idx, word in enumerate(vocab)} - token_ids = [word2idx[word] for word in tokens] - - num_workers = 2 - if is_varlen: - # VARIABLE LENGTH: create sequences of different lengths - sequences = [] - for i in range(0, len(token_ids) - max_len, max_len // 2): # overlap to get more sequences - # Decide target length for this sequence - if np.random.random() < ratio_shorter: - # Shorter sequence - target_len = np.random.randint(min_len + 1, max_len + 1) - else: - # Full length sequence - target_len = max_len + 1 - - # Extract sequence up to target length or whatever's available - seq_end = min(i + target_len, len(token_ids)) - seq = token_ids[i:seq_end] - - # Only keep sequences that are long enough - if len(seq) > min_len + 1: # +1 because we need both input and target - sequences.append(seq) - - print(f"Created {len(sequences)} variable-length sequences") - - # Get some statistics - lens = [len(seq) for seq in sequences] - print(f"Sequence length stats: min={min(lens)}, max={max(lens)}, mean={np.mean(lens):.1f}") - - # split dataset - num_samples = len(sequences) - num_train = int(0.8 * num_samples) - num_val = num_samples - num_train - - # Use appropriate dataset class based on whether we need variable length - dataset_class = VarLenTextDataset - train_sequences = sequences[:num_train] - val_sequences = sequences[num_train:] - - train_dataset = dataset_class(train_sequences, max_len) - val_dataset = dataset_class(val_sequences, max_len) - - - # collate function - def collate_fn(batch): - """ - Collate function that creates a flat representation for variable length flash attention. - """ - # Separate inputs and targets - inputs, targets = zip(*batch) - - # Get sequence lengths - seq_lens = torch.tensor([len(x) for x in inputs], dtype=torch.int32) - - # Concatenate inputs and targets into single tensors - flat_inputs = torch.cat(inputs) - flat_targets = torch.cat(targets) - - # Create cumulative sequence lengths tensor - cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32) - cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0) - - # Calculate max sequence length for this batch - max_seqlen = seq_lens.max().item() - - return flat_inputs, flat_targets, seq_lens, cu_seqlens, max_seqlen - - # data loaders - train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn) - val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn) - else: - # FIXED LENGTH: create sequences of length max_len+1 - sequences = [] - for i in range(0, len(token_ids) - max_len, max_len): - seq = token_ids[i : i + max_len + 1] - if len(seq) == max_len + 1: - sequences.append(seq) - - # split dataset - num_samples = len(sequences) - num_train = int(0.8 * num_samples) - num_val = num_samples - num_train - train_dataset, val_dataset = random_split(TextDataset(sequences), [num_train, num_val]) - - # data loaders - train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) - val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) - - vocab_size = len(vocab) - print(f"vocab size: {vocab_size}, train samples: {len(train_dataset)}, validation samples: {len(val_dataset)}") - return train_dataloader, val_dataloader, vocab_size - -# ------------------------------- -# Training -# ------------------------------- -def train_lm(model, train_dataloader, val_dataloader, optimizer, criterion, num_epochs): - train_losses = [] - val_losses = [] - for epoch in range(num_epochs): - # Training phase - model.train() - epoch_train_loss = 0.0 - for inputs, targets in tqdm(train_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [train]"): - inputs, targets = inputs.to(device), targets.to(device) - - optimizer.zero_grad() - logits = model(inputs) - loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) - loss.backward() - optimizer.step() - - epoch_train_loss += loss.item() - - epoch_train_loss /= len(train_dataloader) - train_losses.append(epoch_train_loss) - print(f"epoch {epoch+1}/{num_epochs} - train loss: {epoch_train_loss:.4f}") - - # Validation phase - model.eval() - epoch_val_loss = 0.0 - with torch.no_grad(): - for inputs, targets in tqdm(val_dataloader, desc=f"epoch {epoch+1}/{num_epochs} [validation]"): - inputs, targets = inputs.to(device), targets.to(device) - logits = model(inputs) - loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1)) - epoch_val_loss += loss.item() - epoch_val_loss /= len(val_dataloader) - val_losses.append(epoch_val_loss) - print(f"epoch {epoch+1}/{num_epochs} - validation loss: {epoch_val_loss:.4f}") - - return train_losses, val_losses - -# ------------------------------- -# Main -# ------------------------------- -def main(): - # hyperparameters - batch_size = 16 - num_epochs = 20 - learning_rate = 3e-4 - max_len = 128 # total length including both input and target tokens - is_varlen = False - causal=True - dropout=0.1 - - # prep data - print("Preparing Dataset") - train_dataloader, val_dataloader, vocab_size = prepare_dataset(batch_size, max_len=max_len, is_varlen=is_varlen) - - # create language models - print("Creating Models") - model_normal = FlashLM( - vocab_size=vocab_size, - dim=256, - depth=3, - num_heads=8, - causal=causal, - dropout=dropout, - max_seq_len=max_len, - ).to(device) - - model_fp8 = FlashLM( - vocab_size=vocab_size, - dim=256, - depth=3, - num_heads=8, - causal=causal, - dropout=dropout, - max_seq_len=max_len, - use_fp8=True - ).to(device) - - # Train Normal model - print("Starting training for Normal model...") - optimizer_normal = optim.AdamW(model_normal.parameters(), lr=learning_rate) - criterion = nn.CrossEntropyLoss() - normal_train_losses, normal_val_losses = train_lm( - model_normal, train_dataloader, val_dataloader, optimizer_normal, criterion, num_epochs - ) - torch.save(model_normal.state_dict(), 'flash_lm_normal.pth') - print("Normal model training complete and saved.") - - # Train FP8 model - print("Starting training for FP8 model...") - optimizer_fp8 = optim.AdamW(model_fp8.parameters(), lr=learning_rate) - fp8_train_losses, fp8_val_losses = train_lm( - model_fp8, train_dataloader, val_dataloader, optimizer_fp8, criterion, num_epochs - ) - torch.save(model_fp8.state_dict(), 'flash_lm_fp8.pth') - print("FP8 model training complete and saved.") - - # save losses to csv - epochs = range(1, num_epochs+1) - loss_data = { - "Epoch": epochs, - "Normal_Training_Loss": normal_train_losses, - "Normal_Validation_Loss": normal_val_losses, - "FP8_Training_Loss": fp8_train_losses, - "FP8_Validation_Loss": fp8_val_losses, - } - df_losses = pd.DataFrame(loss_data) - df_losses.to_csv("losses.csv", index=False) - print("Loss data saved to losses.csv") - - # plot Training Loss - plt.figure(figsize=(10, 6)) - plt.plot(epochs, normal_train_losses, label="Normal Training Loss", marker='o') - plt.plot(epochs, fp8_train_losses, label="FP8 Training Loss", marker='x') - plt.xlabel("Epoch") - plt.ylabel("Training Loss") - plt.title("Training Loss Comparison: Normal vs FP8 Flash Attention") - plt.legend() - plt.grid(True) - plt.savefig("training_loss.png") # Saves the training loss plot to disk - plt.show() - - # Plot Validation Loss - plt.figure(figsize=(10, 6)) - plt.plot(epochs, normal_val_losses, label="Normal Validation Loss", marker='o') - plt.plot(epochs, fp8_val_losses, label="FP8 Validation Loss", marker='x') - plt.xlabel("Epoch") - plt.ylabel("Validation Loss") - plt.title("Validation Loss Comparison: Normal vs FP8 Flash Attention") - plt.legend() - plt.grid(True) - plt.savefig("validation_loss.png") # Saves the validation loss plot to disk - plt.show() - - -if __name__ == "__main__": - main() diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 96d4f662567..44502785a35 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -139,11 +139,11 @@ def check_args(self, q, k, v, o): assert q.dim() == 4 assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 assert self.cu_seqlens_q is None and self.cu_seqlens_k is None - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] # and q.shape[-1] == v.shape[-1] # TODO: Change assert if we support qkl f8 and v f16 assert q.dtype == k.dtype and q.dtype == v.dtype - assert o.shape == q.shape + assert o.shape[:-1] == q.shape[:-1] and o.shape[-1] == v.shape[-1] assert (nheads_q % nheads_k) == 0 assert self.layout is not None assert self.layout == 'thd' or not self.varlen @@ -1064,91 +1064,6 @@ def write_dropout_mask(x, tensor_name = "tensor"): else: writer.writerows(dropout_mask) -# ------------------------------- -# Autotune -# ------------------------------- -def get_fwd_prefill_cdna_autotune_configs(): - return [ - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=4), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] - - -def get_fwd_prefill_rdna_autotune_configs(): - return [ - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - # Fall-back config. - triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, - num_warps=2), - ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'IS_VARLEN', 'HQ', 'HK'] - - -def get_fwd_prefill_autotune_configs(): - if AUTOTUNE: - if is_rdna(): - return get_fwd_prefill_rdna_autotune_configs() - elif is_cdna(): - return get_fwd_prefill_cdna_autotune_configs() - else: - raise ValueError("Unknown Device Type") - else: - arch = get_arch() - if arch == "gfx950": - default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) - elif arch == "gfx942" and False: # Disabled due shared mem oom in CI when using triton==3.3.0 when using top of tree everything seems fine. - default_config = triton.Config( - {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) - else: - default_config = triton.Config( - {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, - num_stages=1, - num_warps=4, - ) - - return [ - default_config - ], [ - "IS_CAUSAL", - "dropout_p", - "MAX_SEQLENS_Q", - "MAX_SEQLENS_K", - "ACTUAL_BLOCK_DMODEL", - "IS_VARLEN", - "HQ", - "HK", - ] - # ------------------------------- # Runtime info # ------------------------------- diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py old mode 100644 new mode 100755 index 0e93f234aa3..644e86e4b13 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -2,16 +2,24 @@ from typing import Optional, Union +import os import torch import torch.nn as nn -# isort: off -# We need to import the CUDA kernels after importing torch -import flash_attn_3._C # Registers operators with PyTorch -# isort: on +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +if USE_TRITON_ROCM: + import sys + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from flash_attn.flash_attn_triton_amd import interface_fa_v3 as flash_attn_3_gpu +else: + # isort: off + # We need to import the CUDA kernels after importing torch + import flash_attn_3._C # Registers operators with PyTorch -flash_attn_3_cuda = torch.ops.flash_attn_3 + # isort: on + + flash_attn_3_gpu = torch.ops.flash_attn_3 def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x @@ -62,7 +70,7 @@ def _flash_attn_forward( ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] seqlens_rotary = maybe_contiguous(seqlens_rotary) - out, softmax_lse, *rest = flash_attn_3_cuda.fwd( + out, softmax_lse, *rest = flash_attn_3_gpu.fwd( q, k, v, @@ -126,7 +134,7 @@ def _flash_attn_backward( ): # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, *rest = flash_attn_3_cuda.bwd( + dq, dk, dv, softmax_d, *rest = flash_attn_3_gpu.bwd( dout, q, k, @@ -625,7 +633,7 @@ def flash_attn_varlen_func( def flash_attn_combine(out_partial, lse_partial, out=None, out_dtype=None): - return flash_attn_3_cuda.fwd_combine(out_partial, lse_partial, out, out_dtype) + return flash_attn_3_gpu.fwd_combine(out_partial, lse_partial, out, out_dtype) def flash_attn_with_kvcache( @@ -812,7 +820,7 @@ def get_scheduler_metadata( cache_seqlens = maybe_contiguous(cache_seqlens) if headdim_v is None: headdim_v = headdim - scheduler_metadata = flash_attn_3_cuda.get_scheduler_metadata( + scheduler_metadata = flash_attn_3_gpu.get_scheduler_metadata( batch_size, max_seqlen_q, max_seqlen_k, num_heads_q, num_heads_kv, headdim, headdim_v, qkv_dtype, cache_seqlens, diff --git a/hopper/setup.py b/hopper/setup.py old mode 100644 new mode 100755 index 10894252db0..437ddc8c9aa --- a/hopper/setup.py +++ b/hopper/setup.py @@ -43,6 +43,10 @@ SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" +# ROCm specific settings +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +if USE_TRITON_ROCM: + SKIP_CUDA_BUILD = True DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" @@ -383,10 +387,10 @@ def nvcc_threads_args(): cmdclass = {} ext_modules = [] - # We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp # files included in the source distribution, in case the user compiles from source. -subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"]) +if not USE_TRITON_ROCM: + subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"]) if not SKIP_CUDA_BUILD: print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) diff --git a/hopper/test_flash_attn_triton_amd.py b/hopper/test_flash_attn_triton_amd.py new file mode 100755 index 00000000000..738ec1d8c13 --- /dev/null +++ b/hopper/test_flash_attn_triton_amd.py @@ -0,0 +1,1174 @@ +import os +import math +import itertools + +import pytest +import torch +import torch.nn.functional as F +from torch._C import parse_schema + +from einops import rearrange, repeat +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from padding import pad_input, unpad_input +from test_util import ( + attention_ref, + generate_qkv, + generate_random_padding_mask, +) + +from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_combine +from flash_attn_interface import flash_attn_with_kvcache, get_scheduler_metadata + + +DISABLE_BACKWARD = os.getenv("FLASH_ATTENTION_DISABLE_BACKWARD", "FALSE") == "TRUE" +DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "TRUE") == "TRUE" +DISABLE_PAGEDKV = os.getenv("FLASH_ATTENTION_DISABLE_PAGEDKV", "FALSE") == "TRUE" +DISABLE_APPENDKV = os.getenv("FLASH_ATTENTION_DISABLE_APPENDKV", "TRUE") == "TRUE" +DISABLE_LOCAL = os.getenv("FLASH_ATTENTION_DISABLE_LOCAL", "TRUE") == "TRUE" +DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "TRUE") == "TRUE" +DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "TRUE") == "TRUE" +DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE" +DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9 +DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE" +DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE" +DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" +DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" +DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" + +COMPILED_HDIMS = ( + [] + + ([64] if not DISABLE_HDIM64 else []) + + ([96] if not DISABLE_HDIM96 else []) + + ([128] if not DISABLE_HDIM128 else []) + + ([192] if not DISABLE_HDIM192 else []) + + ([256] if not DISABLE_HDIM256 else []) +) + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("V_colmajor", [False, True]) +@pytest.mark.parametrize("V_colmajor", [False]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128, 192]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (64, 128), + (128, 192), + (256, 256), + (239, 1), + (799, 3), + (113, 203), + (113, 128), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (4096, 4096), + (4224, 4224), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype +): + if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): + pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + device = "cuda" + # set seed + torch.random.manual_seed(0) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_k <= 2048 else 2 + # batch_size = 1 + nheads = 6 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4) + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)).tolist() + # window_size = (-1, -1) if not local else (16, 0) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().to(dtype).requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach().to(dtype).requires_grad_() if has_qv else None + if V_colmajor: + v = rearrange(rearrange(v.detach(), "b s h d -> b h d s").contiguous(), "b h d s -> b s h d").requires_grad_() + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + None, + None, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() + # if qv is not None: + # qk += torch.einsum('bshd,bthd->bhst', qv_ref, v_ref).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out = flash_attn_func( + q, + k, + v, + causal=causal, + qv=qv, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits + ) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not V_colmajor + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g = torch.randn_like(out) + do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) + # import flash_attn_3_cuda + # dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum = flash_attn_3_cuda.bwd( + # g, + # q, + # k, + # v, + # out, + # lse, + # None, + # None, + # None, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.transpose(1, 2).unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float16] if not DISABLE_FP16 else []) + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +@pytest.mark.parametrize("d", COMPILED_HDIMS) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype +): + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + nheads = 6 + # batch_size = 2 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref).requires_grad_() + if has_qv: + qv_ref = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + qv_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + qv, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, qv=qv, kvpacked=False, + query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask) + q_unpad, k_unpad, v_unpad = [x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + qv=qv_unpad, + q_descale=q_descale, + k_descale=k_descale, v_descale=v_descale, + window_size=window_size, + attention_chunk=attention_chunk, + softcap=softcap, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol + + + if ( + not DISABLE_BACKWARD + and dtype != torch.float8_e4m3fn + and not has_qv + and not dv > 256 + and not attention_chunk != 0 + ): + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + # import flash_attn_3_cuda + # dq_unpad, dk_unpad, dv_unpad, softmax_d, dq_accum, lse_log2 = flash_attn_3_cuda.bwd_varlen( + # g_unpad, + # q_unpad, + # k_unpad, + # v_unpad, + # out_unpad, + # lse, + # None, + # None, + # None, + # cu_seqlens_q, + # cu_seqlens_k, + # None, None, + # max_seqlen_q, + # max_seqlen_k, + # d ** (-0.5), + # causal, + # window_size[0], window_size[1], + # softcap, + # deterministic, + # 0, # sm_margin + # ) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # qk = torch.einsum('bthd,bshd->bhts', q / (d ** 0.5), k).float() + # qk = torch.masked_fill(qk, rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - (g.float() * out.float()).sum(-1).transpose(1, 2).unsqueeze(-1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + # breakpoint() + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol + + +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize("dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) +# @pytest.mark.parametrize("new_kv", [True]) +@pytest.mark.parametrize("new_kv", [False, True]) +@pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) +# @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) +# @pytest.mark.parametrize("causal,local", [(False, False)]) +@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False]) +@pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) +# @pytest.mark.parametrize("rotary_interleaved", [True]) +@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) +# @pytest.mark.parametrize("rotary_fraction", [0.0]) +@pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) +# @pytest.mark.parametrize("page_size", [None]) +@pytest.mark.parametrize("has_leftpad", [False]) +# @pytest.mark.parametrize("has_leftpad", [False]) +@pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("has_batch_idx", [False]) +@pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +@pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [192]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 128), + (1, 339), + (3, 1024), + (64, 800), + (64, 256), + (3, 799), + (64, 2048), + (16, 20000), + # (1, 128 * 1024), + # (16, 128 * 1024), + (128, 128), + (256, 512), # To test appending KV with more than 1 block + (2048, 3577), # Enough tile to test persistent scheduler + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) +def test_flash_attn_kvcache( + seqlen_q, + seqlen_k, + d, + varlen_q, + has_batch_idx, + has_leftpad, + page_size, + rotary_fraction, + rotary_interleaved, + has_rotary_seqlens, + seqlen_new_eq_seqlen_q, + causal, + local, + new_kv, + mha_type, + dtype, +): + if page_size is not None and seqlen_k % page_size != 0: + pytest.skip() + if seqlen_q > seqlen_k and new_kv: + pytest.skip() + if not new_kv and rotary_fraction > 0.0: + pytest.skip() + if rotary_fraction == 0.0 and has_rotary_seqlens: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(0) + batch_size = 5 + # batch_size = 1 + batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 + nheads = 6 + # nheads = 1 + # rotary_dim must be a multiple of 16, and must be <= d + rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + assert nheads % nheads_k == 0 + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] + for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + has_qv = d == 64 and dv >= 256 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if has_qv: + qv = torch.randn(batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + else: + qv = None + if varlen_q: + query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(q, query_padding_mask) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + qv_unpad = rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None + else: + query_padding_mask = None + q_unpad = q + qv_unpad = qv + cu_seqlens_q, max_seqlen_q = None, None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + + seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() + cu_seqlens_k_new = None + key_new_padding_mask = None + if new_kv: + k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v = torch.randn(batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + if varlen_q: # k & v are also varlen + key_new_padding_mask = generate_random_padding_mask(seqlen_new, batch_size, device, mode="random") + k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(k, key_new_padding_mask) + v_unpad, *rest = unpad_input(v, key_new_padding_mask) + else: + k_unpad, v_unpad = k, v + else: + k, v, k_unpad, v_unpad = None, None, None, None + if page_size is None: + k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, dv, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) + page_table = None + else: + ( + k_cache, + v_cache, + page_table, + k_cache_paged, + v_cache_paged, + num_blocks, + ) = _generate_block_kvcache( + seqlen_k, page_size, batch_size_cache, nheads_k, d, dv, device, dtype, dtype_ref + ) + cache_seqlens = torch.randint( + 0 if new_kv else 1, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), + (batch_size,), + dtype=torch.int32, + device=device, + ) + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None + if has_batch_idx: + cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ + :batch_size + ] + else: + cache_batch_idx = None + arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + if not new_kv: + key_padding_mask = arange < cache_seqlens_expanded + else: + k_new_seqlens = key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new + key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) + # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) + rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2 + if rotary_dim > 0: + angle = ( + torch.rand( + seqlen_k if page_size is None else num_blocks * page_size, + rotary_dim // 2, + device=device, + ) + * 2 + * math.pi + ) + cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref) + if causal or local: + q_ro = apply_rotary_emb( + q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=seqlen_q, + ) + # q_ro = q + k_ro = apply_rotary_emb( + k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved + ) + else: + cos, sin = None, None + q_ro, k_ro = q, k + # k_cache[:, 64:] = -1 + k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone() + v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone() + if new_kv: + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + k_new_seqlens + ) + k_to_update = rearrange(k_ro, "b s ... -> (b s) ...") + v_to_update = rearrange(v, "b s ... -> (b s) ...") + if varlen_q: + k_to_update = k_to_update[indices_k] + v_to_update = v_to_update[indices_k] + k_cache_ref[update_mask] = k_to_update + v_cache_ref[update_mask] = v_to_update + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) + out_ref, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + key_leftpad=cache_leftpad, + ) + out_pt, _ = attention_ref( + q_ro, + k_cache_rep, + v_cache_rep, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv, + window_size=window_size, + attention_chunk=attention_chunk, + upcast=False, + reorder_ops=True, + key_leftpad=cache_leftpad, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + ) + q = q.to(dtype) + q_unpad = q_unpad.to(dtype) if varlen_q else None + k_cache = k_cache.to(dtype) + v_cache = v_cache.to(dtype) + k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None + v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None + k = k.to(dtype) if k is not None else None + v = v.to(dtype) if v is not None else None + k_unpad = k_unpad.to(dtype) if k_unpad is not None else None + v_unpad = v_unpad.to(dtype) if v_unpad is not None else None + qv = qv.to(dtype) if qv is not None else None + qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None + cos = cos.to(dtype) if cos is not None else None + sin = sin.to(dtype) if sin is not None else None + k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() + v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() + num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + precompute_metadata_vals = [False] + for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + if precompute_metadata: + scheduler_metadata = get_scheduler_metadata( + batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, + max_seqlen_k_new=seqlen_new, page_size=page_size, + causal=causal, window_size=window_size, attention_chunk=attention_chunk, + num_splits=num_splits + ) + else: + scheduler_metadata = None + # Repeat to test metadata reuse + for _ in range(1 if not precompute_metadata else 2): + if page_size is None: + k_cache.copy_(k_cache_saved) + v_cache.copy_(v_cache_saved) + else: + k_cache_paged.copy_(k_cache_saved) + v_cache_paged.copy_(v_cache_saved) + out, lse, *rest = flash_attn_with_kvcache( + q if not varlen_q else q_unpad, + k_cache if page_size is None else k_cache_paged, + v_cache if page_size is None else v_cache_paged, + k if not new_kv or not varlen_q else k_unpad, + v if not new_kv or not varlen_q else v_unpad, + qv=qv if not varlen_q else qv_unpad, + rotary_cos=cos, + rotary_sin=sin, + cache_seqlens=cache_seqlens, + cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, + page_table=page_table, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k_new, + max_seqlen_q=max_seqlen_q, + rotary_seqlens=rotary_seqlens, + causal=causal, + window_size=window_size, + attention_chunk=attention_chunk, + rotary_interleaved=rotary_interleaved, + scheduler_metadata=scheduler_metadata, + num_splits=num_splits, + return_softmax_lse=True + ) + if varlen_q: + out = output_pad_fn(out) + # out = flash_attn_with_kvcache( + # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size + # ) + # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) + # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) + # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) + # probs = torch.softmax(qk, dim=-1) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + if new_kv: + if page_size is None: + k_cache_select = ( + k_cache.to(dtype_ref) if not has_batch_idx else k_cache.to(dtype_ref)[cache_batch_idx] + ) + v_cache_select = ( + v_cache.to(dtype_ref) if not has_batch_idx else v_cache.to(dtype_ref)[cache_batch_idx] + ) + else: + k_cache_select = rearrange( + k_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + v_cache_select = rearrange( + v_cache_paged.to(dtype_ref)[(page_table if not has_batch_idx else page_table[cache_batch_idx]).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k].to(dtype_ref) + k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref) + v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref) + if dtype is not torch.float8_e4m3fn: + assert torch.equal(v_cache_select, v_cache_ref) + else: + assert torch.allclose(v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3) + # breakpoint() + # if rotary_dim == 0 and dtype is not torch.float8_e4m3fn: + if rotary_dim == 0: + assert torch.equal(k_cache_select, k_cache_ref) + else: + # if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3): + # breakpoint() + if dtype is not torch.float8_e4m3fn: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) + else: + assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1) + mult = 4 if dtype == torch.float8_e4m3fn else 2 + assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 + mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5 + assert (out - out_ref).abs().mean().item() <= mult_mean * (out_pt - out_ref).abs().mean().item() + + +def _generate_block_kvcache(seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref): + num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3 + k_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + v_cache_paged = torch.randn( + num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref + ).to(dtype).to(dtype_ref) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + k_cache = rearrange( + k_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + v_cache = rearrange( + v_cache_paged[page_table.flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize('d', [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (64, 8192), + ], +) +def test_flash_attn_cluster(seqlen_q, seqlen_k, d, causal, dtype): + device = "cuda" + torch.random.manual_seed(0) + batch_size = 2 + nheads = 16 + nheads_kv = 4 + # There was a bug where this would cause "unspecified launch failure" due to Cluster + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype) + for _ in range(100): + flash_attn_func(q, k, v, causal=causal) + + +# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize('causal', [False]) +@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128]) +# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [80]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 239), + (239, 1), + (3, 799), + (799, 3), + (1024, 128), + (97, 97), + (128, 128), + (200, 200), + (256, 256), + (257, 257), + (384, 384), + (512, 512), + (768, 768), + (1024, 1024), + (2048, 2048), + ], +) +@pytest.mark.skip(reason="Cannot be run in parallel with other tests due to memory usage") +def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, causal, dtype): + device = "cuda" + # set seed + torch.random.manual_seed(0) + # Simulate under memory load + dummy = torch.empty(70 * 1024 ** 3, dtype=torch.uint8, device=device) + batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger + nheads = 4 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) + torch.random.manual_seed(42) + out0 = flash_attn_func(q, k, v, causal=causal) + g = torch.randn_like(out0) + dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g) + # Numerical error if we just do any arithmetic on dq + dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() + + for i in range(1000): + torch.random.manual_seed(42) + out = flash_attn_func(q, k, v, causal=causal) + assert torch.equal(out, out0) + # assert torch.equal(lse, lse0) + + dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_equal = torch.allclose(dq, dq0, atol=dq_atol) + if not dq_equal: + print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") + # breakpoint() + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert dq_equal + + +def attention_combine_ref(out_partial, lse_partial): + """ + out_partial: (num_splits, batch_size, seqlen, nheads, d) + lse_partial: (num_splits, batch_size, nheads, seqlen) + """ + lse = torch.logsumexp(lse_partial, dim=0) + scale = torch.exp(lse_partial - lse) + scale = torch.where(torch.isinf(scale) | torch.isnan(scale), torch.zeros_like(scale), scale) + out = (scale.unsqueeze(-1) * out_partial).sum(0) + return out, lse + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float32]) +# @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) +@pytest.mark.parametrize("d", [64, 96, 128, 192, 256, 512]) +# @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("seqlen", [1, 2, 3, 32, 64, 256, 113, 108, 640, 1024]) +# @pytest.mark.parametrize("seqlen", [12, 32, 64, 256, 112, 108, 640, 1024, 2048, 8192]) +# @pytest.mark.parametrize("seqlen", [15]) +@pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 17, 32, 55, 97, 133]) +# @pytest.mark.parametrize("num_splits", [1, 2, 3, 5, 11]) +# @pytest.mark.parametrize("num_splits", [128]) +def test_flash_attn_combine(num_splits, seqlen, d, dtype): + if DISABLE_SPLIT: + pytest.skip() + device = "cuda" + # set seed + torch.random.manual_seed(1) + batch_size = 5 + nheads = 16 + # batch_size = 1 + # nheads = 1 + out_partial = torch.randn(num_splits * 2, batch_size, nheads, seqlen, d, device=device, dtype=torch.float32).transpose(2, 3)[:num_splits] # To test non-contiguous tensor + lse_partial = torch.randn(num_splits, batch_size, nheads * 2, seqlen, device=device, dtype=torch.float32).transpose(-1, -2)[:, :, :, :nheads] # To test non-contiguous tensor + # To test short-circuiting based on num_splits + lse_partial[num_splits // 2:, :batch_size // 3] = -float("inf") + out, lse = flash_attn_combine(out_partial, lse_partial, out_dtype=dtype) + out_ref, lse_ref = attention_combine_ref(out_partial, lse_partial) + out_pt = out_ref.to(dtype) + + print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + print(f"LSE mean diff: {(lse - lse_ref).abs().mean().item()}") + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # breakpoint() + + assert torch.allclose(lse, lse_ref, atol=1e-5, rtol=1e-5) + multiple = 2 + assert ((out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item()) or torch.allclose(out, out_pt, atol=1e-5, rtol=1e-5) + + # from flash_attn.utils.benchmark import pytorch_profiler + # # pytorch_profiler(torch.sum, lse_partial) + # pytorch_profiler(flash_attn_combine, out_partial, lse_partial) + # pytorch_profiler(torch.sum, out_partial) + +@pytest.mark.skip(reason="AMD Triton backend doesn't use torch ops registration") +def test_flash3_bw_compatibility() -> None: + # Let's try to always stay backward compatible! This will make life easier + # for downstream libaries, users, and exported models. + # 1/ Instead of removing arguments, error out if their value is no longer supported + # 2/ When adding arguments, add them at the end with a default value + assert torch.ops.flash_attn_3.fwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd(Tensor q, Tensor k, Tensor v, Tensor(k_new!)? k_new=None, " + "Tensor(v_new!)? v_new=None, Tensor? q_v=None, Tensor(out!)? out=None, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, " + "Tensor? cu_seqlens_k_new=None, Tensor? seqused_q=None, Tensor? seqused_k=None, " + "int? max_seqlen_q=None, int? max_seqlen_k=None, Tensor? page_table=None, " + "Tensor? kv_batch_idx=None, Tensor? leftpad_k=None, Tensor? rotary_cos=None, Tensor? rotary_sin=None, " + "Tensor? seqlens_rotary=None, Tensor? q_descale=None, Tensor? k_descale=None, Tensor? v_descale=None, " + "float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, float softcap=0., bool is_rotary_interleaved=False, " + "Tensor? scheduler_metadata=None, int num_splits=0, bool? pack_gqa=None, int sm_margin=0) " + "-> (Tensor(out!), Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.bwd.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::bwd(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, " + "Tensor(dq!)? dq=None, Tensor(dk!)? dk=None, Tensor(dv!)? dv=None, Tensor? cu_seqlens_q=None, " + "Tensor? cu_seqlens_k=None, Tensor? seqused_q=None, Tensor? seqused_k=None, int? max_seqlen_q=None, " + "int? max_seqlen_k=None, float? softmax_scale=None, bool is_causal=False, int window_size_left=-1, " + "int window_size_right=-1, float softcap=0., bool deterministic=False, int sm_margin=0) " + "-> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)" + )) + assert torch.ops.flash_attn_3.fwd_combine.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::fwd_combine(Tensor out_partial, Tensor lse_partial, Tensor(out!)? out=None, " + "ScalarType? out_dtype=None) -> (Tensor(out!), Tensor)" + )) + assert torch.ops.flash_attn_3.get_scheduler_metadata.default._schema.is_backward_compatible_with(parse_schema( + "flash_attn_3::get_scheduler_metadata(int batch_size, int max_seqlen_q, int max_seqlen_k, " + "int num_heads, int num_heads_k, int headdim, int headdim_v, ScalarType qkv_dtype, Tensor seqused_k, " + "Tensor? cu_seqlens_q=None, Tensor? cu_seqlens_k=None, Tensor? cu_seqlens_k_new=None, " + "Tensor? seqused_q=None, Tensor? leftpad_k=None, int? page_size=None, int max_seqlen_k_new=0, " + "bool is_causal=False, int window_size_left=-1, int window_size_right=-1, " + "int attention_chunk=0, bool has_softcap=False, int num_splits=0, bool? pack_gqa=None, " + "int sm_margin=0) -> Tensor" + )) diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index adebe088a79..ba1932438e2 100755 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -1921,7 +1921,7 @@ def test_flash_attn_splitkv( # @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) -@pytest.mark.parametrize("paged_kv_block_size", [None]) +@pytest.mark.parametrize("paged_kv_block_size", [None, 256]) # @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) # @pytest.mark.parametrize("paged_kv_block_size", [None]) @pytest.mark.parametrize("has_leftpad", [False])