diff --git a/.gitignore b/.gitignore index 9533a45b15..80ef837fd1 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,7 @@ __pycache__ debug aiter_logs *.log + +# artifacts +aiter_meta +aiter/install_mode \ No newline at end of file diff --git a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/__init__.py b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/__init__.py new file mode 100644 index 0000000000..78f85fb268 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/__init__.py @@ -0,0 +1,4 @@ +from . import interface_v2 as flash_attn_2 +from . import interface_v3 as flash_attn_3 + +__all__ = ["flash_attn_2", "flash_attn_3"] diff --git a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/bwd.py b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/bwd.py new file mode 100755 index 0000000000..f75d9977f0 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/bwd.py @@ -0,0 +1,4941 @@ +import os +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +import warnings +from typing import Literal, Optional +from .utils import ( + DEBUG, + AUTOTUNE, + FP8_AUTO_DESCALE, + compute_fp8_scaling_factors, + get_cu_count, + is_cdna, + is_fp8, + get_arch, +) + + +def get_bwd_configs(autotune: bool): + # keys + preprocess_autotune_keys = [ + "max_seqlen_q", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + ] + + causal_autotune_keys = [ + "dropout_p", + "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + "HQ", + "HK", + ] + + noncausal_autotune_keys = [ + "dropout_p", + "max_seqlen_q", + "max_seqlen_k", + "ACTUAL_HEAD_DIM", + "IS_VARLEN", + "HQ", + "HK", + ] + + # default config + if not autotune: + arch = get_arch() + # configs for the kernels + if arch == "gfx942": + if get_cu_count() < 304: + preprocess_autotune_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, 
num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 128, "waves_per_eu": 2}, num_stages=1, num_warps=4 + ), + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=8, + ), + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + else: + preprocess_autotune_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 1}, num_stages=1, num_warps=4 + ), + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + 
"BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + }, + num_stages=1, + num_warps=4, + ), + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + elif arch == "gfx950": + preprocess_autotune_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=1, num_warps=8 + ), + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=4 + ), + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 128, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M1": 16, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 2, + }, + num_stages=1, + num_warps=4, + ), + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + 
"BLOCK_M1": 64, + "BLOCK_N1": 64, + "BLOCK_M2": 64, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + else: + preprocess_autotune_configs = [ + triton.Config( + {"PRE_BLOCK": 64, "waves_per_eu": 2}, num_stages=2, num_warps=8 + ), + ] + noncausal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + causal_autotune_configs = [ + triton.Config( + { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 64, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + }, + num_stages=1, + num_warps=4, + ), + ] + + # assert constraints + for noncausal_cfg, causal_cfg in zip( + noncausal_autotune_configs, causal_autotune_configs + ): + assert ( + noncausal_cfg.all_kwargs()["BLOCK_N1"] + == noncausal_cfg.all_kwargs()["BLOCK_M2"] + ), f"BLOCK_N1 ({noncausal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({noncausal_cfg.all_kwargs()['BLOCK_M2']})" + assert ( + causal_cfg.all_kwargs()["BLOCK_N1"] + == causal_cfg.all_kwargs()["BLOCK_M2"] + ), f"BLOCK_N1 ({causal_cfg.all_kwargs()['BLOCK_N1']}) must equal BLOCK_M2 ({causal_cfg.all_kwargs()['BLOCK_M2']})" + + return ( + (preprocess_autotune_configs, preprocess_autotune_keys), + (causal_autotune_configs, causal_autotune_keys), + (noncausal_autotune_configs, noncausal_autotune_keys), + ) + + # param options + PRE_BLOCK_OPTIONS = [64, 128] # og: 128 + PRE_WAVES_PER_EU_OPTIONS = [1, 2] + PRE_NUM_STAGES_OPTIONS = [1, 2] + PRE_NUM_WARPS_OPTIONS = [4, 8] + NUM_STAGES_OPTIONS = [1, 2] # og: 1 + NUM_WARPS_OPTIONS = [4, 8] # og: 4 + WAVES_PER_EU_OPTIONS = [1, 2] # og: 1 + NON_CAUSAL_BLOCK_M1_OPTIONS = [16, 32, 64, 128] # og: 32 + NON_CAUSAL_BLOCK_N1_M2_OPTIONS = [32, 64, 128, 256] # og: 128 + NON_CAUSAL_BLOCK_N2_OPTIONS = [16, 32, 64, 128] # og: 32 + CAUSAL_BLOCK_M1_OPTIONS = [32, 64] # og: 32 + 
CAUSAL_BLOCK_N1_M2_OPTIONS = [32, 64, 128] # og: 128 + CAUSAL_BLOCK_N2_OPTIONS = [32, 64] # og: 32 + BLK_SLICE_FACTOR_OPTIONS = [2] # og: 2 + + # ==================== sweep configs ================================ + preprocess_autotune_configs = [] + for pre_num_warps in PRE_NUM_WARPS_OPTIONS: + for pre_num_stages in PRE_NUM_STAGES_OPTIONS: + for pre_waves in PRE_WAVES_PER_EU_OPTIONS: + for pre_block in PRE_BLOCK_OPTIONS: + preprocess_autotune_configs.append( + triton.Config( + { + "PRE_BLOCK": pre_block, + "waves_per_eu": pre_waves, + }, + num_stages=pre_num_stages, + num_warps=pre_num_warps, + ) + ) + + causal_autotune_configs = [] + for num_warps in NUM_WARPS_OPTIONS: + for num_stages in NUM_STAGES_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for m1 in CAUSAL_BLOCK_M1_OPTIONS: + for n1 in CAUSAL_BLOCK_N1_M2_OPTIONS: + m2 = n1 + for n2 in CAUSAL_BLOCK_N2_OPTIONS: + # Ensure constraint + assert ( + n1 == m2 + ), f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + + # Skip configs where BLOCK_M2 % BLOCK_N2 != 0 + if m2 % n2 != 0: + continue + + # Skip configs where BLOCK_N1 % BLOCK_M1 != 0 + if n1 % m1 != 0: + continue + + for blk_slice in BLK_SLICE_FACTOR_OPTIONS: + causal_autotune_configs.append( + triton.Config( + { + "BLOCK_M1": m1, + "BLOCK_N1": n1, + "BLOCK_M2": m2, + "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + + noncausal_autotune_configs = [] + for num_warps in NUM_WARPS_OPTIONS: + for num_stages in NUM_STAGES_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for m1 in NON_CAUSAL_BLOCK_M1_OPTIONS: + for n1 in NON_CAUSAL_BLOCK_N1_M2_OPTIONS: + m2 = n1 + for n2 in NON_CAUSAL_BLOCK_N2_OPTIONS: + # Ensure constraint + assert ( + n1 == m2 + ), f"BLOCK_N1 ({n1}) must equal BLOCK_M2 ({m2})" + + # Skip configs where BLOCK_M2 % BLOCK_N2 != 0 + if m2 % n2 != 0: + continue + + # Skip configs where BLOCK_N1 % BLOCK_M1 != 0 + if n1 % m1 != 0: + continue + + for blk_slice in 
BLK_SLICE_FACTOR_OPTIONS: + noncausal_autotune_configs.append( + triton.Config( + { + "BLOCK_M1": m1, + "BLOCK_N1": n1, + "BLOCK_M2": m2, + "BLOCK_N2": n2, + "BLK_SLICE_FACTOR": blk_slice, + "waves_per_eu": waves, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + + return ( + (preprocess_autotune_configs, preprocess_autotune_keys), + (causal_autotune_configs, causal_autotune_keys), + (noncausal_autotune_configs, noncausal_autotune_keys), + ) + + +# os.environ["TRITON_PRINT_AUTOTUNING"] = "1" +( + (preprocess_autotune_configs, preprocess_autotune_keys), + (causal_autotune_configs, causal_autotune_keys), + (noncausal_autotune_configs, noncausal_autotune_keys), +) = get_bwd_configs(AUTOTUNE) + + +@triton.jit +def _bwd_dq_inner_split( + dq, + q, + K, + V, + do, + m, + Delta, + sm_scale, + stride_qm, + stride_qk, + stride_kn, + stride_kk, + stride_vn, + stride_vk, + stride_dropout_m, + stride_dropout_n, + stride_deltam, + seqlen_q, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + RCP_LN2: tl.constexpr = 1.4426950408889634 + + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk + + # D (= delta) is pre-divided by ds_scale. 
+ Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) + + curr_n = start_n + step_n = BLOCK_N + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + for blk_idx in range(num_steps): + offs_n = curr_n + tl.arange(0, BLOCK_N) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + mask_kT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD: + mask_kT &= offs_k[:, None] < BLOCK_D_MODEL + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) + + # dropout + if ENABLE_DROPOUT: + philox_offs = ( + curr_philox_offset + + offs_m[:, None] * stride_dropout_m + + offs_n[None, :] * stride_dropout_n + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + # qk + if IS_FP8: + qk = tl.dot(q, kT) * descale_q * descale_k + else: + qk = tl.dot(q, kT) + p = tl.math.exp2(qk * sm_scale * RCP_LN2 - m * RCP_LN2) + + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask * mask_mn + p = tl.where(mask, p, 0.0) + + # dp + if IS_FP8: + dp = tl.dot(do.to(vT.type.element_ty), vT) * descale_v + else: + dp = tl.dot(do, vT) + + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + + # ds + delta_i = Di[:, None] + ds = p * (dp - delta_i) + + # dq + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. 
+ if IS_FP8: + # Rewrite dq += ds @ kT.T as dq += (kT @ ds.T).T + # This puts FP8 tensor (kT) on LHS of dot product + # Cast the transposed ds to FP8 to match kT's dtype + ds_transposed = tl.trans(ds).to(kT.type.element_ty) + dq += tl.trans(tl.dot(kT, ds_transposed)) * descale_k + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + + curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + + +@triton.jit +def _bwd_dkdv_inner_split( + dk, + dv, + Q, + k, + v, + DO, + M, + D, + sm_scale, + stride_q_m, + stride_q_k, + stride_do_m, + stride_do_k, + stride_dropout_m, + stride_dropout_n, + stride_deltam, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + seqlen_q, + seqlen_k, + start_n, + start_m, + num_steps, + descale_q, + descale_k, + descale_v, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + qT_ptrs = ( + Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k + ) # [BLOCK_D_MODEL_POW2, BLOCK_M] + do_ptrs = DO + offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # Iterate over blocks(BLOCK_M size) of Q while calculating + # a fixed block(BLOCK_N) of dk and dv. Note, during backward + # pass P has to be recomputed. However, this kernel computes + # dV and dK, so we compute we need P^T and S^T. 
See backward pass + # equations + # + # From Flash Attention Paper: + # ForwardPass: S = QkT, P=softmax(S), O=PV + # + # BackwardPass equations + # dV = P^TdO + # dP = dOV^T + # dS = dsoftmax(dP) + # dQ = dSK + # dK = QdS^T + for blk_idx in range(num_steps): + offs_m = curr_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < BLOCK_D_MODEL + mask_do &= offs_k[None, :] < BLOCK_D_MODEL + + # load qT + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + + # dropout + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[None, :] * stride_dropout_m + + offs_n[:, None] * stride_dropout_n + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + + # Load M + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + + # Compute qkT + if IS_FP8: + qkT = tl.dot(k, qT) * descale_q * descale_k + else: + qkT = tl.dot(k, qT) + + # Compute pT(use m and also apply sm_scale) + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + if MASK: + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + pT = tl.where(mask, pT, 0.0) + + # load DO + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + + # dV + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + # Load delta + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + + # Compute dP and dS + if IS_FP8: + dpT = tl.dot(v, tl.trans(do.to(v.type.element_ty))) * descale_v + else: + dpT = tl.dot(v, tl.trans(do)) + + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + 
+ # compute dk + if IS_FP8: + # Rewrite dk += dsT @ qT.T as dk += (qT @ dsT.T).T + # This puts FP8 tensor (qT) on LHS of dot product + # Cast the transposed dsT to FP8 to match qT's dtype + dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) + dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + + # increment pointers + curr_m += step_m + qT_ptrs += step_m * stride_q_m + do_ptrs += step_m * stride_do_m + + return dk, dv + + +@triton.jit +def _bwd_dkdvdq_inner_atomic( + dk, + dv, + Q, + k, + v, + DO, + DQ, + M, + D, + sm_scale, + stride_q_m, + stride_q_k, + stride_dq_m, + stride_dq_k, + stride_do_m, + stride_do_k, + stride_dropout_m, + stride_dropout_n, + stride_deltam, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + seqlen_q, + seqlen_k, + start_n, + start_m, + num_steps, + descale_q, + descale_k, + descale_v, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + workgroup_id: tl.int32, +): + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + + qT_ptrs_start = ( + Q + offs_m[None, :] * stride_q_m + offs_k[:, None] * stride_q_k + ) # [BLOCK_D_MODEL_POW2, BLOCK_M] + dq_ptrs_start = ( + DQ + offs_m[:, None] * stride_dq_m + offs_k[None, :] * stride_dq_k + ) # [BLOCK_M, BLOCK_D_MODEL_POW2] + + do_ptrs_start = DO + offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # Iterate over 
blocks(BLOCK_M size) of Q while calculating + # a fixed block(BLOCK_N) of dk and dv. Note, during backward + # pass P has to be recomputed. However, this kernel computes + # dV and dK, so we compute we need P^T and S^T. See backward pass + # equations + # + # From Flash Attention Paper: + # ForwardPass: S = QkT, P=softmax(S), O=PV + # + # BackwardPass equations + # dV = P^TdO + # dP = dOV^T + # dS = dsoftmax(dP) + # dQ = dSK + # dK = QdS^T + + # Compute a starting index and step based on workgroup_id + # Use a simple hash-like function to spread out the starting points + start_idx = ( + workgroup_id * 17 + ) % num_steps # 17 is an arbitrary prime to spread indices + # Ensure step is coprime with num_steps to visit all indices exactly once + step = 1 # 3 if num_steps > 1 or num_steps==3 else 1 # coprime with num_steps + + for iter in range(num_steps): + # Compute the permuted block index + blk_idx = (start_idx + iter * step) % num_steps + + curr_m = start_m + blk_idx * step_m + qT_ptrs = qT_ptrs_start + blk_idx * step_m * stride_q_m + dq_ptrs = dq_ptrs_start + blk_idx * step_m * stride_dq_m + do_ptrs = do_ptrs_start + blk_idx * step_m * stride_do_m + + offs_m = curr_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + + if PADDED_HEAD: + mask_qT &= offs_k[:, None] < BLOCK_D_MODEL + mask_do &= offs_k[None, :] < BLOCK_D_MODEL + + # load qT + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + + # dropout + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[None, :] * stride_dropout_m + + offs_n[:, None] * stride_dropout_n + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + + # Load M + m = tl.load(M + offs_m * stride_deltam, mask=mask_m, other=0.0) + + # Compute qkT + if IS_FP8: + qkT = tl.dot(k, 
qT) * descale_q * descale_k + else: + qkT = tl.dot(k, qT) + + # Compute pT(use m and also apply sm_scale) + pT = tl.math.exp(qkT * sm_scale - m[None, :]) + + if MASK: + causal_mask = (offs_m[None, :] - delta_qk) >= (offs_n[:, None]) + mask = causal_mask & mask_nm + pT = tl.where(mask, pT, 0.0) + + # load DO + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + + # dV + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + # Load delta + Di = tl.load(D + offs_m * stride_deltam, mask=mask_m) + + # Compute dP and dS + if IS_FP8: + dpT = tl.dot(v, tl.trans(do.to(v.type.element_ty))) * descale_v + else: + dpT = tl.dot(v, tl.trans(do)) + + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + + # compute dk + if IS_FP8: + # Rewrite dk += dsT @ qT.T as dk += (qT @ dsT.T).T + # This puts FP8 tensor (qT) on LHS of dot product + # Cast the transposed dsT to FP8 to match qT's dtype + dsT_transposed = tl.trans(dsT).to(qT.type.element_ty) + dk += tl.trans(tl.dot(qT, dsT_transposed)) * descale_q + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + + # We can compute the dq_partial here and do a atomic add to the correct memory location + # NOTE: Possible problems with the atomic add: contention, is inside a loop which has achieved bad perf before + # (BLOCK_M, BLOCK_N) x (BLOCK_N, D) + if IS_FP8: + dq_partial = tl.dot(dsT.to(k.type.element_ty).T, k) * descale_k + else: + dq_partial = tl.dot(dsT.to(k.type.element_ty).T, k) + tl.atomic_add( + dq_ptrs, + dq_partial * sm_scale, + mask=mask_m[:, None], + sem="relaxed", + ) + + return dk, dv + + +@triton.jit +def _bwd_kernel_fused_atomic_causal( + q_ptr, + k_ptr, + v_ptr, + sm_scale, + do_ptr, + dk_ptr, + dv_ptr, + dq_ptr, + m_ptr, + delta_ptr, + stride_q_b, + stride_q_h, + stride_q_m, + stride_q_k, + 
stride_k_b, + stride_k_h, + stride_k_n, + stride_k_k, + stride_v_b, + stride_v_h, + stride_v_n, + stride_v_k, + stride_dk_b, + stride_dk_h, + stride_dk_n, + stride_dk_k, + stride_dq_b, + stride_dq_h, + stride_dq_m, + stride_dq_k, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_do_b, + stride_do_h, + stride_do_m, + stride_do_k, + stride_dropout_b, + stride_dropout_h, + stride_dropout_m, + stride_dropout_n, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BATCH, + NUM_K_PIDS, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + wid = tl.program_id(0) # workgoup id: 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1 + + # workgroups get launched first along batch dim, then in head_k dim, and then in seq k block dim + batch_idx = wid % BATCH + head_k_idx = wid // BATCH % NUM_K_HEADS + seq_k_blk_idx = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS + + # Determine q and k start along with seqlen_q and seqlen_k + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. 
+ # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + delta_qk = seqlen_q - seqlen_k + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N + delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M + if delta_qk >= 0: + start_delta = delta_qk + else: + start_delta = start_delta_q_lt_k + + start_n = seq_k_blk_idx * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_kv &= mask_k[None, :] + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = ( + batch_idx * stride_k_b + + head_k_idx * stride_k_h + + k_start * stride_k_n + + offs_n[:, None] * stride_k_n + + offs_k[None, :] * stride_k_k + ) + adj_v = ( + batch_idx * stride_v_b + + head_k_idx * stride_v_h + + k_start * stride_v_n + + offs_n[:, None] * stride_v_n + + offs_k[None, :] * stride_v_k + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(k_ptr + adj_k, mask=mask_kv, other=0.0) + v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) + + # If MQA / GQA, set the K and V head offsets appropriately. 
+ for head_q_idx in range( + head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE + ): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = (start_m // BLOCK_M) * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + + # offset input and output tensor by batch and Q/K heads + adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + adj_dq = ( + batch_idx * stride_dq_b + head_q_idx * stride_dq_h + q_start * stride_dq_m + ) + + q_ptr_adj = q_ptr + adj_q + dq_ptr_adj = dq_ptr + adj_dq + + adj_do = ( + batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m + ) + do_ptr_adj = do_ptr + adj_do + adj_delta = ( + batch_idx * stride_delta_b + + head_q_idx * stride_delta_h + + q_start * stride_delta_m + ) + m_ptr_adj = m_ptr + adj_delta + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h + ) + dropout_offset = ( + dropout_mask + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h + ) + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + + # when q < k, we may skip the initial masked op + # if seq_k_blk_idx < num_blocks_skip: + # num_steps = 0 + + if IS_FP8: + # For MQA/GQA, q_descale uses the same indexing as k/v (head_k_idx) + descale_q = tl.load( + descale_q_ptr + 
batch_idx * stride_descale_q_z + head_k_idx + ) + descale_k = tl.load( + descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx + ) + descale_v = tl.load( + descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx + ) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + # if unaligned start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + dk, dv = _bwd_dkdvdq_inner_atomic( + dk, + dv, # output tensors + q_ptr_adj, + k, + v, + do_ptr_adj, + dq_ptr_adj, + m_ptr_adj, + delta_ptr_adj, + sm_scale, # input tensors + stride_q_m, + stride_q_k, # strides for q + stride_dq_m, + stride_dq_k, # strides for q + stride_do_m, + stride_do_k, # strides for o + stride_dropout_m, + stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + MASK_BLOCK_M, + BLOCK_N, # block dim + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, # head dim + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + workgroup_id=seq_k_blk_idx, + ) + + start_m += num_steps * MASK_BLOCK_M + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) + end_m = start_m + num_steps * BLOCK_M + + dk, dv = _bwd_dkdvdq_inner_atomic( + dk, + dv, # output tensors + q_ptr_adj, + k, + v, + do_ptr_adj, + dq_ptr_adj, + m_ptr_adj, + delta_ptr_adj, + sm_scale, # input tensors + stride_q_m, + stride_q_k, # strides for q + stride_dq_m, + stride_dq_k, # strides for dq + stride_do_m, + stride_do_k, # strides for o + stride_dropout_m, + stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + 
descale_q, + descale_k, + descale_v, + BLOCK_M, + BLOCK_N, # block dim + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, # head dim + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + workgroup_id=seq_k_blk_idx, + ) + + # Write back dV and dK. + offs_dkdv = ( + batch_idx * stride_dk_b + + head_k_idx * stride_dk_h + + k_start * stride_dk_n + + offs_n[:, None] * stride_dk_n + + offs_k[None, :] * stride_dk_k + ) + tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv) + dk *= sm_scale + tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv) + + +@triton.jit +def _bwd_kernel_split_dkdv_causal( + q_ptr, + k_ptr, + v_ptr, + sm_scale, + do_ptr, + dk_ptr, + dv_ptr, + m_ptr, + delta_ptr, + stride_q_b, + stride_q_h, + stride_q_m, + stride_q_k, + stride_k_b, + stride_k_h, + stride_k_n, + stride_k_k, + stride_v_b, + stride_v_h, + stride_v_n, + stride_v_k, + stride_dk_b, + stride_dk_h, + stride_dk_n, + stride_dk_k, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_do_b, + stride_do_h, + stride_do_m, + stride_do_k, + stride_dropout_b, + stride_dropout_h, + stride_dropout_m, + stride_dropout_n, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + descale_q_ptr, + descale_k_ptr, + descale_v_ptr, + NUM_Q_HEADS: tl.constexpr, + NUM_K_HEADS: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_D_MODEL: tl.constexpr, + BLOCK_D_MODEL_POW2: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, +): + # seq block, batch, head_k + seq_k_blk_idx = tl.program_id(0) + batch_idx = tl.program_id(1) + head_k_idx = tl.program_id(2) + + # Determine q and k start along with seqlen_q and seqlen_k + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if 
IS_VARLEN: + q_start = tl.load(cu_seqlens_q + batch_idx) + q_end = tl.load(cu_seqlens_q + batch_idx + 1) + k_start = tl.load(cu_seqlens_k + batch_idx) + k_end = tl.load(cu_seqlens_k + batch_idx + 1) + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32) + + # Figure out causal starting block since we have seqlen_q >=< seqlen_k. + # Unlike forward pass where we tile on M dim and iterate on N dim, so that + # we can skip some M blocks, in backward pass, we tile on the N dim for kv + # and iterate over the M. In this way, we cannot skip N blocks, but only to + # determine the starting M blocks to skip some initial blocks masked by + # causal. + delta_qk = seqlen_q - seqlen_k + + # q > k: diretcly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + + # q < k: some blocks will have no Masked block, other needs to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N + delta_aligned = (num_blocks_skip + 1) * BLOCK_N + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M * BLOCK_M + if delta_qk >= 0: + start_delta = delta_qk + else: + start_delta = start_delta_q_lt_k + + start_n = seq_k_blk_idx * BLOCK_N + + offs_k = tl.arange(0, BLOCK_D_MODEL_POW2) + offs_n = start_n + tl.arange(0, BLOCK_N) + # Mask for loading K and V + mask_kv = offs_n[:, None] < seqlen_k + PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2 + if PADDED_HEAD: + mask_k = offs_k < BLOCK_D_MODEL + mask_kv &= mask_k[None, :] + + GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS + adj_k = ( + batch_idx * stride_k_b + + head_k_idx * stride_k_h + + k_start * stride_k_n + + offs_n[:, None] * stride_k_n + + offs_k[None, :] * stride_k_k + ) + adj_v = ( + batch_idx * stride_v_b + + head_k_idx * stride_v_h + + k_start * stride_v_n + + 
offs_n[:, None] * stride_v_n + + offs_k[None, :] * stride_v_k + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(k_ptr + adj_k, mask=mask_kv, other=0.0) + v = tl.load(v_ptr + adj_v, mask=mask_kv, other=0.0) + + # If MQA / GQA, set the K and V head offsets appropriately. + for head_q_idx in range( + head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE + ): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M * BLOCK_M + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N + residue_m + + # offset input and output tensor by batch and Q/K heads + adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m + q_ptr_adj = q_ptr + adj_q + adj_do = ( + batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m + ) + do_ptr_adj = do_ptr + adj_do + adj_delta = ( + batch_idx * stride_delta_b + + head_q_idx * stride_delta_h + + q_start * stride_delta_m + ) + m_ptr_adj = m_ptr + adj_delta + delta_ptr_adj = delta_ptr + adj_delta + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h + ) + dropout_offset = ( + dropout_mask + + batch_idx * stride_dropout_b + + head_q_idx * stride_dropout_h + ) + + MASK_BLOCK_M: tl.constexpr = BLOCK_M // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to wast cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M) + # when q < k, we may skip the initial masked op + if seq_k_blk_idx < 
num_blocks_skip: + num_steps = 0 + + if IS_FP8: + # For MQA/GQA, q_descale uses the same indexing as k/v (head_k_idx) + descale_q = tl.load( + descale_q_ptr + batch_idx * stride_descale_q_z + head_k_idx + ) + descale_k = tl.load( + descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx + ) + descale_v = tl.load( + descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx + ) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything have no causal mask + dk, dv = _bwd_dkdv_inner_split( + dk, + dv, # output tensors + q_ptr_adj, + k, + v, + do_ptr_adj, + m_ptr_adj, + delta_ptr_adj, + sm_scale, # input tensors + stride_q_m, + stride_q_k, # strides for q + stride_do_m, + stride_do_k, # strides for o + stride_dropout_m, + stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + MASK_BLOCK_M, + BLOCK_N, # block dim + BLOCK_D_MODEL, + BLOCK_D_MODEL_POW2, # head dim + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + ) + start_m += num_steps * MASK_BLOCK_M + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M) + end_m = start_m + num_steps * BLOCK_M + + dk, dv = _bwd_dkdv_inner_split( + dk, + dv, # output tensors + q_ptr_adj, + k, + v, + do_ptr_adj, + m_ptr_adj, + delta_ptr_adj, + sm_scale, # input tensors + stride_q_m, + stride_q_k, # strides for q + stride_do_m, + stride_do_k, # strides for o + stride_dropout_m, + stride_dropout_n, # strides for dropout + stride_delta_m, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + 
            # --- continuation: trailing arguments of the unmasked
            # _bwd_dkdv_inner_split call begun on the previous source line ---
            descale_k, descale_v,
            BLOCK_M, BLOCK_N,  # block dims
            BLOCK_D_MODEL, BLOCK_D_MODEL_POW2,  # head dim
            MASK=False,  # no causal masking on fully-unmasked tiles
            ENABLE_DROPOUT=ENABLE_DROPOUT,  # activate dropout
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
        )

    # Write back dV and dK (accumulated over all q-heads of the group).
    offs_dkdv = (
        batch_idx * stride_dk_b + head_k_idx * stride_dk_h + k_start * stride_dk_n
        + offs_n[:, None] * stride_dk_n + offs_k[None, :] * stride_dk_k
    )
    tl.store(dv_ptr + offs_dkdv, dv, mask=mask_kv)
    dk *= sm_scale
    tl.store(dk_ptr + offs_dkdv, dk, mask=mask_kv)


@triton.jit
def _bwd_kernel_split_dq_causal(
    # tensor base pointers
    q_ptr, k_ptr, v_ptr, sm_scale, do_ptr, dq_ptr, m_ptr, delta_ptr,
    # strides: (batch, head, seq, head_dim) per tensor
    stride_q_b, stride_q_h, stride_q_m, stride_q_k,
    stride_k_b, stride_k_h, stride_k_n, stride_k_k,
    stride_v_b, stride_v_h, stride_v_n, stride_v_k,
    stride_dq_b, stride_dq_h, stride_dq_m, stride_dq_k,
    stride_delta_b, stride_delta_h, stride_delta_m,
    stride_do_b, stride_do_h, stride_do_m, stride_do_k,
    stride_dropout_b, stride_dropout_h, stride_dropout_m, stride_dropout_n,
    stride_descale_q_z, stride_descale_k_z, stride_descale_v_z,
    # varlen bookkeeping
    cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
    # dropout / Philox RNG state
    dropout_mask, dropout_p, philox_seed, philox_offset_base,
    # FP8 per-(batch, head) descale factor tables
    descale_q_ptr, descale_k_ptr, descale_v_ptr,
    # compile-time parameters
    NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    BLK_SLICE_FACTOR: tl.constexpr,
    BLOCK_D_MODEL: tl.constexpr, BLOCK_D_MODEL_POW2: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr,
    IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr,
):
    """Causal backward, split kernel: dQ for one BLOCK_M Q tile.

    Grid: (seq-q blocks, batch, k-heads).  Each program loads one Q/DO tile
    per q-head of the GQA group and iterates over N blocks of K/V — first the
    diagonal (masked) blocks with a narrower MASK_BLOCK_N step, then the
    unmasked blocks — writing dQ once per q-head.
    """
    seq_q_blk_idx = tl.program_id(0)
    batch_idx = tl.program_id(1)
    head_k_idx = tl.program_id(2)

    q_start = 0
    k_start = 0
    seqlen_q = max_seqlen_q
    seqlen_k = max_seqlen_k
    if IS_VARLEN:
        q_start = tl.load(cu_seqlens_q + batch_idx)
        q_end = tl.load(cu_seqlens_q + batch_idx + 1)
        k_start = tl.load(cu_seqlens_k + batch_idx)
        k_end = tl.load(cu_seqlens_k + batch_idx + 1)
        seqlen_q = q_end - q_start
        seqlen_k = k_end - k_start

    # Figure out the causal starting block since seqlen_q and seqlen_k may
    # differ.  dQ tiles on the M dim and iterates on the N dim, so there may
    # be tiles we can simply skip, and the starting position must be adjusted.
    start_m = seq_q_blk_idx * BLOCK_M
    # seqlen_q > seqlen_k: rows entirely above the causal band produce no dq
    delta_qk = seqlen_q - seqlen_k
    if start_m + BLOCK_M < delta_qk:
        return

    offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)
    offs_m = start_m + tl.arange(0, BLOCK_M)
    # Mask for loading Q and DO (rows past seqlen_q are padded out)
    mask_q = offs_m[:, None] < seqlen_q
    PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2
    if PADDED_HEAD:
        mask_k = offs_k < BLOCK_D_MODEL
        mask_q &= mask_k[None, :]
    offs_q = offs_m[:, None] * stride_q_m + offs_k[None, :] * stride_q_k
    offs_do = offs_m[:, None] * stride_do_m + offs_k[None, :] * stride_do_k
    adj_k = batch_idx * stride_k_b + head_k_idx * stride_k_h + k_start * stride_k_n
    adj_v = batch_idx * stride_v_b + head_k_idx * stride_v_h + k_start * stride_v_n
    k_ptr_adj = k_ptr
    v_ptr_adj = v_ptr
    k_ptr_adj += adj_k
    v_ptr_adj += adj_v

    # If MQA / GQA, set the K and V head offsets appropriately.
    GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS
    for head_q_idx in range(
        head_k_idx * GROUP_SIZE, head_k_idx * GROUP_SIZE + GROUP_SIZE
    ):
        # seqlen_q < seqlen_k: delta_qk more kv tokens are visible at the
        # front of every M-tile
        end_n = start_m + BLOCK_M - delta_qk
        # clamp end_n to [0, seqlen_k]
        end_n = max(min(end_n, seqlen_k), 0)

        # offset input and output tensors by batch and Q head
        adj_q = batch_idx * stride_q_b + head_q_idx * stride_q_h + q_start * stride_q_m
        adj_do = (
            batch_idx * stride_do_b + head_q_idx * stride_do_h + q_start * stride_do_m
        )
        adj_delta = (
            batch_idx * stride_delta_b
            + head_q_idx * stride_delta_h
            + q_start * stride_delta_m
        )
        delta_ptr_adj = delta_ptr + adj_delta

        # batch_philox_offset is the ACTUAL dropout offset;
        # dropout_offset is for debug purposes and will be removed later
        batch_philox_offset = 0
        dropout_offset = 0
        if ENABLE_DROPOUT:
            batch_philox_offset = (
                philox_offset_base
                + batch_idx * stride_dropout_b
                + head_q_idx * stride_dropout_h
            )
            dropout_offset = (
                dropout_mask
                + batch_idx * stride_dropout_b
                + head_q_idx * stride_dropout_h
            )

        q = tl.load(q_ptr + adj_q + offs_q, mask=mask_q, other=0.0)
        do = tl.load(do_ptr + adj_do + offs_do, mask=mask_q, other=0.0)
        # m: per-row softmax statistics, broadcast along the head dim
        m = tl.load(m_ptr + adj_delta + offs_m * stride_delta_m, mask=offs_m < seqlen_q)
        m = m[:, None]

        MASK_BLOCK_N: tl.constexpr = BLOCK_N // BLK_SLICE_FACTOR
        # start can only be 0 at minimum
        start_n = max(end_n - BLOCK_M, 0)
        num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N)

        if IS_FP8:
            # For MQA/GQA, q_descale uses the same indexing as k/v (head_k_idx)
            descale_q = tl.load(
                descale_q_ptr + batch_idx * stride_descale_q_z + head_k_idx
            )
            descale_k = tl.load(
                descale_k_ptr + batch_idx * stride_descale_k_z + head_k_idx
            )
            descale_v = tl.load(
                descale_v_ptr + batch_idx * stride_descale_v_z + head_k_idx
            )
        else:
            descale_q, descale_k, descale_v = 1.0, 1.0, 1.0

        dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32)
        # Compute dQ for masked (diagonal) blocks.
        # NOTE: This scans each row of QK^T backward (from right to left, but
        # inside each call to the inner helper, from left to right); that
        # ordering only mirrors the dK/dV loop structure above.
        dq = _bwd_dq_inner_split(
            dq, q, k_ptr_adj, v_ptr_adj, do, m, delta_ptr_adj, sm_scale,
            stride_q_m, stride_q_k,
            stride_k_n, stride_k_k,
            stride_v_n, stride_v_k,
            stride_dropout_m, stride_dropout_n,
            stride_delta_m,
            seqlen_q, seqlen_k,
            dropout_p, philox_seed, batch_philox_offset, dropout_offset,
            start_m, start_n, end_n, num_steps,
            descale_q, descale_k, descale_v,
            BLOCK_M, MASK_BLOCK_N,  # narrower N step on the diagonal tiles
            BLOCK_D_MODEL, BLOCK_D_MODEL_POW2,
            MASK=True,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
        )
        end_n -= num_steps * MASK_BLOCK_N
        num_steps = tl.cdiv(end_n, BLOCK_N)
        start_n = max(end_n - num_steps * BLOCK_N, 0)
        # unmasked phase over the remaining N blocks
        dq = _bwd_dq_inner_split(
            dq, q, k_ptr_adj, v_ptr_adj, do, m, delta_ptr_adj, sm_scale,
            stride_q_m, stride_q_k,
            stride_k_n, stride_k_k,
            stride_v_n, stride_v_k,
            stride_dropout_m, stride_dropout_n,
            stride_delta_m,
            seqlen_q, seqlen_k,
            dropout_p, philox_seed, batch_philox_offset, dropout_offset,
            start_m, start_n, end_n, num_steps,
            descale_q, descale_k, descale_v,
            BLOCK_M, BLOCK_N,
            BLOCK_D_MODEL, BLOCK_D_MODEL_POW2,
            MASK=False,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
        )
        # Write back dQ.
        # (the store itself continues on the next source line)
        # --- continuation: dQ write-back of the causal split-dq kernel ---
        offs_dq = (
            batch_idx * stride_dq_b + head_q_idx * stride_dq_h + q_start * stride_dq_m
            + offs_m[:, None] * stride_dq_m + offs_k[None, :] * stride_dq_k
        )
        dq *= sm_scale
        tl.store(dq_ptr + offs_dq, dq, mask=mask_q)


@triton.jit
def _bwd_kernel_fused_atomic_noncausal(
    # tensor base pointers (dQ is accumulated atomically inside the inner helper)
    Q, K, V, sm_scale, DO, DK, DV, DQ, M, Delta,
    # strides: (batch, head, seq, head_dim) per tensor
    stride_qb, stride_qh, stride_qm, stride_qk,
    stride_kb, stride_kh, stride_kn, stride_kk,
    stride_vb, stride_vh, stride_vn, stride_vk,
    stride_dkb, stride_dkh, stride_dkn, stride_dkk,
    stride_dqb, stride_dqh, stride_dqm, stride_dqk,
    stride_deltab, stride_deltah, stride_deltam,
    stride_dob, stride_doh, stride_dom, stride_dok,
    stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn,
    stride_descale_q_z, stride_descale_k_z, stride_descale_v_z,
    # varlen bookkeeping
    cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
    # dropout / Philox RNG state
    dropout_mask, dropout_p, philox_seed, philox_offset,
    # FP8 per-(batch, head) descale factor tables
    descale_q_ptr, descale_k_ptr, descale_v_ptr,
    NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr,
    BATCH,        # runtime batch size (used to decode the linear workgroup id)
    NUM_K_PIDS,   # runtime number of seq-k blocks
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    BLK_SLICE_FACTOR: tl.constexpr,
    BLOCK_D_MODEL: tl.constexpr, BLOCK_D_MODEL_POW2: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr,
    IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr,
):
    """Non-causal fused backward: dK/dV stored directly, dQ via atomics.

    Launched on a 1-D grid; the linear workgroup id is decoded so that ids
    advance along batch first, then k-head, then seq-k block.
    """
    # workgroup id
    wid = tl.program_id(0)  # 0, ..., NUM_K_PIDS * BATCH * NUM_K_HEADS - 1

    # Workgroups are launched first along the batch dim, then the head_k dim,
    # and then the seq-k block dim, to reduce contention on the tl.atomic_add
    # (inside _bwd_dkdvdq_inner_atomic) between workgroups sharing the same
    # batch and head_k.
    bid = wid % BATCH
    hkid = wid // BATCH % NUM_K_HEADS
    pid = wid // (BATCH * NUM_K_HEADS) % NUM_K_PIDS

    q_start = 0
    k_start = 0
    seqlen_q = max_seqlen_q
    seqlen_k = max_seqlen_k

    if IS_VARLEN:
        q_start = tl.load(cu_seqlens_q + bid)
        q_end = tl.load(cu_seqlens_q + bid + 1)
        k_start = tl.load(cu_seqlens_k + bid)
        k_end = tl.load(cu_seqlens_k + bid + 1)
        seqlen_q = q_end - q_start
        seqlen_k = k_end - k_start

    dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32)

    start_n = pid * BLOCK_N

    offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)
    offs_n = start_n + tl.arange(0, BLOCK_N)
    mask_kv = offs_n[:, None] < seqlen_k
    PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2
    if PADDED_HEAD:
        # broadcast (BLOCK_N, 1) & (POW2,) -> (BLOCK_N, POW2)
        mask_kv &= offs_k < BLOCK_D_MODEL

    GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS
    adj_k = (
        bid * stride_kb + hkid * stride_kh + k_start * stride_kn
        + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk
    )
    adj_v = (
        bid * stride_vb + hkid * stride_vh + k_start * stride_vn
        + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk
    )

    # K/V tile stays resident for the whole q-head loop
    k = tl.load(K + adj_k, mask=mask_kv, other=0.0)
    v = tl.load(V + adj_v, mask=mask_kv, other=0.0)

    for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE):
        adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm
        adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm

        Q_ptr = Q + adj_q
        DQ_ptr = DQ + adj_dq

        adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom
        DO_ptr = DO + adj_do
        adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam
        M_ptr = M + adj_delta
        Delta_ptr = Delta + adj_delta

        # dropout: philox offset is the real one, dropout_offset is debug-only
        batch_philox_offset = 0
        dropout_offset = 0
        if ENABLE_DROPOUT:
            batch_philox_offset = (
                philox_offset + bid * stride_dropoutb + hqid * stride_dropouth
            )
            dropout_offset = (
                dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth
            )
            # (the dropout_offset expression is completed on the next source line)
        # --- continuation of the fused atomic non-causal kernel's q-head loop ---

        if IS_FP8:
            # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing
            # as k/v (hkid); for MHA (GROUP_SIZE == 1) hqid == hkid anyway.
            descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid)
            descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid)
            descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid)
        else:
            descale_q, descale_k, descale_v = 1.0, 1.0, 1.0

        # non-causal: sweep the full q length in one unmasked pass
        start_m = 0
        num_steps = tl.cdiv(seqlen_q, BLOCK_M)

        dk, dv = _bwd_dkdvdq_inner_atomic(
            dk, dv,
            Q_ptr, k, v, DO_ptr, DQ_ptr, M_ptr, Delta_ptr, sm_scale,
            stride_qm, stride_qk,
            stride_dqm, stride_dqk,
            stride_dom, stride_dok,
            stride_dropoutm, stride_dropoutn,
            stride_deltam,
            dropout_p, philox_seed, batch_philox_offset, dropout_offset,
            seqlen_q, seqlen_k,
            start_n, start_m, num_steps,
            descale_q, descale_k, descale_v,
            BLOCK_M, BLOCK_N,
            BLOCK_D_MODEL, BLOCK_D_MODEL_POW2,
            MASK=False,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
            workgroup_id=pid,
        )

    # write back dV and dK accumulated over the whole GQA group
    adj_dkdv = (
        bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn
        + offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk
    )
    tl.store(DV + adj_dkdv, dv, mask=mask_kv)
    dk *= sm_scale
    tl.store(DK + adj_dkdv, dk, mask=mask_kv)


@triton.jit
def _bwd_kernel_split_dkdv_noncausal(
    # tensor base pointers
    Q, K, V, sm_scale, DO, DK, DV, M, Delta,
    # strides: (batch, head, seq, head_dim) per tensor
    stride_qb, stride_qh, stride_qm, stride_qk,
    stride_kb, stride_kh, stride_kn, stride_kk,
    stride_vb, stride_vh, stride_vn, stride_vk,
    stride_dkb, stride_dkh, stride_dkn, stride_dkk,
    stride_deltab, stride_deltah, stride_deltam,
    stride_dob, stride_doh, stride_dom, stride_dok,
    stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn,
    stride_descale_q_z, stride_descale_k_z, stride_descale_v_z,
    # varlen bookkeeping
    cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
    # dropout / Philox RNG state
    dropout_mask, dropout_p, philox_seed, philox_offset,
    # FP8 per-(batch, head) descale factor tables
    descale_q_ptr, descale_k_ptr, descale_v_ptr,
    NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    BLK_SLICE_FACTOR: tl.constexpr,
    BLOCK_D_MODEL: tl.constexpr, BLOCK_D_MODEL_POW2: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr,
    IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr,
):
    """Non-causal backward, split kernel: dK/dV only (dQ has its own kernel).

    Grid: (seq-k blocks, batch, k-heads); structure mirrors
    _bwd_kernel_fused_atomic_noncausal but without the atomic dQ path.
    """
    pid = tl.program_id(0)
    bid = tl.program_id(1)
    hkid = tl.program_id(2)

    q_start = 0
    k_start = 0
    seqlen_q = max_seqlen_q
    seqlen_k = max_seqlen_k

    if IS_VARLEN:
        q_start = tl.load(cu_seqlens_q + bid)
        q_end = tl.load(cu_seqlens_q + bid + 1)
        k_start = tl.load(cu_seqlens_k + bid)
        k_end = tl.load(cu_seqlens_k + bid + 1)
        seqlen_q = q_end - q_start
        seqlen_k = k_end - k_start

    dk = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_D_MODEL_POW2], dtype=tl.float32)

    start_n = pid * BLOCK_N

    offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)
    offs_n = start_n + tl.arange(0, BLOCK_N)
    mask_kv = offs_n[:, None] < seqlen_k
    PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2
    if PADDED_HEAD:
        mask_kv &= offs_k < BLOCK_D_MODEL

    GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS
    adj_k = (
        bid * stride_kb + hkid * stride_kh + k_start * stride_kn
        + offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk
    )
    adj_v = (
        bid * stride_vb + hkid * stride_vh + k_start * stride_vn
        + offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk
    )

    k = tl.load(K + adj_k, mask=mask_kv, other=0.0)
    v = tl.load(V + adj_v, mask=mask_kv, other=0.0)

    for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE):
        adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm
        Q_ptr = Q + adj_q
        adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom
        DO_ptr = DO + adj_do
        adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam
        # (this loop body continues on the next source line)
        # --- continuation of the split non-causal dK/dV kernel's q-head loop ---
        M_ptr = M + adj_delta
        Delta_ptr = Delta + adj_delta

        # dropout: philox offset is the real one, dropout_offset is debug-only
        batch_philox_offset = 0
        dropout_offset = 0
        if ENABLE_DROPOUT:
            batch_philox_offset = (
                philox_offset + bid * stride_dropoutb + hqid * stride_dropouth
            )
            dropout_offset = (
                dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth
            )

        if IS_FP8:
            # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing
            # as k/v (hkid); for MHA (GROUP_SIZE == 1) hqid == hkid anyway.
            descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid)
            descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid)
            descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid)
        else:
            descale_q, descale_k, descale_v = 1.0, 1.0, 1.0

        # non-causal: one unmasked sweep over the full q length
        start_m = 0
        num_steps = tl.cdiv(seqlen_q, BLOCK_M)
        dk, dv = _bwd_dkdv_inner_split(
            dk, dv,
            Q_ptr, k, v, DO_ptr, M_ptr, Delta_ptr, sm_scale,
            stride_qm, stride_qk,
            stride_dom, stride_dok,
            stride_dropoutm, stride_dropoutn,
            stride_deltam,
            dropout_p, philox_seed, batch_philox_offset, dropout_offset,
            seqlen_q, seqlen_k,
            start_n, start_m, num_steps,
            descale_q, descale_k, descale_v,
            BLOCK_M, BLOCK_N,
            BLOCK_D_MODEL, BLOCK_D_MODEL_POW2,
            MASK=False,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
        )

    # write back dV and dK accumulated over the whole GQA group
    adj_dkdv = (
        bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn
        + offs_n[:, None] * stride_dkn + offs_k[None, :] * stride_dkk
    )
    tl.store(DV + adj_dkdv, dv, mask=mask_kv)
    dk *= sm_scale
    tl.store(DK + adj_dkdv, dk, mask=mask_kv)


@triton.jit
def _bwd_kernel_split_dq_noncausal(
    # tensor base pointers ("delta" is lowercase here, unlike sibling kernels)
    Q, K, V, sm_scale, DO, DQ, M, delta,
    # strides: (batch, head, seq, head_dim) per tensor
    stride_qb, stride_qh, stride_qm, stride_qk,
    stride_kb, stride_kh, stride_kn, stride_kk,
    stride_vb, stride_vh, stride_vn, stride_vk,
    stride_dqb, stride_dqh, stride_dqm, stride_dqk,
    stride_deltab, stride_deltah, stride_deltam,
    stride_dob, stride_doh, stride_dom, stride_dok,
    stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn,
    stride_descale_q_z, stride_descale_k_z, stride_descale_v_z,
    # varlen bookkeeping
    cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
    # dropout / Philox RNG state
    dropout_mask, dropout_p, philox_seed, philox_offset_base,
    # FP8 per-(batch, head) descale factor tables
    descale_q_ptr, descale_k_ptr, descale_v_ptr,
    NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
    BLK_SLICE_FACTOR: tl.constexpr,
    BLOCK_D_MODEL: tl.constexpr, BLOCK_D_MODEL_POW2: tl.constexpr,
    ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr,
    IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr,
):
    """Non-causal backward, split kernel: dQ for one BLOCK_M Q tile.

    Grid: (seq-q blocks, batch, k-heads).  Per q-head of the GQA group the
    whole K/V length is swept once with no causal mask.
    """
    pid = tl.program_id(0)  # seqlen
    bid = tl.program_id(1)  # batch
    hkid = tl.program_id(2)  # head_k

    q_start = 0
    k_start = 0
    seqlen_q = max_seqlen_q
    seqlen_k = max_seqlen_k

    if IS_VARLEN:
        # Compute actual sequence lengths
        q_start = tl.load(cu_seqlens_q + bid)
        q_end = tl.load(cu_seqlens_q + bid + 1)
        k_start = tl.load(cu_seqlens_k + bid)
        k_end = tl.load(cu_seqlens_k + bid + 1)
        seqlen_q = q_end - q_start
        seqlen_k = k_end - k_start

    start_m = pid * BLOCK_M

    offs_k = tl.arange(0, BLOCK_D_MODEL_POW2)
    offs_m = start_m + tl.arange(0, BLOCK_M)

    # mask for loading Q and DO (rows past seqlen_q padded out)
    mask_q = offs_m[:, None] < seqlen_q
    PADDED_HEAD: tl.constexpr = BLOCK_D_MODEL != BLOCK_D_MODEL_POW2
    if PADDED_HEAD:
        mask_k = offs_k < BLOCK_D_MODEL
        mask_q &= mask_k[None, :]
    offs_q = offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
    offs_do = offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok
    adj_k = bid * stride_kb + hkid * stride_kh + k_start * stride_kn
    adj_v = bid * stride_vb + hkid * stride_vh + k_start * stride_vn
    K += adj_k
    V += adj_v

    GROUP_SIZE = NUM_Q_HEADS // NUM_K_HEADS
    for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE):
        adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm
        adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom
        adj_delta = bid * stride_deltab + hqid * stride_deltah + q_start * stride_deltam
        delta_ptr = delta + adj_delta

        batch_philox_offset = 0
        dropout_offset = 0
        if ENABLE_DROPOUT:
            batch_philox_offset = (
                philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth
            )
            dropout_offset = (
                dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth
            )

        q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0)
        do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0)
        # m: per-row softmax statistics, broadcast along the head dim
        m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q)
        m = m[:, None]

        # FP8
        if IS_FP8:
            # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing
            # as k/v (hkid); for MHA (GROUP_SIZE == 1) hqid == hkid anyway.
            descale_q = tl.load(descale_q_ptr + bid * stride_descale_q_z + hkid)
            descale_k = tl.load(descale_k_ptr + bid * stride_descale_k_z + hkid)
            descale_v = tl.load(descale_v_ptr + bid * stride_descale_v_z + hkid)
        else:
            descale_q, descale_k, descale_v = 1.0, 1.0, 1.0

        start_n = 0
        end_n = seqlen_k
        num_steps = tl.cdiv(seqlen_k, BLOCK_N)
        dq = tl.zeros([BLOCK_M, BLOCK_D_MODEL_POW2], dtype=tl.float32)
        dq = _bwd_dq_inner_split(
            dq, q, K, V, do, m, delta_ptr, sm_scale,
            stride_qm, stride_qk,
            stride_kn, stride_kk,
            stride_vn, stride_vk,
            stride_dropoutm, stride_dropoutn,
            stride_deltam,
            seqlen_q, seqlen_k,
            dropout_p, philox_seed, batch_philox_offset, dropout_offset,
            start_m, start_n, end_n, num_steps,
            descale_q, descale_k, descale_v,
            BLOCK_M, BLOCK_N,
            BLOCK_D_MODEL, BLOCK_D_MODEL_POW2,
            MASK=False,
            ENABLE_DROPOUT=ENABLE_DROPOUT,
            IS_FP8=IS_FP8,
            FP8_MAX=FP8_MAX,
        )

        adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm
        offs_dq = offs_m[:, None] * stride_dqm + offs_k[None, :] * stride_dqk
        dq *= sm_scale
        tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q)
function computes delta given output Out and gradient DO +# Here is the I/O shape: +# Out: (batch, nhead_q, max_seqlens_q, headDim) +# DO: (batch, nhead_q, max_seqlens_q, headDim) +# Delta: (batch, nheads_q, max_seqlens_q) +@triton.autotune( + configs=preprocess_autotune_configs, + key=preprocess_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def _bwd_preprocess( + O, + DO, # noqa: E741 + Delta, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_delta_b, + stride_delta_h, + stride_delta_m, + cu_seqlens_q, + max_seqlen_q, + PRE_BLOCK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_FP8: tl.constexpr, +): + pid_m = tl.program_id(0) + bid = tl.program_id(1) + hid = tl.program_id(2) + # Handle varlen + if IS_VARLEN: + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + seqlen_q = q_end - q_start + else: + q_start = 0 + seqlen_q = max_seqlen_q + + # Compute offsets + offs_m = pid_m * PRE_BLOCK + tl.arange(0, PRE_BLOCK) + offs_d = tl.arange(0, HEAD_DIM_V) + # pointer offsets for O & DO + off_o = ( + bid * stride_ob + + hid * stride_oh + + q_start * stride_om + + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_od + ) # noqa: E741 + off_do = ( + bid * stride_dob + + hid * stride_doh + + q_start * stride_dom + + offs_m[:, None] * stride_dom + + offs_d[None, :] * stride_dod + ) + + # create masks + mask_m = offs_m < seqlen_q + mask_md = mask_m[:, None] + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + if PADDED_HEAD_V: + mask_md &= offs_d[None, :] < ACTUAL_HEAD_DIM_V + # load + o = tl.load(O + off_o, mask=mask_md, other=0.0) + do = tl.load(DO + off_do, mask=mask_md, other=0.0) + # compute and write-back to delta + # NOTE: Both o and do are FP32 + delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1) + off_delta = ( + bid * stride_delta_b + + hid * stride_delta_h + + q_start * 
stride_delta_m + + offs_m * stride_delta_m + ) + tl.store(Delta + off_delta, delta, mask=mask_m) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _bwd_dkdv_inner( + dk, + dv, # output + Q, + k, + v, + DO, + M, + D, + sm_scale, # input tensor + stride_qm, + stride_qk, + stride_dom, + stride_dok, + stride_dropoutm, + stride_dropoutn, + stride_lse_m, + stride_delta_m, + BLOCK_M: tl.constexpr, # 16 + BLOCK_N: tl.constexpr, # 128 + HEAD_DIM_QK: tl.constexpr, # + HEAD_DIM_V: tl.constexpr, # + ACTUAL_HEAD_DIM_QK: tl.constexpr, # + ACTUAL_HEAD_DIM_V: tl.constexpr, # + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + # Filled in by the wrapper. + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + MASK: tl.constexpr, # causal masking, only apply to tiles on mask diagonal + ENABLE_DROPOUT: tl.constexpr, # activate dropout + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, # activate exp2 + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_AUTO_DESCALE: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) + offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) + # mask to make sure not OOB of seqlen_q + mask_n = offs_n < seqlen_k + # Q and DO are (seqlen_q, head_dim) + # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM_QK, 1), transpose of q + qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k_qk[:, None] * stride_qk + # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM_V), NOT transposed + do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k_v[None, :] * stride_dok + # 
BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N % BLOCK_M == 0) + curr_m = start_m + step_m = BLOCK_M + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + + for blk_idx in range(num_steps): + if DEBUG_TRITON: + print(f"iter {blk_idx}: curr_m = {curr_m}") # noqa: E701 + offs_m = curr_m + tl.arange(0, BLOCK_M) + # update the mask because offs_m advanced + mask_m = offs_m < seqlen_q + mask_qT = mask_m[None, :] + mask_do = mask_m[:, None] + mask_nm = mask_n[:, None] & (offs_m[None, :] < seqlen_q) + if PADDED_HEAD_QK: + mask_qT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_do &= offs_k_v[None, :] < ACTUAL_HEAD_DIM_V + qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + # generate dropout mask + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[None, :] * stride_dropoutm + + offs_n[:, None] * stride_dropoutn + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1.0 / (1 - dropout_p) + # Load m before computing qk to reduce pipeline stall. 
+ m = tl.load(M + offs_m * stride_lse_m, mask=mask_m, other=0.0) + + # Compute qk + if IS_FP8: + qkT = tl.dot(k, qT) * descale_q * descale_k + else: + qkT = tl.dot(k, qT) + qkT_scaled = qkT * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_n[:, None] + seqlen_q - seqlen_k - offs_m[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qkT_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"qT: {qT.shape}\n", qT) + print(f"k: {k.shape}\n", k) + print(f"qkT scaled: {qkT.shape}\n", qkT_scaled) + + # Compute probabilities - handle invalid rows where m is -inf + # For rows where m is -inf, no keys were valid, so pT should be 0 + # We shift qkT by m to avoid numerical issues + qkT_shifted = tl.where( + m[None, :] == float("-inf"), float("-inf"), qkT_scaled - m[None, :] + ) + + if USE_EXP2: + pT = tl.math.exp2(qkT_shifted * RCP_LN2) + else: + pT = tl.math.exp(qkT_shifted) + + # Autoregressive masking. + if MASK: + # offset offs_m with delta_qk since the causal mask starts at + # bottom right of the (seqlen_q, seqlen_k) matrix + causal_mask = (offs_m[None, :] - delta_qk) >= offs_n[:, None] + mask = causal_mask & mask_nm + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"causal_mask: {causal_mask.shape}\n", causal_mask) + print( + f"qkT after causal: {qkT.shape}\n", + tl.where(causal_mask, qkT * sm_scale, 0.0), + ) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs, mask=mask_do, other=0.0) + # Compute dV. + # Note: pT and do are both high precision, so no need for auto-descaling here + if ENABLE_DROPOUT: + pT_dropout = tl.where(dropout_mask, pT, 0.0) * dropout_scale + dv += tl.dot(pT_dropout.to(do.type.element_ty), do) + else: + dv += tl.dot(pT.to(do.type.element_ty), do) + + if DEBUG_TRITON_DETAIL: + if start_n == 256: + print(f"pT: {pT.shape}\n", pT) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m * stride_delta_m, mask=mask_m) + + # Compute dP and dS. 
+ # Note: v is fp8, do is fp32, so we need to scale do before casting to fp8 + if IS_FP8: + if FP8_AUTO_DESCALE: + do_scale, do_descale = compute_fp8_scaling_factors(do, FP8_MAX) + dpT = ( + tl.dot(v, tl.trans((do * do_scale).to(v.type.element_ty))) + * descale_v + * do_descale + ) + else: + dpT = tl.dot(v, tl.trans(do.to(v.type.element_ty))) * descale_v + else: + dpT = tl.dot(v, tl.trans(do)) + + if ENABLE_DROPOUT: + dpT = tl.where(dropout_mask, dpT, 0.0) * dropout_scale + delta_i = Di[None, :] + dsT = pT * (dpT - delta_i) + + # Compute dK + if IS_FP8: + if FP8_AUTO_DESCALE: + # Apply dynamic scaling to dsT before casting to FP8 + dsT_scale, dsT_descale = compute_fp8_scaling_factors(dsT, FP8_MAX) + dk += ( + tl.dot((dsT * dsT_scale).to(qT.type.element_ty), tl.trans(qT)) + * descale_q + * dsT_descale + ) + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) * descale_q + else: + dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_qm + do_ptrs += step_m * stride_dom + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _bwd_dq_inner( + dq, # output + q, + K, + V, + do, + m, + Delta, + sm_scale, # input + # shared by Q/K/V. + stride_qm, + stride_qk, + stride_kn, + stride_kk, + stride_vn, + stride_vk, + stride_dropoutm, + stride_dropoutn, # stride for dropout + stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, # + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + # Filled in by the wrapper. 
+ start_m, + start_n, + end_n, + num_steps, # + descale_q, + descale_k, + descale_v, + MASK: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_AUTO_DESCALE: tl.constexpr, + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # if HEAD_DIM is padded + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + delta_qk = seqlen_q - seqlen_k + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k_qk = tl.arange(0, HEAD_DIM_QK) + offs_k_v = tl.arange(0, HEAD_DIM_V) + + # mask to make sure not OOB of seqlen_q + mask_m = offs_m < seqlen_q + + kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k_qk[:, None] * stride_kk + vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k_v[:, None] * stride_vk + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(Delta + offs_m * stride_delta_m, mask=mask_m, other=0.0) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + curr_philox_offset = batch_philox_offset + curr_dropout_offset = dropout_offset + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + for blk_idx in range(num_steps): + if DEBUG_TRITON: + print(f"iter {blk_idx}: curr_n = {curr_n}") # noqa: E701 + offs_n = curr_n + tl.arange(0, BLOCK_N2) + # end_n is needed because the end of causal True might not be perfectly + # aligned with the end of the block + mask_n = offs_n < end_n + if DEBUG_TRITON_DETAIL: + print( + f"start_n = {start_n}, end_n = {end_n}, offs_n: {offs_n.shape}\n{offs_n}" + ) # noqa: E701 + if DEBUG_TRITON_DETAIL: + print(f"mask_n: {mask_n.shape}\n{mask_n}") # noqa: E701 + mask_kT = mask_n[None, :] + mask_vT = mask_n[None, :] + mask_mn = mask_m[:, None] & (offs_n[None, :] < end_n) + if PADDED_HEAD_QK: + mask_kT &= offs_k_qk[:, None] < ACTUAL_HEAD_DIM_QK + if PADDED_HEAD_V: + mask_vT &= offs_k_v[:, None] < ACTUAL_HEAD_DIM_V + + kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + vT = tl.load(vT_ptrs, mask=mask_vT, other=0.0) + + if ENABLE_DROPOUT: + # NOTE: dropout is transposed because it is used to mask pT + philox_offs = ( + curr_philox_offset + + offs_m[:, None] * stride_dropoutm + + offs_n[None, :] * stride_dropoutn + ) + rand_vals = tl.rand(philox_seed, philox_offs) + dropout_mask = rand_vals > dropout_p + dropout_scale = 1 / (1 - dropout_p) + + if IS_FP8: + qk = tl.dot(q, kT) * descale_q * descale_k + else: + qk = tl.dot(q, kT) + qk_scaled = qk * sm_scale + + if USE_ALIBI: + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + qk_scaled += alibi_block + + if DEBUG_TRITON_DETAIL: + print(f"qk scaled: {qk.shape}\n", qk_scaled) # noqa: E701 + + # Compute probabilities - handle invalid rows where m is -inf + # For rows where m is -inf, no keys were valid, so p should be 0 + # We shift qk by m to avoid numerical issues + 
qk_shifted = tl.where(m == float("-inf"), float("-inf"), qk_scaled - m) + + if USE_EXP2: + p = tl.math.exp2(qk_shifted * RCP_LN2) + else: + p = tl.math.exp(qk_shifted) + + # Autoregressive masking. + if MASK: + causal_mask = (offs_m[:, None] - delta_qk) >= offs_n[None, :] + mask = causal_mask & mask_mn + p = tl.where(mask, p, 0.0) + + # Compute dP and dS. + # Note: do is fp32, vT is fp8, so we need to scale do before casting to fp8 + if IS_FP8: + if FP8_AUTO_DESCALE: + do_scale, do_descale = compute_fp8_scaling_factors(do, FP8_MAX) + dp = ( + tl.dot((do * do_scale).to(vT.type.element_ty), vT) + * descale_v + * do_descale + ) + else: + dp = tl.dot(do.to(vT.type.element_ty), vT) * descale_v + else: + dp = tl.dot(do, vT) + + if ENABLE_DROPOUT: + dp = tl.where(dropout_mask, dp, 0.0) * dropout_scale + delta_i = Di[:, None] + ds = p * (dp - delta_i) + + # Compute dQ + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + if IS_FP8: + if FP8_AUTO_DESCALE: + # Apply dynamic scaling to ds before casting to FP8 + ds_scale, ds_descale = compute_fp8_scaling_factors(ds, FP8_MAX) + dq += ( + tl.dot((ds * ds_scale).to(kT.type.element_ty), tl.trans(kT)) + * descale_k + * ds_descale + ) + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) * descale_k + else: + dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + # Increment pointers. 
+ curr_n += step_n + kT_ptrs += step_n * stride_kn + vT_ptrs += step_n * stride_vn + return dq + + +@triton.autotune( + configs=causal_autotune_configs, + key=causal_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def bwd_kernel_fused_causal( # grid = (nheads_k, tl.cdiv(max_seqlen_q // BLOCK_M2), batch) + Q, + K, + V, + sm_scale, + DO, + DQ, + DK, + DV, + M, + Delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + HQ, + HK, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Add seqused parameters + max_seqlen_q, + max_seqlen_k, + Dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + Alibi_slopes, + Descale_q, + Descale_k, + Descale_v, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_AUTO_DESCALE: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) + if DEBUG_TRITON: + print(f"\npid: {pid}, bid: 
{bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = ( + tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + ) + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = ( + tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + ) + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + delta_qk = seqlen_q - seqlen_k + if DEBUG_TRITON: + print(f"delta_qk = {delta_qk}") # noqa: E701 + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) + GROUP_SIZE: tl.constexpr = HQ // HK + + # align the delta_qk + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + # This section does dk and dv + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) + + # q > k: directly skip all the way until the start of causal block + start_delta_q_gt_k = delta_qk + # q < k: some blocks will have no Masked block, others need to re-calc + # starting position + # delta_qk is negative so flip it, only multiple of BLOCK_N can skip the + # masked op + num_blocks_skip = -delta_qk // BLOCK_N1 + delta_aligned = (num_blocks_skip + 1) * BLOCK_N1 + delta_qk + start_delta_q_lt_k = delta_aligned // BLOCK_M1 * BLOCK_M1 + if delta_qk >= 0: + start_delta = delta_qk + if DEBUG_TRITON: + print( + f"q >= k: start_delta = delta_qk aligned to BLOCK_M = 
{start_delta_q_gt_k}" + ) # noqa: E701 + else: + start_delta = start_delta_q_lt_k + if DEBUG_TRITON: + print( + f"q < k: start_delta = residue btw multiple BLOCK_N and delta_qk = {delta_aligned} = aligned to BLOCK_M = {start_delta_q_lt_k}" + ) # noqa: E701 + + offs_n = start_n + tl.arange(0, BLOCK_N1) + # Mask for loading K and V + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] + + # K/V tensors not changed for the group + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d_qk[None, :] * stride_kd + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_d_v[None, :] * stride_vd + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_k, mask=mask_k) + v = tl.load(V + adj_v, mask=mask_v) + # If MQA / GQA, set the K and V head offsets appropriately. 
+ # hqid = hkid + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + if delta_qk >= 0: + start_m = start_n + start_delta + len_m = BLOCK_N1 + else: + start_m = max(start_n + delta_qk, 0) + start_m = start_m // BLOCK_M1 * BLOCK_M1 + # because we might shift the masked blocks up, we are deeper into + # the masked out region, so we would potentially increase the total + # steps with masked operation to get out of it + residue_m = max(start_n + delta_qk - start_m, 0) + len_m = BLOCK_N1 + residue_m + if DEBUG_TRITON: + print(f"residue_m = {residue_m}") # noqa: E701 + + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + if IS_FP8: + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + else: + descale_q, descale_k, descale_v = 1.0, 
1.0, 1.0 + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # bound the masked operation to q len so it does not have to waste cycles + len_m = min(len_m, seqlen_q) + num_steps = tl.cdiv(len_m, MASK_BLOCK_M1) + # when q < k, we may skip the initial masked op + if pid < num_blocks_skip: + num_steps = 0 + + # if start_m is negative, the current N-tile has no block on the + # diagonal of causal mask, so everything has no causal mask + if DEBUG_TRITON: + print( + f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}" + ) # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, + dv, # output tensors + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_lse_m, + stride_delta_m, + MASK_BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + MASK=True, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_AUTO_DESCALE=FP8_AUTO_DESCALE, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + start_m += num_steps * MASK_BLOCK_M1 + num_steps = tl.cdiv(seqlen_q - start_m, BLOCK_M1) + end_m = start_m + num_steps * BLOCK_M1 + + if DEBUG_TRITON: + print( + f"start_m after Masked step: {start_m}; num_steps: {num_steps}" + ) # noqa: E701 + if DEBUG_TRITON: + print( + f"unMasked: start_n: {start_n}, start_m: {start_m}, end_m: {end_m}, num_steps: {num_steps}" + ) # noqa: E701 + if DEBUG_TRITON: + print("unMasked") # noqa: E701 + dk, dv = _bwd_dkdv_inner( + dk, + dv, # 
output tensors + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_lse_m, + stride_delta_m, + BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_AUTO_DESCALE=FP8_AUTO_DESCALE, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # end of GQA/MQA of dkdv + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd + dk *= sm_scale + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) + + # This part does dq + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + # seqlen_q > seqlen_k, no need to process these tile for dq + if DEBUG_TRITON: + print( + f"end_n = start_m + BLOCK_M = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2}" + ) # noqa: E701 + if start_m + BLOCK_M2 < delta_qk: + if DEBUG_TRITON: + print( + f"start_m + BLOCK_M2 = {start_m} + {BLOCK_M2} = {start_m + BLOCK_M2} < delta_qk of {delta_qk}" + ) # noqa: E701 + return + + offs_m = start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: 
+ mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod + # NOTE: don't assume that the strides for k and v are the same! + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn + + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # seqlen_q < seqlen_k: delta_qk more kv tokens are added at the front + # for every M-tile + end_n = start_m + BLOCK_M2 - delta_qk + # clamp end_n at [0, seqlen_k] + end_n = max(min(end_n, seqlen_k), 0) + if DEBUG_TRITON: + print(f"delta_qk: {delta_qk}; end_n: {end_n}") # noqa: E701 + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, 
other=0.0) + m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) + m = m[:, None] + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + # start can only be 0 at minimum + start_n = max(end_n - BLOCK_M2, 0) + num_steps = tl.cdiv(end_n - start_n, MASK_BLOCK_N2) + + if IS_FP8: + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, + BLOCK_M2, + MASK_BLOCK_N2, + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + MASK=True, # + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_AUTO_DESCALE=FP8_AUTO_DESCALE, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + end_n -= num_steps * MASK_BLOCK_N2 + num_steps = tl.cdiv(end_n, BLOCK_N2) + start_n = max(end_n - num_steps * BLOCK_N2, 0) + if DEBUG_TRITON: + print( + f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}" + ) # noqa: E701 + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + 
stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, + BLOCK_M2, + BLOCK_N2, + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_AUTO_DESCALE=FP8_AUTO_DESCALE, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. + adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + # end of GQA/MQA of dq + + +@triton.autotune( + configs=noncausal_autotune_configs, + key=noncausal_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def bwd_kernel_fused_noncausal( + Q, + K, + V, + sm_scale, + DO, + DQ, + DK, + DV, + M, + Delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + HQ, + HK, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Add seqused parameters + max_seqlen_q, + max_seqlen_k, + Dropout_mask, + dropout_p, + philox_seed, + philox_offset_base, + Alibi_slopes, + Descale_q, + Descale_k, + Descale_v, + BLOCK_M1: tl.constexpr, # 32 + BLOCK_N1: 
tl.constexpr, # 128 + BLOCK_M2: tl.constexpr, # 128 + BLOCK_N2: tl.constexpr, # 32 + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM_QK: tl.constexpr, + HEAD_DIM_V: tl.constexpr, + ACTUAL_HEAD_DIM_QK: tl.constexpr, + ACTUAL_HEAD_DIM_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_AUTO_DESCALE: tl.constexpr, + USE_SEQUSED: tl.constexpr, # Add flag for seqused + DEBUG_TRITON: tl.constexpr, + DEBUG_TRITON_DETAIL: tl.constexpr, +): + # program ids + hkid = tl.program_id(0) + pid = tl.program_id(1) + bid = tl.program_id(2) + if DEBUG_TRITON: + print(f"\npid: {pid}, bid: {bid}, hkid: {hkid}") # noqa: E701 + # figure out varlen start and end + q_start = 0 + k_start = 0 + seqlen_q = max_seqlen_q + seqlen_k = max_seqlen_k + if IS_VARLEN: + # Compute actual sequence lengths + q_start = tl.load(cu_seqlens_q + bid) + q_end = tl.load(cu_seqlens_q + bid + 1) + k_start = tl.load(cu_seqlens_k + bid) + k_end = tl.load(cu_seqlens_k + bid + 1) + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = ( + tl.load(seqused_q + bid) if seqused_q is not None else q_end - q_start + ) + seqlen_q = tl.minimum(actual_seqlen_q, q_end - q_start) + actual_seqlen_k = ( + tl.load(seqused_k + bid) if seqused_k is not None else k_end - k_start + ) + seqlen_k = tl.minimum(actual_seqlen_k, k_end - k_start) + else: + seqlen_q = q_end - q_start + seqlen_k = k_end - k_start + + PADDED_HEAD_QK: tl.constexpr = ACTUAL_HEAD_DIM_QK != HEAD_DIM_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_HEAD_DIM_V != HEAD_DIM_V + offs_d_qk = tl.arange(0, HEAD_DIM_QK) + offs_d_v = tl.arange(0, HEAD_DIM_V) + GROUP_SIZE: tl.constexpr = HQ // HK + + start_n = pid * BLOCK_N1 + if start_n < seqlen_k: + dk = tl.zeros([BLOCK_N1, HEAD_DIM_QK], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM_V], dtype=tl.float32) + + offs_n = start_n + tl.arange(0, 
BLOCK_N1) + # Mask for loading K and V + mask_k = offs_n[:, None] < seqlen_k + mask_v = offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_k &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_v &= mask_d_v[None, :] + # NOTE: don't assume that the strides for k and v are the same! + # K/V tensors not changed for the group + adj_k = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d_qk[None, :] * stride_kd + ) + adj_v = ( + bid * stride_vb + + hkid * stride_vh + + k_start * stride_vn + + offs_n[:, None] * stride_vn + + offs_d_v[None, :] * stride_vd + ) + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + adj_k, mask=mask_k) + v = tl.load(V + adj_v, mask=mask_v) + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + Q_ptr = Q + adj_q + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + DO_ptr = DO + adj_do + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + if IS_FP8: + # For MQA/GQA 
(GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + # For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + # because there is no causal, we always start from the beginning + start_m = 0 + num_steps = tl.cdiv(seqlen_q, BLOCK_M1) + dk, dv = _bwd_dkdv_inner( + dk, + dv, # output tensors + Q_ptr, + k, + v, + DO_ptr, + M_ptr, + Delta_ptr, + sm_scale, # input tensors + stride_qm, + stride_qd, # strides for q + stride_dom, + stride_dod, # strides for o + stride_dropoutm, + stride_dropoutn, # strides for dropout + stride_lse_m, + stride_delta_m, + BLOCK_M1, + BLOCK_N1, # block dim + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, # head dim + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, # + alibi_slope, + seqlen_q, + seqlen_k, # max sequence length for q and k + start_n, + start_m, + num_steps, # iteration numbers + descale_q, + descale_k, + descale_v, + MASK=False, # causal masking + ENABLE_DROPOUT=ENABLE_DROPOUT, # activate dropout + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_AUTO_DESCALE=FP8_AUTO_DESCALE, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + + # Write back dV + adj_dv = bid * stride_dvb + hkid * stride_dvh + k_start * stride_dvn + offs_dv = offs_n[:, None] * stride_dvn + offs_d_v[None, :] * stride_dvd + tl.store(DV + adj_dv + offs_dv, dv, mask=mask_v) + # write back dk + adj_dk = bid * stride_dkb + hkid * stride_dkh + k_start * stride_dkn + offs_dk = offs_n[:, None] * stride_dkn + offs_d_qk[None, :] * stride_dkd + dk *= sm_scale + tl.store(DK + adj_dk + offs_dk, dk, mask=mask_k) + + # THIS PART DOES DQ + start_m = pid * BLOCK_M2 + if start_m < seqlen_q: + offs_m = 
start_m + tl.arange(0, BLOCK_M2) + # Mask for loading K and V + mask_q = offs_m[:, None] < seqlen_q + mask_do = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + mask_d_qk = offs_d_qk < ACTUAL_HEAD_DIM_QK + mask_q &= mask_d_qk[None, :] + if PADDED_HEAD_V: + mask_d_v = offs_d_v < ACTUAL_HEAD_DIM_V + mask_do &= mask_d_v[None, :] + offs_q = offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qd + offs_do = offs_m[:, None] * stride_dom + offs_d_v[None, :] * stride_dod + K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn + V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn + # If MQA / GQA, set the K and V head offsets appropriately. + for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): + # offset input and output tensor by batch and Q/K heads + adj_q = bid * stride_qb + hqid * stride_qh + q_start * stride_qm + adj_do = bid * stride_dob + hqid * stride_doh + q_start * stride_dom + adj_delta = ( + bid * stride_delta_b + hqid * stride_delta_h + q_start * stride_delta_m + ) + Delta_ptr = Delta + adj_delta + adj_m = bid * stride_lse_b + hqid * stride_lse_h + q_start * stride_lse_m + M_ptr = M + adj_m + + if USE_ALIBI: + alibi_offset = bid * stride_az + hqid * stride_ah + alibi_slope = tl.load(Alibi_slopes + alibi_offset) + else: + alibi_slope = None + + # batch_philox_offset is the ACTUALLY dropout offset + # dropout_offset is for debug purpose and will be removed later + batch_philox_offset = 0 + dropout_offset = 0 + if ENABLE_DROPOUT: + batch_philox_offset = ( + philox_offset_base + bid * stride_dropoutb + hqid * stride_dropouth + ) + dropout_offset = ( + Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth + ) + + q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + do = tl.load(DO + adj_do + offs_do, mask=mask_do, other=0.0) + m = tl.load(M + adj_m + offs_m * stride_lse_m, mask=offs_m < seqlen_q) + m = m[:, None] + + if IS_FP8: + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (hkid) + 
# For MHA (GROUP_SIZE == 1), hqid == hkid, so it doesn't matter + descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hkid) + descale_k = tl.load(Descale_k + bid * stride_descale_k_z + hkid) + descale_v = tl.load(Descale_v + bid * stride_descale_v_z + hkid) + else: + descale_q, descale_k, descale_v = 1.0, 1.0, 1.0 + + # start can only be 0 at minimum + start_n = 0 + end_n = seqlen_k + num_steps = tl.cdiv(seqlen_k, BLOCK_N2) + + dq = tl.zeros([BLOCK_M2, HEAD_DIM_QK], dtype=tl.float32) + dq = _bwd_dq_inner( + dq, + q, + K, + V, + do, + m, + Delta_ptr, + sm_scale, + stride_qm, + stride_qd, + stride_kn, + stride_kd, + stride_vn, + stride_vd, + stride_dropoutm, + stride_dropoutn, + stride_lse_m, + stride_delta_m, + seqlen_q, + seqlen_k, + BLOCK_M2, + BLOCK_N2, + HEAD_DIM_QK, + HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V, + dropout_p, + philox_seed, + batch_philox_offset, + dropout_offset, + alibi_slope, + start_m, + start_n, + end_n, + num_steps, + descale_q, + descale_k, + descale_v, + MASK=False, + ENABLE_DROPOUT=ENABLE_DROPOUT, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_AUTO_DESCALE=FP8_AUTO_DESCALE, + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + # Write back dQ. 
+ adj_dq = bid * stride_dqb + hqid * stride_dqh + q_start * stride_dqm + offs_dq = offs_m[:, None] * stride_dqm + offs_d_qk[None, :] * stride_dqd + dq *= sm_scale + tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + + +def is_contiguous(x, name): + if x.is_contiguous(): + return x + else: + print(f"{name} is not contiguous") + return x.contiguous() + + +DEBUG_TRITON: bool = False +DEBUG_TRITON_DETAIL: bool = False + + +def attention_backward_triton_impl( + *, + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + softmax_lse: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + delta: torch.Tensor, + sm_scale: float, + alibi_slopes: Optional[torch.Tensor], + causal: bool, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + max_seqlen_q: Optional[int], + max_seqlen_k: Optional[int], + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + philox_seed: Optional[int] = None, + philox_offset: Optional[int] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + use_exp2: bool = True, + mode: Literal["fused", "fused_atomic", "split"] = "fused", +): + # get params, strides and shape + IS_VARLEN = layout == "thd" + use_dropout = dropout_p > 0.0 + + # common assertions + assert ( + 0.0 <= dropout_p <= 1.0 + ), f"dropout_p must be between 0 and 1, got {dropout_p}" + assert ( + q.device == k.device == v.device == o.device == do.device == softmax_lse.device + ), f"All tensors must be on the same device. 
Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}, do={do.device}, softmax_lse={softmax_lse.device}" + assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype" + current_device = torch.cuda.current_device() + assert ( + q.is_cuda and q.device.index == current_device + ), f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}" + + # get shapes and strides + if IS_VARLEN: + # shape + total_seqlen_q, nheads_q, head_size_q = q.shape + total_seqlen_k, nheads_k, head_size_k = k.shape + total_seqlen_v, nheads_v, head_size_v = v.shape + nheads_lse, total_seqlen_lse = softmax_lse.shape + + # assert shapes + assert ( + total_seqlen_lse == total_seqlen_q + ), f"softmax_lse seqlen {total_seqlen_lse} != q seqlen {total_seqlen_q}" + assert ( + cu_seqlens_q is not None + ), "cu_seqlens_q must be provided for varlen layout" + assert ( + cu_seqlens_k is not None + ), "cu_seqlens_k must be provided for varlen layout" + assert ( + max_seqlen_q is not None + ), "max_seqlen_q must be provided for varlen layout" + assert ( + max_seqlen_k is not None + ), "max_seqlen_k must be provided for varlen layout" + + # assert head dimensions + assert ( + head_size_q == head_size_k + ), f"head sizes must match: q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + assert ( + nheads_lse == nheads_q + ), f"softmax_lse heads {nheads_lse} != q heads {nheads_q}" + + # assert output shapes + assert o.shape == ( + total_seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" + assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" + assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"dk shape {dk.shape} != 
k shape {k.shape}" + assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}" + + # assert cu_seqlens + assert ( + cu_seqlens_q.dtype == torch.int32 + ), f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" + assert ( + cu_seqlens_k.dtype == torch.int32 + ), f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}" + assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" + assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0" + assert ( + cu_seqlens_q[-1] == total_seqlen_q + ), f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}" + assert ( + cu_seqlens_k[-1] == total_seqlen_k + ), f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}" + + # set vars + batch = len(cu_seqlens_q) - 1 + head_size_qk = head_size_q + + # strides + stride_qb, stride_qm, stride_qh, stride_qd = ( + 0, + q.stride(0), + q.stride(1), + q.stride(2), + ) + stride_kb, stride_kn, stride_kh, stride_kd = ( + 0, + k.stride(0), + k.stride(1), + k.stride(2), + ) + stride_vb, stride_vn, stride_vh, stride_vd = ( + 0, + v.stride(0), + v.stride(1), + v.stride(2), + ) + stride_ob, stride_om, stride_oh, stride_od = ( + 0, + o.stride(0), + o.stride(1), + o.stride(2), + ) + stride_dqb, stride_dqm, stride_dqh, stride_dqd = ( + 0, + dq.stride(0), + dq.stride(1), + dq.stride(2), + ) + stride_dkb, stride_dkn, stride_dkh, stride_dkd = ( + 0, + dk.stride(0), + dk.stride(1), + dk.stride(2), + ) + stride_dvb, stride_dvn, stride_dvh, stride_dvd = ( + 0, + dv.stride(0), + dv.stride(1), + dv.stride(2), + ) + stride_dob, stride_dom, stride_doh, stride_dod = ( + 0, + do.stride(0), + do.stride(1), + do.stride(2), + ) + stride_lse_b, stride_lse_h, stride_lse_m = ( + 0, + softmax_lse.stride(0), + softmax_lse.stride(1), + ) + else: + # shapes + batch_q, seqlen_q, nheads_q, head_size_q = q.shape + batch_k, seqlen_k, nheads_k, head_size_k = k.shape + batch_v, seqlen_v, nheads_v, head_size_v = v.shape + batch_lse, nheads_lse, seqlen_lse = softmax_lse.shape + 
+ # assert batch dimensions + assert ( + batch_q == batch_k == batch_v + ), f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" + + # assert head dimensions + assert ( + head_size_q == head_size_k + ), f"head sizes must match: q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + + # assert sequence lengths + assert ( + seqlen_k == seqlen_v + ), f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" + + # assert output shapes + assert o.shape == ( + batch_q, + seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected" + assert do.shape == o.shape, f"do shape {do.shape} != o shape {o.shape}" + assert dq.shape == q.shape, f"dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"dv shape {dv.shape} != v shape {v.shape}" + + # assert softmax_lse shape + assert softmax_lse.shape == ( + batch_q, + nheads_q, + seqlen_q, + ), f"softmax_lse shape {softmax_lse.shape} != expected" + + # set vars + batch = batch_q + head_size_qk = head_size_q + max_seqlen_q = seqlen_q + max_seqlen_k = seqlen_k + + # strides + stride_qb, stride_qm, stride_qh, stride_qd = q.stride() + stride_kb, stride_kn, stride_kh, stride_kd = k.stride() + stride_vb, stride_vn, stride_vh, stride_vd = v.stride() + stride_ob, stride_om, stride_oh, stride_od = o.stride() + stride_dqb, stride_dqm, stride_dqh, stride_dqd = dq.stride() + stride_dkb, stride_dkn, stride_dkh, stride_dkd = dk.stride() + stride_dvb, stride_dvn, stride_dvh, stride_dvd = dv.stride() + stride_dob, stride_dom, stride_doh, stride_dod = do.stride() + stride_lse_b, stride_lse_h, stride_lse_m = softmax_lse.stride() + + # fp8 + IS_FP8 = is_fp8([q, k, v]) + if IS_FP8: + FP8_MAX = torch.finfo(q.dtype).max + + # 
For GQA/MQA, q_descale should be shaped (batch, nheads_k) to match forward pass + if q_descale is not None: + assert ( + q_descale.shape[0] == batch and q_descale.shape[1] == nheads_k + ), f"q_descale shape {q_descale.shape} != expected {(batch, nheads_k)}" + if q_descale.dtype != torch.float32: + warnings.warn( + f"q_descale is {q_descale.dtype}, but float32 is recommended for better precision." + ) + assert ( + q_descale.device == q.device + ), f"q_descale must be on same device as q" + else: + q_descale = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + + if k_descale is not None: + assert ( + k_descale.shape[0] == batch and k_descale.shape[1] == nheads_k + ), f"k_descale shape {k_descale.shape} != expected {(batch, nheads_k)}" + if k_descale.dtype != torch.float32: + warnings.warn( + f"k_descale is {k_descale.dtype}, but float32 is recommended for better precision." + ) + assert ( + k_descale.device == q.device + ), f"k_descale must be on same device as q" + else: + k_descale = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + + if v_descale is not None: + assert ( + v_descale.shape[0] == batch and v_descale.shape[1] == nheads_k + ), f"v_descale shape {v_descale.shape} != expected {(batch, nheads_k)}" + if v_descale.dtype != torch.float32: + warnings.warn( + f"v_descale is {v_descale.dtype}, but float32 is recommended for better precision." 
+ ) + assert ( + v_descale.device == q.device + ), f"v_descale must be on same device as q" + else: + v_descale = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + + assert ( + q_descale is not None and k_descale is not None and v_descale is not None + ), "q_descale, k_descale, and v_descale must be provided for fp8 training" + + stride_descale_q_z = q_descale.stride(0) + stride_descale_k_z = k_descale.stride(0) + stride_descale_v_z = v_descale.stride(0) + + if DEBUG: + print(f"FP8 path triggered in bwd.py") + else: + FP8_MAX = None + q_descale = k_descale = v_descale = None + stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = None + + # alibi setup + use_alibi, (stride_az, stride_ah) = ( + (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + ) + + # get closest power of 2 over or equal to 32. + padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() + padded_d_model_qk = max(padded_d_model_qk, 32) + padded_d_model_v = 1 << (head_size_v - 1).bit_length() + padded_d_model_v = max(padded_d_model_v, 32) + HEAD_DIM_QK = padded_d_model_qk + HEAD_DIM_V = padded_d_model_v + ACTUAL_HEAD_DIM_QK = head_size_qk + ACTUAL_HEAD_DIM_V = head_size_v + + # Validate pre-allocated delta tensor + if IS_VARLEN: + # Shape expected by interface varlen backward: (Hq, Total_Q) + total_q, _, _ = q.shape + assert ( + delta.shape[0] == nheads_q + ), f"delta.shape[0] ({delta.shape[0]}) must equal nheads_q ({nheads_q})" + assert ( + delta.shape[1] >= total_q + ), f"delta.shape[1] ({delta.shape[1]}) must be >= total_q ({total_q})" + assert delta.dtype == torch.float32, f"delta must be float32, got {delta.dtype}" + assert delta.device == q.device, f"delta must be on same device as q" + stride_delta_b, stride_delta_h, stride_delta_m = ( + 0, + delta.stride(0), + delta.stride(1), + ) + else: + # Shape expected by dense backward: (B, Hq, Sq) + seqlen_q = q.shape[1] + assert ( + delta.shape[0] == batch + ), f"delta.shape[0] 
({delta.shape[0]}) must equal batch ({batch})" + assert ( + delta.shape[1] == nheads_q + ), f"delta.shape[1] ({delta.shape[1]}) must equal nheads_q ({nheads_q})" + assert ( + delta.shape[2] >= seqlen_q + ), f"delta.shape[2] ({delta.shape[2]}) must be >= seqlen_q ({seqlen_q})" + assert delta.dtype == torch.float32, f"delta must be float32, got {delta.dtype}" + assert delta.device == q.device, f"delta must be on same device as q" + stride_delta_b, stride_delta_h, stride_delta_m = delta.stride() + + pre_grid = lambda META: ( + triton.cdiv(max_seqlen_q, META["PRE_BLOCK"]), + batch, + nheads_q, + ) + _bwd_preprocess[pre_grid]( + o, + do, + delta, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_delta_b, + stride_delta_h, + stride_delta_m, + cu_seqlens_q, + max_seqlen_q, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + ) + + if False: + print("delta:", delta, delta.shape) + + # dropout mask tensor for debugging. 
We dump the dropout mask created in + # the kernel for testing + dropout_mask = None + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = (0, 0, 0, 0) + if use_dropout: + dropout_mask = torch.zeros( + (batch, nheads_q, max_seqlen_q, max_seqlen_k), + device=q.device, + dtype=torch.float32, + ) + + stride_dropoutb, stride_dropouth, stride_dropoutm, stride_dropoutn = ( + dropout_mask.stride() + ) + + # Choose which kernels to call based on mode + if mode == "fused": + seqlen = max(max_seqlen_q, max_seqlen_k) + grid = lambda META: ( + nheads_k, + (seqlen + META["BLOCK_N1"] - 1) // META["BLOCK_N1"], + batch, + ) + if causal: + if DEBUG_TRITON: + print(f"bwd_kernel: grid = {grid}") # noqa: E701 + bwd_kernel_fused_causal[grid]( + q, + k, + v, + sm_scale, + do, + dq, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + nheads_q, + nheads_k, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + q_descale, + k_descale, + v_descale, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + 
FP8_AUTO_DESCALE=FP8_AUTO_DESCALE, + USE_SEQUSED=( + seqused_q is not None or seqused_k is not None + ), # Add flag for seqused + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + else: + bwd_kernel_fused_noncausal[grid]( + q, + k, + v, + sm_scale, + do, + dq, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_dvb, + stride_dvh, + stride_dvn, + stride_dvd, + stride_lse_b, + stride_lse_h, + stride_lse_m, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + stride_az, + stride_ah, + nheads_q, + nheads_k, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + alibi_slopes, + q_descale, + k_descale, + v_descale, + HEAD_DIM_QK=HEAD_DIM_QK, + HEAD_DIM_V=HEAD_DIM_V, + ACTUAL_HEAD_DIM_QK=ACTUAL_HEAD_DIM_QK, + ACTUAL_HEAD_DIM_V=ACTUAL_HEAD_DIM_V, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + USE_ALIBI=use_alibi, + USE_EXP2=use_exp2, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_AUTO_DESCALE=FP8_AUTO_DESCALE, + USE_SEQUSED=( + seqused_q is not None or seqused_k is not None + ), # Add flag for seqused + DEBUG_TRITON=DEBUG_TRITON, + DEBUG_TRITON_DETAIL=DEBUG_TRITON_DETAIL, + ) + elif mode == "fused_atomic": + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 + BLK_SLICE_FACTOR = 2 + BLOCK_D_MODEL_POW2 = max(triton.next_power_of_2(HEAD_DIM_QK), 16) + + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, 
nheads_k) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) + + # fuses dk, dv, dq computations into one kernel by computing the dq using atomic adds between workgroups + BLOCK_N = ( + 128 if BLOCK_D_MODEL_POW2 < 160 else 64 + ) # larger head sizes lead to oom + config = { + "BLOCK_M": 32, + "BLOCK_N": BLOCK_N, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 1, + "BLK_SLICE_FACTOR": 2, + } + + num_k_pids = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N + grid_dkdvdq = (batch * nheads_k * num_k_pids,) + + if causal: + _bwd_kernel_fused_atomic_causal[grid_dkdvdq]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + q_descale, + k_descale, + v_descale, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + else: + _bwd_kernel_fused_atomic_noncausal[grid_dkdvdq]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, 
+ stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + q_descale, + k_descale, + v_descale, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BATCH=batch, + NUM_K_PIDS=num_k_pids, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + **config, + ) + elif mode == "split": + NUM_WARPS, NUM_STAGES = 4, 1 + WAVES_PER_EU = 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 64, 64, 64, 16 + BLK_SLICE_FACTOR = 2 + BLOCK_D_MODEL_POW2 = max(triton.next_power_of_2(HEAD_DIM_QK), 16) + + grid_dkdv = ((max_seqlen_k + BLOCK_N1 - 1) // BLOCK_N1, batch, nheads_k) + grid_dq = ((max_seqlen_q + BLOCK_M2 - 1) // BLOCK_M2, batch, nheads_k) + + if causal: + _bwd_kernel_split_dkdv_causal[grid_dkdv]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + q_descale, + k_descale, + v_descale, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M1, + BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + 
num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + _bwd_kernel_split_dq_causal[grid_dq]( + q, + k, + v, + sm_scale, + do, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + q_descale, + k_descale, + v_descale, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + else: + _bwd_kernel_split_dkdv_noncausal[grid_dkdv]( + q, + k, + v, + sm_scale, + do, + dk, + dv, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dkd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + q_descale, + k_descale, + v_descale, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M1, + 
BLOCK_N=BLOCK_N1, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + + _bwd_kernel_split_dq_noncausal[grid_dq]( + q, + k, + v, + sm_scale, + do, + dq, + softmax_lse, + delta, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dqd, + stride_delta_b, + stride_delta_h, + stride_delta_m, + stride_dob, + stride_doh, + stride_dom, + stride_dod, + stride_dropoutb, + stride_dropouth, + stride_dropoutm, + stride_dropoutn, + stride_descale_q_z, + stride_descale_k_z, + stride_descale_v_z, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_mask, + dropout_p, + philox_seed, + philox_offset, + q_descale, + k_descale, + v_descale, + NUM_Q_HEADS=nheads_q, + NUM_K_HEADS=nheads_k, + BLOCK_M=BLOCK_M2, + BLOCK_N=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_D_MODEL=HEAD_DIM_QK, + BLOCK_D_MODEL_POW2=HEAD_DIM_QK, + ENABLE_DROPOUT=use_dropout, + IS_VARLEN=IS_VARLEN, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + waves_per_eu=WAVES_PER_EU, + ) + else: + raise ValueError( + f"Unknown backward mode '{mode}'. Expected 'split', 'fused_atomic' or 'fused'." 
+ ) diff --git a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/fwd_decode.py b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/fwd_decode.py new file mode 100755 index 0000000000..4645dcc97f --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/fwd_decode.py @@ -0,0 +1,1404 @@ +import os +import warnings +import torch +import triton +import triton.language as tl +from typing import Literal, Optional +from .utils import ( + DEBUG, + AUTOTUNE, + get_arch, + get_padded_headsize, + get_shape_and_strides_from_layout, + apply_rotary, + is_cdna, + is_fp8, + get_recommended_fp8_dtype, +) + + +def get_cdna_autotune_configs(): + return [ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 3, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "waves_per_eu": 2, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + # Fall-back config. 
+ triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 16, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL", + "VARLEN", + "HQ", + "HK", + ] + + +def get_autotune_configs(): + if AUTOTUNE: + if is_cdna(): + autotune_configs, autotune_keys = get_cdna_autotune_configs() + fwd_auto_tune_configs, fwd_autotune_keys = autotune_configs, autotune_keys + reduce_auto_tune_configs, reduce_autotune_keys = ( + autotune_configs, + autotune_keys, + ) + return (fwd_auto_tune_configs, fwd_autotune_keys), ( + reduce_auto_tune_configs, + reduce_autotune_keys, + ) + else: + raise ValueError("Unknown Device Type") + else: + autotune_configs, autotune_keys = [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 1, "PRE_LOAD_V": False}, + num_stages=1, + num_warps=4, + ), + ], [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL", + "VARLEN", + "HQ", + "HK", + ] + + fwd_auto_tune_configs, fwd_autotune_keys = autotune_configs, autotune_keys + reduce_auto_tune_configs, reduce_autotune_keys = autotune_configs, autotune_keys + return (fwd_auto_tune_configs, fwd_autotune_keys), ( + reduce_auto_tune_configs, + reduce_autotune_keys, + ) + + +(fwd_auto_tune_configs, fwd_autotune_keys), ( + reduce_auto_tune_configs, + reduce_autotune_keys, +) = get_autotune_configs() + + +@triton.jit +def _attn_fwd_inner( + q, + kT, + v, + pos, + col_mask, + m_i, + l_i, + acc, + pid_m, + q_descale, + k_descale, + v_descale, # FP8 scaling factors + IS_FP8: tl.constexpr, # FP8 flag + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + N_CTX_Q: tl.constexpr, + N_CTX_K_FINAL: tl.constexpr, + USE_ALIBI: tl.constexpr, + alibi_slope, + USE_SLIDING_WINDOW: tl.constexpr, + IS_CAUSAL: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + APPLY_COL_MASK: tl.constexpr, # apply provided col_mask when True +): + # -- compute qk 
--- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_FP8: + qk += tl.dot(q, kT) * q_descale * k_descale # Apply FP8 scaling + else: + qk += tl.dot(q, kT) # noqa: F821 + + if USE_ALIBI: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = pos + tl.arange(0, BLOCK_N) + + # Compute relative positions + relative_pos = row_idx[:, None] + N_CTX_K_FINAL - (N_CTX_Q + col_idx[None, :]) + relative_pos = tl.abs(relative_pos) + + # Compute ALiBi bias + alibi_bias = -1 * alibi_slope * relative_pos + qk += alibi_bias * 1.44269504 + + # ------------------------------------------------------------------ + # masking + # ------------------------------------------------------------------ + if USE_SLIDING_WINDOW: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # q positions + col_idx = pos + tl.arange(0, BLOCK_N) # k positions + row = row_idx[:, None] # [M,1] + col = col_idx[None, :] # [1,N] + + if IS_CAUSAL: + # -------- causal + window -------- + diag = N_CTX_K_FINAL - N_CTX_Q # sk-sq + causal_ok = col <= row + diag + if WINDOW_SIZE_LEFT < 0: # only right window + win_ok = col <= row + diag + WINDOW_SIZE_RIGHT + else: # both sides + win_ok = (col >= row + diag - WINDOW_SIZE_LEFT) & ( + col <= row + diag + WINDOW_SIZE_RIGHT + ) + mask = ~(causal_ok & win_ok) # True ⇒ -inf + else: + # -------- non-causal window -------- + sk, sq = N_CTX_K_FINAL, N_CTX_Q + if WINDOW_SIZE_LEFT < 0: + mask = col > row + (sk - sq) + WINDOW_SIZE_RIGHT + else: + right = tl.minimum(row + (sk - sq) + WINDOW_SIZE_RIGHT, sk) + left = row + (sk - sq) - WINDOW_SIZE_LEFT + mask = (col > right) | (col < left) + qk = tl.where(mask, float("-inf"), qk) + else: + if IS_CAUSAL: + row_idx = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + col_idx = pos + tl.arange(0, BLOCK_N) + + # create a N_CTX_Q x kv_len causal mask + col_offset = N_CTX_K_FINAL - N_CTX_Q + causal_mask = row_idx[:, None] >= (col_idx[None, :] - col_offset) + + # Apply the mask + qk = tl.where(causal_mask, qk, float("-inf")) + + # 
Column mask (tail / variable-length). Instead of recomputing an arange each time, + # we accept a precomputed mask from the caller (col_valid_mask). + if APPLY_COL_MASK: + # Expect col_mask shape: [BLOCK_N]. True where column is within sequence. + qk = tl.where(col_mask[None, :], qk, float("-inf")) + + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) # per-row max so far + + # rows that are *all* -inf after masking + valid = m_i_new > float("-inf") + + # scale previous partial sums safely + alpha = tl.where(valid, tl.math.exp2(m_i - m_i_new), 0.0) + + # subtract the row max only on valid rows + qk = tl.where(valid[:, None], qk - m_i_new[:, None], float("-inf")) + p = tl.math.exp2(qk) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(q.dtype) + + # -- scale and update acc -- + acc *= alpha[:, None] + if IS_FP8: + acc += tl.dot(p.to(v.dtype), v) * v_descale # Apply FP8 scaling for V + else: + acc += tl.dot(p.to(v.dtype), v) + + return m_i, l_i, acc + + +# @triton.autotune( +# configs=fwd_auto_tune_configs, +# key=fwd_autotune_keys, +# use_cuda_graph=True, +# ) +@triton.jit +def _fwd_kernel_splitK( + Q, + K, + V, + Q_Descale, # FP8 descale factors for Q + K_Descale, # FP8 descale factors for K + V_Descale, # FP8 descale factors for V + sm_scale, + Out_splitK, # [B*H*G, split_k, Mq, K] + Metadata, # [B*H*G, 2, split_k, M_ceil] contains [mi, li] + K_new, + V_new, + Cache_seqlens, + Cache_batch_idx, + Block_table, + Alibi_slopes, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qd, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kd, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vd, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_d, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_kn_z, + stride_kn_n, + stride_kn_g, + stride_kn_h, + stride_kn_d, + stride_vn_z, + stride_vn_n, + stride_vn_g, + stride_vn_h, + stride_vn_d, + stride_bt_b, + stride_bt_s, + stride_az, + 
stride_ah, + stride_q_descale_z, # FP8 descale strides + stride_q_descale_h, + stride_k_descale_z, + stride_k_descale_h, + stride_v_descale_z, + stride_v_descale_h, + Z, + N_CTX_Q, + N_CTX_K, + N_CTX_NEW, + BLOCK_N_PER_SPLIT, + BLOCK_SIZE_K: tl.constexpr, + H_q: tl.constexpr, + H_kv: tl.constexpr, + G_q: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_CACHE_SEQLENs: tl.constexpr, + USE_CACHE_BATCH_IDX: tl.constexpr, + NEW_KV: tl.constexpr, + IS_GQA: tl.constexpr, + IS_CAUSAL: tl.constexpr, + USE_ALIBI: tl.constexpr, + PADDED_HEAD: tl.constexpr, + GROUP_SIZE: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + USE_BLOCK_TABLE: tl.constexpr, + IS_FP8: tl.constexpr, # FP8 flag +): + # get program ids + pid_m = tl.program_id(0) + pid_zhg = tl.program_id(1) + pid_splitk = tl.program_id(2) + + # compute z, h and g ids + z_id = pid_zhg // (H_q * G_q) + hq_id = (pid_zhg // G_q) % H_q + g_id = pid_zhg % G_q + + # is gqa + if IS_GQA: + hk_id = hq_id // GROUP_SIZE + hv_id = hk_id + else: + hk_id = hq_id + hv_id = hq_id + + # Load FP8 descale factors if needed + if IS_FP8: + if IS_GQA: + # For MQA/GQA, q_descale uses the same indexing as k/v (hk_id) + q_descale = tl.load( + Q_Descale + z_id * stride_q_descale_z + hk_id * stride_q_descale_h + ) + else: + # For MHA, q_descale uses hq_id + q_descale = tl.load( + Q_Descale + z_id * stride_q_descale_z + hq_id * stride_q_descale_h + ) + k_descale = tl.load( + K_Descale + z_id * stride_k_descale_z + hk_id * stride_k_descale_h + ) + v_descale = tl.load( + V_Descale + z_id * stride_v_descale_z + hv_id * stride_v_descale_h + ) + else: + q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 + + # figure out seqlens + lo = pid_splitk * BLOCK_N_PER_SPLIT + if USE_CACHE_SEQLENs: + cache_seqlen_last_idx = tl.load(Cache_seqlens + z_id) + N_CTX_K_FINAL = 
cache_seqlen_last_idx + else: + N_CTX_K_FINAL = N_CTX_K + hi = tl.minimum((pid_splitk + 1) * BLOCK_N_PER_SPLIT, N_CTX_K_FINAL) + + # pick batch index + if USE_CACHE_BATCH_IDX: + cache_batch_idx = tl.load(Cache_batch_idx + z_id) + else: + cache_batch_idx = z_id + + # compute offsets + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + # compute ptrs + q_offset = Q + hq_id * stride_qh + z_id * stride_qz + g_id * stride_qg + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + + # Handle block table for paged attention + if USE_BLOCK_TABLE: + # K and V now point to paged cache + # Each batch has its own block table row + block_table_ptr = Block_table + z_id * stride_bt_b + else: + k_offset = ( + K + hk_id * stride_kh + cache_batch_idx * stride_kz + g_id * stride_kg + ) + v_offset = ( + V + hv_id * stride_vh + cache_batch_idx * stride_vz + g_id * stride_vg + ) + + # compute masks + if PADDED_HEAD: + q_mask = (offs_m < N_CTX_Q)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] + kT_mask = (offs_d < ACTUAL_BLOCK_DMODEL)[:, None] & (offs_n < N_CTX_K_FINAL)[ + None, : + ] + v_mask = (offs_n < N_CTX_K_FINAL)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[ + None, : + ] + osk_mask = (offs_m < N_CTX_Q)[:, None] & (offs_d < ACTUAL_BLOCK_DMODEL)[None, :] + else: + q_mask = (offs_m < N_CTX_Q)[:, None] + kT_mask = (offs_n < N_CTX_K_FINAL)[None, :] + v_mask = (offs_n < N_CTX_K_FINAL)[:, None] + osk_mask = (offs_m < N_CTX_Q)[:, None] + + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs, mask=q_mask, other=0.0) + q = (q * qk_scale).to(q.dtype) + + # load ALiBi slope if enabled + if USE_ALIBI: + a_offset = z_id * stride_az + hq_id * stride_ah + alibi_slope = tl.load(Alibi_slopes + a_offset) + 
else: + alibi_slope = None + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) # noqa: F821 + + # loop over k, v and update accumulator + if USE_BLOCK_TABLE: + # Paged attention: process all KV blocks from cache + # Note: Cache should be updated externally before calling this kernel + num_kv_blocks = (N_CTX_K_FINAL + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K + + for block_idx in range(num_kv_blocks): + # Calculate sequence range for this block + block_start = block_idx * BLOCK_SIZE_K + block_end = tl.minimum(block_start + BLOCK_SIZE_K, N_CTX_K_FINAL) + + # Check if block overlaps with our split-k range [lo, hi) + if block_end > lo and block_start < hi: + # Load physical block number + physical_block = tl.load(block_table_ptr + block_idx * stride_bt_s) + + # Calculate the range within this block that overlaps with [lo, hi) + process_start = tl.maximum(lo - block_start, 0) + process_end = tl.minimum(hi - block_start, BLOCK_SIZE_K) + process_end = tl.minimum(process_end, block_end - block_start) + + # Instead of forcing a floor alignment to BLOCK_N (which can still skip + # part of the intended range if start falls mid-tile for small splits), + # start from the raw (possibly unaligned) process_start rounded *down* but + # allow the loop to begin earlier (at most BLOCK_N before) so that any + # partial tile overlapping [lo, hi) is covered. Masking below will remove + # columns < lo or >= hi ensuring numerically identical coverage without + # duplication. 
+ aligned_start = (process_start // BLOCK_N) * BLOCK_N + if aligned_start > 0 and aligned_start + BLOCK_N > process_start: + # ensure we include the tile that contains process_start + process_start = aligned_start + else: + process_start = aligned_start + + for offset in range(process_start, process_end, BLOCK_N): + # Current position (may begin slightly before logical split range; masking fixes it) + pos = block_start + offset + # Proceed unconditionally; masking below enforces [lo, hi) + # Calculate base addresses for K and V in this physical block + k_base = ( + K + + physical_block * BLOCK_SIZE_K * stride_kn + + hk_id * stride_kh + + g_id * stride_kg + ) + v_base = ( + V + + physical_block * BLOCK_SIZE_K * stride_vn + + hv_id * stride_vh + + g_id * stride_vg + ) + + # Offsets within the current block + block_offs = offset + offs_n + + # Masks for valid data respecting: + # (1) global key length (seq_mask) + # (2) block bounds (block_mask) + # (3) current split range [lo, hi) + seq_mask = (pos + offs_n) < N_CTX_K_FINAL + block_mask = block_offs < BLOCK_SIZE_K + end_mask = block_offs < process_end + split_mask = ((pos + offs_n) >= lo) & ((pos + offs_n) < hi) + col_mask = seq_mask & block_mask & end_mask & split_mask + + # Apply masks + kT_mask_final = kT_mask & col_mask[None, :] + v_mask_final = v_mask & col_mask[:, None] + + # Load K and V + kT_ptrs = ( + k_base + + offs_d[:, None] * stride_kd + + block_offs[None, :] * stride_kn + ) + v_ptrs = ( + v_base + + block_offs[:, None] * stride_vn + + offs_d[None, :] * stride_vd + ) + + kT = tl.load(kT_ptrs, mask=kT_mask_final, other=0.0) + v = tl.load(v_ptrs, mask=v_mask_final, other=0.0) + + # Unified inner function handles both paged and contiguous + m_i, l_i, acc = _attn_fwd_inner( + q, + kT, + v, + pos, + col_mask, + m_i, + l_i, + acc, + pid_m, + q_descale, + k_descale, + v_descale, + IS_FP8, + BLOCK_M, + BLOCK_N, + N_CTX_Q, + N_CTX_K_FINAL, + USE_ALIBI, + alibi_slope, + USE_SLIDING_WINDOW, + IS_CAUSAL, + 
WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, + True, + ) + else: + # Non-paged attention: process KV from cache + # Note: Cache should be updated externally before calling this kernel + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + kT_ptrs = ( + k_offset + + offs_d[:, None] * stride_kd + + (start_n + offs_n)[None, :] * stride_kn + ) + V_ptrs = ( + v_offset + + (start_n + offs_n)[:, None] * stride_vn + + offs_d[None, :] * stride_vd + ) + + # load k + kT = tl.load(kT_ptrs, mask=kT_mask, other=0.0) + v = tl.load(V_ptrs, mask=v_mask, other=0.0) + + # Use the same inner loop logic + # Precompute column validity mask for this tile (all True for full tiles). + # hi is the upper bound of the overall split range; start_n marks this tile's base. + col_valid_mask = offs_n < (hi - start_n) + + m_i, l_i, acc = _attn_fwd_inner( + q, + kT, + v, + start_n, + col_valid_mask, + m_i, + l_i, + acc, + pid_m, + q_descale, + k_descale, + v_descale, + IS_FP8, + BLOCK_M, + BLOCK_N, + N_CTX_Q, + N_CTX_K_FINAL, + USE_ALIBI, + alibi_slope, + USE_SLIDING_WINDOW, + IS_CAUSAL, + WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, + BOUNDS_CHECKS_N, + ) + + # write back O + osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_splitk * stride_osk_s + osk_ptrs = ( + osk_offset + offs_m[:, None] * stride_osk_m + offs_d[None, :] * stride_osk_d + ) + tl.store( + osk_ptrs, + acc, + mask=osk_mask, + ) + + # write metadata for split-K reduction + metadata_offset = Metadata + pid_zhg * stride_mzhg + pid_splitk * stride_ms + metadata_ptr = metadata_offset + offs_m + tl.store(metadata_ptr, m_i) + tl.store(metadata_ptr + stride_m2, l_i) + + +# @triton.autotune( +# configs=reduce_auto_tune_configs, +# key=reduce_autotune_keys, +# use_cuda_graph=True, +# ) +@triton.jit +def _splitK_reduce( + Out_splitK, # [B*H*G, split_k, Mq, K] + Metadata, # [B*H*G, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, G, M, K] + LSE, # [B*H*G, M] + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + 
stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + K_BLOCK_SIZE: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + split_k: tl.constexpr, + splitK_pow2: tl.constexpr, + MASK_SPLITK: tl.constexpr, + PADDED_HEAD: tl.constexpr, +): + # get pids + pid_zhg = tl.program_id(0) + pid_m = tl.program_id(1) + pid_k = tl.program_id(2) + + # compute offsets + offs_splitK = tl.arange(0, splitK_pow2) + offs_k = pid_k * K_BLOCK_SIZE + tl.arange(0, K_BLOCK_SIZE) + + # compute masks + if PADDED_HEAD: + o_mask = offs_k < ACTUAL_BLOCK_DMODEL + else: + o_mask = None + + # compute ptrs + metadata_offset = Metadata + pid_zhg * stride_mzhg + metadata_ptr = metadata_offset + offs_splitK * stride_ms + pid_m * stride_mm + + osk_offset = Out_splitK + pid_zhg * stride_osk_zhg + pid_m * stride_osk_m + osk_ptr = ( + osk_offset + + offs_splitK[:, None] * stride_osk_s + + offs_k[None, :] * stride_osk_k + ) + + # read max values of each splitK + if MASK_SPLITK: + splitK_mask = offs_splitK < split_k + l_m = tl.load(metadata_ptr, mask=splitK_mask, other=float("-inf")) + l_sum = tl.load(metadata_ptr + stride_m2, mask=splitK_mask, other=0.0) + acc = tl.load(osk_ptr, mask=splitK_mask[:, None], other=0.0) + else: + l_m = tl.load(metadata_ptr) + l_sum = tl.load(metadata_ptr + stride_m2) + acc = tl.load(osk_ptr) + + g_m = tl.max(l_m, axis=0) + + alpha = tl.where(l_m > float("-inf"), tl.math.exp2(l_m - g_m), 0.0) + + # read sum + l_sum *= alpha + g_sum = tl.sum(l_sum, axis=0) + acc = acc * alpha[:, None] + + g_sum_safe = tl.where(g_sum > 0, g_sum, 1.0) + acc_out = tl.sum(acc, axis=0) / g_sum_safe + + # Store output + z_id = pid_zhg // (H * G) + h_id = (pid_zhg // G) % H + g_id = pid_zhg % G + out_offset = Out + z_id * stride_oz + h_id * stride_oh + g_id * stride_og + out_ptr = out_offset + pid_m * stride_om + offs_k + 
tl.store(out_ptr, acc_out, mask=o_mask) + + # Store lse + l_ptrs = LSE + pid_zhg * stride_lse_zhg + pid_m + lse_val = tl.where(g_sum > 0, (g_m + tl.math.log2(g_sum)) / 1.44269504, g_m) + tl.store(l_ptrs, lse_val) + + +@triton.jit +def cast_uint32_to_half2(scale_shift): + # Extract two float16 packed into one int32 + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + + +@triton.jit +def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, +): + # PACKED_PER_VAL is the number of values packed into + # each element x_. For example, for int4 quantization + # and x_ of type int32, PACKED_PER_VAL is 8. + + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = ( + x_[:, None, :] >> offsets[None, :, None] + ) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + # Scale and shift are such that quantization linearly maps + # int4 values range [0..15] to input values range min(k)..max(k) + # individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: torch.Tensor = (max_vals - min_vals) / 15 + + shift_k = torch.min(k, dim=-1, keepdim=True).values + 
scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat( + [scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1 + ) + k_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + + +def dequantize_kv_fp16(quant_k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + k_i16 = quant_k.view(torch.int16) + k_ui8 = k_i16.view(torch.uint8) + + ss_size = num_groups * 4 + scale_shift_ui8 = k_ui8[..., 0:ss_size] + scale_shift_ui8 = scale_shift_ui8.reshape( + *scale_shift_ui8.shape[:-1], num_groups, 4 + ) + scale = scale_shift_ui8[..., 0:2].view(torch.float16) + shift = scale_shift_ui8[..., 2:4].view(torch.float16) + + kv_ui8 = k_ui8[..., ss_size:] + k_ui8 = kv_ui8.reshape(*kv_ui8.shape[:-1], num_groups, -1) + k1_i4 = k_ui8 & 0xF + k2_i4 = (k_ui8 & 0xF0) >> 4 + k_shape = k1_i4.shape + k1_f16 = k1_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + k2_f16 = k2_i4.to(torch.float16) * scale.expand(k_shape) + shift.expand(k_shape) + + out = torch.empty( + (*k1_f16.shape[:-1], k1_f16.shape[-1] * 2), + dtype=torch.float16, + device=quant_k.device, + ) + out[..., ::2] = k1_f16 + out[..., 1::2] = k2_f16 + out = out.reshape(*k_shape[:-2], -1) + + return out + + +def get_split_k(B: int, G: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = max(B * H, 1) # NOTE: Handle B*h=0 case + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + while B * H * G * split_k >= 1024: + split_k = split_k // 2 + split_k = min(split_k, 512) + split_k = max(split_k, 1) + return split_k + + +def 
attention_forward_decode_triton_impl( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_new: Optional[torch.Tensor], + v_new: Optional[torch.Tensor], + out: torch.Tensor, + softmax_lse: torch.Tensor, + sm_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + alibi_slopes: Optional[torch.Tensor], + layout: Literal["bshd"], + cache_seqlens: Optional[torch.Tensor], + cache_batch_idx: Optional[torch.Tensor], + block_table: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + # rotary (optional) + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + rotary_interleaved: bool = False, + seqlens_rotary: Optional[torch.Tensor] = None, +): + # apply rotary embedding + if rotary_cos is not None and rotary_sin is not None: + # Prefer explicitly provided rotary sequence start offsets if given; fall back to cache_seqlens. 
+ seqlen_offsets = ( + seqlens_rotary + if seqlens_rotary is not None + else (cache_seqlens if cache_seqlens is not None else 0) + ) + local = (window_size_left != -1) or (window_size_right != -1) + q, k_new = apply_rotary( + q, + k_new, + rotary_cos, + rotary_sin, + causal=causal, + local=local, + interleaved=rotary_interleaved, + seqlen_offsets=seqlen_offsets, + ) + + # handle cache updates + if k_new is not None and v_new is not None: + # Update cache with new KV values + if block_table is None: + # Non-paged attention: update cache directly + batch_size = k_new.shape[0] + seqlen_new = k_new.shape[1] + + if cache_seqlens is not None: + # Use cache_seqlens to determine where to insert new KV + for b in range(batch_size): + start_idx = int(cache_seqlens[b].item()) + end_idx = start_idx + seqlen_new + k_cache[b, start_idx:end_idx] = k_new[b] + v_cache[b, start_idx:end_idx] = v_new[b] + cache_seqlens[b] = end_idx + else: + # Append at the end of existing cache + seqlen_cache = k_cache.shape[1] + k_cache[:, seqlen_cache - seqlen_new :] = k_new + v_cache[:, seqlen_cache - seqlen_new :] = v_new + else: + # Paged attention: update cache using block table + batch_size = k_new.shape[0] + seqlen_new = k_new.shape[1] + block_size = k_cache.shape[ + 1 + ] # k_cache shape: [num_blocks, block_size, nheads, head_dim] + + # Update cache for each batch element + for b in range(batch_size): + if cache_seqlens is not None: + start_idx = int(cache_seqlens[b].item()) + else: + # If no cache_seqlens, assume we're appending at the end + # Find the last used position from block table + start_idx = 0 + for block_idx in range(block_table.shape[1]): + if block_table[b, block_idx] >= 0: + start_idx = (block_idx + 1) * block_size + else: + start_idx = block_idx * block_size + break + + # Copy new KV values into the paged cache + for i in range(seqlen_new): + pos = start_idx + i + block_idx = pos // block_size + within_block_idx = pos % block_size + + # Get the physical block number from 
block table + if block_idx < block_table.shape[1]: + physical_block = int(block_table[b, block_idx].item()) + + # Update k_cache and v_cache at the physical block location + k_cache[physical_block, within_block_idx] = k_new[b, i] + v_cache[physical_block, within_block_idx] = v_new[b, i] + + # Update cache_seqlens if provided + if cache_seqlens is not None: + cache_seqlens[b] = start_idx + seqlen_new + + # triton configs + BLOCK_M = 16 + BLOCK_N = 64 + num_stages = 1 + num_warps_fwd = 1 + num_warps_reduce = 4 + + # kernel_configs + is_new_kv = False # Cache has been updated, so no new KV in kernel + use_alibi, (stride_az, stride_ah) = True if alibi_slopes is not None else False, ( + alibi_slopes.stride() if alibi_slopes is not None else (None, None) + ) + use_cache_seqlens = cache_seqlens is not None + use_sliding_window = window_size_left != -1 or window_size_right != -1 + use_block_table = block_table is not None + SPLIT_K = None + NUM_QUANT_GROUPS = 1 + + # get shapes and strides + (batch_size, seqlen_q, nheads_q, dim_q), ( + stride_qz, + stride_qh, + stride_qm, + stride_qd, + ) = get_shape_and_strides_from_layout(q, layout) + + # Handle paged KV cache layout + if use_block_table: + # For paged attention, k_cache and v_cache have shape [num_blocks, block_size, nheads, head_dim] + num_blocks_kc, block_size_k, nheads_kc, dim_kc = k_cache.shape + num_blocks_vc, block_size_v, nheads_vc, dim_vc = v_cache.shape + # Get the actual sequence length from cache_seqlens or block_table + if cache_seqlens is not None: + seqlen_kc = int(cache_seqlens.max().item()) + else: + # Infer from block_table shape [batch_size, num_blocks_per_seq] + num_blocks_per_seq = block_table.shape[1] + seqlen_kc = num_blocks_per_seq * block_size_k + seqlen_vc = seqlen_kc + + # Strides for paged layout + stride_kc_z = 0 # No batch dimension in paged cache + stride_kc_n = k_cache.stride(1) # Sequence stride + stride_kc_h = k_cache.stride(2) # Head stride + stride_kc_d = k_cache.stride(3) # Dim stride 
+ + stride_vc_z = 0 + stride_vc_n = v_cache.stride(1) + stride_vc_h = v_cache.stride(2) + stride_vc_d = v_cache.stride(3) + else: + (_, seqlen_kc, nheads_kc, dim_kc), ( + stride_kc_z, + stride_kc_h, + stride_kc_n, + stride_kc_d, + ) = get_shape_and_strides_from_layout(k_cache, layout) + (_, seqlen_vc, nheads_vc, dim_vc), ( + stride_vc_z, + stride_vc_h, + stride_vc_n, + stride_vc_d, + ) = get_shape_and_strides_from_layout(v_cache, layout) + block_size_k = 0 # Not used + if is_new_kv: + (_, seqlen_kn, nheads_kn, dim_kn), ( + stride_kn_z, + stride_kn_h, + stride_kn_n, + stride_kn_d, + ) = get_shape_and_strides_from_layout(k_new, layout) + (_, seqlen_vn, nheads_vn, dim_vn), ( + stride_vn_z, + stride_vn_h, + stride_vn_n, + stride_vn_d, + ) = get_shape_and_strides_from_layout(v_new, layout) + else: + (_, seqlen_kn, nheads_kn, dim_kn), ( + stride_kn_z, + stride_kn_h, + stride_kn_n, + stride_kn_d, + ) = (None, None, None, None,), (None, None, None, None) + (_, seqlen_vn, nheads_vn, dim_vn), ( + stride_vn_z, + stride_vn_h, + stride_vn_n, + stride_vn_d, + ) = (None, None, None, None,), (None, None, None, None) + (_, seqlen_o, nheads_o, dim_o), (stride_oz, stride_oh, stride_om, stride_od) = ( + get_shape_and_strides_from_layout(out, layout) + ) + assert ( + dim_q == dim_kc == dim_vc + ), f"Dimensions must match: {dim_q}, {dim_kc}, {dim_vc}" + + # add extra information needed by the kernels + if layout == "bshd": + (n_group_q, heads_per_group_q), stride_qg = (1, nheads_q), stride_qm + (n_group_k, heads_per_group_k), stride_kc_g = (1, nheads_kc), stride_kc_n + (n_group_v, heads_per_group_v), stride_vc_g = (1, nheads_vc), stride_vc_n + if is_new_kv: + (n_group_kn, heads_per_group_kn), stride_kn_g = (1, nheads_kn), stride_kn_n + (n_group_vn, heads_per_group_vn), stride_vn_g = (1, nheads_vn), stride_vn_n + else: + (n_group_kn, heads_per_group_kn), stride_kn_g = (None, None), None + (n_group_vn, heads_per_group_vn), stride_vn_g = (None, None), None + (n_group_o, heads_per_group_o), 
stride_og = (1, nheads_o), stride_om + else: + raise ValueError(f"{layout} layout is not supported") + + # get padded size + dim_padded = get_padded_headsize(dim_kc) + is_padded_head = dim_padded != dim_kc + + # Handle MQA/GQA case + group_size = nheads_q // nheads_kc + if group_size > 1: + is_gqa = True + else: + is_gqa = False + + if SPLIT_K is not None: + split_k = SPLIT_K + else: + # Use heuristics + if use_block_table: + # For paged attention, use the actual sequence length from cache_seqlens + max_seqlen = ( + int(cache_seqlens.max().item()) + if cache_seqlens is not None + else block_size_k + ) + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, max_seqlen) + else: + split_k = get_split_k(batch_size, n_group_q, heads_per_group_q, seqlen_kc) + split_size = (seqlen_kc + split_k - 1) // split_k + + # setup grid + seqlen_q_ceil = (seqlen_q + BLOCK_M - 1) // BLOCK_M * BLOCK_M + grid = lambda META: ( + triton.cdiv(seqlen_q, META["BLOCK_M"]), + batch_size * n_group_q * heads_per_group_q, + split_k, + ) + + # create intermediate tensors + out_splitk = torch.empty( + [batch_size * n_group_q * heads_per_group_q, split_k, seqlen_q_ceil, dim_kc], + dtype=torch.float32, + device=q.device, + ) + metadata = torch.empty( + [batch_size * n_group_q * heads_per_group_q, 2, split_k, seqlen_q_ceil], + dtype=torch.float32, + device=q.device, + ) + + # Validate pre-allocated softmax_lse tensor + # Expected shape after view: (batch_size, n_group_q * heads_per_group_q, seqlen_q) + # Internal shape: (batch_size * n_group_q * heads_per_group_q, seqlen_q) + expected_h_total = batch_size * n_group_q * heads_per_group_q + assert ( + softmax_lse.shape[0] == batch_size + ), f"softmax_lse.shape[0] ({softmax_lse.shape[0]}) must equal batch_size ({batch_size})" + assert ( + softmax_lse.shape[1] == n_group_q * heads_per_group_q + ), f"softmax_lse.shape[1] ({softmax_lse.shape[1]}) must equal n_group_q * heads_per_group_q ({n_group_q * heads_per_group_q})" + assert ( + 
softmax_lse.shape[2] >= seqlen_q + ), f"softmax_lse.shape[2] ({softmax_lse.shape[2]}) must be >= seqlen_q ({seqlen_q})" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert softmax_lse.device == q.device, f"softmax_lse must be on same device as q" + + # Create internal lse view for kernel use + lse = softmax_lse.view(expected_h_total, -1)[:, :seqlen_q].contiguous() + + # get intermediate tensor strides + stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d = out_splitk.stride() + stride_mzhg, stride_m2, stride_ms, stride_mm = metadata.stride() + stride_lse_zhg, stride_lse_m = lse.stride() + + # Block table strides + if use_block_table: + stride_bt_b, stride_bt_s = block_table.stride() + else: + stride_bt_b, stride_bt_s = 0, 0 + + # FP8 support + IS_FP8 = is_fp8([q, k_cache, v_cache]) + if IS_FP8: + rec_dtype = get_recommended_fp8_dtype(q) + if ( + q.dtype != rec_dtype + or k_cache.dtype != rec_dtype + or v_cache.dtype != rec_dtype + ): + arch = get_arch() + warnings.warn( + f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k_cache.dtype}, v: {v_cache.dtype}", + UserWarning, + ) + if (q_descale is None) or (k_descale is None) or (v_descale is None): + warnings.warn( + "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", + UserWarning, + ) + # Create default descale tensors if not provided + if q_descale is None: + q_descale = torch.ones( + batch_size, nheads_q, dtype=torch.float32, device=q.device + ) + if k_descale is None: + k_descale = torch.ones( + batch_size, nheads_kc, dtype=torch.float32, device=q.device + ) + if v_descale is None: + v_descale = torch.ones( + batch_size, nheads_vc, dtype=torch.float32, device=q.device + ) + else: + # Enforce exact expected shapes; no reshaping or normalization. 
+ assert ( + q_descale.dim() == 2 + and q_descale.shape[0] == batch_size + and q_descale.shape[1] == nheads_kc + ), f"q_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(q_descale.shape)}" + assert ( + k_descale.dim() == 2 + and k_descale.shape[0] == batch_size + and k_descale.shape[1] == nheads_kc + ), f"k_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(k_descale.shape)}" + assert ( + v_descale.dim() == 2 + and v_descale.shape[0] == batch_size + and v_descale.shape[1] == nheads_kc + ), f"v_descale expected shape ({batch_size}, {nheads_kc}) got {tuple(v_descale.shape)}" + stride_q_descale_z, stride_q_descale_h = q_descale.stride() + stride_k_descale_z, stride_k_descale_h = k_descale.stride() + stride_v_descale_z, stride_v_descale_h = v_descale.stride() + else: + q_descale = None + k_descale = None + v_descale = None + stride_q_descale_z = 0 + stride_q_descale_h = 0 + stride_k_descale_z = 0 + stride_k_descale_h = 0 + stride_v_descale_z = 0 + stride_v_descale_h = 0 + + if DEBUG: + print( + "batch_size, seqlen_q, nheads_q, dim_q", + (batch_size, seqlen_q, nheads_q, dim_q), + ) + print("_, seqlen_kc, nheads_kc, dim_kc", (_, seqlen_kc, nheads_kc, dim_kc)) + print("dim_padded:", dim_padded) + print( + "stride_qz, stride_qm, stride_qg, stride_qh, stride_qd", + (stride_qz, stride_qm, stride_qg, stride_qh, stride_qd), + ) + print( + "stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d", + (stride_kc_z, stride_kc_n, stride_kc_g, stride_kc_h, stride_kc_d), + ) + print( + "stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d", + (stride_vc_z, stride_vc_n, stride_vc_g, stride_vc_h, stride_vc_d), + ) + if is_new_kv: + print( + "stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d", + (stride_kn_z, stride_kn_n, stride_kn_g, stride_kn_h, stride_kn_d), + ) + print( + "stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d", + (stride_vn_z, stride_vn_n, stride_vn_g, stride_vn_h, stride_vn_d), + ) + print( + 
"stride_oz, stride_om, stride_og, stride_oh, stride_od", + (stride_oz, stride_om, stride_og, stride_oh, stride_od), + ) + print( + "stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d", + (stride_osk_zhg, stride_osk_s, stride_osk_m, stride_osk_d), + ) + print( + "stride_mzhg, stride_m2, stride_ms, stride_mm", + (stride_mzhg, stride_m2, stride_ms, stride_mm), + ) + print("stride_lse_zhg, stride_lse_m", (stride_lse_zhg, stride_lse_m)) + + _fwd_kernel_splitK[grid]( + Q=q, + K=k_cache, + V=v_cache, + Q_Descale=q_descale, + K_Descale=k_descale, + V_Descale=v_descale, + sm_scale=sm_scale, + Out_splitK=out_splitk, + Metadata=metadata, + K_new=None, + V_new=None, + Cache_seqlens=cache_seqlens, + Cache_batch_idx=cache_batch_idx, + Block_table=block_table, + Alibi_slopes=alibi_slopes, + # q strides + stride_qz=stride_qz, + stride_qm=stride_qm, + stride_qg=stride_qg, + stride_qh=stride_qh, + stride_qd=stride_qd, + # k strides + stride_kz=stride_kc_z, + stride_kn=stride_kc_n, + stride_kg=stride_kc_g, + stride_kh=stride_kc_h, + stride_kd=stride_kc_d, + # v strides + stride_vz=stride_vc_z, + stride_vn=stride_vc_n, + stride_vg=stride_vc_g, + stride_vh=stride_vc_h, + stride_vd=stride_vc_d, + # out_splitk strides + stride_osk_zhg=stride_osk_zhg, + stride_osk_s=stride_osk_s, + stride_osk_m=stride_osk_m, + stride_osk_d=stride_osk_d, + # metadata strides + stride_mzhg=stride_mzhg, + stride_m2=stride_m2, + stride_ms=stride_ms, + stride_mm=stride_mm, + # k_new strides + stride_kn_z=stride_kn_z, + stride_kn_n=stride_kn_n, + stride_kn_g=stride_kn_g, + stride_kn_h=stride_kn_h, + stride_kn_d=stride_kn_d, + # v_new strides + stride_vn_z=stride_vn_z, + stride_vn_n=stride_vn_n, + stride_vn_g=stride_vn_g, + stride_vn_h=stride_vn_h, + stride_vn_d=stride_vn_d, + # block table strides + stride_bt_b=stride_bt_b, + stride_bt_s=stride_bt_s, + # alibi strides + stride_az=stride_az, + stride_ah=stride_ah, + # FP8 descale strides + stride_q_descale_z=stride_q_descale_z, + 
stride_q_descale_h=stride_q_descale_h, + stride_k_descale_z=stride_k_descale_z, + stride_k_descale_h=stride_k_descale_h, + stride_v_descale_z=stride_v_descale_z, + stride_v_descale_h=stride_v_descale_h, + Z=batch_size, + H_q=heads_per_group_q, + H_kv=heads_per_group_k, + G_q=n_group_q, + N_CTX_Q=seqlen_q, + N_CTX_K=seqlen_kc, + N_CTX_NEW=0, # No new KV, cache already updated + BLOCK_N_PER_SPLIT=split_size, + BLOCK_SIZE_K=block_size_k if use_block_table else 256, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=dim_padded, + ACTUAL_BLOCK_DMODEL=dim_kc, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_cache_seqlens, + USE_CACHE_SEQLENs=use_cache_seqlens, + USE_CACHE_BATCH_IDX=cache_batch_idx is not None, + NEW_KV=False, # Cache already updated + IS_GQA=is_gqa, + IS_CAUSAL=causal, + USE_ALIBI=use_alibi, + PADDED_HEAD=is_padded_head, + GROUP_SIZE=group_size, + USE_SLIDING_WINDOW=use_sliding_window, + WINDOW_SIZE_LEFT=window_size_left, + WINDOW_SIZE_RIGHT=window_size_right, + USE_BLOCK_TABLE=use_block_table, + IS_FP8=IS_FP8, + num_warps=num_warps_fwd, + num_stages=num_stages, + ) + + if DEBUG: + print("Out_splitK:", out_splitk, out_splitk.shape) + print("metadata:", metadata, metadata.shape) + print("lse:", lse, lse.shape) + print("Out:", out, out.shape) + + # Merge together + splitK_pow2 = triton.next_power_of_2(split_k) + mask_split_k = splitK_pow2 > split_k + if batch_size * n_group_q * heads_per_group_q * seqlen_q >= 512: + k_block_num = 1 + else: + k_block_num = 2 + assert dim_padded % k_block_num == 0 + k_block_size = dim_padded // k_block_num + grid = (batch_size * n_group_q * heads_per_group_q, seqlen_q, k_block_num) + + if DEBUG: + print("splitK_pow2:", splitK_pow2) + print("k_block_num:", k_block_num) + print("k_block_size:", k_block_size) + print("grid:", grid) + + _splitK_reduce[grid]( + out_splitk, + metadata, + out, + lse, + # Split-K output strides + stride_osk_zhg=stride_osk_zhg, + stride_osk_s=stride_osk_s, + stride_osk_m=stride_osk_m, + 
stride_osk_k=stride_osk_d, + # Metadata strides + stride_mzhg=stride_mzhg, + stride_m2=stride_m2, + stride_ms=stride_ms, + stride_mm=stride_mm, + # Output tensor strides + stride_oz=stride_oz, + stride_oh=stride_oh, + stride_og=stride_og, + stride_om=stride_om, + stride_ok=stride_od, + # LSE strides + stride_lse_zhg=stride_lse_zhg, + stride_lse_m=stride_lse_m, + K_BLOCK_SIZE=k_block_size, + BLOCK_DMODEL=dim_padded, + ACTUAL_BLOCK_DMODEL=dim_kc, + G=n_group_q, + H=heads_per_group_q, + # TODO: Tune num_warps + split_k=split_k, + splitK_pow2=splitK_pow2, + MASK_SPLITK=mask_split_k, + PADDED_HEAD=is_padded_head, + num_warps=num_warps_reduce, + ) diff --git a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/fwd_prefill.py b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/fwd_prefill.py new file mode 100755 index 0000000000..3cc427382d --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/fwd_prefill.py @@ -0,0 +1,2090 @@ +import os +import warnings +import torch +import triton +import triton.language as tl +from typing import Literal, Optional +from .utils import ( + DEBUG, + AUTOTUNE, + FP8_AUTO_DESCALE, + compute_alibi_block, + compute_fp8_scaling_factors, + get_arch, + get_cu_count, + is_cdna, + is_fp8, + is_rdna, + apply_rotary, + get_recommended_fp8_dtype, +) + + +def get_fwd_configs(autotune: bool): + configs = [] + keys = [ + "IS_CAUSAL", + "dropout_p", + "MAX_SEQLENS_Q", + "MAX_SEQLENS_K", + "ACTUAL_BLOCK_DMODEL_QK", + "ACTUAL_BLOCK_DMODEL_V", + "IS_VARLEN", + "HQ", + "HK", + ] + + # get best config for the architecture + if not autotune: + arch = get_arch() + if arch == "gfx950": + configs.append( + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) + ) + elif arch == "gfx942": + if get_cu_count() < 304: + configs.extend( + [ + # best fp8 config + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": 
False, + }, + num_stages=1, + num_warps=4, + ), + # best f16 config + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 32, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=2, + num_warps=4, + ), + ] + ) + else: + configs.append( + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) + ) + elif arch in ( + "gfx1030", + "gfx1100", + "gfx1101", + "gfx1102", + "gfx1200", + "gfx1201", + ): # RDNA architectures + configs.append( + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=2, + ) + ) + else: + configs.append( + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ) + ) + + return configs, keys + + # ===================== Autotune Sweep ===================== + BLOCK_M_OPTIONS = [128, 64, 32] + BLOCK_N_OPTIONS = [128, 64, 32] + NUM_WARPS_OPTIONS = [2, 4, 8] + NUM_STAGES_OPTIONS = [1, 2] + WAVES_PER_EU_OPTIONS = [4, 2, 1] + PRE_LOAD_V_OPTIONS = [False] + for bm in BLOCK_M_OPTIONS: + for bn in BLOCK_N_OPTIONS: + for waves in WAVES_PER_EU_OPTIONS: + for nw in NUM_WARPS_OPTIONS: + for ns in NUM_STAGES_OPTIONS: + for preload_v in PRE_LOAD_V_OPTIONS: + configs.append( + triton.Config( + { + "BLOCK_M": bm, + "BLOCK_N": bn, + "waves_per_eu": waves, + "PRE_LOAD_V": preload_v, + }, + num_stages=ns, + num_warps=nw, + ) + ) + + return configs, keys + + +fwd_prefill_autotune_configs, fwd_prefill_autotune_keys = get_fwd_configs(AUTOTUNE) + + +@triton.jit +def _attn_fwd_no_mask( + acc, + l_i, + m_i, + q, + k_base_ptrs, + v_base_ptrs, + bias_base_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + stride_sm, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_offset_base, + sd_mask, + stride_sz, + stride_sh, + off_z, + off_h_q, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, + block_max, 
+ alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_AUTO_DESCALE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + PADDED_HEAD_QK: tl.constexpr, + PADDED_HEAD_V: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, + ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + SM_SCALE: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + RETURN_SCORES: tl.constexpr, + ACCUMULATOR_TYPE, +): + if USE_EXP2: + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # get ptrs + k_ptrs = k_base_ptrs + start_n * stride_kn + v_ptrs = v_base_ptrs + start_n * stride_vk + + kv_offs_n = start_n + tl.arange(0, BLOCK_N) + # Load K + if PADDED_HEAD_QK: + k_mask = offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK + k = tl.load(k_ptrs, mask=k_mask, other=0.0) + else: + k = tl.load(k_ptrs) + + # Optionally preload V + if PRE_LOAD_V: + if PADDED_HEAD_V: + v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + v = tl.load(v_ptrs) + + # setup qk accumlator + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) + + # -- compute qk ---- + if IS_FP8: + qk += tl.dot(q, k) * q_descale * k_descale + else: + qk += tl.dot(q, k) + qk_scaled = qk * SM_SCALE + + if USE_ALIBI: + # compute the global position of each token within the sequence + q_offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + alibi_block = compute_alibi_block( + alibi_slope, seqlen_q, seqlen_k, q_offs_m, kv_offs_n + ) + qk_scaled += alibi_block + + # compute qk mask + qk_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + + # compute bias + if bias_base_ptrs is not None: + bias_ptrs = bias_base_ptrs + start_n * stride_bn + bias = tl.load(bias_ptrs, mask=qk_mask, other=0.0) + 
qk_scaled += bias + + # get max scores so far + m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) + + # scale and subtract max + q_shifted = tl.where( + m_ij[:, None] == float("-inf"), float("-inf"), qk_scaled - m_ij[:, None] + ) + + # Compute scaled QK and softmax probabilities + if USE_EXP2: + p = tl.math.exp2(q_shifted * RCP_LN2) + else: + p = tl.math.exp(q_shifted) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + # Compute pointers for this block + philox_base = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh + philox_ptrs = ( + philox_base + + offs_m[:, None] * stride_sm + + kv_offs_n[None, :] * stride_sn + ) + + # compute dropout mask + rng_output = tl.rand(philox_seed, philox_ptrs) + dropout_mask = rng_output > dropout_p + + # return scores with negative values for dropped vals (only if RETURN_SCORES is True) + if RETURN_SCORES: + sd_mask_value = tl.where(dropout_mask, p, -p) + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = ( + sd_mask_base + + offs_m[:, None] * stride_sm + + kv_offs_n[None, :] * stride_sn + ) + + # Compute mask for sd_mask storage + sd_store_mask = (offs_m[:, None] < seqlen_q) & ( + kv_offs_n[None, :] < seqlen_k + ) + tl.store(sd_mask_ptrs, sd_mask_value, mask=sd_store_mask) + + # apply dropout mask in place + p = tl.where(dropout_mask, p, 0.0) + elif RETURN_SCORES: + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = ( + sd_mask_base + + offs_m[:, None] * stride_sm + + kv_offs_n[None, :] * stride_sn + ) + + # Compute mask for sd_mask storage + sd_store_mask = (offs_m[:, None] < seqlen_q) & ( + kv_offs_n[None, :] < seqlen_k + ) + tl.store(sd_mask_ptrs, p, mask=sd_store_mask) + + # -- update output accumulator -- + # alpha is an adjustment factor for acc and li as we loop and find new maxes + # store the diff in maxes to adjust acc and li as we discover new maxes + m_diff = tl.where(m_ij == float("-inf"), float("-inf"), m_i - m_ij) + if USE_EXP2: + alpha = tl.math.exp2(m_diff * RCP_LN2) + else: + alpha = tl.math.exp(m_diff) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + if PADDED_HEAD_V: + v_mask = offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + else: + v = tl.load(v_ptrs) + + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + if IS_FP8: + if FP8_AUTO_DESCALE: + scale_p, descale_p = compute_fp8_scaling_factors(p, FP8_MAX) + acc += ( + tl.dot((p * scale_p).to(v.type.element_ty), v) + * descale_p + * v_descale + ) + else: + acc += tl.dot(p.to(v.type.element_ty), v) * v_descale + else: + acc += tl.dot(p.to(v.type.element_ty), v) + + return acc, l_i, m_i + + +@triton.jit +def _attn_fwd_mask( + acc, + l_i, + m_i, + q, + k_base_ptrs, + v_base_ptrs, + bias_base_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + stride_sm, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_offset_base, + sd_mask, + stride_sz, + stride_sh, + off_z, + off_h_q, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, + block_max, + n_extra_tokens, + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_AUTO_DESCALE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, + BLOCK_N: 
tl.constexpr, + PRE_LOAD_V: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + PADDED_HEAD_QK: tl.constexpr, + PADDED_HEAD_V: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, + ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + SM_SCALE: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + RETURN_SCORES: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + ACCUMULATOR_TYPE, +): + if USE_EXP2: + RCP_LN2: tl.constexpr = 1.4426950408889634 + + # seqlen diff + seqlen_delta_qk = seqlen_k - seqlen_q + + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # get ptrs + k_ptrs = k_base_ptrs + start_n * stride_kn + v_ptrs = v_base_ptrs + start_n * stride_vk + + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + kv_offs_n = start_n + tl.arange(0, BLOCK_N) + k_mask = kv_offs_n[None, :] < seqlen_k + v_mask = kv_offs_n[:, None] < seqlen_k + if PADDED_HEAD_QK: + k_mask = k_mask & (offs_d_qk[:, None] < ACTUAL_BLOCK_DMODEL_QK) + if PADDED_HEAD_V: + v_mask = v_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) + + # load k and if preload_v then v + k = tl.load(k_ptrs, mask=k_mask, other=0.0) + if PRE_LOAD_V: + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + + # setup qk accumlator + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=ACCUMULATOR_TYPE) + + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. 
+ if (n_extra_tokens != 0) and (start_n + BLOCK_N == block_max): + boundary_m = tl.full([BLOCK_M], seqlen_k, dtype=tl.int32) + size_n = start_n + offs_n[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + + # -- compute qk ---- + if IS_FP8: + qk += tl.dot(q, k) * q_descale * k_descale + else: + qk += tl.dot(q, k) + qk_scaled = qk * SM_SCALE + + if USE_ALIBI: + # compute the global position of each token within the sequence + q_offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + alibi_block = compute_alibi_block( + alibi_slope, seqlen_q, seqlen_k, q_offs_m, kv_offs_n + ) + qk_scaled += alibi_block + + if USE_SLIDING_WINDOW: + if IS_CAUSAL: + # ========== CAUSAL SLIDING WINDOW MASKING ========== + # For causal sliding window, we need to apply both constraints: + # 1. Causal: col_idx <= row_idx + (seqlen_k - seqlen_q) + # 2. Sliding window: row_idx - window_left <= col_idx <= row_idx + window_right + + # Get positions + row_idx = offs_m # Query positions + col_idx = kv_offs_n # Key positions + + # Expand for broadcasting + row_idx_expanded = row_idx[:, None] # [BLOCK_M, 1] + col_idx_expanded = col_idx[None, :] # [1, BLOCK_N] + + # Apply causal constraint: can only attend to positions before or at the diagonal + causal_offset = seqlen_k - seqlen_q + causal_mask = col_idx_expanded > (row_idx_expanded + causal_offset) + + # Apply sliding window constraint + if WINDOW_SIZE_LEFT < 0: + # Only right window constraint + window_mask = col_idx_expanded > ( + row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT + ) + else: + # Both left and right window constraints + # Adjust window bounds by causal offset + left_bound = row_idx_expanded + causal_offset - WINDOW_SIZE_LEFT + right_bound = row_idx_expanded + causal_offset + WINDOW_SIZE_RIGHT + + # Can't attend to positions outside the window + window_mask = (col_idx_expanded < left_bound) | ( + col_idx_expanded > right_bound + ) + + # Final mask is the union of both constraints (True = 
cannot attend) + mask = causal_mask | window_mask + + # Apply mask + qk_scaled = tl.where(mask, float("-inf"), qk_scaled) + else: + # ========== NON-CAUSAL SLIDING WINDOW MASKING ========== + # Exactly matching reference construct_local_mask: + # row_idx = query positions, col_idx = key positions + # sk = seqlen_k, sq = seqlen_q + + # Get positions + row_idx = offs_m # Query positions + col_idx = kv_offs_n # Key positions + + # sk and sq from reference (no padding masks in this test) + sk = seqlen_k + sq = seqlen_q + + # Expand for broadcasting + row_idx_expanded = row_idx[:, None] # [BLOCK_M, 1] + col_idx_expanded = col_idx[None, :] # [1, BLOCK_N] + + # Reference logic for mask computation + if WINDOW_SIZE_LEFT < 0: + # Reference: return col_idx > row_idx + sk - sq + window_size[1] + mask = col_idx_expanded > ( + row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT + ) + else: + # Reference: + # sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + # return torch.logical_or( + # col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + # col_idx < row_idx + sk - sq - window_size[0], + # ) + # Create sk tensor with proper shape for broadcasting + # sk represents the key sequence length, which should be compared per column + sk_full = tl.full((1, BLOCK_N), sk, dtype=tl.int32) + + # Compute boundaries + right_bound_val = row_idx_expanded + sk - sq + WINDOW_SIZE_RIGHT + right_bound = tl.minimum(right_bound_val, sk_full) + left_bound = row_idx_expanded + sk - sq - WINDOW_SIZE_LEFT + + # Mask where True = cannot attend (matching reference) + mask = (col_idx_expanded > right_bound) | ( + col_idx_expanded < left_bound + ) + + # Apply mask (set to -inf where mask is True) + qk_scaled = tl.where(mask, float("-inf"), qk_scaled) + else: + if IS_CAUSAL: + causal_boundary = start_n + offs_n - seqlen_delta_qk + causal_mask = offs_m[:, None] >= causal_boundary[None, :] + qk_scaled = tl.where(causal_mask, qk_scaled, float("-inf")) + + # compute qk mask + 
qk_mask = (offs_m[:, None] < seqlen_q) & (kv_offs_n[None, :] < seqlen_k) + + # compute bias + if bias_base_ptrs is not None: + bias_ptrs = bias_base_ptrs + start_n * stride_bn + bias = tl.load(bias_ptrs, mask=qk_mask, other=0.0) + qk_scaled += bias + + # get max scores so far + m_ij = tl.maximum(m_i, tl.max(qk_scaled, 1)) + + # scale and subtract max + # IMPORTANT: Handle the case where all values are -inf + # When m_ij = -inf and qk_scaled = -inf, subtraction gives NaN + # We need to handle this explicitly + if USE_SLIDING_WINDOW: + # Check if this block has any valid values (m_ij != -inf) + # For rows where everything is -inf, set q_shifted to -inf (not NaN) + q_shifted = tl.where( + m_ij[:, None] == float("-inf"), float("-inf"), qk_scaled - m_ij[:, None] + ) + else: + q_shifted = qk_scaled - m_ij[:, None] + + # Compute scaled QK and softmax probabilities + if USE_EXP2: + p = tl.math.exp2(q_shifted * RCP_LN2) + else: + p = tl.math.exp(q_shifted) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + # Compute pointers for this block + philox_base = philox_offset_base + off_z * stride_sz + off_h_q * stride_sh + philox_ptrs = ( + philox_base + + offs_m[:, None] * stride_sm + + kv_offs_n[None, :] * stride_sn + ) + + # compute dropout mask + rng_output = tl.rand(philox_seed, philox_ptrs) + dropout_mask = rng_output > dropout_p + + # return scores with negative values for dropped vals (only if RETURN_SCORES is True) + if RETURN_SCORES: + sd_mask_value = tl.where(dropout_mask, p, -p) + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = ( + sd_mask_base + + offs_m[:, None] * stride_sm + + kv_offs_n[None, :] * stride_sn + ) + + # Compute mask for sd_mask storage - include bounds check + sd_store_mask = (offs_m[:, None] < seqlen_q) & ( + kv_offs_n[None, :] < seqlen_k + ) + + # Add causal mask if applicable to prevent writing to invalid positions + if IS_CAUSAL: + seqlen_delta_qk = seqlen_k - 
seqlen_q + causal_constraint = kv_offs_n[None, :] <= ( + offs_m[:, None] + seqlen_delta_qk + ) + sd_store_mask = sd_store_mask & causal_constraint + + # Add sliding window mask if applicable + if USE_SLIDING_WINDOW: + seqlen_delta_qk = seqlen_k - seqlen_q + if WINDOW_SIZE_LEFT < 0: + # Only right window constraint + window_constraint = kv_offs_n[None, :] <= ( + offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT + ) + else: + # Both left and right window constraints + left_bound = ( + offs_m[:, None] + seqlen_delta_qk - WINDOW_SIZE_LEFT + ) + right_bound = ( + offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT + ) + window_constraint = (kv_offs_n[None, :] >= left_bound) & ( + kv_offs_n[None, :] <= right_bound + ) + sd_store_mask = sd_store_mask & window_constraint + + tl.store(sd_mask_ptrs, sd_mask_value, mask=sd_store_mask) + + # apply dropout mask in place + p = tl.where(dropout_mask, p, 0.0) + elif RETURN_SCORES: + # NOTE: the returned score is not the same as the reference because we need to adjust as we find new maxes per block. 
We are not doing that + sd_mask_base = sd_mask + off_z * stride_sz + off_h_q * stride_sh + sd_mask_ptrs = ( + sd_mask_base + + offs_m[:, None] * stride_sm + + kv_offs_n[None, :] * stride_sn + ) + + # Compute mask for sd_mask storage - include bounds check + sd_store_mask = (offs_m[:, None] < seqlen_q) & ( + kv_offs_n[None, :] < seqlen_k + ) + + # Add causal mask if applicable + if IS_CAUSAL: + seqlen_delta_qk = seqlen_k - seqlen_q + causal_constraint = kv_offs_n[None, :] <= ( + offs_m[:, None] + seqlen_delta_qk + ) + sd_store_mask = sd_store_mask & causal_constraint + + # Add sliding window mask if applicable + if USE_SLIDING_WINDOW: + seqlen_delta_qk = seqlen_k - seqlen_q + if WINDOW_SIZE_LEFT < 0: + # Only right window constraint + window_constraint = kv_offs_n[None, :] <= ( + offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT + ) + else: + # Both left and right window constraints + left_bound = offs_m[:, None] + seqlen_delta_qk - WINDOW_SIZE_LEFT + right_bound = offs_m[:, None] + seqlen_delta_qk + WINDOW_SIZE_RIGHT + window_constraint = (kv_offs_n[None, :] >= left_bound) & ( + kv_offs_n[None, :] <= right_bound + ) + sd_store_mask = sd_store_mask & window_constraint + + tl.store(sd_mask_ptrs, p, mask=sd_store_mask) + + # -- update output accumulator -- + # alpha is an adjustment factor for acc and li as we loop and find new maxes + # store the diff in maxes to adjust acc and li as we discover new maxes + m_diff = tl.where(m_ij == float("-inf"), float("-inf"), m_i - m_ij) + if USE_EXP2: + alpha = tl.math.exp2(m_diff * RCP_LN2) + else: + alpha = tl.math.exp(m_diff) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = tl.load(v_ptrs, mask=v_mask, other=0.0) + + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + m_i = m_ij + + if IS_FP8: + if FP8_AUTO_DESCALE: + p_scale, p_descale = compute_fp8_scaling_factors(p, FP8_MAX) + acc += ( + tl.dot((p * p_scale).to(v.type.element_ty), v) + * p_descale + * v_descale + ) + else: + acc += 
tl.dot(p.to(v.type.element_ty), v) * v_descale + else: + acc += tl.dot(p.to(v.type.element_ty), v) + + return acc, l_i, m_i + + +@triton.jit +def compute_window_bounds( + q_start, + q_end, + diag, + seqlen_k, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): + """Calculate the window boundaries for a query block.""" + # Left boundary + if WINDOW_SIZE_LEFT < 0: + left_min = 0 + left_max = 0 + else: + left_min = tl.maximum(0, q_start + diag - WINDOW_SIZE_LEFT) + left_max = tl.maximum(0, q_end + diag - WINDOW_SIZE_LEFT) + + # Right boundary + if IS_CAUSAL: + # Causal cap: col ≤ row + diag + right_min = tl.minimum(seqlen_k - 1, q_start + diag) + right_max = tl.minimum(seqlen_k - 1, q_end + diag) + else: + if WINDOW_SIZE_RIGHT < 0: + right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) + else: + # Non-causal doesn't have the diagonal constraint + right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) + + return left_min, left_max, right_min, right_max + + +@triton.jit +def classify_window_blocks( + left_min, left_max, right_min, right_max, BLOCK_N: tl.constexpr +): + """Classify blocks based on window boundaries.""" + # First and last blocks that have ANY overlap with window + first_block = left_min // BLOCK_N + last_block = right_max // BLOCK_N + + # First block that is FULLY visible for all rows in Q block + full_left_block = left_max // BLOCK_N + (left_max % BLOCK_N != 0) + clipped_left = tl.minimum(full_left_block, last_block + 1) + + # Last block that is FULLY visible for all rows in Q block + last_full_block_candidate = right_min // BLOCK_N + if (last_full_block_candidate + 1) * BLOCK_N - 1 > right_min: + last_full_block_candidate -= 1 + full_right_block = tl.maximum(last_full_block_candidate, clipped_left - 1) + + # Calculate counts 
+ n_front_skip_blocks = first_block + n_front_masked_blocks = tl.maximum(0, clipped_left - first_block) + n_full_blocks = tl.maximum(0, full_right_block - clipped_left + 1) + n_back_masked_blocks = tl.maximum(0, last_block - full_right_block) + + return ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + clipped_left, + ) # Return clipped_left for padded block handling + + +@triton.jit +def handle_padded_last_block( + n_extra_tokens, + last_block, + total_k_blocks, + clipped_left, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, +): + """Ensure a padded last K-block is never classified as 'full'. + + We move the padded last block (if visible) into the back-masked bucket. + If it's already back-masked, we do nothing. If it was counted in the + front-masked range, we decrement front-masked; if it was counted as full, + we decrement full. Then we increment back-masked. + """ + padded_last_k = (n_extra_tokens != 0) & (last_block == total_k_blocks - 1) + + if padded_last_k: + # current 'full' range right edge + full_right_block = clipped_left + n_full_blocks - 1 + + # If last_block is already beyond full_right_block, it's already in back-masked → nothing to do + last_already_back_masked = last_block > full_right_block + if not last_already_back_masked: + # If the window starts past last_block, it was counted in front-masked + if clipped_left > last_block: + n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1) + else: + # Otherwise it was counted 'full' → move it out of full + n_full_blocks = tl.maximum(0, n_full_blocks - 1) + # In both cases we need one more back-masked block + n_back_masked_blocks = n_back_masked_blocks + 1 + + return n_front_masked_blocks, n_full_blocks, n_back_masked_blocks + + +@triton.jit +def compute_padding_info(seqlen_k, BLOCK_N: tl.constexpr): + """Calculate padding information for the last K block.""" + # check if we will need to do masking due either BLOCK_N being bigger 
than seqlen_k or seqlen_k not being a factor of BLOCK_N + # n_extra_tokens = 10 % 4 = 2 + # This means the last K block has 2 valid tokens and 2 padding positions + # K blocks visualization: + # Block 0 Block 1 Block 2 (last) + # K0 K1 K2 K3 K4 K5 K6 K7 K8 K9 ?? ?? + # ↑---------↑ ↑---------↑ ↑---↑ ↑---↑ + # full block full block valid pad + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + else: + n_extra_tokens = 0 + return n_extra_tokens + + +@triton.jit +def compute_block_masking( + seqlen_k, + seqlen_q, + start_m, + IS_CAUSAL: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + Classify K blocks for attention computation with sliding window support. + + Returns: + - n_front_skip_blocks: Blocks completely before the window + - n_front_masked_blocks: Blocks partially overlapping window front + - n_full_blocks: Blocks completely inside the window + - n_back_masked_blocks: Blocks partially overlapping window back + - n_extra_tokens: Padding tokens in last K block + """ + + # common + q_start = start_m * BLOCK_M + q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) + diag = seqlen_k - seqlen_q + total_k_blocks = tl.cdiv(seqlen_k, BLOCK_N) + n_extra_tokens = compute_padding_info(seqlen_k, BLOCK_N) + + if USE_SLIDING_WINDOW: + # get window bounds + left_min, left_max, right_min, right_max = compute_window_bounds( + q_start, + q_end, + diag, + seqlen_k, + WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, + IS_CAUSAL, + ) + + # window vanishes → early exit + if right_max < left_min: + return 0, 0, 0, 0, n_extra_tokens + + # classify blocks + ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + clipped_left, + ) = classify_window_blocks(left_min, left_max, right_min, right_max, BLOCK_N) + + # handle padded last block if needed + 
if n_extra_tokens != 0: + last_block = right_max // BLOCK_N + n_front_masked_blocks, n_full_blocks, n_back_masked_blocks = ( + handle_padded_last_block( + n_extra_tokens, + last_block, + total_k_blocks, + clipped_left, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + ) + ) + return ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + n_extra_tokens, + ) + else: + if IS_CAUSAL: + # ========== CAUSAL MODE: Classify K Blocks ========== + # Calculate causal boundary for this Q block + # [K0 K1 K2 K3] [K4 K5 K6 K7] [K8 K9 ?? ??] + # Q0-Q3: [ 1 0 0 0] [ 0 0 0 0] [ 0 0 -- --] ← Q0 + # [ 1 1 0 0] [ 0 0 0 0] [ 0 0 -- --] ← Q1 + # [ 1 1 1 0] [ 0 0 0 0] [ 0 0 -- --] ← Q2 + # [ 1 1 1 1] [ 1 1 0 0] [ 0 0 -- --] ← Q3 + # ↑ can see up to K5 + # + # Q4-Q7: [ 1 1 1 1] [ 1 1 1 0] [ 0 0 -- --] ← Q4 + # [ 1 1 1 1] [ 1 1 1 1] [ 0 0 -- --] ← Q5 + # [ 1 1 1 1] [ 1 1 1 1] [ 1 0 -- --] ← Q6 + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -- --] ← Q7 + + # ------------------------------------------------------------ + # 1. figure out, in tokens, the right-most K position + # this Q-block may attend to + # ------------------------------------------------------------ + k_max_token = q_end + diag # last visible K index + + # this Q-block is entirely above the diagonal ⇒ nothing to do + if k_max_token < 0: + return 0, 0, 0, 0, n_extra_tokens + + k_max_token = tl.minimum(k_max_token, seqlen_k - 1) + + # ------------------------------------------------------------ + # 2. translate token indices into K-block indices + # ------------------------------------------------------------ + last_visible_k_block = k_max_token // BLOCK_N + n_visible_k_blocks = tl.minimum(last_visible_k_block + 1, total_k_blocks) + + # ------------------------------------------------------------ + # 3. 
classify those visible blocks + # – we *never* skip or mask blocks in front, because causal + # attention always starts at K0 + # – the back side can require several masked blocks: + # • intersection of the causal diagonal with K-grid + # (at most ⌈BLOCK_M / BLOCK_N⌉ blocks) + # • plus one extra block if this Q-block stops in the + # middle of a K-block or the last K-block is padded + # ------------------------------------------------------------ + padded_last_k = n_extra_tokens != 0 + is_modulo_mn = (not padded_last_k) & (seqlen_q % BLOCK_M == 0) + + n_back_masked_blocks = BLOCK_M // BLOCK_N + tl.where(is_modulo_mn, 0, 1) + n_back_masked_blocks = tl.minimum(n_back_masked_blocks, n_visible_k_blocks) + + n_front_skip_blocks = 0 # causal never skips the left side + n_front_masked_blocks = 0 # ditto + n_full_blocks = n_visible_k_blocks - n_back_masked_blocks + else: + # ========== NON-CAUSAL MODE ========== + # Without causal mask, all positions can attend to all positions + # Only need to handle the padding in the last block + # [K0 K1 K2 K3] [K4 K5 K6 K7] [K8 K9 ?? ??] 
+ # Q0-Q3: [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # + # Q4-Q7: [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + # [ 1 1 1 1] [ 1 1 1 1] [ 1 1 -∞ -∞] + + n_front_skip_blocks = 0 # never skips the left side + n_front_masked_blocks = 0 # ditto + if n_extra_tokens != 0: + n_back_masked_blocks = 1 # Last block needs padding mask + n_full_blocks = total_k_blocks - 1 + else: + n_back_masked_blocks = 0 # All blocks are aligned + n_full_blocks = total_k_blocks + + return ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + n_extra_tokens, + ) + + +@triton.autotune( + configs=fwd_prefill_autotune_configs, + key=fwd_prefill_autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + Q_Descale, + K_Descale, + V_Descale, + stride_q_descale_z, + stride_k_descale_z, + stride_v_descale_z, + LSE, + Out, + SD_MASK, + ALIBI_SLOPES, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + stride_az, + stride_ah, + stride_sz, + stride_sh, + stride_sm, + stride_sn, + stride_lse_z, + stride_lse_h, + stride_lse_m, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Add seqused parameters + dropout_p, + philox_seed, + philox_offset_base, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL_QK: tl.constexpr, + ACTUAL_BLOCK_DMODEL_V: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + IS_VARLEN: tl.constexpr, + SM_SCALE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, + 
BLOCK_DMODEL_QK: tl.constexpr, + BLOCK_DMODEL_V: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_SCORES: tl.constexpr, + USE_ALIBI: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_FP8: tl.constexpr, + FP8_MAX: tl.constexpr, + FP8_AUTO_DESCALE: tl.constexpr, + USE_SEQUSED: tl.constexpr, +): + # set params + ACCUMULATOR_TYPE = tl.float32 + + # compute offsets + off_z = tl.program_id(0) + off_h_q = tl.program_id(1) + start_m = tl.program_id(2) + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + # Determine if we need to mask the heads + PADDED_HEAD_QK: tl.constexpr = ACTUAL_BLOCK_DMODEL_QK != BLOCK_DMODEL_QK + PADDED_HEAD_V: tl.constexpr = ACTUAL_BLOCK_DMODEL_V != BLOCK_DMODEL_V + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d_qk = tl.arange(0, BLOCK_DMODEL_QK) + offs_d_v = tl.arange(0, BLOCK_DMODEL_V) + + # handle seqlen + if IS_VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + + # If seqused is provided, use it to limit the actual sequence length + if USE_SEQUSED: + actual_seqlen_q = ( + tl.load(seqused_q + off_z) + if seqused_q is not None + else cu_seqlens_q_end - cu_seqlens_q_start + ) + seqlen_q = tl.minimum( + actual_seqlen_q, cu_seqlens_q_end - cu_seqlens_q_start + ) + else: + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + + # we have a one-size-fits-all grid in id(0). Some seqlens might be too small for all start_m so for those we return early. 
+ if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + + # If seqused is provided, use it to limit the actual sequence length for keys + if USE_SEQUSED: + actual_seqlen_k = ( + tl.load(seqused_k + off_z) + if seqused_k is not None + else cu_seqlens_k_end - cu_seqlens_k_start + ) + seqlen_k = tl.minimum( + actual_seqlen_k, cu_seqlens_k_end - cu_seqlens_k_start + ) + else: + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Load scale factors if IS_FP8. + if IS_FP8: + # For MQA/GQA (GROUP_SIZE != 1), q_descale uses the same indexing as k/v (off_h_k) + # For MHA (GROUP_SIZE == 1), q_descale uses off_h_q (same as off_h_k) + if GROUP_SIZE != 1: + q_descale = tl.load( + Q_Descale + off_z * stride_q_descale_z + off_h_k + ) # MQA/GQA: broadcast using k/v head index + else: + q_descale = tl.load( + Q_Descale + off_z * stride_q_descale_z + off_h_q + ) # MHA: use q head index + k_descale = tl.load(K_Descale + off_z * stride_k_descale_z + off_h_k) + v_descale = tl.load(V_Descale + off_z * stride_v_descale_z + off_h_k) + else: + q_descale, k_descale, v_descale = 1.0, 1.0, 1.0 + + # figure out masking pattern + ( + n_front_skip_blocks, + n_front_masked_blocks, + n_full_blocks, + n_back_masked_blocks, + n_extra_tokens, + ) = compute_block_masking( + seqlen_k, + seqlen_q, + start_m, + IS_CAUSAL, + USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT, + BLOCK_M, + BLOCK_N, + ) + + # ============================================================ + # PROGRAM EARLY EXIT (All K Blocks Skipped) + # ============================================================ + total_visible_blocks = n_front_masked_blocks + n_full_blocks + n_back_masked_blocks + if total_visible_blocks == 0: + """ + No K blocks visible - write zeros and exit. 
+ """ + # Write zeros to output + o_offset = ( + Out + + off_z * stride_oz + + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om + ) + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d_v[None, :] * stride_on + o_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD_V: + o_mask = o_mask & (offs_d_v[None, :] < ACTUAL_BLOCK_DMODEL_V) + tl.store( + o_ptrs, + tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=Out.type.element_ty), + mask=o_mask, + ) + + # Write zeros to LSE + l_ptrs = ( + LSE + + off_z * stride_lse_z + + off_h_q * stride_lse_h + + cu_seqlens_q_start * stride_lse_m + + offs_m * stride_lse_m + ) + tl.store(l_ptrs, tl.zeros([BLOCK_M], dtype=tl.float32), mask=offs_m < seqlen_q) + return + + # ============================================================ + # NORMAL PROCESSING (Some K Blocks Visible) + # ============================================================ + """ + This program has visible K blocks to process. + We'll use two calls to handle different block types efficiently. + """ + + # Initialize for processing + # Compute pointers for all the tensors used in this kernel. 
+ q_offset = ( + Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + ) + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d_qk[None, :] * stride_qk + k_offset = ( + K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + ) + k_ptrs = k_offset + offs_d_qk[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = ( + V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + ) + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d_v[None, :] * stride_vn + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = ( + bias + + bias_offset + + offs_m[:, None] * stride_bm + + offs_n[None, :] * stride_bn + ) + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(ALIBI_SLOPES + a_offset) + else: + alibi_slope = None + + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=ACCUMULATOR_TYPE) + l_i = tl.full([BLOCK_M], 1.0, dtype=ACCUMULATOR_TYPE) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_V], dtype=ACCUMULATOR_TYPE) + + # Q is loaded once at the beginning and shared by all N blocks. + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD_QK: + q_ptrs_mask = q_ptrs_mask & (offs_d_qk[None, :] < ACTUAL_BLOCK_DMODEL_QK) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + + # ========== Process MASKED K Blocks in the front ========== + # NOTE: we use USE_SLIDING_WINDOW as guard because the compiler will crash other wise. front masking is only for sliding window so that is fine. 
+ if n_front_masked_blocks > 0 and USE_SLIDING_WINDOW: + block_min = n_front_skip_blocks * BLOCK_N + block_max = (n_front_skip_blocks + n_front_masked_blocks) * BLOCK_N + + acc, l_i, m_i = _attn_fwd_mask( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + stride_sm, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + off_h_q, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, # Start of front masked blocks + block_max, # End of front masked blocks + 0, # n_extra_tokens (0 for front blocks, only relevant for last block) + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8, + FP8_MAX, + FP8_AUTO_DESCALE, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL_QK, + BLOCK_DMODEL_V, + BLOCK_N, + PRE_LOAD_V, + ENABLE_DROPOUT, + PADDED_HEAD_QK, + PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V, + SM_SCALE, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, + USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, + ACCUMULATOR_TYPE=ACCUMULATOR_TYPE, + ) + + # ========== Process FULL K Blocks (Fast Path) ========== + if n_full_blocks > 0: + block_min = (n_front_skip_blocks + n_front_masked_blocks) * BLOCK_N + block_max = ( + n_front_skip_blocks + n_front_masked_blocks + n_full_blocks + ) * BLOCK_N + + acc, l_i, m_i = _attn_fwd_no_mask( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + stride_sm, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + off_h_q, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, # Start of range: 0 + block_max, # End of range: n_full_blocks * BLOCK_N + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8, + FP8_MAX, + FP8_AUTO_DESCALE, + BLOCK_M, + 
BLOCK_DMODEL_QK, + BLOCK_DMODEL_V, + BLOCK_N, + PRE_LOAD_V, + ENABLE_DROPOUT, + PADDED_HEAD_QK, + PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V, + SM_SCALE, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, + ACCUMULATOR_TYPE=ACCUMULATOR_TYPE, + ) + + # ========== Process MASKED K Blocks in the back ========== + if n_back_masked_blocks > 0: + block_min = ( + n_front_skip_blocks + n_front_masked_blocks + n_full_blocks + ) * BLOCK_N + block_max = ( + n_front_skip_blocks + + n_front_masked_blocks + + n_full_blocks + + n_back_masked_blocks + ) * BLOCK_N + + acc, l_i, m_i = _attn_fwd_mask( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + stride_sn, + stride_sm, + start_m, + seqlen_k, + seqlen_q, + dropout_p, + philox_seed, + philox_offset_base, + SD_MASK, + stride_sz, + stride_sh, + off_z, + off_h_q, + offs_m, + offs_n, + offs_d_qk, + offs_d_v, + block_min, # Start of range: n_full_blocks * BLOCK_N + block_max, # End of range: n_visible_k_blocks * BLOCK_N + n_extra_tokens, # Padding tokens in last block + alibi_slope, + q_descale, + k_descale, + v_descale, + IS_FP8, + FP8_MAX, + FP8_AUTO_DESCALE, + IS_CAUSAL, # Use actual causal flag + BLOCK_M, + BLOCK_DMODEL_QK, + BLOCK_DMODEL_V, + BLOCK_N, + PRE_LOAD_V, + ENABLE_DROPOUT, + PADDED_HEAD_QK, + PADDED_HEAD_V, + ACTUAL_BLOCK_DMODEL_QK, + ACTUAL_BLOCK_DMODEL_V, + SM_SCALE, + USE_ALIBI=USE_ALIBI, + USE_EXP2=USE_EXP2, + RETURN_SCORES=RETURN_SCORES, + USE_SLIDING_WINDOW=USE_SLIDING_WINDOW, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, + ACCUMULATOR_TYPE=ACCUMULATOR_TYPE, + ) + + # ============================================================ + # EPILOGUE + # ============================================================ + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. 
+ # Instead of directly computing 1/l_i which can be inf, + # we check for the invalid case first + if USE_SLIDING_WINDOW: + # For rows where m_i is still -inf, no keys were valid + # Set l_i to 1.0 to avoid division by zero (acc is already 0) + invalid_mask = m_i == float("-inf") + l_i_safe = tl.where(invalid_mask, 1.0, l_i) + l_recip = 1 / l_i_safe[:, None] + else: + invalid_mask = None + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + if ENABLE_DROPOUT: + dropout_scale = 1 / (1 - dropout_p) + acc = acc * dropout_scale + + # compute log-sum-exp + if USE_EXP2: + RCP_LN2: tl.constexpr = 1.4426950408889634 + LN2: tl.constexpr = 0.6931471824645996 + # compute log-sum-exp in base 2 units + mi_base2 = m_i * RCP_LN2 + # For invalid rows, log(l_i) would be -inf, but we want LSE to be -inf + # So we handle this case explicitly + if USE_SLIDING_WINDOW: + log_l_i = tl.where(invalid_mask, 0.0, tl.math.log2(l_i)) + softmax_lse = mi_base2 + log_l_i + # Ensure invalid rows have LSE = -inf + softmax_lse = tl.where(invalid_mask, float("-inf"), softmax_lse) + else: + softmax_lse = mi_base2 + tl.math.log2(l_i) + # convert back to natural units + softmax_lse *= LN2 + else: + if USE_SLIDING_WINDOW: + log_l_i = tl.where(invalid_mask, 0.0, tl.math.log(l_i)) + softmax_lse = m_i + log_l_i + softmax_lse = tl.where(invalid_mask, float("-inf"), softmax_lse) + else: + softmax_lse = m_i + tl.math.log(l_i) + + # handle masking edge cases + if USE_SLIDING_WINDOW: + if IS_CAUSAL: + pass + else: + pass + else: + if IS_CAUSAL: + # When seqlen_q > seqlen_k, some rows are completely above the causal diagonal + # These rows have all -inf attention scores, resulting in NaN after softmax + # e.g. + # Q length: 6, K length: 4 + # Causal mask (X = can attend, . = cannot): + # K0 K1 K2 K3 + # Q0 . . . . <- All masked, would give NaN + # Q1 . . . . <- All masked, would give NaN + # Q2 X . . . <- First valid row + # Q3 X X . . + # Q4 X X X . 
+ # Q5 X X X X + causal_start_idx = seqlen_q - seqlen_k + start_m_idx = start_m * BLOCK_M + + # Create mask for rows that need zeroing + row_indices = start_m_idx + tl.arange(0, BLOCK_M) + causal_mask = row_indices < causal_start_idx + + # Zero out both acc and LSE for these rows + if causal_start_idx > start_m_idx: + end_m_idx = (start_m + 1) * BLOCK_M + if causal_start_idx < end_m_idx: + # This block contains the boundary - need to mask acc + out_mask_boundary = tl.full( + (BLOCK_DMODEL_V,), causal_start_idx, dtype=tl.int32 + ) + out_ptrs_mask = row_indices[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + + # Zero out LSE for rows above diagonal + softmax_lse = tl.where(causal_mask, 0.0, softmax_lse) + + # write back LSE(Log Sum Exponents), the log of the normalization constant + l_offset = ( + LSE + + off_z * stride_lse_z + + off_h_q * stride_lse_h + + cu_seqlens_q_start * stride_lse_m + ) + l_ptrs = l_offset + offs_m * stride_lse_m + + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last Q block. 
For the others, overflow_size will be negative
+ IS_VARLEN = layout == "thd" + + # common assertions + assert ( + 0.0 <= dropout_p <= 1.0 + ), f"dropout_p must be between 0 and 1, got {dropout_p}" + assert ( + q.device == k.device == v.device == o.device + ), f"All tensors must be on the same device. Got: q={q.device}, k={k.device}, v={v.device}, o={o.device}" + assert q.dtype == k.dtype == v.dtype, "q, k, v must have the same dtype" + current_device = torch.cuda.current_device() + assert ( + q.is_cuda and q.device.index == current_device + ), f"Device mismatch: Kernel will launch on cuda:{current_device}, but tensors are on {q.device}" + + # get shapes and strides + if IS_VARLEN: + # shape + total_seqlen_q, nheads_q, head_size_q = q.shape + total_seqlen_k, nheads_k, head_size_k = k.shape + total_seqlen_v, nheads_v, head_size_v = v.shape + + # assert shapes + assert ( + cu_seqlens_q is not None + ), "cu_seqlens_q must be provided for varlen layout" + assert ( + cu_seqlens_k is not None + ), "cu_seqlens_k must be provided for varlen layout" + assert ( + max_seqlens_q is not None and max_seqlens_q > 0 + ), "max_seqlens_q must be provided and positive for varlen layout" + assert ( + max_seqlens_k is not None and max_seqlens_k > 0 + ), "max_seqlens_k must be provided and positive for varlen layout" + + # assert head dimensions + assert ( + head_size_q == head_size_k + ), f"head sizes must match: q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + + # assert output shapes + assert o.shape == ( + total_seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected {(total_seqlen_q, nheads_q, head_size_v)}" + + # assert cu_seqlens + assert ( + cu_seqlens_q.dtype == torch.int32 + ), f"cu_seqlens_q must be int32, got {cu_seqlens_q.dtype}" + assert ( + cu_seqlens_k.dtype == torch.int32 + ), 
f"cu_seqlens_k must be int32, got {cu_seqlens_k.dtype}" + assert cu_seqlens_q[0] == 0, "cu_seqlens_q must start with 0" + assert cu_seqlens_k[0] == 0, "cu_seqlens_k must start with 0" + assert ( + cu_seqlens_q[-1] == total_seqlen_q + ), f"cu_seqlens_q[-1] {cu_seqlens_q[-1]} != total_seqlen_q {total_seqlen_q}" + assert ( + cu_seqlens_k[-1] == total_seqlen_k + ), f"cu_seqlens_k[-1] {cu_seqlens_k[-1]} != total_seqlen_k {total_seqlen_k}" + + # set vars + batch = len(cu_seqlens_q) - 1 + head_size_qk = head_size_q + + # Assert softmax_lse tensor is large enough + assert ( + softmax_lse.shape[0] >= nheads_q + ), f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= nheads_q={nheads_q}" + assert ( + softmax_lse.shape[1] >= total_seqlen_q + ), f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= total_seqlen_q={total_seqlen_q}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert ( + softmax_lse.device == q.device + ), f"softmax_lse must be on same device as q" + + # strides + stride_qb, stride_qh, stride_qm, stride_qd = ( + 0, + q.stride(1), + q.stride(0), + q.stride(2), + ) + stride_kb, stride_kh, stride_kn, stride_kd = ( + 0, + k.stride(1), + k.stride(0), + k.stride(2), + ) + stride_vb, stride_vh, stride_vn, stride_vd = ( + 0, + v.stride(1), + v.stride(0), + v.stride(2), + ) + stride_ob, stride_oh, stride_om, stride_od = ( + 0, + o.stride(1), + o.stride(0), + o.stride(2), + ) + stride_lse_z, stride_lse_h, stride_lse_m = ( + 0, + softmax_lse.stride(0), + softmax_lse.stride(1), + ) + else: + # shapes + batch_q, seqlen_q, nheads_q, head_size_q = q.shape + batch_k, seqlen_k, nheads_k, head_size_k = k.shape + batch_v, seqlen_v, nheads_v, head_size_v = v.shape + + # assert batch dimensions + assert ( + batch_q == batch_k == batch_v + ), f"batch sizes must match: q={batch_q}, k={batch_k}, v={batch_v}" + + # assert head dimensions + assert ( + head_size_q == head_size_k + ), f"head sizes must match: 
q={head_size_q}, k={head_size_k}" + assert ( + nheads_k == nheads_v + ), f"k and v must have same number of heads: k={nheads_k}, v={nheads_v}" + assert ( + nheads_q % nheads_k == 0 + ), f"nheads_q {nheads_q} must be divisible by nheads_k {nheads_k} for GQA/MQA" + + # assert sequence lengths + assert ( + seqlen_k == seqlen_v + ), f"k and v sequence lengths must match: k={seqlen_k}, v={seqlen_v}" + + # assert output shapes + assert o.shape == ( + batch_q, + seqlen_q, + nheads_q, + head_size_v, + ), f"o shape {o.shape} != expected {(batch_q, seqlen_q, nheads_q, head_size_v)}" + + # set vars + batch = batch_q + head_size_qk = head_size_q + max_seqlens_q = seqlen_q + max_seqlens_k = seqlen_k + + # Assert softmax_lse tensor is large enough + assert ( + softmax_lse.shape[0] >= batch + ), f"softmax_lse.shape[0]={softmax_lse.shape[0]} must be >= batch={batch}" + assert ( + softmax_lse.shape[1] >= nheads_q + ), f"softmax_lse.shape[1]={softmax_lse.shape[1]} must be >= nheads_q={nheads_q}" + assert ( + softmax_lse.shape[2] >= seqlen_q + ), f"softmax_lse.shape[2]={softmax_lse.shape[2]} must be >= seqlen_q={seqlen_q}" + assert ( + softmax_lse.dtype == torch.float32 + ), f"softmax_lse must be float32, got {softmax_lse.dtype}" + assert ( + softmax_lse.device == q.device + ), f"softmax_lse must be on same device as q" + + # strides + stride_qb, stride_qh, stride_qm, stride_qd = ( + q.stride(0), + q.stride(2), + q.stride(1), + q.stride(3), + ) + stride_kb, stride_kh, stride_kn, stride_kd = ( + k.stride(0), + k.stride(2), + k.stride(1), + k.stride(3), + ) + stride_vb, stride_vh, stride_vn, stride_vd = ( + v.stride(0), + v.stride(2), + v.stride(1), + v.stride(3), + ) + stride_ob, stride_oh, stride_om, stride_od = ( + o.stride(0), + o.stride(2), + o.stride(1), + o.stride(3), + ) + stride_lse_z, stride_lse_h, stride_lse_m = softmax_lse.stride() + + # apply rotary embeddings + if rotary_cos is not None and rotary_sin is not None: + if IS_VARLEN: + raise NotImplementedError( + "Rotary 
embeddings with varlen (thd layout) prefill are not implemented yet." + ) + seqlen_offsets = seqlens_rotary if seqlens_rotary is not None else 0 + local = (window_size_left != -1) or (window_size_right != -1) + q, _ = apply_rotary( + q, + None, + rotary_cos, + rotary_sin, + causal=causal, + local=local, + interleaved=rotary_interleaved, + seqlen_offsets=seqlen_offsets, + ) + + # fp8 setup and assertions + IS_FP8 = is_fp8([q, k, v]) + if IS_FP8: + FP8_MAX = torch.finfo(q.dtype).max + rec_dtype = get_recommended_fp8_dtype(q) + if q.dtype != rec_dtype or k.dtype != rec_dtype or v.dtype != rec_dtype: + arch = get_arch() + warnings.warn( + f"Use {rec_dtype} data type on {arch}. Got q: {q.dtype}, k: {k.dtype}, v: {v.dtype}", + UserWarning, + ) + + if (q_descale is None) or (k_descale is None) or (v_descale is None): + warnings.warn( + "FP8 tensors detected but descale factors not provided. Using default scale of 1.0", + UserWarning, + ) + # Create default descale tensors if not provided + if q_descale is None: + q_descale = torch.ones( + batch, nheads_q, dtype=torch.float32, device=q.device + ) + if k_descale is None: + k_descale = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + if v_descale is None: + v_descale = torch.ones( + batch, nheads_k, dtype=torch.float32, device=q.device + ) + else: + # Enforce exact expected shapes; no reshaping or normalization. 
+ assert ( + q_descale.dim() == 2 + and q_descale.shape[0] == batch + and q_descale.shape[1] == nheads_k + ), f"q_descale expected shape ({batch}, {nheads_k}) got {tuple(q_descale.shape)}" + assert ( + k_descale.dim() == 2 + and k_descale.shape[0] == batch + and k_descale.shape[1] == nheads_k + ), f"k_descale expected shape ({batch}, {nheads_k}) got {tuple(k_descale.shape)}" + assert ( + v_descale.dim() == 2 + and v_descale.shape[0] == batch + and v_descale.shape[1] == nheads_k + ), f"v_descale expected shape ({batch}, {nheads_k}) got {tuple(v_descale.shape)}" + + # o should be fp32 or fp16/bf16 + assert o.dtype in [ + torch.float16, + torch.bfloat16, + torch.float32, + ], f"Output tensor o must be fp16, bf16, or fp32 when using fp8, got {o.dtype}" + + stride_q_descale_z = q_descale.stride(0) if q_descale is not None else 0 + stride_k_descale_z = k_descale.stride(0) if k_descale is not None else 0 + stride_v_descale_z = v_descale.stride(0) if v_descale is not None else 0 + + if DEBUG: + print(f"FP8 path triggered in fwd_prefill.py") + else: + FP8_MAX = None + q_descale = k_descale = v_descale = None + stride_q_descale_z = stride_k_descale_z = stride_v_descale_z = None + + # check output dtype matches input dtype when not using fp8 + assert ( + o.dtype == q.dtype + ), f"Output dtype {o.dtype} must match input dtype {q.dtype} when not using fp8" + + # check features + use_sliding_window = window_size_left != -1 or window_size_right != -1 + use_alibi, (stride_az, stride_ah) = ( + (True, alibi_slopes.stride()) if alibi_slopes is not None else (False, (0, 0)) + ) + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if bias is not None: + assert bias.numel() < 2**31 + + # Get closest power of 2 over or equal to 32 for both QK and V dimensions + padded_d_model_qk = 1 << (head_size_qk - 1).bit_length() + padded_d_model_v = 1 << (head_size_v - 1).bit_length() + # Smallest head_dim supported is 16. 
If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model_qk = max(padded_d_model_qk, 16) + padded_d_model_v = max(padded_d_model_v, 16) + + # sd_mask assertions and strides + if sd_mask is not None: + assert dropout_p > 0.0 or return_scores, "sd_mask provided but not used" + assert ( + sd_mask is not None + ), "sd_mask must be provided when return_scores=True or dropout_p > 0" + # Assert sd_mask tensor is large enough + assert ( + sd_mask.shape[0] >= batch + ), f"sd_mask.shape[0]={sd_mask.shape[0]} must be >= batch={batch}" + assert ( + sd_mask.shape[1] >= nheads_q + ), f"sd_mask.shape[1]={sd_mask.shape[1]} must be >= nheads_q={nheads_q}" + assert ( + sd_mask.shape[2] >= max_seqlens_q + ), f"sd_mask.shape[2]={sd_mask.shape[2]} must be >= max_seqlens_q={max_seqlens_q}" + assert ( + sd_mask.shape[3] >= max_seqlens_k + ), f"sd_mask.shape[3]={sd_mask.shape[3]} must be >= max_seqlens_k={max_seqlens_k}" + assert sd_mask.device == q.device, f"sd_mask must be on same device as q" + + stride_sz, stride_sh, stride_sm, stride_sn = ( + sd_mask.stride(0), + sd_mask.stride(1), + sd_mask.stride(2), + sd_mask.stride(3), + ) + else: + stride_sz, stride_sh, stride_sm, stride_sn = (0, 0, 0, 0) + + if bias is not None: + stride_bz, stride_bh, stride_bm, stride_bn = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) + else: + stride_bz, stride_bh, stride_bm, stride_bn = (0, 0, 0, 0) + + # launch kernel + grid = lambda META: (batch, nheads_q, triton.cdiv(max_seqlens_q, META["BLOCK_M"])) + attn_fwd[grid]( + q, + k, + v, + bias, + q_descale, + k_descale, + v_descale, + stride_q_descale_z, + stride_k_descale_z, + stride_v_descale_z, + softmax_lse, + o, + sd_mask, + alibi_slopes, + stride_qb, + stride_qh, + stride_qm, + stride_qd, + stride_kb, + stride_kh, + stride_kn, + stride_kd, + stride_vb, + stride_vh, + stride_vn, + stride_vd, + stride_ob, + stride_oh, + stride_om, + stride_od, + stride_bz, + stride_bh, 
+ stride_bm, + stride_bn, + stride_az, + stride_ah, + stride_sz, + stride_sh, + stride_sm, + stride_sn, + stride_lse_z, + stride_lse_h, + stride_lse_m, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, # Pass seqused tensors + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + HQ=nheads_q, + HK=nheads_k, + ACTUAL_BLOCK_DMODEL_QK=head_size_qk, + ACTUAL_BLOCK_DMODEL_V=head_size_v, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + SM_SCALE=sm_scale, + IS_CAUSAL=causal, + USE_SLIDING_WINDOW=use_sliding_window, + WINDOW_SIZE_LEFT=window_size_left, + WINDOW_SIZE_RIGHT=window_size_right, + IS_VARLEN=IS_VARLEN, + BLOCK_DMODEL_QK=padded_d_model_qk, + BLOCK_DMODEL_V=padded_d_model_v, + USE_BIAS=False if bias is None else True, + USE_ALIBI=use_alibi, + ENABLE_DROPOUT=dropout_p > 0.0, + USE_EXP2=use_exp2, + RETURN_SCORES=return_scores, + IS_FP8=IS_FP8, + FP8_MAX=FP8_MAX, + FP8_AUTO_DESCALE=FP8_AUTO_DESCALE, + USE_SEQUSED=(seqused_q is not None or seqused_k is not None), + ) diff --git a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/interface_v2.py b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/interface_v2.py new file mode 100644 index 0000000000..5c83fc42c8 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/interface_v2.py @@ -0,0 +1,817 @@ +import torch +import os +from typing import Optional, Union +from .fwd_prefill import attention_forward_prefill_triton_impl +from .fwd_decode import attention_forward_decode_triton_impl +from .bwd import attention_backward_triton_impl +from .utils import ( + DEBUG, + USE_EXP2, + BWD_MODE, + PHILOX_SEED, + PHILOX_OFFSET, + SHAPE_EXPECTATIONS, + round_multiple, +) + + +def fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: Optional[torch.Tensor], + alibi_slopes: Optional[torch.Tensor], + dropout_p: float, + softmax_scale: float, + causal: bool, + window_size_left: int, + window_size_right: int, + softcap: float, + 
    return_softmax: bool,
    gen_: Optional[torch.Tensor] = None,
):
    """FA2-compatible forward pass ('bshd' layout) mapped onto the AMD Triton prefill kernel.

    Returns (out, softmax_lse, sd_mask, rng_state); sd_mask is only allocated when
    dropout is active or return_softmax is requested.
    """

    # Reject FP8 tensors (FA2 AMD path does not support FP8)
    if str(q.dtype).startswith("torch.float8"):
        raise NotImplementedError(
            "FP8 tensors are not supported in the AMD Triton FA2 interface. Use the FA3 path instead."
        )

    # Unsupported features assertions (keep behavior explicit like v3 shim)
    if softcap != 0.0:
        raise NotImplementedError(
            "softcap is not supported in the AMD Triton FA2 interface (expected 0.0)."
        )

    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::fwd inputs")
        print("q:", q.shape)
        print("k:", k.shape)
        print("v:", v.shape)
        print("out:", out.shape if out is not None else None)
        print("alibi_slopes:", alibi_slopes)
        print("dropout_p:", dropout_p)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("softcap:", softcap)
        print("return_softmax:", return_softmax)

    # Allocate the output if the caller did not supply one; otherwise clear it in place.
    if out is None:
        out = torch.zeros_like(q)
    else:
        out.zero_()

    # Layout / shapes
    layout = "bshd"
    max_seqlen_q = q.shape[1]
    max_seqlen_k = k.shape[1]
    batch, _, nheads_q, _ = q.shape

    # Normalize / validate alibi
    if alibi_slopes is not None:
        if alibi_slopes.dim() == 1:
            alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1)
        assert alibi_slopes.is_cuda and alibi_slopes.dim() == 2
        assert alibi_slopes.shape == (batch, nheads_q)

    # Dropout + RNG seed
    # NOTE(review): seed/offset are module-level constants here, not drawn from a
    # generator — presumably for reproducibility; gen_ is accepted but unused.
    philox_seed, philox_offset = PHILOX_SEED, PHILOX_OFFSET
    rng_state = torch.as_tensor([philox_seed, philox_offset])

    # argument checks
    assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert q.dtype == k.dtype == v.dtype
    assert out.shape[:-1] == q.shape[:-1] and out.shape[-1] == v.shape[-1]
    nheads_k = k.shape[2]
    assert (nheads_q % nheads_k) == 0

    # Create output tensors based on shape expectations
    # ("rounded" pads Sq/Sk up to multiples of 128 to match upstream FA buffer shapes).
    if SHAPE_EXPECTATIONS == "rounded":
        softmax_lse = torch.zeros(
            (batch, nheads_q, round_multiple(max_seqlen_q, 128)),
            device=q.device,
            dtype=torch.float32,
        )
        if dropout_p > 0.0 or return_softmax:
            sd_mask = torch.zeros(
                (
                    batch,
                    nheads_q,
                    round_multiple(max_seqlen_q, 128),
                    round_multiple(max_seqlen_k, 128),
                ),
                device=q.device,
                dtype=torch.float32,
            )
        else:
            sd_mask = None
    else:
        softmax_lse = torch.zeros(
            (batch, nheads_q, max_seqlen_q),
            device=q.device,
            dtype=torch.float32,
        )
        if dropout_p > 0.0 or return_softmax:
            sd_mask = torch.zeros(
                (batch, nheads_q, max_seqlen_q, max_seqlen_k),
                device=q.device,
                dtype=torch.float32,
            )
        else:
            sd_mask = None

    # call implementation
    if DEBUG:
        print("Using Triton implementation")
    # NOTE(review): the trailing seven positional Nones presumably map to
    # q/k/v descales and seqused/rotary parameters of the prefill impl —
    # confirm against attention_forward_prefill_triton_impl's signature.
    attention_forward_prefill_triton_impl(
        q,
        k,
        v,
        out,
        softmax_lse,
        sd_mask,
        softmax_scale,
        alibi_slopes,
        causal,
        window_size_left,
        window_size_right,
        None,
        layout,
        None,
        None,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        philox_seed,
        philox_offset,
        return_softmax,
        USE_EXP2,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    )

    if DEBUG:
        print("flash_attn_triton_amd.py::fwd outputs")
        print("o:", out.shape if out is not None else None)
        print("softmax_lse:", softmax_lse.shape if softmax_lse is not None else None)
        print("sd_mask:", sd_mask.shape if sd_mask is not None else None)
        print("rng_state:", rng_state)

    # --- Assertions (shape + dtype contracts) ---
    # out: (B, Sq, Hq, D)
    assert out.shape == q.shape, f"[fwd] out shape {out.shape} != q shape {q.shape}"
    # softmax_lse dtype
    assert (
        softmax_lse.dtype == torch.float32
    ), f"[fwd] softmax_lse dtype {softmax_lse.dtype} != torch.float32"
    # softmax_lse shape depends on SHAPE_EXPECTATIONS
    if SHAPE_EXPECTATIONS == "rounded":
        expected_lse_shape = (q.shape[0], q.shape[2], round_multiple(q.shape[1], 128))
    else:
        expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1])
    assert (
        softmax_lse.shape == expected_lse_shape
    ), f"[fwd] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}"
    if return_softmax:
        # sd_mask: (B, Hq, Sq, Sk)
        assert sd_mask is not None, "[fwd] return_softmax=True but sd_mask is None"
        assert sd_mask.dim() == 4, f"[fwd] sd_mask dim {sd_mask.dim()} != 4"
        if SHAPE_EXPECTATIONS == "rounded":
            expected_sq = round_multiple(q.shape[1], 128)
            expected_sk = round_multiple(k.shape[1], 128)
            assert (
                sd_mask.shape[0] == q.shape[0]
                and sd_mask.shape[1] == q.shape[2]
                and sd_mask.shape[2] == expected_sq
                and sd_mask.shape[3] == expected_sk
            ), f"[fwd] sd_mask shape {sd_mask.shape} != (B={q.shape[0]}, Hq={q.shape[2]}, Sq={expected_sq}, Sk={expected_sk})"
        else:
            assert (
                sd_mask.shape[0] == q.shape[0]
                and sd_mask.shape[1] == q.shape[2]
                and sd_mask.shape[2] == q.shape[1]
            ), f"[fwd] sd_mask leading dims {sd_mask.shape[:3]} mismatch (B,Hq,Sq) {(q.shape[0], q.shape[2], q.shape[1])}"
    else:
        assert sd_mask is None, "[fwd] return_softmax=False but sd_mask is not None"

    return out, softmax_lse, sd_mask, rng_state


def bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    alibi_slopes: Optional[torch.Tensor],
    dropout_p: float,
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    deterministic: bool,
    gen_: Optional[torch.Tensor] = None,
    rng_state: Optional[torch.Tensor] = None,
):
    """FA2-compatible backward pass ('bshd' layout); returns (dq, dk, dv, delta)."""
    if softcap != 0.0:
        raise NotImplementedError(
            "softcap is not supported in the AMD Triton FA2 interface (expected 0.0)."
        )

    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::bwd inputs")
        print("dout:", dout, dout.shape)
        print("q:", q.shape)
        print("k:", k.shape)
        print("v:", v.shape)
        print("out:", out.shape)
        print("softmax_lse:", softmax_lse.shape)
        print("dq:", dq.shape if dq is not None else None)
        print("dk:", dk.shape if dk is not None else None)
        print("dv:", dv.shape if dv is not None else None)
        print("alibi_slopes:", alibi_slopes)
        print("dropout_p:", dropout_p)
        print("out:", out)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("deterministic:", deterministic)
        print("gen_:", gen_)
        print("rng_state:", rng_state)

    # Allocate gradient buffers when not supplied; otherwise clear them in place
    # (in-place zero_() returns the tensor, so the conditional expression is safe).
    dq = torch.zeros_like(q) if dq is None else dq.zero_()
    dk = torch.zeros_like(k) if dk is None else dk.zero_()
    dv = torch.zeros_like(v) if dv is None else dv.zero_()

    # get shape
    batch, seqlen_q, nheads_q, _ = q.shape

    # Create delta tensor with shape based on expectations
    # delta (softmax_d) : (B, Hq, Sq) or (B, Hq, round_multiple(Sq, 128))
    if SHAPE_EXPECTATIONS == "rounded":
        delta = torch.zeros(
            (batch, nheads_q, round_multiple(seqlen_q, 128)),
            device=q.device,
            dtype=torch.float32,
        )
    else:
        delta = torch.zeros(
            (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32
        )

    # Upstream change: base seeding logic on provided rng_state instead of dropout probability.
    # NOTE(review): with dropout_p > 0 and rng_state=None this forwards seed/offset=None
    # to the impl — confirm the impl tolerates that combination.
    if rng_state is not None:
        philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item()
    else:
        philox_seed, philox_offset = None, None

    # Accept alibi as (nheads,) and broadcast to (batch, nheads); reject other ranks.
    if alibi_slopes is not None:
        if alibi_slopes.dim() == 2:
            pass
        elif alibi_slopes.dim() == 1:
            alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1)
        else:
            raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).")

    # call implementation
    if DEBUG:
        print(f"Using Triton implementation in {BWD_MODE} mode")
    attention_backward_triton_impl(
        do=dout,
        q=q,
        k=k,
        v=v,
        o=out,
        softmax_lse=softmax_lse,
        dq=dq,
        dk=dk,
        dv=dv,
        delta=delta,
        sm_scale=softmax_scale,
        alibi_slopes=alibi_slopes,
        causal=causal,
        layout="bshd",
        cu_seqlens_q=None,
        cu_seqlens_k=None,
        max_seqlen_q=seqlen_q,
        max_seqlen_k=k.shape[1],
        seqused_q=None,
        seqused_k=None,
        dropout_p=dropout_p,
        philox_seed=philox_seed,
        philox_offset=philox_offset,
        use_exp2=USE_EXP2,
        mode=BWD_MODE,
    )

    if DEBUG:
        print("flash_attn_triton_amd.py::bwd outputs")
        print("dv:", dv, dv.shape)
        print("dk:", dk, dk.shape)
        print("dq:", dq, dq.shape)
    # --- Assertions ---
    assert dq.shape == q.shape, f"[bwd] dq shape {dq.shape} != q shape {q.shape}"
    assert dk.shape == k.shape, f"[bwd] dk shape {dk.shape} != k shape {k.shape}"
    assert dv.shape == v.shape, f"[bwd] dv shape {dv.shape} != v shape {v.shape}"
    # delta (softmax_d) : (B, Hq, Sq)
    if SHAPE_EXPECTATIONS == "rounded":
        expected_delta_shape = (q.shape[0], q.shape[2], round_multiple(q.shape[1], 128))
    else:
        expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1])
    assert (
        delta.shape == expected_delta_shape
    ), f"[bwd] delta shape {delta.shape} != {expected_delta_shape}"
    return dq, dk, dv, delta


def varlen_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    seqused_k: Optional[torch.Tensor],
    leftpad_k: Optional[torch.Tensor],
    block_table_: Optional[torch.Tensor],
    alibi_slopes: Optional[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    zero_tensors: bool,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    return_softmax: bool,
    gen_: Optional[torch.Tensor] = None,
):
    """FA2-compatible variable-length forward pass ('thd' layout).

    q/k/v are packed as (Total_Tokens, H, D) with per-batch boundaries given by
    cu_seqlens_q / cu_seqlens_k. Returns (out, softmax_lse, sd_mask, rng_state).
    """

    if str(q.dtype).startswith("torch.float8"):
        raise NotImplementedError(
            "FP8 tensors are not supported in the AMD Triton FA2 interface (varlen_fwd). Use the FA3 path instead."
        )

    if softcap != 0.0:
        raise NotImplementedError(
            "softcap is not supported in varlen_fwd (expected 0.0)."
        )
    if leftpad_k is not None:
        raise NotImplementedError(
            "leftpad_k is not supported in AMD Triton FA2 varlen_fwd."
        )
    if block_table_ is not None:
        raise NotImplementedError(
            "block_table / paged attention is not supported in AMD Triton FA2 varlen_fwd."
        )
    if seqused_k is not None:
        raise NotImplementedError(
            "seqused_k is not supported in AMD Triton FA2 varlen_fwd."
        )

    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::varlen_fwd")
        print("q:", q.shape)
        print("k:", k.shape)
        print("v:", v.shape)
        print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape)
        print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape)
        print("alibi_slopes:", alibi_slopes)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("dropout_p:", dropout_p)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("gen_:", gen_)
    out = torch.zeros_like(q) if out is None else out.zero_()

    # Layout and basic info for varlen
    layout = "thd"
    batch = len(cu_seqlens_q) - 1
    total_q, nheads_q, _ = q.shape

    # Create softmax_lse tensor - varlen always uses exact shape (Hq, Total_Q)
    softmax_lse = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32)

    # Create sd_mask tensor if needed
    if return_softmax:
        # sd_mask: (B, Hq, Sq, Sk) - shape based on expectations
        # NOTE(review): sd_mask here uses dtype=q.dtype while the non-varlen fwd
        # allocates it as torch.float32 — confirm which dtype the kernel writes.
        if SHAPE_EXPECTATIONS == "rounded":
            sd_mask = torch.zeros(
                (
                    batch,
                    nheads_q,
                    round_multiple(max_seqlen_q, 128),
                    round_multiple(max_seqlen_k, 128),
                ),
                device=q.device,
                dtype=q.dtype,
            )
        else:
            sd_mask = torch.zeros(
                (batch, nheads_q, max_seqlen_q, max_seqlen_k),
                device=q.device,
                dtype=q.dtype,
            )
    else:
        sd_mask = None

    if alibi_slopes is not None:
        if alibi_slopes.dim() == 1:
            alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1)
        assert alibi_slopes.is_cuda and alibi_slopes.dim() == 2
        assert alibi_slopes.shape == (batch, nheads_q)

    # Fixed module-level seed/offset; gen_ is accepted but unused.
    philox_seed, philox_offset = PHILOX_SEED, PHILOX_OFFSET
    rng_state = torch.as_tensor([philox_seed, philox_offset])

    # Inline checks (subset appropriate for varlen)
    assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3
    assert q.shape[-1] == k.shape[-1] == v.shape[-1]
    assert q.dtype == k.dtype == v.dtype
    assert out.shape == q.shape

    nheads_k = k.shape[1]
    assert (nheads_q % nheads_k) == 0

    # call implementation
    if DEBUG:
        print("Using Triton implementation")
    attention_forward_prefill_triton_impl(
        q,
        k,
        v,
        out,
        softmax_lse,
        sd_mask,
        softmax_scale,
        alibi_slopes,
        causal,
        window_size_left,
        window_size_right,
        None,
        layout,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        philox_seed,
        philox_offset,
        return_softmax,
        USE_EXP2,
        None,
        None,
        None,
    )

    if DEBUG:
        print("varlen_fwd outputs")
        print("out:", out, out.shape)
        print("softmax_lse:", softmax_lse, softmax_lse.shape)
        print("sd_mask:", sd_mask, sd_mask.shape if sd_mask is not None else None)
    # --- Assertions ---
    # out: (Total_Q, Hq, D)
    assert (
        out.shape == q.shape
    ), f"[varlen_fwd] out shape {out.shape} != q shape {q.shape}"
    # softmax_lse: (Hq, Total_Q)
    expected_lse_shape = (q.shape[1], q.shape[0])
    assert (
        softmax_lse.shape == expected_lse_shape
    ), f"[varlen_fwd] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}"
    assert (
        softmax_lse.dtype == torch.float32
    ), f"[varlen_fwd] softmax_lse dtype {softmax_lse.dtype} != torch.float32"
    if return_softmax:
        # sd_mask expected: (B, Hq, max_seqlen_q, max_seqlen_k)
        assert (
            sd_mask is not None
        ), "[varlen_fwd] return_softmax=True but sd_mask is None"
        assert sd_mask.dim() == 4, f"[varlen_fwd] sd_mask dim {sd_mask.dim()} != 4"
        batch = len(cu_seqlens_q) - 1
        assert (
            sd_mask.shape[0] == batch
        ), f"[varlen_fwd] sd_mask batch {sd_mask.shape[0]} != {batch}"
        assert (
            sd_mask.shape[1] == q.shape[1]
        ), f"[varlen_fwd] sd_mask nheads {sd_mask.shape[1]} != {q.shape[1]}"
        if SHAPE_EXPECTATIONS == "rounded":
            expected_sq = round_multiple(max_seqlen_q, 128)
            expected_sk = round_multiple(max_seqlen_k, 128)
            assert (
                sd_mask.shape[2] == expected_sq and sd_mask.shape[3] == expected_sk
            ), f"[varlen_fwd] sd_mask shape {sd_mask.shape} != (B={batch}, Hq={q.shape[1]}, Sq={expected_sq}, Sk={expected_sk})"
        else:
            assert (
                sd_mask.shape[2] == max_seqlen_q and sd_mask.shape[3] == max_seqlen_k
            ), f"[varlen_fwd] sd_mask shape {sd_mask.shape} != (B={batch}, Hq={q.shape[1]}, Sq={max_seqlen_q}, Sk={max_seqlen_k})"
    else:
        assert (
            sd_mask is None
        ), "[varlen_fwd] return_softmax=False but sd_mask is not None"
    return out, softmax_lse, sd_mask, rng_state


def varlen_bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    alibi_slopes: Optional[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float,
    softmax_scale: float,
    zero_tensors: bool,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    deterministic: bool,
    gen_: Optional[torch.Tensor] = None,
    rng_state: Optional[torch.Tensor] = None,
):
    """FA2-compatible variable-length backward pass ('thd' layout); returns (dq, dk, dv, delta)."""
    if str(q.dtype).startswith("torch.float8"):
        raise NotImplementedError(
            "FP8 tensors are not supported in the AMD Triton FA2 interface (varlen_bwd). Use the FA3 path instead."
        )
    if softcap != 0.0:
        raise NotImplementedError(
            "softcap is not supported in varlen_bwd (expected 0.0)."
        )

    if DEBUG:
        print()
        print("varlen_bwd")
        print("dout:", dout.shape)
        print("q:", q.shape)
        print("k:", k.shape)
        print("v:", v.shape)
        print("out:", out)
        print("softmax_lse:", softmax_lse.shape)
        print("dq:", dq.shape if dq is not None else None)
        print("dk:", dk.shape if dk is not None else None)
        print("dv:", dv.shape if dv is not None else None)
        print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape)
        print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape)
        print("alibi_slopes:", alibi_slopes)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("dropout_p:", dropout_p)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("deterministic:", deterministic)
        print("gen_:", gen_)
        print("rng_state:", rng_state)

    # Allocate or clear gradient buffers (zero_() returns the tensor in place).
    dq = torch.zeros_like(q) if dq is None else dq.zero_()
    dk = torch.zeros_like(k) if dk is None else dk.zero_()
    dv = torch.zeros_like(v) if dv is None else dv.zero_()

    # get shape
    batch = len(cu_seqlens_q) - 1
    total_q, nheads_q, _ = q.shape

    # Create delta tensor with shape based on expectations
    # delta (softmax_d) : (Hq, Total_Q) or (Hq, Total_Q + 128*batch)
    # ("rounded" reserves up to one extra 128-wide tile per batch entry.)
    if SHAPE_EXPECTATIONS == "rounded":
        delta = torch.zeros(
            (nheads_q, total_q + 128 * batch), device=q.device, dtype=torch.float32
        )
    else:
        delta = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32)

    # Upstream change: base seeding logic on provided rng_state instead of dropout probability.
    # NOTE(review): with dropout_p > 0 and rng_state=None this forwards seed/offset=None —
    # confirm the impl tolerates that combination.
    if rng_state is not None:
        philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item()
    else:
        philox_seed, philox_offset = None, None

    # Accept alibi as (nheads,) and broadcast to (batch, nheads); reject other ranks.
    if alibi_slopes is not None:
        if alibi_slopes.dim() == 2:
            pass
        elif alibi_slopes.dim() == 1:
            alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1)
        else:
            raise ValueError("Alibi can be (nheads,) or (batch_size, nheads).")

    # call implementation
    if DEBUG:
        print(f"Using Triton implementation in {BWD_MODE} mode")
    attention_backward_triton_impl(
        do=dout,
        q=q,
        k=k,
        v=v,
        o=out,
        softmax_lse=softmax_lse,
        dq=dq,
        dk=dk,
        dv=dv,
        delta=delta,
        sm_scale=softmax_scale,
        alibi_slopes=alibi_slopes,
        causal=causal,
        layout="thd",
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        seqused_q=None,
        seqused_k=None,
        dropout_p=dropout_p,
        philox_seed=philox_seed,
        philox_offset=philox_offset,
        use_exp2=USE_EXP2,
        mode=BWD_MODE,
    )

    if DEBUG:
        print("varlen_bwd outputs")
        print("delta:", delta, delta.shape)
        print("dv:", dv, dv.shape)
        print("dk:", dk, dk.shape)
        print("dq:", dq, dq.shape)
    # --- Assertions ---
    assert dq.shape == q.shape, f"[varlen_bwd] dq shape {dq.shape} != q shape {q.shape}"
    assert dk.shape == k.shape, f"[varlen_bwd] dk shape {dk.shape} != k shape {k.shape}"
    assert dv.shape == v.shape, f"[varlen_bwd] dv shape {dv.shape} != v shape {v.shape}"
    if SHAPE_EXPECTATIONS == "rounded":
        batch = len(cu_seqlens_q) - 1
        expected_delta_shape = (q.shape[1], q.shape[0] + 128 * batch)
    else:
        expected_delta_shape = (q.shape[1], q.shape[0])  # (Hq, Total_Q)
    assert (
        delta.shape == expected_delta_shape
    ), f"[varlen_bwd] delta shape {delta.shape} != {expected_delta_shape}"
    return dq, dk, dv, delta


def fwd_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k: Optional[torch.Tensor],
    v: Optional[torch.Tensor],
    cache_seqlens: Optional[Union[(int, torch.Tensor)]],
    rotary_cos: Optional[torch.Tensor],
    rotary_sin: Optional[torch.Tensor],
    cache_batch_idx: Optional[torch.Tensor],
    cache_leftpad: Optional[torch.Tensor],
    block_table: Optional[torch.Tensor],
    alibi_slopes: Optional[torch.Tensor],
    out: Optional[torch.Tensor],
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    rotary_interleaved: bool,
    num_splits: int,
):
    """FA2-compatible KV-cache forward pass routed to the AMD Triton decode kernel.

    k/v (if given) are new tokens appended to k_cache/v_cache by the decode impl.
    Returns (out, softmax_lse).
    """

    if softcap != 0.0:
        raise NotImplementedError(
            "softcap is not supported in fwd_kvcache (expected 0.0)."
        )
    if num_splits not in (0, 1):
        raise NotImplementedError(
            "num_splits > 1 not supported in AMD Triton FA2 fwd_kvcache."
        )

    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::fwd_kvcache inputs")
        print("q:", q, q.shape)
        print("k_cache:", k_cache, k_cache.shape)
        print("v_cache:", v_cache, v_cache.shape)
        print("k:", k, k.shape if k is not None else None)
        print("v:", v, v.shape if v is not None else None)
        print("cache_seqlens:", cache_seqlens)
        print("rotary_cos:", rotary_cos)
        print("rotary_sin:", rotary_sin)
        print("cache_batch_idx:", cache_batch_idx)
        print("cache_leftpad:", cache_leftpad)
        print("block_table:", block_table)
        print("alibi_slopes:", alibi_slopes)
        print("out:", out)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("softcap:", softcap)
        print("rotary_interleaved:", rotary_interleaved)
        print("num_splits:", num_splits)

    # output
    out = torch.zeros_like(q) if out is None else out.zero_()

    # Basic layout info for decode path
    layout = "bshd"
    max_seqlen_q = q.shape[1]
    max_seqlen_k = k_cache.shape[1]
    # NOTE(review): a scalar cache_seqlens becomes a 0-d tensor here — confirm the
    # decode impl broadcasts it across the batch.
    cache_seqlens_tensor = (
        torch.tensor(cache_seqlens, device=q.device)
        if isinstance(cache_seqlens, int)
        else cache_seqlens
    )
    # Window bounds may arrive as 0-d tensors; normalize to Python ints.
    window_left = (
        int(window_size_left.item())
        if isinstance(window_size_left, torch.Tensor)
        else window_size_left
    )
    window_right = (
        int(window_size_right.item())
        if isinstance(window_size_right, torch.Tensor)
        else window_size_right
    )

    k_new = k
    v_new = v

    # get shape
    batch, seqlen_q, nheads_q, _ = q.shape

    # Create softmax_lse tensor - decode always uses exact shape (B, Hq, Sq)
    softmax_lse = torch.zeros(
        (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32
    )

    if alibi_slopes is not None:
        if alibi_slopes.dim() == 1:
            alibi_slopes = alibi_slopes.unsqueeze(0).expand(batch, -1)
        assert alibi_slopes.is_cuda and alibi_slopes.dim() == 2
        assert alibi_slopes.shape == (batch, nheads_q)

    # launch kernel
    if DEBUG:
        print("Using Triton implementation")
    # NOTE(review): the three positional Nones presumably map to q/k/v descales
    # of the decode impl — confirm against its signature.
    attention_forward_decode_triton_impl(
        q,
        k_cache,
        v_cache,
        k_new,
        v_new,
        out,
        softmax_lse,
        softmax_scale,
        causal,
        window_left,
        window_right,
        alibi_slopes,
        layout,
        cache_seqlens_tensor,
        cache_batch_idx,
        block_table,
        None,
        None,
        None,
        rotary_cos=rotary_cos,
        rotary_sin=rotary_sin,
        rotary_interleaved=rotary_interleaved,
    )

    if DEBUG:
        print("out:", out, out.shape)
        print("softmax_lse:", softmax_lse, softmax_lse.shape)
    # --- Assertions ---
    assert (
        out.shape == q.shape
    ), f"[fwd_kvcache] out shape {out.shape} != q shape {q.shape}"
    expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1])
    assert (
        softmax_lse.shape == expected_lse_shape
    ), f"[fwd_kvcache] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}"
    assert (
        softmax_lse.dtype == torch.float32
    ), f"[fwd_kvcache] softmax_lse dtype {softmax_lse.dtype} != torch.float32"
    return out, softmax_lse


# ==== file boundary: aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/interface_v3.py ====
import os
import warnings
import torch
from typing import Optional, Union, Tuple
from .fwd_prefill import attention_forward_prefill_triton_impl
from .fwd_decode import attention_forward_decode_triton_impl
from .bwd import attention_backward_triton_impl
from .utils import (
    DEBUG,
    USE_EXP2,
    BWD_MODE,
    PHILOX_SEED,
    PHILOX_OFFSET,
    is_fp8,
    get_recommended_fp8_dtype,
)


def fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    k_new: Optional[torch.Tensor],
    v_new: Optional[torch.Tensor],
    qv: Optional[torch.Tensor],
    out: Optional[torch.Tensor],
    cu_seqlens_q: Optional[torch.Tensor],
    cu_seqlens_k: Optional[torch.Tensor],
    cu_seqlens_k_new: Optional[torch.Tensor],
    seqused_q: Optional[torch.Tensor],
    seqused_k: Optional[torch.Tensor],
    max_seqlen_q: Optional[int],
    max_seqlen_k: Optional[int],
    page_table: Optional[torch.Tensor],
    kv_batch_idx: Optional[torch.Tensor],
    leftpad_k: Optional[torch.Tensor],
    rotary_cos: Optional[torch.Tensor],
    rotary_sin: Optional[torch.Tensor],
    seqlens_rotary: Optional[torch.Tensor],
    q_descale: Optional[torch.Tensor],
    k_descale: Optional[torch.Tensor],
    v_descale: Optional[torch.Tensor],
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    attention_chunk: int,
    softcap: float,
    rotary_interleaved: bool,
    scheduler_metadata=None,
    num_splits: int = 1,
    pack_gqa=None,
    sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flash Attention v3 forward pass compatible interface for AMD Triton implementation.

    This function maps v3 parameters to the existing AMD Triton implementation.
    """

    if DEBUG:
        print()
        print("interface_fa_v3.py::fwd inputs")
        print("q:", q.dtype if q is not None else None, q.shape)
        print("k:", k.dtype if k is not None else None, k.shape)
        print("v:", v.dtype if v is not None else None, v.shape)
        print(
            "k_new:",
            k_new.dtype if k_new is not None else None,
            k_new.shape if k_new is not None else None,
        )
        print(
            "v_new:",
            v_new.dtype if v_new is not None else None,
            v_new.shape if v_new is not None else None,
        )
        print(
            "qv:",
            qv.dtype if qv is not None else None,
            qv.shape if qv is not None else None,
        )
        print(
            "out:",
            out.dtype if out is not None else None,
            out.shape if out is not None else None,
        )
        print(
            "cu_seqlens_q:",
            cu_seqlens_q,
            cu_seqlens_q.shape if cu_seqlens_q is not None else None,
        )
        print(
            "cu_seqlens_k:",
            cu_seqlens_k,
            cu_seqlens_k.shape if cu_seqlens_k is not None else None,
        )
        print(
            "cu_seqlens_k_new:",
            cu_seqlens_k_new,
            cu_seqlens_k_new.shape if cu_seqlens_k_new is not None else None,
        )
        print(
            "seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None
        )
        print(
            "seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None
        )
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print(
            "page_table:",
            page_table,
            page_table.shape if page_table is not None else None,
        )
        print(
            "kv_batch_idx:",
            kv_batch_idx,
            kv_batch_idx.shape if kv_batch_idx is not None else None,
        )
        print(
            "leftpad_k:", leftpad_k, leftpad_k.shape if leftpad_k is not None else None
        )
        print(
            "rotary_cos:",
            rotary_cos,
            rotary_cos.shape if rotary_cos is not None else None,
        )
        print(
            "rotary_sin:",
            rotary_sin,
            rotary_sin.shape if rotary_sin is not None else None,
        )
        print(
            "seqlens_rotary:",
            seqlens_rotary,
            seqlens_rotary.shape if seqlens_rotary is not None else None,
        )
        print(
            "q_descale:",
            q_descale.dtype if q_descale is not None else None,
            q_descale.shape if q_descale is not None else None,
        )
        print(
            "k_descale:",
            k_descale.dtype if k_descale is not None else None,
            k_descale.shape if k_descale is not None else None,
        )
        print(
            "v_descale:",
            v_descale.dtype if v_descale is not None else None,
            v_descale.shape if v_descale is not None else None,
        )
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("attention_chunk:", attention_chunk)
        print("softcap:", softcap)
        print("rotary_interleaved:", rotary_interleaved)
        print("scheduler_metadata:", scheduler_metadata)
        print("num_splits:", num_splits)
        print("pack_gqa:", pack_gqa)
        print("sm_margin:", sm_margin)

    # Handle qv packed input
    if qv is not None:
        raise NotImplementedError(
            "QV packed input is not yet supported in the AMD Triton backend"
        )

    # Handle softcap
    if softcap != 0.0:
        raise NotImplementedError(
            f"Softcap is not yet supported in the AMD Triton backend (got softcap={softcap}, expected 0.0)"
        )

    # Handle attention_chunk
    if attention_chunk != 0 and attention_chunk != 1:
        raise NotImplementedError(
            f"attention_chunk is not yet supported in the AMD Triton backend (got attention_chunk={attention_chunk})"
        )

    # Handle scheduler metadata
    if scheduler_metadata is not None:
        raise NotImplementedError(
            "Scheduler metadata is not yet supported in the AMD Triton backend"
        )

    # Handle pack_gqa
    if pack_gqa is not None and pack_gqa is not False:
        raise NotImplementedError(
            f"pack_gqa is not yet supported in the AMD Triton backend (got pack_gqa={pack_gqa})"
        )

    # Handle num_splits
    if num_splits != 1:
        raise NotImplementedError(
            f"Split attention (num_splits > 1) is not yet supported in the AMD Triton backend (got num_splits={num_splits})"
        )

    # Handle sm_margin
    if sm_margin != 0:
        raise NotImplementedError(
            f"sm_margin is not yet supported in the AMD Triton backend (got sm_margin={sm_margin}, expected 0)"
        )

    # Handle leftpad_k
    if leftpad_k is not None:
        raise NotImplementedError(
            "Left padding (leftpad_k) is not yet supported in the AMD Triton backend"
        )

    # Handle cu_seqlens_k_new
    if cu_seqlens_k_new is not None:
        raise NotImplementedError(
            "cu_seqlens_k_new is not yet supported in the AMD Triton backend"
        )

    # establish layout / varlen & max seq lens
    if cu_seqlens_q is not None:
        if len(q.shape) != 3:
            raise ValueError(
                f"cu_seqlens_q provided but q has shape {q.shape}, expected 3D tensor for varlen"
            )
        layout = "thd"
        cu_seqlens_q_local = cu_seqlens_q
        max_seqlens_q_local = max_seqlen_q
        if cu_seqlens_k is not None:
            cu_seqlens_k_local = cu_seqlens_k
            max_seqlens_k_local = max_seqlen_k
        else:
            cu_seqlens_k_local = None
            max_seqlens_k_local = k.shape[1] if len(k.shape) == 4 else max_seqlen_k
    else:
        layout = "bshd"
        cu_seqlens_q_local = None
        cu_seqlens_k_local = None
        max_seqlens_q_local = q.shape[1] if max_seqlen_q is None else max_seqlen_q
        max_seqlens_k_local = k.shape[1] if max_seqlen_k is None else max_seqlen_k

    # Now determine if we should use decode or prefill kernel
    # Decode kernel should be used for KV cache scenarios where:
    # 1. k_new/v_new are provided - incremental KV cache update (primary KV cache indicator)
    # 2. kv_batch_idx is provided - KV cache batch indexing (primary KV cache indicator)
    # 3. seqused_k without seqused_q - indicates KV cache fill levels (not varlen masking)
    # Note: In varlen, both seqused_q and seqused_k are used for sequence masking
    # In KV cache, only seqused_k is used to track cache fill levels
    # Detect KV cache scenarios:
    # - Clear KV cache indicators (k_new, v_new, kv_batch_idx)
    # - OR seqused_k without seqused_q (KV cache fill tracking, not varlen masking)
    use_decode = (
        k_new is not None  # Have new KV to append (KV cache indicator)
        or v_new is not None  # Have new KV to append (KV cache indicator)
        or kv_batch_idx is not None  # Have KV cache batch indexing (KV cache indicator)
        or (
            seqused_k is not None and seqused_q is None
        )  # KV cache fill levels (not varlen)
    )

    # Check for unsupported features with decode kernel
    if use_decode:
        if layout == "thd":
            raise NotImplementedError(
                "Varlen is not yet supported with the decode kernel in the AMD Triton backend"
            )
        if kv_batch_idx is not None:
            raise NotImplementedError(
                "kv_batch_idx is not yet supported with the decode kernel in the AMD Triton backend"
            )

    if out is None:
        # NOTE: Using types that are lower precision than float32 such as bfloat16 for fp8 causes mismatches on a small set of tests.
        out_dtype = torch.float32 if is_fp8([q, k, v]) else q.dtype
        if layout == "bshd":
            out = torch.zeros(
                q.shape[0],
                q.shape[1],
                q.shape[2],
                v.shape[-1],
                dtype=out_dtype,
                device=q.device,
            )
        elif layout == "thd":
            out = torch.zeros(
                q.shape[0], q.shape[1], v.shape[-1], dtype=out_dtype, device=q.device
            )
        else:
            raise ValueError(
                f"Unsupported layout: {layout}. Only 'bshd' and 'thd' layouts are supported."
            )
    else:
        out = out.zero_()

    # Handle causal mask
    causal_flag = bool(causal)

    # Handle alibi slopes
    # (FA3 interface exposes no alibi; dropout is likewise fixed off in this shim.)
    alibi_slopes = None

    # Handle dropout
    dropout_p = 0.0
    return_softmax = False
    philox_seed = PHILOX_SEED
    philox_offset = PHILOX_OFFSET

    # Call implementation
    if DEBUG:
        print("Using Triton implementation")

    if use_decode:
        if DEBUG:
            print(
                f"Using Decode Triton implementation (cache_seqlens={seqused_k is not None}, k_new={k_new is not None}, v_new={v_new is not None}, kv_batch_idx={kv_batch_idx is not None})"
            )

        # Create softmax_lse tensor for decode - always exact shape (B, Hq, Sq)
        batch, seqlen_q, nheads_q, _ = q.shape
        softmax_lse = torch.zeros(
            (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32
        )

        attention_forward_decode_triton_impl(
            q,
            k,
            v,
            k_new,
            v_new,
            out,
            softmax_lse,
            softmax_scale,
            causal_flag,
            window_size_left,
            window_size_right,
            alibi_slopes,
            layout,
            seqused_k,
            kv_batch_idx,
            page_table,
            q_descale,
            k_descale,
            v_descale,
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            rotary_interleaved=rotary_interleaved,
            seqlens_rotary=seqlens_rotary,
        )
    else:
        if DEBUG:
            print("Using Prefill Triton implementation")

        # Create softmax_lse tensor - FA3 always uses exact shapes
        if layout == "thd":
            # varlen: (Hq, Total_Q)
            total_q, nheads_q, _ = q.shape
            softmax_lse = torch.zeros(
                (nheads_q, total_q), device=q.device, dtype=torch.float32
            )
        else:
            # bshd: (B, Hq, Sq)
            batch, seqlen_q, nheads_q, _ = q.shape
            softmax_lse = torch.zeros(
                (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32
            )

        # sd_mask is not returned in v3 interface
        sd_mask = None

        attention_forward_prefill_triton_impl(
            q,
            k,
            v,
            out,
            softmax_lse,
            sd_mask,
            softmax_scale,
            alibi_slopes,
            causal_flag,
            window_size_left,
            window_size_right,
            None,
            layout,
            cu_seqlens_q_local,
            cu_seqlens_k_local,
            max_seqlens_q_local,
            max_seqlens_k_local,
            dropout_p,
            philox_seed,
            philox_offset,
            return_softmax,
            USE_EXP2,
            q_descale,
            k_descale,
            v_descale,
            seqused_q,
            seqused_k,
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            rotary_interleaved=rotary_interleaved,
            seqlens_rotary=seqlens_rotary,
        )

    if DEBUG:
        print("interface_fa_v3.py::fwd outputs")
        print(
            "out:",
            out.dtype if out is not None else None,
            out.shape if out is not None else None,
        )
        print(
            "softmax_lse:",
            softmax_lse.dtype if softmax_lse is not None else None,
            softmax_lse.shape if softmax_lse is not None else None,
        )

    # --- Assertions (FA3 always expects exact shapes) ---
    # out: same shape as q except last dim is v's head_dim
    if layout == "thd":
        # varlen: (Total_Q, Hq, Dv)
        assert (
            out.shape[0] == q.shape[0]
        ), f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}"
        assert (
            out.shape[1] == q.shape[1]
        ), f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}"
        assert (
            out.shape[2] == v.shape[-1]
        ), f"[fwd_v3] out.shape[2] {out.shape[2]} != v.shape[-1] {v.shape[-1]}"
    else:
        # bshd: (B, Sq, Hq, Dv)
        assert (
            out.shape[0] == q.shape[0]
        ), f"[fwd_v3] out.shape[0] {out.shape[0]} != q.shape[0] {q.shape[0]}"
        assert (
            out.shape[1] == q.shape[1]
        ), f"[fwd_v3] out.shape[1] {out.shape[1]} != q.shape[1] {q.shape[1]}"
        assert (
            out.shape[2] == q.shape[2]
        ), f"[fwd_v3] out.shape[2] {out.shape[2]} != q.shape[2] {q.shape[2]}"
        assert (
            out.shape[3] == v.shape[-1]
        ), f"[fwd_v3] out.shape[3] {out.shape[3]} != v.shape[-1] {v.shape[-1]}"

    # softmax_lse dtype
    assert (
        softmax_lse.dtype == torch.float32
    ), f"[fwd_v3] softmax_lse dtype {softmax_lse.dtype} != torch.float32"
    # softmax_lse shape depends on layout
    if layout == "thd":
        # varlen: (Hq, Total_Q)
        expected_lse_shape = (q.shape[1], q.shape[0])
    else:
        # bshd: (B, Hq, Sq)
        expected_lse_shape = (q.shape[0], q.shape[2], q.shape[1])
    assert (
        softmax_lse.shape == expected_lse_shape
    ), f"[fwd_v3] softmax_lse shape {softmax_lse.shape} != {expected_lse_shape}"

    # Return format compatible with v3
    # V3 returns (out, softmax_lse, *rest) where rest can be empty or contain additional outputs
    return out, softmax_lse


def bwd(
    dout: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    dq: Optional[torch.Tensor],
    dk: Optional[torch.Tensor],
    dv: Optional[torch.Tensor],
    cu_seqlens_q: Optional[torch.Tensor],
    cu_seqlens_k: Optional[torch.Tensor],
    seqused_q: Optional[torch.Tensor],
    seqused_k: Optional[torch.Tensor],
    max_seqlen_q: Optional[int],
    max_seqlen_k: Optional[int],
    softmax_scale: float,
    causal: bool,
    window_size_left: int,
    window_size_right: int,
    softcap: float,
    deterministic: bool,
    sm_margin: int = 0,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Flash Attention v3 backward pass compatible interface for AMD Triton implementation.

    This function maps v3 parameters to the existing AMD Triton implementation.
+ """ + + if DEBUG: + print() + print("interface_fa_v3.py::bwd inputs") + print( + "dout:", + dout.dtype if dout is not None else None, + dout.shape if dout is not None else None, + ) + print( + "q:", q.dtype if q is not None else None, q.shape if q is not None else None + ) + print( + "k:", k.dtype if k is not None else None, k.shape if k is not None else None + ) + print( + "v:", v.dtype if v is not None else None, v.shape if v is not None else None + ) + print( + "out:", + out.dtype if out is not None else None, + out.shape if out is not None else None, + ) + print( + "softmax_lse:", + softmax_lse.dtype if softmax_lse is not None else None, + softmax_lse.shape if softmax_lse is not None else None, + ) + print( + "dq:", + dq.dtype if dq is not None else None, + dq.shape if dq is not None else None, + ) + print( + "dk:", + dk.dtype if dk is not None else None, + dk.shape if dk is not None else None, + ) + print( + "dv:", + dv.dtype if dv is not None else None, + dv.shape if dv is not None else None, + ) + print( + "cu_seqlens_q:", + cu_seqlens_q, + cu_seqlens_q.shape if cu_seqlens_q is not None else None, + ) + print( + "cu_seqlens_k:", + cu_seqlens_k, + cu_seqlens_k.shape if cu_seqlens_k is not None else None, + ) + print( + "seqused_q:", seqused_q, seqused_q.shape if seqused_q is not None else None + ) + print( + "seqused_k:", seqused_k, seqused_k.shape if seqused_k is not None else None + ) + print("max_seqlen_q:", max_seqlen_q) + print("max_seqlen_k:", max_seqlen_k) + print("softmax_scale:", softmax_scale) + print("causal:", causal) + print("window_size_left:", window_size_left) + print("window_size_right:", window_size_right) + print("softcap:", softcap) + print("deterministic:", deterministic) + print("sm_margin:", sm_margin) + + # Check for unsupported features in backward pass + + # Handle softcap + if softcap != 0.0: + raise NotImplementedError( + f"Softcap is not yet supported in the AMD Triton backend backward pass (got softcap={softcap}, expected 0.0)" 
+ ) + + # Handle sm_margin + if sm_margin != 0: + raise NotImplementedError( + f"sm_margin is not yet supported in the AMD Triton backend backward pass (got sm_margin={sm_margin}, expected 0)" + ) + + # Initialize gradient tensors if not provided + # NOTE: Using types that are lower precision than float32 such as bfloat16 for fp8 causes mismatches on a small set of tests. + grad_dtype = torch.float32 if is_fp8([q, k, v]) else q.dtype + dq = torch.zeros_like(q, dtype=grad_dtype) if dq is None else dq.zero_() + dk = torch.zeros_like(k, dtype=grad_dtype) if dk is None else dk.zero_() + dv = torch.zeros_like(v, dtype=grad_dtype) if dv is None else dv.zero_() + + # Determine layout based on cu_seqlens + if cu_seqlens_q is not None and cu_seqlens_k is not None: + # Variable length sequence mode + layout = "thd" + batch = len(cu_seqlens_q) - 1 + total_q, nheads_q, _ = q.shape + # Create delta tensor - varlen: (Hq, Total_Q) + delta = torch.zeros((nheads_q, total_q), device=q.device, dtype=torch.float32) + else: + # Regular batch mode + layout = "bshd" + batch, seqlen_q, nheads_q, _ = q.shape + max_seqlen_q = q.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_k = k.shape[1] if max_seqlen_k is None else max_seqlen_k + # Create delta tensor - bshd: (B, Hq, Sq) + delta = torch.zeros( + (batch, nheads_q, seqlen_q), device=q.device, dtype=torch.float32 + ) + + # V3 backward doesn't have dropout or alibi slopes + dropout_p = 0.0 + philox_seed, philox_offset = None, None + alibi_slopes = None + + # Call implementation + if DEBUG: + print(f"Using Triton implementation in {BWD_MODE} mode") + attention_backward_triton_impl( + do=dout, + q=q, + k=k, + v=v, + o=out, + softmax_lse=softmax_lse, + dq=dq, + dk=dk, + dv=dv, + delta=delta, + sm_scale=softmax_scale, + alibi_slopes=alibi_slopes, + causal=causal, + layout=layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=seqused_q, + 
seqused_k=seqused_k, + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset=philox_offset, + use_exp2=USE_EXP2, + mode=BWD_MODE, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + + if DEBUG: + print("interface_fa_v3.py::bwd outputs") + print( + "dq:", + dq.dtype if dq is not None else None, + dq.shape if dq is not None else None, + ) + print( + "dk:", + dk.dtype if dk is not None else None, + dk.shape if dk is not None else None, + ) + print( + "dv:", + dv.dtype if dv is not None else None, + dv.shape if dv is not None else None, + ) + print( + "delta:", + delta.dtype if delta is not None else None, + delta.shape if delta is not None else None, + ) + + # --- Assertions (FA3 always expects exact shapes) --- + # Gradients should match input shapes + assert dq.shape == q.shape, f"[bwd_v3] dq shape {dq.shape} != q shape {q.shape}" + assert dk.shape == k.shape, f"[bwd_v3] dk shape {dk.shape} != k shape {k.shape}" + assert dv.shape == v.shape, f"[bwd_v3] dv shape {dv.shape} != v shape {v.shape}" + # delta (softmax_d) should match softmax_lse shape + assert ( + delta.dtype == torch.float32 + ), f"[bwd_v3] delta dtype {delta.dtype} != torch.float32" + if layout == "thd": + # varlen: (Hq, Total_Q) + expected_delta_shape = (q.shape[1], q.shape[0]) + else: + # bshd: (B, Hq, Sq) + expected_delta_shape = (q.shape[0], q.shape[2], q.shape[1]) + assert ( + delta.shape == expected_delta_shape + ), f"[bwd_v3] delta shape {delta.shape} != {expected_delta_shape}" + + # V3 expects (dq, dk, dv, softmax_d, *rest) + # delta is the softmax_d in this case + return dq, dk, dv, delta + + +def fwd_combine( + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """ + Combine partial outputs from split attention computation. + + This is used when num_splits > 1 to combine the partial results. 
+ + Args: + out_partial: Partial output tensor from split computation + lse_partial: Partial log-sum-exp tensor + out: Optional output tensor to write to + out_dtype: Optional dtype for output + + Returns: + Combined output tensor + """ + raise NotImplementedError( + "fwd_combine is not yet implemented in the AMD Triton backend" + ) + + +def get_scheduler_metadata( + batch_size: int, + max_seqlen_q: int, + max_seqlen_k: int, + num_heads_q: int, + num_heads_kv: int, + headdim: int, + headdim_v: int, + qkv_dtype: torch.dtype, + cache_seqlens: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_size: Optional[int] = None, + max_seqlen_k_new: int = 0, + causal: bool = False, + window_size_left: int = -1, + window_size_right: int = -1, + attention_chunk: int = 0, + has_softcap: bool = False, + num_splits: int = 0, + pack_gqa: Optional[bool] = None, + sm_margin: int = 0, +): + """ + Get scheduler metadata for optimized kernel selection. + + This function is used to precompute metadata for kernel scheduling in FA3. + The AMD Triton backend currently doesn't use scheduler metadata, so this + raises an error. + + Args: + Various attention parameters used for scheduling decisions + + Returns: + None - scheduler metadata is not used in AMD Triton backend + """ + raise NotImplementedError( + "get_scheduler_metadata is not supported in the AMD Triton backend yet." 
+ ) diff --git a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py new file mode 100644 index 0000000000..44c8a53541 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py @@ -0,0 +1,1512 @@ +import csv +import math +import torch +import os +import random +import functools +import triton +import triton.language as tl +import numpy as np +from typing import Literal, Optional, Union, Tuple + +# ------------------------------- +# Gloabl Variables +# ------------------------------- +AUTOTUNE = os.environ.get("FLASH_ATTENTION_TRITON_AMD_AUTOTUNE", "0").lower() in ( + "1", + "true", + "yes", +) +DEBUG = os.environ.get("FLASH_ATTENTION_TRITON_AMD_DEBUG", "0").lower() in ( + "1", + "true", + "yes", +) +if AUTOTUNE or DEBUG: + os.environ["TRITON_PRINT_AUTOTUNING"] = "1" +USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE" +USE_TRITON_INTERPRET = os.environ.get("TRITON_INTERPRET", "0").lower() in ( + "1", + "true", + "yes", +) +DEBUG_TRITON = ( + os.environ.get("DEBUG_TRITON", "0").lower() in ("1", "true", "yes") + and USE_TRITON_INTERPRET +) +DEBUG_TRITON_DETAIL = ( + os.environ.get("DEBUG_TRITON_DETAIL", "0").lower() in ("1", "true", "yes") + and USE_TRITON_INTERPRET +) +if USE_TRITON_ROCM: # TODO remove this + random.seed(42) +BWD_MODE: Literal["fused", "fused_atomic", "split"] = "fused" +USE_EXP2 = True +PHILOX_SEED = 0x1BF58 +PHILOX_OFFSET = 0x1D4B49 +SHAPE_EXPECTATIONS: Literal["exact", "rounded"] = "exact" +FP8_AUTO_DESCALE = False + + +# ------------------------------- +# Input Helper +# ------------------------------- +def random_seqlens_composition(SEQ_LEN, BATCH): + # generate a random composition of N into Z positive parts. 
+ idx = torch.randperm(SEQ_LEN - 1)[: BATCH - 1] + 1 + idx, _ = torch.sort(idx) + breakpoints = torch.cat( + [ + torch.tensor([0], dtype=torch.long), + idx, + torch.tensor([SEQ_LEN], dtype=torch.long), + ] + ) + seqlens = (breakpoints[1:] - breakpoints[:-1]).to(torch.int32) + return seqlens + + +def generate_varlen_tensor( + total_seqlen: int, + num_heads: int, + head_size: int, + batch_size: Optional[int] = None, + equal_seqlens: bool = False, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + mode: Literal["random", "ones", "incremental", "identity"] = "random", +): + if DEBUG: + print("total_seqlen", total_seqlen) + print("num_heads", num_heads) + print("head_size", head_size) + + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # get valid batch_size + if batch_size is None: + valid_batch_sizes = [ + bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen + ] + batch_size = random.choice(valid_batch_sizes) + + # get seqlens + if equal_seqlens: + seqlens = torch.full( + (batch_size,), total_seqlen // batch_size, dtype=torch.int32, device=device + ) + seqlens[-1] += total_seqlen % batch_size + else: + seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) + + # create cumulative sequence lengths + cu_seqlens = ( + torch.cat( + [torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)] + ) + .to(torch.int32) + .to(device=device) + ) + max_seqlen = torch.max(seqlens).to(torch.int32).item() + + # create varlen tensor based on mode + if mode == "incremental": + x = torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device) + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + length = end - start + + x[start:end, :, :] = ( + torch.arange(length, dtype=dtype, device=device) + .view(length, 1, 1) + .expand(length, num_heads, head_size) + ) + elif mode == "identity": + x = 
torch.zeros(total_seqlen, num_heads, head_size, dtype=dtype, device=device) + # for each batch, create identity pattern within that batch's sequence + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + length = end - start + + # create identity pattern for positions within this batch + for pos in range(min(length, head_size)): + x[start + pos, :, pos] = 1.0 + elif mode == "random": + x = torch.randn( + (total_seqlen, num_heads, head_size), dtype=dtype, device=device + ) + elif mode == "ones": + x = torch.ones((total_seqlen, num_heads, head_size), dtype=dtype, device=device) + else: + raise ValueError(f"Unkown mode {mode}") + + if is_fp8_dtype: + # cast to fp8 + x, descale_x = cast_to_fp8( + x, og_fp8_dtype, "thd", cu_seqlens=cu_seqlens, max_seqlen=max_seqlen + ) + x.requires_grad_() + return x, cu_seqlens, max_seqlen, descale_x + else: + x.requires_grad_() + return x, cu_seqlens, max_seqlen + + +def generate_bshd_tensor( + BATCH, + SEQ_LEN, + NUM_HEADS, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + mode: Literal["random", "ones", "incremental", "identity"] = "random", +): + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor based on mode + tensor_shape = (BATCH, SEQ_LEN, NUM_HEADS, D_HEAD) + if mode == "incremental": + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, SEQ_LEN, 1, 1) + .expand(*tensor_shape) + .contiguous() + ) + elif mode == "identity": + x = torch.zeros(tensor_shape, dtype=dtype, device=device) + # create identity pattern: position i has value 1 at dimension i + for i in range(min(SEQ_LEN, D_HEAD)): + x[:, i, :, i] = 1.0 + elif mode == "random": + x = torch.randn(tensor_shape, dtype=dtype, device=device) + elif mode == "ones": + x = torch.ones(tensor_shape, dtype=dtype, device=device) + else: + raise ValueError(f"Unkown mode {mode}") + + if is_fp8_dtype: + # cast to fp8 + x, 
descale_x = cast_to_fp8(x, og_fp8_dtype, "bshd") + x.requires_grad_() + return x, descale_x + else: + x.requires_grad_() + return x + + +def generate_bhsd_tensor( + BATCH, + NUM_HEADS, + SEQ_LEN, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + mode: Literal["random", "ones", "incremental", "identity"] = "random", +): + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor based on mode + tensor_shape = (BATCH, NUM_HEADS, SEQ_LEN, D_HEAD) + if mode == "incremental": + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, 1, SEQ_LEN, 1) + .expand(*tensor_shape) + .contiguous() + ) + elif mode == "identity": + x = torch.zeros(tensor_shape, dtype=dtype, device=device) + # create identity pattern: position i has value 1 at dimension i + for i in range(min(SEQ_LEN, D_HEAD)): + x[:, :, i, i] = 1.0 + elif mode == "random": + x = torch.randn(tensor_shape, dtype=dtype, device=device) + elif mode == "ones": + x = torch.ones(tensor_shape, dtype=dtype, device=device) + else: + raise ValueError(f"Unkown mode {mode}") + + if is_fp8_dtype: + raise ValueError("fp8 not supported for bhsd yet") + else: + x.requires_grad_() + return x + + +def generate_bshd_qkv_packed( + BATCH, + SEQ_LEN, + NUM_HEADS, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + DEBUG_INPUT=False, +): + """Generate QKV packed tensor with shape (BATCH, SEQ_LEN, 3, NUM_HEADS, D_HEAD)""" + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor + tensor_shape = (BATCH, SEQ_LEN, 3, NUM_HEADS, D_HEAD) + if DEBUG_INPUT: + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, SEQ_LEN, 1, 1, 1) + .expand(*tensor_shape) + .contiguous() + ) + else: + x = torch.randn(tensor_shape, dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise 
NotImplementedError("FP8 not supported for QKV packing yet") + else: + x.requires_grad_() + return x + + +def generate_bshd_kv_packed( + BATCH, + SEQ_LEN, + NUM_HEADS, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + DEBUG_INPUT=False, +): + """Generate KV packed tensor with shape (BATCH, SEQ_LEN, 2, NUM_HEADS, D_HEAD)""" + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor + tensor_shape = (BATCH, SEQ_LEN, 2, NUM_HEADS, D_HEAD) + if DEBUG_INPUT: + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, SEQ_LEN, 1, 1, 1) + .expand(*tensor_shape) + .contiguous() + ) + else: + x = torch.randn(tensor_shape, dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise NotImplementedError("FP8 not supported for KV packing yet") + else: + x.requires_grad_() + return x + + +def generate_bhsd_qkv_packed( + BATCH, + NUM_HEADS, + SEQ_LEN, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + DEBUG_INPUT=False, +): + """Generate QKV packed tensor with shape (BATCH, 3, NUM_HEADS, SEQ_LEN, D_HEAD)""" + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor + tensor_shape = (BATCH, 3, NUM_HEADS, SEQ_LEN, D_HEAD) + if DEBUG_INPUT: + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, 1, 1, SEQ_LEN, 1) + .expand(*tensor_shape) + .contiguous() + ) + else: + x = torch.randn(tensor_shape, dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise NotImplementedError("FP8 not supported for QKV packing yet") + else: + x.requires_grad_() + return x + + +def generate_bhsd_kv_packed( + BATCH, + NUM_HEADS, + SEQ_LEN, + D_HEAD, + dtype: torch.dtype = torch.float16, + device="cuda", + DEBUG_INPUT=False, +): + """Generate KV packed tensor with shape (BATCH, 2, 
NUM_HEADS, SEQ_LEN, D_HEAD)""" + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # gen tensor + tensor_shape = (BATCH, 2, NUM_HEADS, SEQ_LEN, D_HEAD) + if DEBUG_INPUT: + x = ( + torch.arange(SEQ_LEN, dtype=dtype, device=device) + .view(1, 1, 1, SEQ_LEN, 1) + .expand(*tensor_shape) + .contiguous() + ) + else: + x = torch.randn(tensor_shape, dtype=dtype, device=device) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise NotImplementedError("FP8 not supported for KV packing yet") + else: + x.requires_grad_() + return x + + +def generate_varlen_qkv_packed( + total_seqlen: int, + num_heads: int, + head_size: int, + batch_size: Optional[int] = None, + equal_seqlens: bool = False, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + DEBUG_INPUT: bool = False, +): + """Generate varlen QKV packed tensor with shape (total_seqlen, 3, num_heads, head_size)""" + if DEBUG: + print("generate_varlen_qkv_packed") + print("total_seqlen", total_seqlen) + print("num_heads", num_heads) + print("head_size", head_size) + + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # get valid batch_size + if batch_size is None: + valid_batch_sizes = [ + bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen + ] + batch_size = random.choice(valid_batch_sizes) + + # get seqlens + if equal_seqlens: + seqlens = torch.full( + (batch_size,), total_seqlen // batch_size, dtype=torch.int32, device=device + ) + seqlens[-1] += total_seqlen % batch_size + else: + seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) + + # create cumulative sequence lengths + cu_seqlens = ( + torch.cat( + [torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)] + ) + .to(torch.int32) + .to(device=device) + ) + max_seqlen = torch.max(seqlens).to(torch.int32).item() + + # create varlen qkv 
packed tensor + if DEBUG_INPUT: + x = torch.zeros( + total_seqlen, 3, num_heads, head_size, dtype=dtype, device=device + ) + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + length = end - start + + x[start:end, :, :, :] = ( + torch.arange(length, dtype=dtype, device=device) + .view(length, 1, 1, 1) + .expand(length, 3, num_heads, head_size) + ) + else: + x = torch.randn( + (total_seqlen, 3, num_heads, head_size), dtype=dtype, device=device + ) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise NotImplementedError("FP8 not supported for QKV packing yet") + else: + x.requires_grad_() + return x, cu_seqlens, max_seqlen + + +def generate_varlen_kv_packed( + total_seqlen: int, + num_heads: int, + head_size: int, + batch_size: Optional[int] = None, + equal_seqlens: bool = False, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + DEBUG_INPUT: bool = False, +): + """Generate varlen KV packed tensor with shape (total_seqlen, 2, num_heads, head_size)""" + if DEBUG: + print("generate_varlen_kv_packed") + print("total_seqlen", total_seqlen) + print("num_heads", num_heads) + print("head_size", head_size) + + # save fp8 type + is_fp8_dtype = is_dtype_fp8(dtype) + if is_fp8_dtype: + og_fp8_dtype = dtype + dtype = torch.float32 + + # get valid batch_size + if batch_size is None: + valid_batch_sizes = [ + bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen + ] + batch_size = random.choice(valid_batch_sizes) + + # get seqlens + if equal_seqlens: + seqlens = torch.full( + (batch_size,), total_seqlen // batch_size, dtype=torch.int32, device=device + ) + seqlens[-1] += total_seqlen % batch_size + else: + seqlens = random_seqlens_composition(total_seqlen, batch_size).to(device=device) + + # create cumulative sequence lengths + cu_seqlens = ( + torch.cat( + [torch.tensor([0], dtype=torch.int32, device=device), seqlens.cumsum(dim=0)] + ) + .to(torch.int32) + .to(device=device) + ) + 
max_seqlen = torch.max(seqlens).to(torch.int32).item() + + # create varlen kv packed tensor + if DEBUG_INPUT: + x = torch.zeros( + total_seqlen, 2, num_heads, head_size, dtype=dtype, device=device + ) + for i in range(batch_size): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + length = end - start + + x[start:end, :, :, :] = ( + torch.arange(length, dtype=dtype, device=device) + .view(length, 1, 1, 1) + .expand(length, 2, num_heads, head_size) + ) + else: + x = torch.randn( + (total_seqlen, 2, num_heads, head_size), dtype=dtype, device=device + ) + + if is_fp8_dtype: + # cast to fp8 - need to handle the packed dimension + raise NotImplementedError("FP8 not supported for KV packing yet") + else: + x.requires_grad_() + return x, cu_seqlens, max_seqlen + + +# ------------------------------- +# Alibi +# ------------------------------- +@triton.jit +def compute_alibi_block( + alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False +): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +# ------------------------------- +# FP8 +# ------------------------------- +def is_dtype_fp8(dtype) -> bool: + supported = { + torch.float8_e4m3fnuz, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + } + if dtype not in supported: + return False + return True + + +_RECOMMENDED_FP8_REPLACEMENTS = { + "gfx942": { + torch.float8_e4m3fn: torch.float8_e4m3fnuz, + torch.float8_e5m2: torch.float8_e5m2fnuz, + }, +} + + +def get_recommended_fp8_dtype(x): + dtype = x.dtype if isinstance(x, torch.Tensor) else x + if not is_dtype_fp8(dtype): + return dtype + arch = get_arch() + return _RECOMMENDED_FP8_REPLACEMENTS.get(arch, {}).get(dtype, dtype) + + +def is_fp8(x) -> bool: + """Return whether tensor(s) use FP8. + + Accepts either a single tensor or a list/tuple of tensors. + + Rules: + * Single tensor: return True if FP8 (after arch validation), else False. + * Multiple tensors: + - If all tensors are FP8 -> return True. + - If none are FP8 -> return False. + - If a mix of FP8 and non-FP8 -> raise ValueError. + + Empty list/tuple returns False. + """ + + def _is_fp8_single(t: torch.Tensor) -> bool: + if is_dtype_fp8(t.dtype): + arch = get_arch() + if arch not in ("gfx942", "gfx950"): + raise RuntimeError( + f"{arch} is not in the list of supported architectures for FP8" + ) + return True + return False + + if isinstance(x, (list, tuple)): + if len(x) == 0: + return False + flags = [_is_fp8_single(t) for t in x] + if all(flags): + return True + if not any(flags): + return False + raise ValueError( + "Mixed FP8 and non-FP8 tensors provided; either all or none must be FP8." 
+ ) + else: + return _is_fp8_single(x) + + +@triton.jit +def compute_fp8_scaling_factors(x, fp8_max: tl.constexpr): + # compute fp8 scaling and descaling factor for a block + x_amax = tl.max(tl.abs(x)) # NOTE: abs deals with negative values + x_amax = tl.where(x_amax <= 1e-9, 1e-9, x_amax) + scale_x = fp8_max / x_amax + descale_x = x_amax / fp8_max + return scale_x, descale_x + + +@triton.jit +def _cast_varlen_to_fp8_kernel_2d( + X, + X_fp8, + Descale, + cu_seqlens, + H, + MAX_SEQLEN, + stride_batch, + stride_seq, + stride_head, + stride_dim, + stride_out_batch, + stride_out_seq, + stride_out_head, + stride_out_dim, + stride_desc_batch, + stride_desc_head, + FP8_CLAMP_VAL, + FP8_MAX, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + ACTUAL_HEAD_DIM: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + # Process one (batch, head) pair per kernel + b_id = tl.program_id(0) + h_id = tl.program_id(1) + + # Get sequence bounds for this batch + if IS_VARLEN: + seq_start = tl.load(cu_seqlens + b_id) + seq_end = tl.load(cu_seqlens + b_id + 1) + seqlen = seq_end - seq_start + else: + seq_start = 0 + seqlen = MAX_SEQLEN + + # initialize max value tracker + x_max_val = 0.0 + + # STEP 1: Find max absolute value across the entire sequence + num_of_blocks = tl.cdiv(seqlen, BLOCK_SIZE) + for blk_idx in range(0, num_of_blocks): + # print("blk_idx:", blk_idx) + # offsets + offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_dim = tl.arange(0, HEAD_DIM) + + # Create mask for valid elements + mask_seq = offs_seq[:, None] < seqlen + if ACTUAL_HEAD_DIM != HEAD_DIM: + mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM + mask_seq = mask_seq & mask_dim + + # Load block + adj_x = ( + b_id * stride_batch + + h_id * stride_head + + seq_start * stride_seq + + offs_seq[:, None] * stride_seq + + offs_dim[None, :] * stride_dim + ) + x_block = tl.load(X + adj_x, mask=mask_seq, other=0.0) + # print("x_block:", x_block) + + # Find max absolute value in this block + block_max = 
tl.max(tl.abs(x_block)) + # print("block_max:", block_max) + + # Update overall max + x_max_val = tl.maximum(x_max_val, block_max) + # print("x_max_val:", x_max_val) + + # clamp to avoid division by zero issues + x_max_val = tl.maximum(x_max_val, FP8_CLAMP_VAL) + + # compute scale and descale factors for the entire sequence + scale = FP8_MAX / x_max_val + descale = x_max_val / FP8_MAX + + # store descale factor for this (batch, head) pair + desc_ptr = Descale + b_id * stride_desc_batch + h_id # * stride_desc_head + tl.store(desc_ptr, descale) + + # STEP 2: Apply scaling to the entire sequence and convert to FP8 + for blk_idx in range(0, num_of_blocks): + # offsets + offs_seq = blk_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_dim = tl.arange(0, HEAD_DIM) + + # Create mask for valid elements + mask_seq = offs_seq[:, None] < seqlen + if ACTUAL_HEAD_DIM != HEAD_DIM: + mask_dim = offs_dim[None, :] < ACTUAL_HEAD_DIM + mask_seq = mask_seq & mask_dim + + # Load block - Using the fixed addressing + addr = ( + b_id * stride_batch + + h_id * stride_head + + seq_start * stride_seq + + offs_seq[:, None] * stride_seq + + offs_dim[None, :] * stride_dim + ) + x_block = tl.load(X + addr, mask=mask_seq, other=0.0) + + # Apply scale and convert to FP8 + x_fp8_block = (x_block * scale).to(X_fp8.type.element_ty) + + # Store results + addr_out = ( + b_id * stride_out_batch + + h_id * stride_out_head + + seq_start * stride_out_seq + + offs_seq[:, None] * stride_out_seq + + offs_dim[None, :] * stride_out_dim + ) + tl.store(X_fp8 + addr_out, x_fp8_block, mask=mask_seq) + + +def cast_to_fp8( + x: torch.Tensor, + fp8_dtype: torch.dtype, + layout: Literal["bshd", "thd"], + clamp_val: float = 1e-9, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if False: + print() + print("cast_to_fp8") + print("x:", x, x.shape) + print("fp8_dtype:", fp8_dtype) + print("cu_seqlens:", cu_seqlens) + print("max_seqlen:", 
max_seqlen) + print("clamp_val:", clamp_val) + + # check types are valid + assert x.dtype in { + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + } and is_dtype_fp8(fp8_dtype), f"Cannot cast {x.dtype} to {fp8_dtype}" + + # extract dimensions + batch, max_seqlen_final, num_heads, head_dim = get_shape_from_layout( + x, layout, cu_seqlens, max_seqlen + ) + is_varlen = layout == "thd" + fp8_max = torch.finfo(fp8_dtype).max + if False: + print("batch:", batch) + print("max_seqlen_final:", max_seqlen_final) + print("num_heads:", num_heads) + print("head_dim:", head_dim) + + # get closest power of 2 for head_dim + padded_head_dim = 1 << (head_dim - 1).bit_length() + padded_head_dim = max(padded_head_dim, 32) + + # kernel params + x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) + descale_factors = torch.zeros( + (batch, num_heads), device=x.device, dtype=torch.float32 + ) + BLOCK_SIZE = 128 + + # calculate strides + stride_batch, stride_head, stride_seq, stride_dim = get_stride_from_layout( + x, layout + ) + stride_out_batch, stride_out_head, stride_out_seq, stride_out_dim = ( + get_stride_from_layout(x_fp8, layout) + ) + stride_desc_batch, stride_desc_head = descale_factors.stride() + + if False: + print("stride_batch", stride_batch) + print("stride_head", stride_head) + print("stride_seq", stride_seq) + print("stride_dim", stride_dim) + print("stride_out_batch", stride_out_batch) + print("stride_out_head", stride_out_head) + print("stride_out_seq", stride_out_seq) + print("stride_out_dim", stride_out_dim) + print("stride_desc_batch", stride_desc_batch) + print("stride_desc_head", stride_desc_head) + + grid = (batch, num_heads) + _cast_varlen_to_fp8_kernel_2d[grid]( + x, + x_fp8, + descale_factors, + cu_seqlens, + num_heads, + max_seqlen_final, + stride_batch, + stride_seq, + stride_head, + stride_dim, + stride_out_batch, + stride_out_seq, + stride_out_head, + stride_out_dim, + stride_desc_batch, + stride_desc_head, + clamp_val, + fp8_max, + 
BLOCK_SIZE=BLOCK_SIZE, + HEAD_DIM=padded_head_dim, + ACTUAL_HEAD_DIM=head_dim, + IS_VARLEN=is_varlen, + ) + + if False: + print("x_fp8:", x_fp8, x_fp8.shape) + print("descale_factors:", descale_factors, descale_factors.shape) + return x_fp8, descale_factors + + +# ------------------------------- +# Misc +# ------------------------------- +def get_shape_from_layout( + x: torch.Tensor, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +) -> tuple[int, int, int, int]: + if layout == "bhsd": + batch, num_heads, max_seqlen_final, head_dim = x.shape + elif layout == "bshd": + batch, max_seqlen_final, num_heads, head_dim = x.shape + elif layout == "thd": + total_seqlen, num_heads, head_dim = x.shape + if cu_seqlens is None: + raise ValueError("cu_seqlens must be provided for varlen (thd) layout") + if max_seqlen is None: + raise ValueError("max_seqlen must be provided for varlen (thd) layout") + + batch, max_seqlen_final, num_heads, head_dim = ( + len(cu_seqlens) - 1, + max_seqlen, + num_heads, + head_dim, + ) + else: + assert False, "Got unsupported layout." 
+ + return batch, max_seqlen_final, num_heads, head_dim + + +def get_shapes_from_layout( + q, + k, + layout, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, +): + batch_q, seqlen_q, nheads_q, head_size_q = get_shape_from_layout( + q, layout, cu_seqlens_q, max_seqlen_q + ) + batch_k, seqlen_k, nheads_k, head_size_k = get_shape_from_layout( + k, layout, cu_seqlens_k, max_seqlen_k + ) + + # assert + assert batch_q == batch_k + assert head_size_q == head_size_k + + return batch_q, nheads_q, nheads_k, head_size_q, seqlen_q, seqlen_k + + +def get_stride_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"]): + if layout == "thd": + strides = (0, x.stride(1), x.stride(0), x.stride(2)) + elif layout == "bhsd": + strides = (x.stride(0), x.stride(1), x.stride(2), x.stride(3)) + elif layout == "bshd": + strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3)) + else: + assert False, "Got unsupported layout." + return strides + + +def get_shape_and_strides_from_layout( + x: torch.Tensor, + layout: Literal["bshd", "bhsd", "thd"], + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + return get_shape_from_layout( + x, layout, cu_seqlens, max_seqlen + ), get_stride_from_layout(x, layout) + + +def get_strides_from_layout(q, k, v, o, layout): + q_strides = get_stride_from_layout(q, layout) + k_strides = get_stride_from_layout(k, layout) + v_strides = get_stride_from_layout(v, layout) + o_strides = get_stride_from_layout(o, layout) + return q_strides, k_strides, v_strides, o_strides + + +def get_padded_headsize(size): + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. 
+ padded_d_model = max(padded_d_model, 16) + return padded_d_model + + +def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze( + -1 + ) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze( + 0 + ) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) + return ( + -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos + ) # (Z, H, N_CTX_Q, N_CTX_K) + + +def round_multiple(x, m): + return (x + m - 1) // m * m + + +def save_tensor_to_csv(tensor, filename, decimal_places=2): + """ + save a 2d tensor to csv file + + args: + tensor: torch tensor of shape [rows, cols] + filename: output csv filename + decimal_places: number of decimal places (default: 2) + """ + # ensure tensor is 2d + if tensor.ndim != 2: + raise ValueError(f"tensor must be 2d, got shape {tensor.shape}") + + # ensure filename ends with .csv + if not filename.endswith(".csv"): + filename = filename + ".csv" + + # save to csv using numpy + np.savetxt( + filename, + tensor.detach().cpu().numpy(), + delimiter=",", + fmt=f"%.{decimal_places}f", + ) + + +# ------------------------------- +# Dropouts +# ------------------------------- +def create_dropout_mask(dropout_p, shape, seed): + device = "cuda" + rand_vals = torch.rand( + shape, + generator=torch.Generator(device=device).manual_seed(seed), + device=device, + dtype=torch.float32, + ) + return rand_vals > dropout_p + + +def create_dropout_mask_varlen( + dropout_p, batch, nheads_q, cu_seqlens_q, cu_seqlens_k, philox_seed +): + device = "cuda" + qlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + klens = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + max_qlen = qlens.max() + max_klen = klens.max() + dropout_mask = torch.zeros((batch, nheads_q, max_qlen, max_klen), device=device) + for b in range(batch): + qlen = qlens[b] + klen = klens[b] + rand_vals = torch.rand( + (nheads_q, qlen, klen), + 
generator=torch.Generator(device=device).manual_seed(philox_seed), + device=device, + dtype=torch.float32, + ) + submask = rand_vals > dropout_p + dropout_mask[b, :, :qlen, :klen] = submask + + return dropout_mask + + +def write_dropout_mask(x, tensor_name="tensor"): + batch, head, seqlen_m, seqlen_n = x.shape + x = x.tolist() + + with open(f"{tensor_name}.csv", "w") as f: + writer = csv.writer(f) + for b in range(batch): + for h in range(head): + dropout_mask = x[b][h] + if True: + BLOCK_M = 64 + BLOCK_N = 64 + + # Calculate number of blocks in each dimension + m_blocks = math.ceil(seqlen_m / BLOCK_M) + n_blocks = math.ceil(seqlen_n / BLOCK_N) + + # Process each block + for m_block in range(m_blocks): + # Calculate row range for current block + row_start = m_block * BLOCK_M + row_end = min(row_start + BLOCK_M, seqlen_m) + + for n_block in range(n_blocks): + # Calculate column range for current block + col_start = n_block * BLOCK_N + col_end = min(col_start + BLOCK_N, seqlen_n) + + # Extract and write the current block + for row_idx in range(row_start, row_end): + row_data = dropout_mask[row_idx][col_start:col_end] + writer.writerow(row_data) + else: + writer.writerows(dropout_mask) + + +# ------------------------------- +# Rotary +# ------------------------------- +@triton.jit +def _rotary_kernel( + OUT, + X, + COS, + SIN, + CU_SEQLENS, + SEQLEN_OFFSETS, + seqlen, + nheads, + seqlen_ro, + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + ROTARY_DIM: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_M: tl.constexpr, +): + BLOCK_K: tl.constexpr = triton.next_power_of_2(ROTARY_DIM) + ROTARY_DIM_HALF = ROTARY_DIM // 2 + pid_head = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) + pid_batch = tl.program_id(axis=2) + + if not IS_VARLEN: 
+ X = X + pid_batch * stride_x_batch + OUT = OUT + pid_batch * stride_out_batch + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + X = X + start_idx * stride_x_seqlen + OUT = OUT + start_idx * stride_out_seqlen + + if pid_m * BLOCK_M >= seqlen: + return + + rh = pid_head * BLOCK_H + tl.arange(0, BLOCK_H) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + + rk_half = tl.arange(0, BLOCK_K // 2) + COS = COS + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + SIN = SIN + (rm_cs[:, None] * ROTARY_DIM_HALF + rk_half[None, :]) + mask_cs = (rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < ROTARY_DIM_HALF) + cos = tl.load(COS, mask=mask_cs, other=1.0).to(tl.float32) + sin = tl.load(SIN, mask=mask_cs, other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + + if not INTERLEAVED: + X = X + ( + rh[:, None, None] * stride_x_nheads + + rm[None, :, None] * stride_x_seqlen + + rk_half[None, None, :] * stride_x_headdim + ) + OUT = OUT + ( + rh[:, None, None] * stride_out_nheads + + rm[None, :, None] * stride_out_seqlen + + rk_half[None, None, :] * stride_out_headdim + ) + mask = ( + (rh[:, None, None] < nheads) + & (rm[None, :, None] < seqlen) + & (rk_half[None, None, :] < ROTARY_DIM_HALF) + ) + x0 = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(X + ROTARY_DIM_HALF * stride_x_headdim, mask=mask, other=0.0).to( + tl.float32 + ) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + tl.store(OUT, o0, mask=mask) + tl.store(OUT + ROTARY_DIM_HALF * stride_out_headdim, o1, mask=mask) + else: + rk = tl.arange(0, BLOCK_K) + X = X + ( + rh[:, None, None] * stride_x_nheads + + rm[None, :, None] * stride_x_seqlen + + rk[None, None, :] * stride_x_headdim + ) + OUT = OUT + ( + rh[:, None, None] * stride_out_nheads + + rm[None, :, None] * stride_out_seqlen + + rk[None, None, :] * 
stride_out_headdim + ) + mask = ( + (rh[:, None, None] < nheads) + & (rm[None, :, None] < seqlen) + & (rk[None, None, :] < ROTARY_DIM) + ) + x = tl.load(X, mask=mask, other=0.0).to(tl.float32) + x0, x1 = tl.split(tl.reshape(x, [BLOCK_H, BLOCK_M, BLOCK_K // 2, 2])) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + o = tl.reshape(tl.join(o0, o1), [BLOCK_H, BLOCK_M, BLOCK_K]) + tl.store(OUT, o, mask=mask) + + +def _apply_rotary_kernel( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved: bool = False, + inplace: bool = False, + conjugate: bool = False, +) -> torch.Tensor: + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert ( + max_seqlen is not None + ), "If cu_seqlens is passed, max_seqlen must also be provided" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim_half = cos.shape + assert sin.shape == cos.shape + rotary_dim = 2 * rotary_dim_half + assert rotary_dim <= headdim + assert headdim <= 256 + assert seqlen_ro >= seqlen + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in (torch.int32, torch.int64) + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + out = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + # Block heuristics + BLOCK_M = 8 if rotary_dim <= 128 else 4 + grid = ( + triton.cdiv(nheads, 2), + triton.cdiv(seqlen, BLOCK_M), + batch, + ) + + with torch.cuda.device(x.device.index): + torch.library.wrap_triton(_rotary_kernel)[grid]( + out, + x, + cos, + sin, + cu_seqlens, + 
seqlen_offsets, + seqlen, + nheads, + seqlen_ro, + out.stride(0) if not is_varlen else 0, + out.stride(-3), + out.stride(-2), + out.stride(-1), + x.stride(0) if not is_varlen else 0, + x.stride(-3), + x.stride(-2), + x.stride(-1), + rotary_dim, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + BLOCK_M=BLOCK_M, + BLOCK_H=2, + ) + return out + + +class _ApplyRotary(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool, + inplace: bool, + seqlen_offsets: Union[int, torch.Tensor], + cu_seqlens: Optional[torch.Tensor], + max_seqlen: Optional[int], + ): + out = _apply_rotary_kernel( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + conjugate=False, + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do: torch.Tensor): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + dx = _apply_rotary_kernel( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False, + inplace: bool = False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +) 
-> torch.Tensor: + """Public API: apply rotary embeddings to tensor x. + + Args: + x: (B, S, H, D) if `cu_seqlens` is None else (total_S, H, D). + cos, sin: (S_rotary, rotary_dim/2) + interleaved: GPT-J style if True. + inplace: modify x in place (saves memory if rotary_dim == D). + seqlen_offsets: int or (B,) tensor of starting offsets per sequence (KV cache decode). + cu_seqlens: (B+1,) tensor enabling varlen mode. + max_seqlen: required when `cu_seqlens` is provided. + """ + # FP8 path: upcast to bfloat16 (preferred) or float16 for rotary math to avoid excessive error + original_dtype = x.dtype + is_fp8_input = original_dtype == getattr(torch, "float8_e4m3fn", None) + if is_fp8_input: + # Choose bf16 if available in cos.dtype path; otherwise fallback to float16 + target_dtype = ( + torch.bfloat16 + if cos.dtype == torch.bfloat16 or torch.cuda.is_bf16_supported() + else torch.float16 + ) + # Upcast x, cos, sin for computation (without modifying originals in-place) + x_up = x.to(target_dtype) + cos_up = cos.to(target_dtype) if cos.dtype != target_dtype else cos + sin_up = sin.to(target_dtype) if sin.dtype != target_dtype else sin + out_up = _ApplyRotary.apply( + x_up, + cos_up, + sin_up, + interleaved, + False, + seqlen_offsets, + cu_seqlens, + max_seqlen, + ) + # Cast result back to original fp8 dtype + if inplace: + x.copy_(out_up.to(original_dtype)) + return x + return out_up.to(original_dtype) + else: + return _ApplyRotary.apply( + x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen + ) + + +def apply_rotary( + q: torch.Tensor, + k_new: Optional[torch.Tensor], + cos: torch.Tensor, + sin: torch.Tensor, + *, + causal: bool, + local: bool, + interleaved: bool = False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """High-level rotary application used by AMD prefill & decode paths. 
+ + Policy (matches test reference & legacy semantics): + - If causal OR local attention ⇒ apply rotary directly on (B, S, H, D). + - Else (non-causal global) ⇒ flatten heads into sequence: (B, 1, S*H, D), + apply rotary once, then unflatten back. + - k_new (incremental KV slice) is always rotated directly when provided. + + Args: + q: (B, S, H, D) + k_new: Optional (B, S_k, H_k, D) + cos, sin: rotary caches (S_rotary, rotary_dim/2) + causal: causal attention flag + local: sliding-window / local attention flag (pre-computed outside) + interleaved: GPT-J style rotary layout + seqlen_offsets: int or (B,) tensor of per-sequence start offsets + Returns: + (q_rot, k_new_rot) + """ + assert q.ndim == 4, f"Expected q shape (B,S,H,D), got {q.shape}" + B, S, H, D = q.shape + use_flatten = (not causal) and (not local) + + if use_flatten: + # Flatten (S,H) -> (S*H) with an added singleton dim to preserve expected 4D shape. + q_flat = q.reshape(B, S * H, D).unsqueeze(1) # (B, 1, S*H, D) + q_flat = apply_rotary_emb( + q_flat, + cos, + sin, + interleaved=interleaved, + seqlen_offsets=seqlen_offsets, + ) + # Restore shape back to (B, S, H, D) + q = q_flat.view(B, 1, S * H, D).reshape(B, S, H, D) + else: + q = apply_rotary_emb( + q, + cos, + sin, + interleaved=interleaved, + seqlen_offsets=seqlen_offsets, + ) + + if k_new is not None: + k_new = apply_rotary_emb( + k_new, + cos, + sin, + interleaved=interleaved, + seqlen_offsets=seqlen_offsets, + ) + return q, k_new + + +# ------------------------------- +# Runtime info +# ------------------------------- +@functools.cache +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +@functools.cache +def get_arch(): + return triton.runtime.driver.active.get_current_target().arch + + +@functools.cache +def get_cu_count(): + return torch.cuda.get_device_properties( + torch.cuda.current_device() + ).multi_processor_count + + +@functools.cache +def is_cdna(): + return is_hip() and get_arch() in ( + 
"gfx908", + "gfx90a", + "gfx940", + "gfx941", + "gfx942", + "gfx950", + ) + + +@functools.cache +def is_rdna(): + return is_hip() and get_arch() in ( + "gfx1030", + "gfx1100", + "gfx1101", + "gfx1102", + "gfx1200", + "gfx1201", + ) diff --git a/aiter/ops/triton/mha.py b/aiter/ops/triton/mha.py index 43248c0ed2..46007f28f3 100644 --- a/aiter/ops/triton/mha.py +++ b/aiter/ops/triton/mha.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import triton import triton.language as tl @@ -12,6 +12,7 @@ from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton.utils.device_info import get_num_xcds from aiter.ops.triton._triton_kernels.mha import _attn_fwd, _get_config +from aiter.ops.triton._triton_kernels.flash_attn_triton_amd import flash_attn_2 _LOGGER = AiterTritonLogger() @@ -33,103 +34,6 @@ def mha_set_use_int64_strides(value: bool): _USE_INT64_STRIDES = value -def _cast_to_fp8( - x: torch.Tensor, - fp8_dtype, - layout, - clamp_val=1e-9, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Convert a tensor to FP8 format, returning an FP8 tensor and a descale factor. 
- Args: - - x (torch.Tensor): shape [batch, seq_len, heads, dim] - Returns: - - x_fp8 (torch.Tensor): FP8 tensor with the same shape as x - - descale_factor (torch.Tensor): tensor of shape [batch, 1, heads, 1] - """ - if len(x.shape) != 4: - raise ValueError( - f"'bshd' tensor should have shape [batch, seqlen, heads, dim], got {x.shape}" - ) - reduce_dims = (1, 3) # seq_len and dim dimensions - - # Compute the absolute max along reduce_dims, clamped to avoid 0-scale - x_abs_max = x.abs().amax(dim=reduce_dims) - x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) - - # Unsqueeze back to a shape suitable for broadcast - unsqueeze_dims = sorted(reduce_dims) - for d in unsqueeze_dims: - x_abs_max = x_abs_max.unsqueeze(d) - - # compute scale and descale - fp8_max = torch.finfo(fp8_dtype).max - scale = fp8_max / x_abs_max - descale_factor = x_abs_max / fp8_max - - # cast to FP8, optionally setting requires_grad - x_fp8 = (x * scale).to(fp8_dtype) - - return x_fp8, descale_factor - - -def _cast_varlen_to_fp8( - x: torch.Tensor, - fp8_dtype: torch.dtype, - cu_seqlens, - clamp_val: float = 1e-9, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Convert a tensor of sequences with variable seq_len into fp8. 
- Args: - - x (torch.Tensor): shape [total_seq_len, heads, dim] - Returns: - - x_fp8 (torch.Tensor): shape [total_seq_len, heads, dim] - - descale_factors (torch.Tensor): shape [batch, heads] - """ - # validate tensor shape - if len(x.shape) != 3: - raise ValueError( - f"tensor should have shape [total_seqlen, heads, dim], got {x.shape}" - ) - num_heads = x.shape[1] - - # Get batch size from cu_seqlens - batch = cu_seqlens.shape[0] - 1 - fp8_max = torch.finfo(fp8_dtype).max - - # Compute scale and descale factors per sequence - x_fp8 = torch.zeros_like(x, dtype=fp8_dtype) - descale_factors = torch.zeros( - (batch, num_heads), device=x.device, dtype=torch.float32 - ) - - for i in range(batch): - start = cu_seqlens[i] - end = cu_seqlens[i + 1] - x_slice = x[start:end] # Slice for current sequence - - # Standard tensor (0: seq_len, 2: head_dim) - x_abs_max = x_slice.abs().amax(dim=(0, 2)) # [heads] - - # apply minimum clamping - x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) - - # compute scale and descale factors - scale_i = fp8_max / x_abs_max - descale_i = x_abs_max / fp8_max - - # store descale factors - descale_factors[i, :] = descale_i - - scale_reshape = scale_i.reshape(1, num_heads, 1) - - # scale and cast to FP8 - x_fp8[start:end] = (x_slice * scale_reshape).to(fp8_dtype) - - return x_fp8, descale_factors - - def _flash_attn_forward( q: torch.Tensor, k: torch.Tensor, @@ -151,7 +55,7 @@ def _flash_attn_forward( descale_k: Optional[torch.Tensor] = None, descale_v: Optional[torch.Tensor] = None, config: Optional[dict[str, any]] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int, int]: if bias is not None: raise ValueError("Bias is not supported yet in the Triton Backend") @@ -571,221 +475,6 @@ def flash_attn_func( ) -class _FlashAttnFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - dropout_p, - softmax_scale, - causal, - 
window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - is_grad_enabled, - config=None, - ): - is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v]) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(3) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - - # cast input to fp8 - fp8_dtype = types.get_fp8_e4m3_dtype() - q_fp8, descale_q = _cast_to_fp8(q, fp8_dtype, "bshd") - k_fp8, descale_k = _cast_to_fp8(k, fp8_dtype, "bshd") - v_fp8, descale_v = _cast_to_fp8(v, fp8_dtype, "bshd") - - out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = ( - _flash_attn_forward( - q_fp8, - k_fp8, - v_fp8, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=int(window_size[0]), - window_size_right=int(window_size[1]), - bias=None, - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0, - max_seqlen_q=q.shape[1], - max_seqlen_k=k.shape[1], - cu_seqlens_q=None, - cu_seqlens_k=None, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - config=config, - ) - ) - - if is_grad: - ctx.save_for_backward( - q_fp8, - k_fp8, - v_fp8, - out_padded, - softmax_lse, - descale_q, - descale_k, - descale_v, - ) - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - - out = out_padded[..., :head_size_og] - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - - return result[0] if len(result) == 1 else tuple(result) - - @staticmethod - def backward(ctx, do, *args): - q_fp8, k_fp8, v_fp8, out, softmax_lse, descale_q, descale_k, descale_v = ( - ctx.saved_tensors - 
) - dq, dk, dv = ( - torch.zeros_like(q_fp8, dtype=torch.float32), - torch.zeros_like(k_fp8, dtype=torch.float32), - torch.zeros_like(v_fp8, dtype=torch.float32), - ) - head_size_v_og = do.size(3) - do_padded = do - if head_size_v_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) - - fp8_dtype = types.get_fp8_e4m3_dtype() - do_padded_fp8, descale_do = _cast_to_fp8(do_padded, fp8_dtype, "bshd") - if _USE_FUSED_BWD_KERNEL: - flash_attn_fused_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - None, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - None, - None, - max_seqlen_q=q_fp8.shape[1], - max_seqlen_k=k_fp8.shape[1], - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - else: - flash_attn_onekernel_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - None, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - None, - None, - max_seqlen_q=q_fp8.shape[1], - max_seqlen_k=k_fp8.shape[1], - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - - # dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension - # dk = dk[..., : k_fp8.shape[-1]] - # dv = dv[..., : v_fp8.shape[-1]] - return ( - dq, - dk, - dv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -def flash_attn_fp8_func( - q, - k, - v, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_lse=False, - return_attn_probs=False, - config: Optional[dict[str, any]] 
= None, -): - _LOGGER.info( - f"FLASH_ATTN_FP8: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}" - ) - return _FlashAttnFP8Func.apply( - q, - k, - v, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - torch.is_grad_enabled(), - config, - ) - - class _FlashAttnVarlenFunc(torch.autograd.Function): @staticmethod def forward( @@ -1056,229 +745,92 @@ def flash_attn_varlen_func( ) -class _FlashAttnVarlenFP8Func(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, - softmax_scale, - causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_softmax, - block_table, - is_grad_enabled, - config=None, - ): - is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v]) - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - head_size_og = q.size(2) - if head_size_og % 8 != 0: - q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) - k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) - v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) +def flash_attn_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k: Optional[torch.Tensor] = None, + v: Optional[torch.Tensor] = None, + cache_seqlens: Optional[Union[torch.Tensor, int]] = None, + softmax_scale: Optional[float] = None, + causal: bool = True, + window_size: tuple[int, int] = (-1, -1), + softcap: float = 0.0, + num_splits: int = 0, + rotary_cos: Optional[torch.Tensor] = None, + rotary_sin: Optional[torch.Tensor] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + rotary_interleaved: bool = True, + return_softmax_lse: bool = False, +): + """ + This mirrors the public flash_attn v2 interface for KV cache using 
the AMD Triton backend. - # cast input to fp8 - fp8_dtype = types.get_fp8_e4m3_dtype() - q_fp8, descale_q = _cast_varlen_to_fp8(q, fp8_dtype, cu_seqlens=cu_seqlens_q) - k_fp8, descale_k = _cast_varlen_to_fp8(k, fp8_dtype, cu_seqlens=cu_seqlens_k) - v_fp8, descale_v = _cast_varlen_to_fp8(v, fp8_dtype, cu_seqlens=cu_seqlens_k) + Args: + q: (batch, seqlen_q, nheads_q, headdim) + k_cache / v_cache: Either contiguous (batch, seqlen_cache, nheads_k, headdim) or paged + (num_blocks, page_block_size, nheads_k, headdim) when block_table provided. + k, v: Optional incremental tokens to append in-place (appended logically after existing cache). + cache_seqlens: int or (batch,) current valid lengths per batch entry. + softmax_scale: Optional override; defaults to 1/sqrt(headdim). + causal: Apply causal masking. + window_size: (left, right) local attention window; (-1,-1) = full. + softcap: (float) currently must be 0.0 (backend limitation). + num_splits: 0 or 1 only (backend limitation >1). + rotary_cos/rotary_sin: Optional rotary embeddings (applied if provided) – interleaving flag unused here. + cache_batch_idx/cache_leftpad: Optional indexing / left padding metadata. + block_table: Optional paging table mapping logical blocks for paged KV cache. + alibi_slopes: (nheads,) or (batch,nheads) bias slopes (currently ignored if provided – placeholder). + rotary_interleaved: Flag kept for parity (currently forwarded as True constant to backend which ignores it). + return_softmax_lse: If True returns (out, lse) else out. 
- out_padded, softmax_lse, S_dmask, philox_seed, philox_offset = ( - _flash_attn_forward( - q_fp8, - k_fp8, - v_fp8, - dropout_p, - softmax_scale, - causal=causal, - window_size_left=int(window_size[0]), - window_size_right=int(window_size[1]), - bias=None, - alibi_slopes=alibi_slopes, - return_lse=return_lse, - return_softmax=return_softmax and dropout_p > 0, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - config=config, - ) + Returns: + out (and optionally softmax_lse): (batch, seqlen_q, nheads_q, headdim) + """ + # Feature guards / normalization + if softcap != 0.0: + raise NotImplementedError( + "softcap != 0 not supported in v2 KV cache backend yet" + ) + if num_splits not in (0, 1): + raise NotImplementedError( + "num_splits > 1 not supported in v2 KV cache backend yet" ) - if is_grad: - ctx.save_for_backward( - q_fp8, - k_fp8, - v_fp8, - out_padded, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - descale_q, - descale_k, - descale_v, - ) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_k = max_seqlen_k - ctx.philox_seed = philox_seed - ctx.philox_offset = philox_offset - ctx.dropout_p = dropout_p - ctx.softmax_scale = softmax_scale - ctx.causal = causal - ctx.window_size = window_size - ctx.alibi_slopes = alibi_slopes - - out = out_padded[..., :head_size_og] - result = [out] - if return_lse: - result.append(softmax_lse) - if return_softmax: - result.append(S_dmask) - return result[0] if len(result) == 1 else tuple(result) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) - @staticmethod - def backward(ctx, do, *args): - ( - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - cu_seqlens_q, - cu_seqlens_k, - descale_q, - descale_k, - descale_v, - ) = ctx.saved_tensors - dq, dk, dv = ( - torch.zeros_like(q_fp8, dtype=torch.float32), - torch.zeros_like(k_fp8, dtype=torch.float32), - torch.zeros_like(v_fp8, 
dtype=torch.float32), + if cache_seqlens is not None and isinstance(cache_seqlens, int): + cache_seqlens = torch.full( + (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device ) - head_size_v_og = do.size(3) - do_padded = do - if head_size_v_og % 8 != 0: - do_padded = torch.nn.functional.pad(do, [0, 8 - head_size_v_og % 8]) - - fp8_dtype = types.get_fp8_e4m3_dtype() - do_padded_fp8, descale_do = _cast_varlen_to_fp8( - do_padded, fp8_dtype, "thd", cu_seqlens_q - ) - if _USE_FUSED_BWD_KERNEL: - flash_attn_fused_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - None, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_k=ctx.max_seqlen_k, - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - else: - flash_attn_onekernel_backward( - do_padded_fp8, - q_fp8, - k_fp8, - v_fp8, - out, - softmax_lse, - dq, - dk, - dv, - None, - ctx.softmax_scale, - ctx.alibi_slopes, - ctx.causal, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_k=ctx.max_seqlen_k, - dropout_p=ctx.dropout_p, - philox_seed=ctx.philox_seed, - philox_offset=ctx.philox_offset, - descale_q=descale_q, - descale_k=descale_k, - descale_v=descale_v, - descale_do=descale_do, - USE_INT64_STRIDES=_USE_INT64_STRIDES, - ) - dq = dq[..., : q_fp8.shape[-1]] # We could have padded the head dimension - dk = dk[..., : k_fp8.shape[-1]] - dv = dv[..., : v_fp8.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None + # Contiguity (align last dim contiguous requirement similar to v3 path assumptions) + assert q.stride(-1) == 1 and k_cache.stride(-1) == 1 and v_cache.stride(-1) == 1 -def flash_attn_varlen_fp8_func( - q, - k, - v, - cu_seqlens_q, - 
cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - alibi_slopes=None, - deterministic=False, - return_lse=False, - return_attn_probs=False, - block_table=None, - config: Optional[dict[str, any]] = None, -): - _LOGGER.info( - f"FLASH_ATTN_VARLEN_FP8: q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}" - ) - return _FlashAttnVarlenFP8Func.apply( + out, softmax_lse = flash_attn_2.fwd_kvcache( q, + k_cache, + v_cache, k, v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p, + cache_seqlens, + rotary_cos, + rotary_sin, + cache_batch_idx, + cache_leftpad, + block_table, + alibi_slopes, + None, # out tensor softmax_scale, causal, - window_size, - alibi_slopes, - deterministic, - return_lse, - return_attn_probs, - block_table, - torch.is_grad_enabled(), - config, + int(window_size[0]), + int(window_size[1]), + 0.0, # softcap (guarded) + rotary_interleaved, + num_splits, ) + return (out, softmax_lse) if return_softmax_lse else out diff --git a/aiter/ops/triton/mha_v3.py b/aiter/ops/triton/mha_v3.py new file mode 100644 index 0000000000..459a28fcdd --- /dev/null +++ b/aiter/ops/triton/mha_v3.py @@ -0,0 +1,1264 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
class _FlashAttnV3Func(torch.autograd.Function):
    """Autograd bridge to the AMD Triton FlashAttention v3 kernels (dense path).

    Forward dispatches to ``flash_attn_3.fwd``; backward to ``flash_attn_3.bwd``.
    Features the backend does not implement yet are rejected eagerly with
    ``NotImplementedError``.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        softmax_scale: float | None,
        causal: bool,
        qv: Optional[torch.Tensor],
        q_descale: Optional[torch.Tensor],
        k_descale: Optional[torch.Tensor],
        v_descale: Optional[torch.Tensor],
        window_size: Tuple[int, int],
        attention_chunk: int,
        softcap: float,
        num_splits: int,
        pack_gqa: Optional[bool],
        deterministic: bool,
        sm_margin: int,
    ):
        # Guard clauses: fail fast on features the AMD Triton backend lacks.
        if qv is not None:
            raise NotImplementedError("qv is not supported in AMD Triton v3 yet")
        if attention_chunk not in (0, 1):
            raise NotImplementedError("attention_chunk > 1 not supported (0 or 1 only)")
        if softcap != 0.0:
            raise NotImplementedError("softcap not implemented in AMD Triton v3")
        if num_splits != 1:
            raise NotImplementedError("num_splits != 1 not supported in AMD Triton v3")
        if pack_gqa is not None:
            raise NotImplementedError("pack_gqa not implemented in AMD Triton v3")
        if sm_margin != 0:
            raise NotImplementedError("sm_margin != 0 not supported in AMD Triton v3")

        # Default scale mirrors Hopper v3 (1/sqrt(head_dim + qv_dim)); qv is
        # rejected above, so only q's head dim contributes here.
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        win_left, win_right = int(window_size[0]), int(window_size[1])
        out, softmax_lse = flash_attn_3.fwd(
            q, k, v,
            None, None,            # k_new, v_new
            None,                  # qv
            None,                  # out (allocated by the kernel)
            None, None, None,      # cu_seqlens_q / _k / _k_new
            None, None,            # seqused_q, seqused_k
            None, None,            # max_seqlen_q, max_seqlen_k
            None, None, None,      # page_table, kv_batch_idx, leftpad_k
            None, None, None,      # rotary_cos, rotary_sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal,
            win_left, win_right,
            attention_chunk,
            softcap,
            False,                 # rotary_interleaved
            None,                  # scheduler_metadata
            num_splits,
            pack_gqa,
            sm_margin,
        )

        ctx.save_for_backward(
            q, k, v, out, softmax_lse, q_descale, k_descale, v_descale
        )
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        return out

    @staticmethod
    def backward(ctx, dout: torch.Tensor):
        q, k, v, out, softmax_lse, q_descale, k_descale, v_descale = ctx.saved_tensors

        dq, dk, dv, _delta = flash_attn_3.bwd(
            dout, q, k, v, out, softmax_lse,
            None, None, None,      # dq, dk, dv (allocated by the kernel)
            None, None,            # cu_seqlens_q, cu_seqlens_k
            None, None,            # seqused_q, seqused_k
            None, None,            # max_seqlen_q, max_seqlen_k
            ctx.softmax_scale,
            ctx.causal,
            int(ctx.window_size[0]),
            int(ctx.window_size[1]),
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
        )
        # One gradient slot per forward argument; only q/k/v receive gradients.
        return (dq, dk, dv) + (None,) * 13


def flash_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    qv: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    window_size: Tuple[int, int] = (-1, -1),
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
    deterministic: bool = False,
    sm_margin: int = 0,
):
    """FlashAttention v3 entry point."""
    return _FlashAttnV3Func.apply(
        q, k, v,
        softmax_scale, causal, qv,
        q_descale, k_descale, v_descale,
        window_size, attention_chunk, softcap,
        num_splits, pack_gqa, deterministic, sm_margin,
    )
class _FlashAttnVarlenV3Func(torch.autograd.Function):
    """Autograd bridge to the AMD Triton FlashAttention v3 varlen kernels."""

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: float | None,
        causal: bool,
        q_descale: torch.Tensor | None,
        k_descale: torch.Tensor | None,
        v_descale: torch.Tensor | None,
        window_size: tuple[int, int],
        attention_chunk: int,
        softcap: float,
        deterministic: bool,
        sm_margin: int,
    ):
        # Reject backend-unsupported features before doing any work.
        if attention_chunk != 0:
            raise NotImplementedError(
                "attention_chunk != 0 not supported in varlen v3 yet"
            )
        if softcap != 0.0:
            raise NotImplementedError("softcap not implemented in varlen v3 yet")
        if sm_margin != 0:
            raise NotImplementedError("sm_margin != 0 not supported in varlen v3 yet")
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        out, softmax_lse = flash_attn_3.fwd(
            q, k, v,
            None, None,            # k_new, v_new
            None,                  # qv
            None,                  # out (allocated by the kernel)
            cu_seqlens_q, cu_seqlens_k,
            None,                  # cu_seqlens_k_new
            None, None,            # seqused_q, seqused_k
            max_seqlen_q, max_seqlen_k,
            None, None, None,      # page_table, kv_batch_idx, leftpad_k
            None, None, None,      # rotary_cos, rotary_sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal,
            int(window_size[0]), int(window_size[1]),
            attention_chunk,
            softcap,
            False,                 # rotary_interleaved
            None,                  # scheduler_metadata
            1,                     # num_splits
            None,                  # pack_gqa
            sm_margin,
        )

        ctx.save_for_backward(
            q, k, v, out, softmax_lse, q_descale, k_descale, v_descale
        )
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        ctx.cu_seqlens_q = cu_seqlens_q
        ctx.cu_seqlens_k = cu_seqlens_k
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        return out

    @staticmethod
    def backward(ctx, dout: torch.Tensor):
        q, k, v, out, softmax_lse, q_descale, k_descale, v_descale = ctx.saved_tensors

        dq, dk, dv, _delta = flash_attn_3.bwd(
            dout, q, k, v, out, softmax_lse,
            None, None, None,      # dq, dk, dv (allocated by the kernel)
            ctx.cu_seqlens_q, ctx.cu_seqlens_k,
            None, None,            # seqused_q, seqused_k
            ctx.max_seqlen_q, ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            int(ctx.window_size[0]), int(ctx.window_size[1]),
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
        )
        # One gradient slot per forward argument; only q/k/v receive gradients.
        return (dq, dk, dv) + (None,) * 14


def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    window_size: Tuple[int, int] = (-1, -1),
    attention_chunk: int = 0,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
):
    """FlashAttention v3 varlen path."""
    return _FlashAttnVarlenV3Func.apply(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        softmax_scale, causal,
        q_descale, k_descale, v_descale,
        window_size, attention_chunk, softcap,
        deterministic, sm_margin,
    )
def flash_attn_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    qv: Optional[torch.Tensor] = None,
    cache_seqlens: Optional[Union[torch.Tensor, int]] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = True,
    window_size: Tuple[int, int] = (-1, -1),
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 0,
    pack_gqa: Optional[bool] = None,
    sm_margin: int = 0,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    return_softmax_lse: bool = False,
    page_table: Optional[torch.Tensor] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    rotary_cos: Optional[torch.Tensor] = None,
    rotary_sin: Optional[torch.Tensor] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
):
    """
    Arguments mirror Hopper's `flash_attn_with_kvcache` with current backend limitations.
    Unsupported: backward, qv, softcap!=0, pack_gqa, sm_margin!=0, attention_chunk>1, num_splits>1,
    simultaneous varlen (cu_seqlens_q) + cache_seqlens tensor, and partial rotary inputs.
    """
    # Reject unsupported features up front.
    if qv is not None:
        raise NotImplementedError("qv not supported in KV cache path yet")
    if softcap != 0.0:
        raise NotImplementedError("softcap not implemented in KV cache path")
    if pack_gqa is not None:
        raise NotImplementedError("pack_gqa not implemented in KV cache path")
    if sm_margin != 0:
        raise NotImplementedError("sm_margin != 0 not supported in KV cache path")
    if attention_chunk not in (0, 1):
        raise NotImplementedError("attention_chunk > 1 not supported (0 or 1 only)")
    if num_splits not in (0, 1):
        raise NotImplementedError("num_splits > 1 not supported in KV cache path")

    # Default scale; qv was rejected above so only q's head dim contributes.
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    # Broadcast a scalar cache length to a per-batch int32 tensor.
    if isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )

    if cu_seqlens_q is not None and cache_seqlens is not None:
        raise NotImplementedError(
            "Varlen decode with cache_seqlens tensor not supported yet"
        )
    # Rotary cos/sin must be given as a pair.
    if (rotary_cos is None) ^ (rotary_sin is None):
        raise ValueError(
            "Both rotary_cos and rotary_sin must be provided together or neither"
        )
    if (
        (rotary_cos is not None)
        and rotary_seqlens is not None
        and cu_seqlens_q is None
        and cache_seqlens is None
    ):
        raise ValueError(
            "rotary_seqlens provided without cu_seqlens_q or cache_seqlens context"
        )

    out, softmax_lse = flash_attn_3.fwd(
        q, k_cache, v_cache,
        k, v,
        None,                  # qv
        None,                  # out (allocated by the kernel)
        cu_seqlens_q,
        None,                  # cu_seqlens_k
        cu_seqlens_k_new,
        None,                  # seqused_q
        cache_seqlens if isinstance(cache_seqlens, torch.Tensor) else None,
        max_seqlen_q,
        None,                  # max_seqlen_k
        page_table,
        cache_batch_idx,       # kv_batch_idx
        cache_leftpad,         # leftpad_k
        rotary_cos, rotary_sin,
        rotary_seqlens,        # seqlens_rotary
        q_descale, k_descale, v_descale,
        softmax_scale,
        causal,
        int(window_size[0]), int(window_size[1]),
        attention_chunk,
        softcap,
        False,                 # rotary_interleaved
        None,                  # scheduler_metadata
        num_splits or 1,       # 0 collapses to 1; >1 rejected above
        pack_gqa,
        sm_margin,
    )
    return (out, softmax_lse) if return_softmax_lse else out
k_descale, + v_descale, + softmax_scale, + causal, + int(window_size[0]), + int(window_size[1]), + attention_chunk, + softcap, + False, # rotary_interleaved + None, # scheduler_metadata + num_splits if num_splits != 0 else 1, + pack_gqa, + sm_margin, + ) + return (out, softmax_lse) if return_softmax_lse else out + + +# ------------------------------- +# FP8 Wrappers +# ------------------------------- +# do the quantization to fp8 internally and maintain high-precision inputs/outputs + + +def _quantize_bshd( + x: torch.Tensor, + fp8_dtype: torch.dtype, + clamp_val=1e-9, + group_size: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert a tensor to FP8 format, returning an FP8 tensor and a descale factor. + + Args: + x (torch.Tensor): shape [batch, seq_len, heads, dim] + fp8_dtype (torch.dtype): FP8 data type (e.g., torch.float8_e4m3fnuz) + clamp_val (float): minimum value for scaling to avoid division by zero + group_size (int, optional): For GQA/MQA on query tensors, specify the group size (num_heads // num_kv_heads) + to group query heads appropriately. If None, computes scaling per head. 
+ Returns: + x_fp8 (torch.Tensor): FP8 tensor with the same shape as x (leaf tensor if requires_grad=True) + descale_factor (torch.Tensor): tensor of shape [batch, num_heads // group_size] if group_size is specified, + otherwise [batch, heads] + """ + if len(x.shape) != 4: + raise ValueError( + f"'bshd' tensor should have shape [batch, seqlen, heads, dim], got {x.shape}" + ) + + batch, seqlen, num_heads, head_dim = x.shape + + # For GQA/MQA: if group_size is specified and > 1, + # we need to group query heads and compute scaling per group + if group_size is not None and group_size > 1: + assert ( + num_heads % group_size == 0 + ), f"num_heads ({num_heads}) must be divisible by group_size ({group_size})" + + num_groups = num_heads // group_size + + # Reshape to group query heads: [batch, seqlen, num_groups, group_size, head_dim] + x_grouped = x.view(batch, seqlen, num_groups, group_size, head_dim) + + # Compute max over seqlen, group_size (query heads in group), and head_dim + # Result shape: [batch, num_groups] + x_abs_max = x_grouped.abs().amax(dim=(1, 3, 4)) + x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) + + # Unsqueeze to [batch, 1, num_groups, 1, 1] for broadcasting + x_abs_max_broadcast = x_abs_max.unsqueeze(1).unsqueeze(3).unsqueeze(4) + + # Compute scale and descale + fp8_max = torch.finfo(fp8_dtype).max + scale = fp8_max / x_abs_max_broadcast + descale_factor = (x_abs_max / fp8_max).to(torch.float32) + + # Quantize to FP8 and reshape back to original shape + x_fp8 = ( + (x_grouped * scale).view(batch, seqlen, num_heads, head_dim).to(fp8_dtype) + ) + else: + # Standard case: compute scaling per head + reduce_dims = (1, 3) # seq_len and dim dimensions + + # Compute the absolute max along reduce_dims, clamped to avoid 0-scale + # Result shape: [batch, heads] + x_abs_max = x.abs().amax(dim=reduce_dims) + x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) + + # Unsqueeze to [batch, 1, heads, 1] for broadcasting during scaling + 
x_abs_max_broadcast = x_abs_max.unsqueeze(1).unsqueeze(3) + + # compute scale and descale + fp8_max = torch.finfo(fp8_dtype).max + scale = fp8_max / x_abs_max_broadcast + descale_factor = (x_abs_max / fp8_max).to(torch.float32) + + # Quantize to FP8 + x_fp8 = (x * scale).to(fp8_dtype) + + # Detach to make a leaf tensor, This is required because PyTorch only populates .grad on leaf tensors + # x_fp8_leaf = x_fp8.detach().requires_grad_(True) + + return x_fp8, descale_factor + + +def _quantize_thd( + x: torch.Tensor, + fp8_dtype: torch.dtype, + cu_seqlens: torch.Tensor, + clamp_val=1e-9, + group_size: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert a tensor to FP8 format for varlen inputs, returning an FP8 tensor and a descale factor. + + This function computes descale factors per sequence in the batch, analogous to how _quantize_bshd + computes per-batch descale factors. + + Args: + x (torch.Tensor): shape [total_tokens, heads, dim] + fp8_dtype (torch.dtype): FP8 data type (e.g., torch.float8_e4m3fnuz) + cu_seqlens (torch.Tensor): Cumulative sequence lengths [batch_size + 1] + clamp_val (float): minimum value for scaling to avoid division by zero + group_size (int, optional): For GQA/MQA on query tensors, specify the group size (num_heads // num_kv_heads) + to group query heads appropriately. If None, computes scaling per head. 
+ Returns: + x_fp8 (torch.Tensor): FP8 tensor with the same shape as x + descale_factor (torch.Tensor): tensor of shape [batch_size, num_heads // group_size] if group_size is specified, + otherwise [batch_size, heads] + """ + if len(x.shape) != 3: + raise ValueError( + f"'thd' tensor should have shape [total_tokens, heads, dim], got {x.shape}" + ) + + total_tokens, num_heads, head_dim = x.shape + batch_size = len(cu_seqlens) - 1 + + fp8_max = torch.finfo(fp8_dtype).max + + # For GQA/MQA: if group_size is specified and > 1, + # we need to group query heads and compute scaling per group + if group_size is not None and group_size > 1: + assert ( + num_heads % group_size == 0 + ), f"num_heads ({num_heads}) must be divisible by group_size ({group_size})" + + num_groups = num_heads // group_size + + # Reshape to group query heads: [total_tokens, num_groups, group_size, head_dim] + x_grouped = x.view(total_tokens, num_groups, group_size, head_dim) + + # Compute descale factors per sequence (analogous to per-batch in bshd) + descale_list = [] + x_fp8_list = [] + + for b in range(batch_size): + start = cu_seqlens[b].item() + end = cu_seqlens[b + 1].item() + + # Get tokens for this sequence: [seq_len, num_groups, group_size, head_dim] + x_seq = x_grouped[start:end] + + # Compute max over seq_len, group_size, and head_dim + # Result shape: [num_groups] + x_abs_max = x_seq.abs().amax(dim=(0, 2, 3)) + x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) + + # Compute descale for this sequence: [num_groups] + descale_seq = (x_abs_max / fp8_max).to(torch.float32) + descale_list.append(descale_seq) + + # Quantize this sequence + # Unsqueeze to [1, num_groups, 1, 1] for broadcasting + x_abs_max_broadcast = x_abs_max.unsqueeze(0).unsqueeze(2).unsqueeze(3) + scale = fp8_max / x_abs_max_broadcast + x_seq_fp8 = (x_seq * scale).to(fp8_dtype) + x_fp8_list.append(x_seq_fp8) + + # Stack descale factors: [batch_size, num_groups] + descale_factor = torch.stack(descale_list, dim=0) + 
+ # Concatenate quantized sequences and reshape back to original shape + x_fp8 = torch.cat(x_fp8_list, dim=0).view(total_tokens, num_heads, head_dim) + else: + # Standard case: compute scaling per head for each sequence + descale_list = [] + x_fp8_list = [] + + for b in range(batch_size): + start = cu_seqlens[b].item() + end = cu_seqlens[b + 1].item() + + # Get tokens for this sequence: [seq_len, num_heads, head_dim] + x_seq = x[start:end] + + # Compute max over seq_len and head_dim + # Result shape: [num_heads] + x_abs_max = x_seq.abs().amax(dim=(0, 2)) + x_abs_max = torch.maximum(x_abs_max, x.new_tensor(clamp_val)) + + # Compute descale for this sequence: [num_heads] + descale_seq = (x_abs_max / fp8_max).to(torch.float32) + descale_list.append(descale_seq) + + # Quantize this sequence + # Unsqueeze to [1, num_heads, 1] for broadcasting + x_abs_max_broadcast = x_abs_max.unsqueeze(0).unsqueeze(2) + scale = fp8_max / x_abs_max_broadcast + x_seq_fp8 = (x_seq * scale).to(fp8_dtype) + x_fp8_list.append(x_seq_fp8) + + # Stack descale factors: [batch_size, num_heads] + descale_factor = torch.stack(descale_list, dim=0) + + # Concatenate quantized sequences + x_fp8 = torch.cat(x_fp8_list, dim=0) + + return x_fp8, descale_factor + + +class _FlashAttnFP8Wrapper(torch.autograd.Function): + """ + FP8 Flash Attention wrapper that maintains high-precision inputs/outputs. + + This wrapper allows users to pass BF16/FP32 tensors and automatically handles + the FP8 quantization internally, maintaining backward compatibility with + high-precision training workflows. 
class _FlashAttnFP8Wrapper(torch.autograd.Function):
    """
    FP8 Flash Attention wrapper that maintains high-precision inputs/outputs.

    Callers pass BF16/FP32 tensors; quantization to FP8 (with per-batch,
    per-head descale factors) happens internally, so existing high-precision
    training code gets FP8 compute without manual quantization.

    Forward: BF16/FP32 -> FP8 -> flash_attn -> output
    Backward: grad_out -> flash_attn_bwd -> FP32 grads -> input dtype grads
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,  # High precision (BF16/FP32)
        k: torch.Tensor,  # High precision (BF16/FP32)
        v: torch.Tensor,  # High precision (BF16/FP32)
        softmax_scale: Optional[float],
        causal: bool,
        window_size: Tuple[int, int],
        attention_chunk: int,
        softcap: float,
        deterministic: bool,
        sm_margin: int,
    ):
        # Validate unsupported features FIRST: fail fast instead of paying for
        # FP8 quantization of q/k/v before raising.
        if attention_chunk not in (0, 1):
            raise NotImplementedError("attention_chunk > 1 not supported (0 or 1 only)")
        if softcap != 0.0:
            raise NotImplementedError(
                "softcap not implemented in FP8 high-precision API"
            )
        if sm_margin != 0:
            raise NotImplementedError(
                "sm_margin != 0 not supported in FP8 high-precision API"
            )

        batch, seqlen, num_q_heads, head_dim = q.shape
        _, _, num_kv_heads, _ = k.shape

        # Derive softmax scale if not provided
        if softmax_scale is None:
            softmax_scale = head_dim ** (-0.5)

        # Quantize inputs to FP8.  For GQA/MQA the query heads are quantized
        # with grouped scaling so all descale tensors share the KV-head layout.
        fp8_dtype = torch.float8_e4m3fnuz
        group_size = (
            num_q_heads // num_kv_heads if num_q_heads != num_kv_heads else None
        )
        q_fp8, q_descale = _quantize_bshd(q, fp8_dtype, group_size=group_size)
        k_fp8, k_descale = _quantize_bshd(k, fp8_dtype)
        v_fp8, v_descale = _quantize_bshd(v, fp8_dtype)

        # Sanity-check descale shapes (grouped q collapses to num_kv_heads).
        assert q_descale.shape == (
            batch,
            num_kv_heads,
        ), f"q_descale shape {q_descale.shape} != expected {(batch, num_kv_heads)}"
        assert k_descale.shape == (
            batch,
            num_kv_heads,
        ), f"k_descale shape {k_descale.shape} != expected {(batch, num_kv_heads)}"
        assert v_descale.shape == (
            batch,
            num_kv_heads,
        ), f"v_descale shape {v_descale.shape} != expected {(batch, num_kv_heads)}"

        # Call flash attention forward
        out, softmax_lse = flash_attn_3.fwd(
            q_fp8, k_fp8, v_fp8,
            None, None,            # k_new, v_new
            None,                  # qv
            None,                  # out (allocated by the kernel)
            None, None, None,      # cu_seqlens_q / _k / _k_new
            None, None,            # seqused_q, seqused_k
            None, None,            # max_seqlen_q, max_seqlen_k
            None, None, None,      # page_table, kv_batch_idx, leftpad_k
            None, None, None,      # rotary_cos, rotary_sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal,
            int(window_size[0]), int(window_size[1]),
            attention_chunk,
            softcap,
            False,                 # rotary_interleaved
            None,                  # scheduler_metadata
            1,                     # num_splits
            None,                  # pack_gqa
            sm_margin,
        )

        # Save tensors needed for backward
        ctx.save_for_backward(
            q_fp8, k_fp8, v_fp8, out, softmax_lse, q_descale, k_descale, v_descale
        )
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        ctx.input_dtype = q.dtype

        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """
        Compute gradients w.r.t. inputs.
        The backward pass returns FP32 gradients, which we convert to the input dtype.
        """
        q_fp8, k_fp8, v_fp8, out, softmax_lse, q_descale, k_descale, v_descale = (
            ctx.saved_tensors
        )

        # Backward kernel returns FP32 gradients.
        dq, dk, dv, _delta = flash_attn_3.bwd(
            grad_output, q_fp8, k_fp8, v_fp8, out, softmax_lse,
            None, None, None,      # dq, dk, dv (allocated by the kernel)
            None, None,            # cu_seqlens_q, cu_seqlens_k
            None, None,            # seqused_q, seqused_k
            None, None,            # max_seqlen_q, max_seqlen_k
            ctx.softmax_scale,
            ctx.causal,
            int(ctx.window_size[0]), int(ctx.window_size[1]),
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
        )

        # Convert gradients back to the caller's dtype (FP32 -> BF16 if needed).
        dq = dq.to(ctx.input_dtype)
        dk = dk.to(ctx.input_dtype)
        dv = dv.to(ctx.input_dtype)

        # One gradient slot per forward argument; only q/k/v receive gradients.
        return (dq, dk, dv) + (None,) * 7
+ """ + # Retrieve saved tensors + q_fp8, k_fp8, v_fp8, out, softmax_lse, q_descale, k_descale, v_descale = ( + ctx.saved_tensors + ) + + # Call flash attention backward - returns FP32 gradients + dq, dk, dv, _delta = flash_attn_3.bwd( + grad_output, + q_fp8, + k_fp8, + v_fp8, + out, + softmax_lse, + None, + None, + None, # dq, dk, dv (will be allocated) + None, + None, # cu_seqlens_q, cu_seqlens_k + None, + None, + None, + None, # seqused_q, seqused_k, max_seqlen_q, max_seqlen_k + ctx.softmax_scale, + ctx.causal, + int(ctx.window_size[0]), + int(ctx.window_size[1]), + ctx.softcap, + ctx.deterministic, + ctx.sm_margin, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + + # Convert gradients to input dtype (FP32 -> BF16 if needed) + dq = dq.to(ctx.input_dtype) + dk = dk.to(ctx.input_dtype) + dv = dv.to(ctx.input_dtype) + + # Return gradients for all forward inputs (None for non-tensor inputs) + return ( + dq, # q + dk, # k + dv, # v + None, # softmax_scale + None, # causal + None, # window_size + None, # attention_chunk + None, # softcap + None, # deterministic + None, # sm_margin + ) + + +def flash_attn_fp8_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, + qv: Optional[torch.Tensor] = None, + window_size: Tuple[int, int] = (-1, -1), + attention_chunk: int = 0, + softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + sm_margin: int = 0, +): + """ + FlashAttention v3 FP8 high-precision entry point. + + This function accepts high-precision (BF16/FP32) tensors and internally + quantizes them to FP8 for computation. The output and gradients remain + in high precision (FP32 for output, input dtype for gradients). + + This API is designed for seamless integration with existing training code + that uses BF16/FP32 tensors, providing FP8 acceleration without requiring + manual quantization. 
+ + Args: + q: Query tensor [batch, seqlen, num_q_heads, head_dim] (BF16/FP32) + k: Key tensor [batch, seqlen, num_kv_heads, head_dim] (BF16/FP32) + v: Value tensor [batch, seqlen, num_kv_heads, head_dim] (BF16/FP32) + softmax_scale: Scaling factor for softmax (default: 1/sqrt(head_dim)) + causal: Whether to apply causal masking + qv: Extra query-value tensor (not yet supported in FP8 mode) + window_size: Sliding window attention size (left, right) + attention_chunk: Chunking parameter (0 or 1 only) + softcap: Softcapping value (not yet supported in FP8 mode) + num_splits: Number of splits for parallel processing (not yet supported in FP8 mode) + pack_gqa: GQA packing flag (not yet supported in FP8 mode) + deterministic: Whether to use deterministic backward + sm_margin: SM margin parameter (not yet supported in FP8 mode) + + Returns: + out: Output tensor [batch, seqlen, num_q_heads, head_dim] (FP32) + + Note: + - Supports GQA/MQA (num_q_heads != num_kv_heads) + - Automatically handles grouped quantization for GQA/MQA queries + - Gradients are computed in FP32 and converted to input dtype + - qv, softcap, num_splits, pack_gqa, and sm_margin are not yet supported in FP8 mode + """ + # Check that inputs are high precision (not already FP8) + assert q.dtype in [torch.float16, torch.bfloat16, torch.float32], ( + f"flash_attn_fp8_func expects high-precision inputs (fp16/bf16/fp32), got q.dtype={q.dtype}. " + f"If you already have FP8 tensors, use flash_attn_func() with q_descale/k_descale/v_descale parameters instead." + ) + assert k.dtype in [torch.float16, torch.bfloat16, torch.float32], ( + f"flash_attn_fp8_func expects high-precision inputs (fp16/bf16/fp32), got k.dtype={k.dtype}. " + f"If you already have FP8 tensors, use flash_attn_func() with q_descale/k_descale/v_descale parameters instead." + ) + assert v.dtype in [torch.float16, torch.bfloat16, torch.float32], ( + f"flash_attn_fp8_func expects high-precision inputs (fp16/bf16/fp32), got v.dtype={v.dtype}. 
" + f"If you already have FP8 tensors, use flash_attn_func() with q_descale/k_descale/v_descale parameters instead." + ) + + if qv is not None: + raise NotImplementedError("qv not supported in FP8 high-precision API") + if softcap != 0.0: + raise NotImplementedError("softcap not supported in FP8 high-precision API") + if num_splits != 1: + raise NotImplementedError( + "num_splits != 1 not supported in FP8 high-precision API" + ) + if pack_gqa is not None: + raise NotImplementedError("pack_gqa not supported in FP8 high-precision API") + if sm_margin != 0: + raise NotImplementedError( + "sm_margin != 0 not supported in FP8 high-precision API" + ) + + return _FlashAttnFP8Wrapper.apply( + q, + k, + v, + softmax_scale, + causal, + window_size, + attention_chunk, + softcap, + deterministic, + sm_margin, + ) + + +class _FlashAttnVarlenFP8Wrapper(torch.autograd.Function): + """ + FP8 Flash Attention varlen wrapper that maintains high-precision inputs/outputs. + + This wrapper allows users to pass BF16/FP32 tensors and automatically handles + the FP8 quantization internally for variable-length sequences, maintaining + backward compatibility with high-precision training workflows. 
class _FlashAttnVarlenFP8Wrapper(torch.autograd.Function):
    """
    FP8 Flash Attention varlen wrapper that maintains high-precision inputs/outputs.

    Callers pass BF16/FP32 varlen ('thd') tensors; quantization to FP8 (with
    per-sequence descale factors) happens internally, keeping backward
    compatibility with high-precision training workflows.

    Forward: BF16/FP32 -> FP8 -> flash_attn_varlen -> output
    Backward: grad_out -> flash_attn_varlen_bwd -> FP32 grads -> input dtype grads
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,  # High precision (BF16/FP32), [total_q, heads, dim]
        k: torch.Tensor,  # High precision (BF16/FP32), [total_k, heads, dim]
        v: torch.Tensor,  # High precision (BF16/FP32), [total_k, heads, dim]
        cu_seqlens_q: torch.Tensor,
        cu_seqlens_k: torch.Tensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        softmax_scale: Optional[float],
        causal: bool,
        window_size: Tuple[int, int],
        attention_chunk: int,
        softcap: float,
        deterministic: bool,
        sm_margin: int,
    ):
        # Validate unsupported features FIRST: fail fast instead of paying for
        # FP8 quantization of q/k/v before raising.
        if attention_chunk != 0:
            raise NotImplementedError(
                "attention_chunk != 0 not supported in FP8 varlen high-precision API"
            )
        if softcap != 0.0:
            raise NotImplementedError(
                "softcap not implemented in FP8 varlen high-precision API"
            )
        if sm_margin != 0:
            raise NotImplementedError(
                "sm_margin != 0 not supported in FP8 varlen high-precision API"
            )

        # Head layout from the 'thd' shapes (token counts are not needed here;
        # the original unused total_q/total_k locals were dropped).
        num_q_heads = q.shape[1]
        head_dim = q.shape[2]
        num_kv_heads = k.shape[1]

        # Derive softmax scale if not provided
        if softmax_scale is None:
            softmax_scale = head_dim ** (-0.5)

        # Quantize inputs to FP8.  For GQA/MQA the query heads are quantized
        # with grouped scaling so all descale tensors share the KV-head layout.
        fp8_dtype = torch.float8_e4m3fnuz
        group_size = (
            num_q_heads // num_kv_heads if num_q_heads != num_kv_heads else None
        )
        q_fp8, q_descale = _quantize_thd(
            q, fp8_dtype, cu_seqlens_q, group_size=group_size
        )
        k_fp8, k_descale = _quantize_thd(k, fp8_dtype, cu_seqlens_k)
        v_fp8, v_descale = _quantize_thd(v, fp8_dtype, cu_seqlens_k)

        # _quantize_thd returns [batch_size, num_heads] (or [batch_size,
        # num_groups] for grouped q, which equals num_kv_heads).
        batch_size = len(cu_seqlens_q) - 1
        assert q_descale.shape == (
            batch_size,
            num_kv_heads,
        ), f"q_descale shape {q_descale.shape} != expected {(batch_size, num_kv_heads)}"
        assert k_descale.shape == (
            batch_size,
            num_kv_heads,
        ), f"k_descale shape {k_descale.shape} != expected {(batch_size, num_kv_heads)}"
        assert v_descale.shape == (
            batch_size,
            num_kv_heads,
        ), f"v_descale shape {v_descale.shape} != expected {(batch_size, num_kv_heads)}"

        # Call flash attention varlen forward
        out, softmax_lse = flash_attn_3.fwd(
            q_fp8, k_fp8, v_fp8,
            None, None,            # k_new, v_new
            None,                  # qv
            None,                  # out (allocated by the kernel)
            cu_seqlens_q, cu_seqlens_k,
            None,                  # cu_seqlens_k_new
            None, None,            # seqused_q, seqused_k
            max_seqlen_q, max_seqlen_k,
            None, None, None,      # page_table, kv_batch_idx, leftpad_k
            None, None, None,      # rotary_cos, rotary_sin, seqlens_rotary
            q_descale, k_descale, v_descale,
            softmax_scale,
            causal,
            int(window_size[0]), int(window_size[1]),
            attention_chunk,
            softcap,
            False,                 # rotary_interleaved
            None,                  # scheduler_metadata
            1,                     # num_splits
            None,                  # pack_gqa
            sm_margin,
        )

        # Save tensors needed for backward
        ctx.save_for_backward(
            q_fp8, k_fp8, v_fp8, out, softmax_lse, q_descale, k_descale, v_descale
        )
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.deterministic = deterministic
        ctx.sm_margin = sm_margin
        ctx.input_dtype = q.dtype
        ctx.cu_seqlens_q = cu_seqlens_q
        ctx.cu_seqlens_k = cu_seqlens_k
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k

        return out

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """
        Compute gradients w.r.t. inputs.
        The backward pass returns FP32 gradients, which we convert to the input dtype.
        """
        q_fp8, k_fp8, v_fp8, out, softmax_lse, q_descale, k_descale, v_descale = (
            ctx.saved_tensors
        )

        # Backward kernel returns FP32 gradients.
        dq, dk, dv, _delta = flash_attn_3.bwd(
            grad_output, q_fp8, k_fp8, v_fp8, out, softmax_lse,
            None, None, None,      # dq, dk, dv (allocated by the kernel)
            ctx.cu_seqlens_q, ctx.cu_seqlens_k,
            None, None,            # seqused_q, seqused_k
            ctx.max_seqlen_q, ctx.max_seqlen_k,
            ctx.softmax_scale,
            ctx.causal,
            int(ctx.window_size[0]), int(ctx.window_size[1]),
            ctx.softcap,
            ctx.deterministic,
            ctx.sm_margin,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
        )

        # Convert gradients back to the caller's dtype (FP32 -> BF16 if needed).
        dq = dq.to(ctx.input_dtype)
        dk = dk.to(ctx.input_dtype)
        dv = dv.to(ctx.input_dtype)

        # One gradient slot per forward argument; only q/k/v receive gradients.
        return (dq, dk, dv) + (None,) * 11
def flash_attn_varlen_fp8_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    attention_chunk: int = 0,
    softcap: float = 0.0,
    deterministic: bool = False,
    sm_margin: int = 0,
):
    """
    FlashAttention v3 FP8 varlen high-precision entry point.

    This function accepts high-precision (BF16/FP32) tensors and internally
    quantizes them to FP8 for computation. The output and gradients remain
    in high precision (FP32 for output, input dtype for gradients).

    This API is designed for seamless integration with existing training code
    that uses BF16/FP32 tensors with variable-length sequences, providing
    FP8 acceleration without requiring manual quantization.

    Args:
        q: Query tensor [total_q, num_q_heads, head_dim] (BF16/FP32)
        k: Key tensor [total_k, num_kv_heads, head_dim] (BF16/FP32)
        v: Value tensor [total_k, num_kv_heads, head_dim] (BF16/FP32)
        cu_seqlens_q: Cumulative sequence lengths for queries [batch_size + 1]
        cu_seqlens_k: Cumulative sequence lengths for keys [batch_size + 1]
        max_seqlen_q: Maximum query sequence length
        max_seqlen_k: Maximum key sequence length
        softmax_scale: Scaling factor for softmax (default: 1/sqrt(head_dim))
        causal: Whether to apply causal masking
        window_size: Sliding window attention size (left, right)
        attention_chunk: Chunking parameter (must be 0 in varlen FP8 mode)
        softcap: Softcapping value (not yet supported in FP8 mode)
        deterministic: Whether to use deterministic backward
        sm_margin: SM margin parameter (not yet supported in FP8 mode)

    Returns:
        out: Output tensor [total_q, num_q_heads, head_dim] (FP32)

    Raises:
        AssertionError: If q/k/v are not fp16/bf16/fp32 (i.e. already FP8).
        NotImplementedError: If attention_chunk, softcap, or sm_margin is set
            to an unsupported value.

    Note:
        - Supports GQA/MQA (num_q_heads != num_kv_heads)
        - Automatically handles grouped quantization for GQA/MQA queries
        - Gradients are computed in FP32 and converted to input dtype
        - attention_chunk, softcap, and sm_margin are not yet supported in varlen FP8 mode
    """
    # Check that inputs are high precision (not already FP8).  One loop instead
    # of three copy-pasted asserts; messages are identical to the originals.
    # NOTE(review): `assert` is stripped under `python -O`; kept for interface
    # compatibility (callers may expect AssertionError), but a hard `raise`
    # would be more robust.
    for name, tensor in (("q", q), ("k", k), ("v", v)):
        assert tensor.dtype in [torch.float16, torch.bfloat16, torch.float32], (
            f"flash_attn_varlen_fp8_func expects high-precision inputs (fp16/bf16/fp32), got {name}.dtype={tensor.dtype}. "
            f"If you already have FP8 tensors, use flash_attn_varlen_func() with q_descale/k_descale/v_descale parameters instead."
        )

    # Reject features the FP8 varlen path does not implement yet.  These are
    # also re-checked inside the autograd wrapper; failing here keeps the
    # error out of autograd machinery.
    if attention_chunk != 0:
        raise NotImplementedError(
            "attention_chunk != 0 not supported in FP8 varlen high-precision API"
        )
    if softcap != 0.0:
        raise NotImplementedError(
            "softcap not supported in FP8 varlen high-precision API"
        )
    if sm_margin != 0:
        raise NotImplementedError(
            "sm_margin != 0 not supported in FP8 varlen high-precision API"
        )

    return _FlashAttnVarlenFP8Wrapper.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        softmax_scale,
        causal,
        window_size,
        attention_chunk,
        softcap,
        deterministic,
        sm_margin,
    )
return_attn_probs=return_attn_probs, ) else: diff --git a/op_tests/triton_tests/test_mha.py b/op_tests/triton_tests/test_mha.py index 8bae346de0..8d202efda3 100644 --- a/op_tests/triton_tests/test_mha.py +++ b/op_tests/triton_tests/test_mha.py @@ -7,12 +7,14 @@ import numpy as np from aiter.ops.triton.mha import ( flash_attn_func, - flash_attn_fp8_func, flash_attn_varlen_func, - flash_attn_varlen_fp8_func, mha_set_use_fused_bwd_kernel, mha_set_use_int64_strides, ) +from aiter.ops.triton.mha_v3 import ( + flash_attn_fp8_func, + flash_attn_varlen_fp8_func, +) from aiter.test_mha_common import ( attention_ref, generate_random_padding_mask, @@ -22,7 +24,7 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) DEBUG_MODE = False -ATOL_fp8 = 2.5e-1 +ATOL_fp8 = 3.0e-1 RTOL_fp8 = 2.5e-1 @@ -133,14 +135,16 @@ def test_mha( dropout_mask = None if FP8: + if DROPOUT > 0.0 or RETURN_LSE or RETURN_SOFTMAX: + pytest.skip( + "FP8 mode does not support dropout_p, return_lse, or return_attn_probs" + ) + triton_out = flash_attn_fp8_func( q, k, v, - dropout_p=DROPOUT, causal=CAUSAL, - return_lse=RETURN_LSE, - return_attn_probs=RETURN_SOFTMAX, ) else: triton_out = flash_attn_func( @@ -371,6 +375,11 @@ def test_mha_varlen( print(f"cu_seqlens_q={cu_seqlens_q }") print(f"cu_seqlens_k={cu_seqlens_k }") if FP8: + if DROPOUT > 0.0 or RETURN_LSE or RETURN_SOFTMAX: + pytest.skip( + "FP8 varlen mode does not support dropout_p, return_lse, or return_attn_probs" + ) + triton_out = flash_attn_varlen_fp8_func( q_unpad, k_unpad, @@ -379,10 +388,7 @@ def test_mha_varlen( cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p=DROPOUT, causal=CAUSAL, - return_lse=RETURN_LSE, - return_attn_probs=RETURN_SOFTMAX, ) else: triton_out = flash_attn_varlen_func( @@ -456,8 +462,8 @@ def test_mha_varlen( ) if FP8: - fp8_assert_close( - triton_out, torch_out.to(torch_out.dtype), atol=ATOL_fp8, rtol=RTOL_fp8 + torch.testing.assert_close( + triton_out, torch_out.to(triton_out.dtype), 
atol=ATOL_fp8, rtol=RTOL_fp8 ) else: torch.testing.assert_close( @@ -517,15 +523,15 @@ def test_mha_backward( with torch.enable_grad(): if FP8: + if DROPOUT > 0.0: + pytest.skip("FP8 does not support dropout_p") triton_out = flash_attn_fp8_func( q, k, v, - dropout_p=DROPOUT, causal=CAUSAL, - return_lse=True, - return_attn_probs=True, ) + lse, sd_mask = None, None else: triton_out = flash_attn_func( q, @@ -537,8 +543,8 @@ def test_mha_backward( return_attn_probs=True, ) - assert len(triton_out) == 3 - triton_out, lse, sd_mask = triton_out[0], triton_out[1], triton_out[2] + assert len(triton_out) == 3 + triton_out, lse, sd_mask = triton_out[0], triton_out[1], triton_out[2] if DROPOUT > 0.0: dropout_mask = sd_mask >= 0