diff --git a/aiter/ops/triton/_triton_kernels/mha.py b/aiter/ops/triton/_triton_kernels/mha.py index 9fdd6cbdc0..45e50ce0e3 100644 --- a/aiter/ops/triton/_triton_kernels/mha.py +++ b/aiter/ops/triton/_triton_kernels/mha.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -from typing import Optional, Tuple import functools import json import torch @@ -10,9 +9,8 @@ from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH -from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd +from ..utils._triton.pid_preprocessing import remap_xcd from ..utils._triton.mha_kernel_utils import _compute_fp8_scaling_factors -from ..utils.device_info import get_num_xcds @triton.jit @@ -79,7 +77,9 @@ def _attn_fwd_inner( l_i, m_i, q, + q_pe, k_ptrs, + k_pe_ptrs, v_ptrs, stride_kn, stride_vk, @@ -107,6 +107,7 @@ def _attn_fwd_inner( BLOCK_N: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL_POW2: tl.constexpr, + BLOCK_DMODEL_PE: tl.constexpr, # it's zero or a power of 2 SM_SCALE: tl.constexpr, IS_CAUSAL: tl.constexpr, MASK_STEPS: tl.constexpr, @@ -115,12 +116,17 @@ def _attn_fwd_inner( PADDED_HEAD: tl.constexpr, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, + ENABLE_PIPELINING: tl.constexpr, ): RCP_LN2: tl.constexpr = 1.4426950408889634 + HAS_PE: tl.constexpr = BLOCK_DMODEL_PE > 0 # loop over k, v, and update accumulator - for start_n in range(block_min, block_max, BLOCK_N): + num_stages: tl.constexpr = ( + None if ENABLE_PIPELINING else 1 + ) # Set num_stages==1 if we want to disable pipelining + for start_n in tl.range(block_min, block_max, BLOCK_N, num_stages=num_stages): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. 
if MASK_STEPS: @@ -129,6 +135,14 @@ def _attn_fwd_inner( k_offs_n = None k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL_POW2) k = _load_fn(k_ptrs, k_offs_k, k_offs_n, BLOCK_DMODEL, seqlen_k) + if HAS_PE: + k_pe = _load_fn( + k_pe_ptrs, + None, + k_offs_n, + (BLOCK_DMODEL + BLOCK_DMODEL_PE), + seqlen_k, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need @@ -163,11 +177,13 @@ def _attn_fwd_inner( qk += tl.dot(q, k) * descale_q * descale_k else: qk += tl.dot(q, k) + if HAS_PE: + qk += tl.dot(q_pe, k_pe) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - mask = mask and causal_mask + mask = mask & causal_mask qk = tl.where(mask, qk, float("-inf")) @@ -229,6 +245,8 @@ def _attn_fwd_inner( acc += tl.dot(p.to(v.type.element_ty), v) k_ptrs += BLOCK_N * stride_kn + if HAS_PE: + k_pe_ptrs += BLOCK_N * stride_kn v_ptrs += BLOCK_N * stride_vk if RETURN_SCORES: sd_mask_ptrs += BLOCK_N * stride_sn @@ -287,8 +305,8 @@ def _attn_fwd( dropout_p, philox_seed, philox_offset_base_in, - SEQLEN_Q: tl.constexpr, - SEQLEN_K: tl.constexpr, + SEQLEN_Q, + SEQLEN_K, IS_CAUSAL: tl.constexpr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, @@ -296,6 +314,7 @@ def _attn_fwd( BLOCK_N: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL_POW2: tl.constexpr, + BLOCK_DMODEL_PE: tl.constexpr, # it's zero or a power of 2 RETURN_SCORES: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_FP8: tl.constexpr, @@ -321,6 +340,9 @@ def _attn_fwd( offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL_POW2) + HAS_PE: tl.constexpr = BLOCK_DMODEL_PE > 0 + if HAS_PE: + offs_pe = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL_PE) # NOTE: # Workaround for int64 strides, In the absence of strides being int64, parts of the offset @@ -395,6 +417,38 @@ def _attn_fwd( stride_lse_h = stride_lse_h_in stride_lse_m = stride_lse_m_in + tl.assume(stride_qz_in >= 0) + tl.assume(stride_qh_in >= 0) + tl.assume(stride_qm_in >= 0) + tl.assume(stride_qk_in >= 0) + tl.assume(stride_kz_in >= 0) + tl.assume(stride_kh_in >= 0) + tl.assume(stride_kn_in >= 0) + tl.assume(stride_kk_in >= 0) + tl.assume(stride_vz_in >= 0) + tl.assume(stride_vh_in >= 0) + tl.assume(stride_vn_in >= 0) + tl.assume(stride_vk_in >= 0) + if IS_FP8: + tl.assume(stride_descale_q_z_in >= 0) + tl.assume(stride_descale_k_z_in >= 0) + tl.assume(stride_descale_v_z_in >= 0) + tl.assume(stride_oz_in >= 0) + tl.assume(stride_oh_in >= 0) + tl.assume(stride_om_in >= 0) + tl.assume(stride_on_in >= 0) + tl.assume(stride_alibi_z_in >= 0) + tl.assume(stride_alibi_h_in >= 0) + # NOTE: philox offset is need in dropout pointer calculations + tl.assume(philox_offset_base_in >= 0) + tl.assume(stride_sd_z_in >= 0) + tl.assume(stride_sd_h_in >= 0) + tl.assume(stride_sd_m_in >= 0) + tl.assume(stride_sd_n_in >= 0) + tl.assume(stride_lse_z_in >= 0) + tl.assume(stride_lse_h_in >= 0) + tl.assume(stride_lse_m_in >= 0) + if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -479,6 +533,17 @@ def _attn_fwd( + offs_d[None, :] * stride_qk ) q_ptrs = q_ptr + q_offs + if HAS_PE: + q_pe_offs = ( + off_z * stride_qz + + off_q_head * stride_qh + + cu_seqlens_q_start * stride_qm + + offs_m[:, None] * stride_qm + + offs_pe[None, :] * stride_qk + ) + q_pe_ptrs = q_ptr + q_pe_offs + else: + q_pe_ptrs = None k_offs = ( off_z * stride_kz @@ -488,6 
+553,17 @@ def _attn_fwd( + offs_n[None, :] * stride_kn ) k_ptrs = k_ptr + k_offs + if HAS_PE: + k_pe_offs = ( + off_z * stride_kz + + off_k_head * stride_kh + + cu_seqlens_k_start * stride_kn + + offs_pe[:, None] * stride_kk + + offs_n[None, :] * stride_kn + ) + k_pe_ptrs = k_ptr + k_pe_offs + else: + k_pe_ptrs = None v_offs = ( off_z * stride_vz @@ -545,6 +621,11 @@ def _attn_fwd( else: q_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_mask, other=0.0) + if HAS_PE: + q_pe = tl.load(q_pe_ptrs, mask=q_mask, other=0.0) + else: + q_pe = None + if IS_FP8: descale_q = tl.load(descale_q_ptr + off_z * stride_descale_q_z + off_q_head) descale_k = tl.load(descale_k_ptr + off_z * stride_descale_k_z + off_k_head) @@ -584,7 +665,9 @@ def _attn_fwd( l_i, m_i, q, + q_pe, k_ptrs, + k_pe_ptrs, v_ptrs, stride_kn, stride_vn, @@ -612,6 +695,7 @@ def _attn_fwd( BLOCK_N, BLOCK_DMODEL, BLOCK_DMODEL_POW2, + BLOCK_DMODEL_PE, sm_scale, False, MASK_STEPS=False, @@ -620,6 +704,7 @@ def _attn_fwd( PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, + ENABLE_PIPELINING=True, ) block_min = block_max block_max = n_blocks * BLOCK_N @@ -631,6 +716,8 @@ def _attn_fwd( else: offs_n_causal = 0 k_ptrs += n_full_blocks * BLOCK_N * stride_kn + if HAS_PE: + k_pe_ptrs += n_full_blocks * BLOCK_N * stride_kn v_ptrs += n_full_blocks * BLOCK_N * stride_vn if RETURN_SCORES: s_dmask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n @@ -641,7 +728,9 @@ def _attn_fwd( l_i, m_i, q, + q_pe, k_ptrs, + k_pe_ptrs, v_ptrs, stride_kn, stride_vn, @@ -669,6 +758,7 @@ def _attn_fwd( BLOCK_N, BLOCK_DMODEL, BLOCK_DMODEL_POW2, + BLOCK_DMODEL_PE, sm_scale, IS_CAUSAL, MASK_STEPS=True, @@ -677,6 +767,7 @@ def _attn_fwd( PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, + ENABLE_PIPELINING=False, ) # epilogue # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. @@ -759,6 +850,7 @@ def _attn_fwd( def _get_config( enable_dropout: bool, dtype: torch.dtype, + has_pe: bool = False, ): if not hasattr(_get_config, "_config_dict"): dev = arch_info.get_device() @@ -768,7 +860,9 @@ def _get_config( config = json.load(file) _get_config._config_dict["default"] = config - if enable_dropout or dtype == torch.float32: + if has_pe and "pe" in _get_config._config_dict["default"]["fwd"]: + return _get_config._config_dict["default"]["fwd"]["pe"] + elif enable_dropout or dtype == torch.float32: return _get_config._config_dict["default"]["fwd"]["dropout_or_fp32"] else: return _get_config._config_dict["default"]["fwd"]["default"] diff --git a/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py b/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py index 9808636501..f6e8870349 100644 --- a/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py +++ b/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py @@ -1,15 +1,12 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -from typing import Optional, Dict import functools import json -import torch import triton # type: ignore import triton.language as tl # type: ignore from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH -from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton.mha_kernel_utils import _compute_fp8_scaling_factors @@ -108,10 +105,12 @@ def _bwd_preprocess( # The main inner-loop logic for computing dK and dV. 
@triton.jit def _bwd_dkdv_inner( - dk, + dk, # output + dk_pe, # optional output, pass None for non-PE case dv, # output Q, k, + k_pe, v, DO, M, @@ -128,6 +127,7 @@ def _bwd_dkdv_inner( BLOCK_N: tl.constexpr, # 128 HEAD_DIM: tl.constexpr, # ACTUAL_HEAD_DIM: tl.constexpr, # + PE_HEAD_DIM: tl.constexpr, dropout_p, philox_seed, batch_philox_offset, @@ -154,15 +154,20 @@ def _bwd_dkdv_inner( ): # if HEAD_DIM is padded PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM + HAS_PE: tl.constexpr = PE_HEAD_DIM > 0 delta_qk = seqlen_q - seqlen_k offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) offs_k = tl.arange(0, HEAD_DIM) + if HAS_PE: + offs_k_pe = HEAD_DIM + tl.arange(0, PE_HEAD_DIM) # mask to make sure not OOB of seqlen_q mask_n = offs_n < seqlen_k # Q and DO are (seqlen_q, head_dim) # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk + if HAS_PE: + qT_pe_ptrs = Q + offs_m[None, :] * stride_qm + offs_k_pe[:, None] * stride_qk # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. @@ -186,6 +191,8 @@ def _bwd_dkdv_inner( mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + if HAS_PE: + qT_pe = tl.load(qT_pe_ptrs, mask=mask_qT, other=0.0) # generate dropout mask if ENABLE_DROPOUT: # NOTE: dropout is transposed because it is used to mask pT @@ -210,6 +217,8 @@ def _bwd_dkdv_inner( qkT = tl.dot(k, qT) * descale_q * descale_k else: qkT = tl.dot(k, qT) + if HAS_PE: + qkT += tl.dot(k_pe, qT_pe) qkT_scaled = qkT * sm_scale if USE_ALIBI: @@ -291,18 +300,24 @@ def _bwd_dkdv_inner( ) else: dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + if HAS_PE: + dk_pe += tl.dot(dsT.to(qT_pe.type.element_ty), tl.trans(qT_pe)) # Increment pointers. curr_m += step_m qT_ptrs += step_m * stride_qm + if HAS_PE: + qT_pe_ptrs += step_m * stride_qm do_ptrs += step_m * stride_dom - return dk, dv + return dk, dk_pe, dv # the main inner-loop logic for computing dQ @triton.jit def _bwd_dq_inner( dq, # output + dq_pe, # optional output, pass None for non-PE case q, + q_pe, K, V, do, @@ -325,6 +340,7 @@ def _bwd_dq_inner( BLOCK_N2: tl.constexpr, # HEAD_DIM: tl.constexpr, ACTUAL_HEAD_DIM: tl.constexpr, # + PE_HEAD_DIM: tl.constexpr, dropout_p, philox_seed, batch_philox_offset, @@ -350,15 +366,20 @@ def _bwd_dq_inner( ): # if HEAD_DIM is padded PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM + HAS_PE: tl.constexpr = PE_HEAD_DIM > 0 delta_qk = seqlen_q - seqlen_k offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) offs_k = tl.arange(0, HEAD_DIM) + if HAS_PE: + offs_k_pe = HEAD_DIM + tl.arange(0, PE_HEAD_DIM) # mask to make sure not OOB of seqlen_q mask_m = offs_m < seqlen_q kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + if HAS_PE: + kT_pe_ptrs = K + offs_n[None, :] * stride_kn + offs_k_pe[:, None] * stride_kk vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk # D (= delta) is pre-divided by ds_scale. 
Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) @@ -388,6 +409,8 @@ def _bwd_dq_inner( mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + if HAS_PE: + kT_pe = tl.load(kT_pe_ptrs, mask=mask_kT, other=0.0) vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) if ENABLE_DROPOUT: @@ -412,6 +435,8 @@ def _bwd_dq_inner( qk = tl.dot(q, kT) * descale_q * descale_k else: qk = tl.dot(q, kT) + if HAS_PE: + qk += tl.dot(q_pe, kT_pe) qk_scaled = qk * sm_scale if USE_ALIBI: @@ -451,11 +476,15 @@ def _bwd_dq_inner( ) else: dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + if HAS_PE: + dq_pe += tl.dot(ds.to(kT_pe.type.element_ty), tl.trans(kT_pe)) # Increment pointers. curr_n += step_n kT_ptrs += step_n * stride_kn + if HAS_PE: + kT_pe_ptrs += step_n * stride_kn vT_ptrs += step_n * stride_vn - return dq + return dq, dq_pe @triton.jit @@ -533,6 +562,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea BLK_SLICE_FACTOR: tl.constexpr, HEAD_DIM: tl.constexpr, ACTUAL_HEAD_DIM: tl.constexpr, + PE_HEAD_DIM: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -656,7 +686,10 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM + HAS_PE: tl.constexpr = PE_HEAD_DIM > 0 offs_d = tl.arange(0, HEAD_DIM) + if HAS_PE: + offs_d_pe = HEAD_DIM + tl.arange(0, PE_HEAD_DIM) GROUP_SIZE: tl.constexpr = HQ // HK # align the delta_qk @@ -664,6 +697,11 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea if start_n < seqlen_k: # This section does dk and dv dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + if HAS_PE: + dk_pe = tl.zeros([BLOCK_N1, PE_HEAD_DIM], dtype=tl.float32) + else: + # Couldn't assign None to dk_pe because _bwd_dkdv_inner can't return None. + dk_pe = dk dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) # q > k: diretcly skip all the way until the start of causal block @@ -703,6 +741,14 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd ) + if HAS_PE: + adj_k_pe = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d_pe[None, :] * stride_kd + ) adj_v = ( bid * stride_vb + hkid * stride_vh @@ -712,6 +758,10 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea ) # load K and V: they stay in SRAM throughout the inner loop. k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + if HAS_PE: + k_pe = tl.load(K + adj_k_pe, mask=mask_kv, other=0.0) + else: + k_pe = None v = tl.load(V + adj_v, mask=mask_kv, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. 
# hqid = hkid @@ -781,11 +831,13 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea print( f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}" ) # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, - dv, # output tensors + dk, dk_pe, dv = _bwd_dkdv_inner( + dk, # output tensor + dk_pe, # optional output tensor + dv, # output tensor Q_ptr, k, + k_pe, v, DO_ptr, M_ptr, @@ -802,6 +854,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -839,11 +892,13 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea ) # noqa: E701 if DEBUG_TRITON: print("unMasked") # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, - dv, # output tensors + dk, dk_pe, dv = _bwd_dkdv_inner( + dk, # output tensor + dk_pe, # optional output tensor + dv, # output tensor Q_ptr, k, + k_pe, v, DO_ptr, M_ptr, @@ -860,6 +915,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -893,6 +949,10 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd dk *= sm_scale tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + if HAS_PE: + offs_dk_pe = offs_n[:, None] * stride_dkn + offs_d_pe[None, :] * stride_dkd + dk_pe *= sm_scale + tl.store(DK + adj_dk + offs_dk_pe, dk_pe, mask=mask_kv) # This part does dq start_m = pid * BLOCK_M2 @@ -916,6 +976,8 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea mask_d = offs_d < ACTUAL_HEAD_DIM mask_q &= mask_d[None, :] offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + if HAS_PE: + offs_q_pe = offs_m[:, None] * stride_qm + offs_d_pe[None, :] * stride_qd offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod # NOTE: don't assume that the strides for k and v are the same! K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn @@ -956,6 +1018,10 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth ) q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + if HAS_PE: + q_pe = tl.load(Q + adj_q + offs_q_pe, mask=mask_q, other=0.0) + else: + q_pe = None do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) m = m[:, None] @@ -974,9 +1040,15 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, + if HAS_PE: + dq_pe = tl.zeros([BLOCK_M2, PE_HEAD_DIM], dtype=tl.float32) + else: + dq_pe = dq # Couldn't assign None to dq_pe because _bwd_dq_inner can't return None. 
+ dq, dq_pe = _bwd_dq_inner( + dq, # output tensor + dq_pe, # optional output tensor q, + q_pe, K, V, do, @@ -998,6 +1070,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea MASK_BLOCK_N2, HEAD_DIM, ACTUAL_HEAD_DIM, + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -1027,9 +1100,11 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea print( f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}" ) # noqa: E701 - dq = _bwd_dq_inner( - dq, + dq, dq_pe = _bwd_dq_inner( + dq, # output tensor + dq_pe, # optional output tensor q, + q_pe, K, V, do, @@ -1051,6 +1126,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea BLOCK_N2, HEAD_DIM, ACTUAL_HEAD_DIM, + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -1078,6 +1154,12 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + if HAS_PE: + offs_dq_pe = ( + offs_m[:, None] * stride_dqm + offs_d_pe[None, :] * stride_dqd + ) + dq_pe *= sm_scale + tl.store(DQ + adj_dq + offs_dq_pe, dq_pe, mask=mask_q) # end of GQA/MQA of dq @@ -1156,6 +1238,7 @@ def bwd_kernel_noncausal( BLK_SLICE_FACTOR: tl.constexpr, HEAD_DIM: tl.constexpr, ACTUAL_HEAD_DIM: tl.constexpr, + PE_HEAD_DIM: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -1276,12 +1359,20 @@ def bwd_kernel_noncausal( seqlen_k = k_end - k_start PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM + HAS_PE: tl.constexpr = PE_HEAD_DIM > 0 offs_d = tl.arange(0, HEAD_DIM) + if HAS_PE: + offs_d_pe = HEAD_DIM + tl.arange(0, PE_HEAD_DIM) GROUP_SIZE: tl.constexpr = HQ // HK start_n = pid * BLOCK_N1 if start_n < seqlen_k: dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + if HAS_PE: + dk_pe = tl.zeros([BLOCK_N1, PE_HEAD_DIM], dtype=tl.float32) + else: + # Couldn't assign None to dk_pe because _bwd_dkdv_inner can't return None. + dk_pe = dk dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) offs_n = start_n + tl.arange(0, BLOCK_N1) @@ -1299,6 +1390,14 @@ def bwd_kernel_noncausal( + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd ) + if HAS_PE: + adj_k_pe = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d_pe[None, :] * stride_kd + ) adj_v = ( bid * stride_vb + hkid * stride_vh @@ -1308,6 +1407,10 @@ def bwd_kernel_noncausal( ) # load K and V: they stay in SRAM throughout the inner loop. k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + if HAS_PE: + k_pe = tl.load(K + adj_k_pe, mask=mask_kv, other=0.0) + else: + k_pe = None v = tl.load(V + adj_v, mask=mask_kv, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. 
for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): @@ -1351,11 +1454,13 @@ def bwd_kernel_noncausal( # because there is no causal, we always start from the beginning start_m = 0 num_steps = tl.cdiv(seqlen_q, BLOCK_M1) - dk, dv = _bwd_dkdv_inner( - dk, - dv, # output tensors + dk, dk_pe, dv = _bwd_dkdv_inner( + dk, # output tensor + dk_pe, # optional output tensor + dv, # output tensor Q_ptr, k, + k_pe, v, DO_ptr, M_ptr, @@ -1372,6 +1477,7 @@ def bwd_kernel_noncausal( BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -1405,6 +1511,10 @@ def bwd_kernel_noncausal( offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd dk *= sm_scale tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + if HAS_PE: + offs_dk_pe = offs_n[:, None] * stride_dkn + offs_d_pe[None, :] * stride_dkd + dk_pe *= sm_scale + tl.store(DK + adj_dk + offs_dk_pe, dk_pe, mask=mask_kv) # THIS PART DOES DQ start_m = pid * BLOCK_M2 @@ -1416,6 +1526,8 @@ def bwd_kernel_noncausal( mask_d = offs_d < ACTUAL_HEAD_DIM mask_q &= mask_d[None, :] offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + if HAS_PE: + offs_q_pe = offs_m[:, None] * stride_qm + offs_d_pe[None, :] * stride_qd offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn @@ -1448,6 +1560,10 @@ def bwd_kernel_noncausal( ) q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + if HAS_PE: + q_pe = tl.load(Q + adj_q + offs_q_pe, mask=mask_q, other=0.0) + else: + q_pe = None do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) m = m[:, None] @@ -1466,9 +1582,15 @@ def bwd_kernel_noncausal( num_steps = tl.cdiv(seqlen_k, BLOCK_N2) dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, + if HAS_PE: + dq_pe = tl.zeros([BLOCK_M2, PE_HEAD_DIM], dtype=tl.float32) + else: + dq_pe = dq # Couldn't assign None to dq_pe because _bwd_dq_inner can't return None. 
+ dq, dq_pe = _bwd_dq_inner( + dq, # output tensor + dq_pe, # optional output tensor q, + q_pe, K, V, do, @@ -1490,6 +1612,7 @@ def bwd_kernel_noncausal( BLOCK_N2, HEAD_DIM, ACTUAL_HEAD_DIM, + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -1517,6 +1640,12 @@ def bwd_kernel_noncausal( offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + if HAS_PE: + offs_dq_pe = ( + offs_m[:, None] * stride_dqm + offs_d_pe[None, :] * stride_dqd + ) + dq_pe *= sm_scale + tl.store(DQ + adj_dq + offs_dq_pe, dq_pe, mask=mask_q) @functools.lru_cache(maxsize=1024) diff --git a/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json b/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json index de5c09cd1e..a38732610e 100644 --- a/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json +++ b/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json @@ -15,6 +15,14 @@ "num_warps": 4, "num_ctas": 1, "num_stages": 1 + }, + "pe": { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 1, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 1 } }, "bkwd_fused" : { @@ -55,6 +63,18 @@ "num_warps": 4, "num_ctas": 1, "num_stages": 1 + }, + "onekernel_pe" : { + "BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2 } } } diff --git a/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json b/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json index 9cc497755b..d853fa8159 100644 --- a/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json +++ b/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json @@ -15,6 +15,14 @@ "num_warps": 4, "num_ctas": 1, "num_stages": 1 + }, + "pe": { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 2, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 4 } }, "bkwd_fused" : { diff --git a/aiter/ops/triton/mha.py b/aiter/ops/triton/mha.py index ca2d75d586..43248c0ed2 100644 --- a/aiter/ops/triton/mha.py +++ b/aiter/ops/triton/mha.py @@ -7,12 +7,8 @@ import triton.language as tl import aiter.ops.triton.utils.types as types -import aiter.ops.triton.utils._triton.arch_info as arch_info -from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH -from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd from aiter.ops.triton.mha_onekernel_bwd import flash_attn_onekernel_backward from aiter.ops.triton.mha_fused_bwd import flash_attn_fused_backward -from aiter.ops.triton.utils.mha_kernel_utils import _compute_fp8_scaling_factors from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton.utils.device_info import get_num_xcds from aiter.ops.triton._triton_kernels.mha import _attn_fwd, _get_config @@ -168,35 +164,50 @@ def _flash_attn_forward( is_varlen = True if cu_seqlens_q is not None else False if IS_FP8: - o = torch.zeros_like(q, dtype=torch.float32) + o = torch.zeros( + (q.shape[:-1] + v.shape[-1:]), dtype=torch.float32, device=q.device + ) else: - o = torch.zeros_like(q) + o = torch.zeros((q.shape[:-1] + v.shape[-1:]), dtype=q.dtype, device=q.device) if is_varlen: - # Layout for q,k,v is thd ie [total_tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = ( + # Layout is thd. + # q and k are [total_tokens, num_head, head_dim_qk]. + # v is [total_tokens, num_head, head_dim_v]. 
+ batch, seqlen_q, num_q_heads = ( len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], - q.shape[2], ) - seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] + num_k_heads = k.shape[1] q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) else: - # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape - seqlen_k = k.shape[1] + # Layout is bshd. + # q and k are [batch, seq_len, num_head, head_dim_qk]. + # v is [batch, seq_len, num_head, head_dim_v]. + batch, seqlen_q, num_q_heads = q.shape[:-1] num_k_heads = k.shape[2] q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + qk_head_dim = q.shape[-1] + v_head_dim = v.shape[-1] + pe_head_dim = qk_head_dim - v_head_dim # padding for head_dim. Power of 2 or 16 - BLOCK_DMODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_DMODEL_POW2 = max(BLOCK_DMODEL_POW2, 16) + BLOCK_DMODEL_POW2 = max(triton.next_power_of_2(v_head_dim), 16) + BLOCK_DMODEL_PE_POW2 = ( + 0 if pe_head_dim == 0 else max(triton.next_power_of_2(pe_head_dim), 16) + ) + assert (pe_head_dim == 0 and BLOCK_DMODEL_PE_POW2 == 0) or ( + v_head_dim == BLOCK_DMODEL_POW2 and pe_head_dim == BLOCK_DMODEL_PE_POW2 + ), "Positional encoding support requires NOPE and PE head sizes to be unpadded powers of 2." + assert (not IS_FP8) or ( + IS_FP8 and pe_head_dim == 0 + ), "Positional encoding doesn't support FP8." # softmax_lse [batch, num_q_heads, seqlen_q] if is_varlen: @@ -242,7 +253,7 @@ def _flash_attn_forward( dropout_mask = None if config is None: - config = _get_config(enable_dropout, q.dtype) + config = _get_config(enable_dropout, q.dtype, has_pe=pe_head_dim > 0) """ # Tuned for MI300x @@ -309,8 +320,9 @@ def _flash_attn_forward( IS_CAUSAL=causal, NUM_Q_HEADS=num_q_heads, NUM_K_HEADS=num_k_heads, - BLOCK_DMODEL=head_sz, + BLOCK_DMODEL=v_head_dim, BLOCK_DMODEL_POW2=BLOCK_DMODEL_POW2, + BLOCK_DMODEL_PE=pe_head_dim, RETURN_SCORES=return_softmax, ENABLE_DROPOUT=enable_dropout, IS_FP8=IS_FP8, diff --git a/aiter/ops/triton/mha_fused_bwd.py b/aiter/ops/triton/mha_fused_bwd.py index 518b300158..bc45c3ccbc 100644 --- a/aiter/ops/triton/mha_fused_bwd.py +++ b/aiter/ops/triton/mha_fused_bwd.py @@ -4,7 +4,6 @@ from typing import Optional, Dict import torch import triton -import triton.language as tl from aiter.ops.triton.utils.types import _is_fp8 from aiter.ops.triton.utils.logger import AiterTritonLogger @@ -14,6 +13,7 @@ _bwd_kernel_dkdvdq_noncausal, _get_config, ) +from aiter.ops.triton.utils.device_info import get_num_xcds _LOGGER = AiterTritonLogger() @@ -53,6 +53,10 @@ def flash_attn_fused_backward( ) if dbias is not None: raise ValueError("Bias is not supported yet in the Triton Backend") + if q.shape[-1] == k.shape[-1] and k.shape[-1] > v.shape[-1]: + raise ValueError( + "'Fused' backward doesn't support Positional Encoding (PE). Please use 'one kernel' backward implementation for PE." 
+ ) IS_FP8 = _is_fp8(q) if IS_FP8: @@ -268,7 +272,6 @@ def flash_attn_fused_backward( FP8_MAX=FP8_MAX, NUM_SMS=NUM_SMS, USE_INT64_STRIDES=USE_INT64_STRIDES, - NUM_XCD=get_num_xcds(), **config_dkdvdq, ) diff --git a/aiter/ops/triton/mha_onekernel_bwd.py b/aiter/ops/triton/mha_onekernel_bwd.py index 59e5e53ec6..0e1c36d6ea 100644 --- a/aiter/ops/triton/mha_onekernel_bwd.py +++ b/aiter/ops/triton/mha_onekernel_bwd.py @@ -90,12 +90,13 @@ def flash_attn_onekernel_backward( # get strides and shape if IS_VARLEN: - # Layout for q,k,v is thd ie [total tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = ( + # Layout is thd. + # q and k are [total_tokens, num_head, head_dim_qk]. + # v is [total_tokens, num_head, head_dim_v]. + batch, seqlen_q, num_q_heads = ( len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], - q.shape[2], ) _, num_k_heads = max_seqlen_k, k.shape[1] q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) @@ -108,8 +109,10 @@ def flash_attn_onekernel_backward( dv_strides = (0, dv.stride(1), dv.stride(0), dv.stride(2)) do_strides = (0, do.stride(1), do.stride(0), do.stride(2)) else: - # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape + # Layout is bshd. + # q and k are [batch, seq_len, num_head, head_dim_qk]. + # v is [batch, seq_len, num_head, head_dim_v] + batch, seqlen_q, num_q_heads = q.shape[:-1] _, num_k_heads = k.shape[1], k.shape[2] q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) @@ -120,10 +123,21 @@ def flash_attn_onekernel_backward( dv_strides = (dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3)) do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3)) + qk_head_dim = q.shape[-1] + v_head_dim = v.shape[-1] + pe_head_dim = qk_head_dim - v_head_dim # BLOCK_D_MODEL, BLOCK_D_MODEL_POW2 # padding for head_dim. Power of 2 or 16 - BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16) + BLOCK_D_MODEL_POW2 = max(triton.next_power_of_2(v_head_dim), 16) + BLOCK_D_MODEL_PE_POW2 = ( + 0 if pe_head_dim == 0 else max(triton.next_power_of_2(pe_head_dim), 16) + ) + assert (pe_head_dim == 0 and BLOCK_D_MODEL_PE_POW2 == 0) or ( + v_head_dim == BLOCK_D_MODEL_POW2 and pe_head_dim == BLOCK_D_MODEL_PE_POW2 + ), "Positional encoding support requires NOPE and PE head sizes to be unpadded powers of 2." + assert (not IS_FP8) or ( + IS_FP8 and pe_head_dim == 0 + ), "Positional encoding doesn't support FP8." # Configs if config is None: @@ -156,7 +170,7 @@ def flash_attn_onekernel_backward( max_seqlen_q, descale_do, BLOCK_M=config["preprocess_kernel"]["PRE_BLOCK"], - BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL=v_head_dim, BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, IS_VARLEN=IS_VARLEN, IS_FP8=IS_FP8, @@ -177,7 +191,13 @@ def flash_attn_onekernel_backward( seqlen = max(max_seqlen_q, max_seqlen_k) - config_onekernel = config["onekernel"] + # "onekernel_pe" is for Positional Encoding (PE) causal case, it's going to be + # used if present. Otherwise, fallback to default "onekernel" config. 
+ config_onekernel = ( + config["onekernel_pe"] + if (pe_head_dim > 0 and causal and "onekernel_pe" in config) + else config["onekernel"] + ) grid = ( num_k_heads, triton.cdiv(seqlen, config_onekernel["BLOCK_N1"]), @@ -223,8 +243,9 @@ def flash_attn_onekernel_backward( descale_k, descale_v, descale_do, - HEAD_DIM=head_sz, + HEAD_DIM=v_head_dim, ACTUAL_HEAD_DIM=BLOCK_D_MODEL_POW2, + PE_HEAD_DIM=pe_head_dim, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, USE_ALIBI=use_alibi, @@ -276,8 +297,9 @@ def flash_attn_onekernel_backward( descale_k, descale_v, descale_do, - HEAD_DIM=head_sz, + HEAD_DIM=v_head_dim, ACTUAL_HEAD_DIM=BLOCK_D_MODEL_POW2, + PE_HEAD_DIM=pe_head_dim, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, USE_ALIBI=use_alibi, diff --git a/op_tests/op_benchmarks/triton/bench_mha.py b/op_tests/op_benchmarks/triton/bench_mha.py index e661d0119f..cfebe1b39a 100644 --- a/op_tests/op_benchmarks/triton/bench_mha.py +++ b/op_tests/op_benchmarks/triton/bench_mha.py @@ -114,6 +114,7 @@ def create_benchmark_configs(custom, args): hk = args.hq if not args.hk else args.hk sk = args.sq if not args.sk else args.sk head_size = 128 if not args.d else args.d + head_size_v = head_size if not args.dv else args.dv mode = args.mode x_names = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K"] causal = args.causal @@ -121,7 +122,13 @@ def create_benchmark_configs(custom, args): configs = [] plot_name = get_caller_name_no_ext() - extra_args = {"D_HEAD": head_size, "dtype": dtype, "causal": causal, "mode": mode} + extra_args = { + "D_HEAD": head_size, + "D_HEAD_V": head_size_v, + "dtype": dtype, + "causal": causal, + "mode": mode, + } if custom: x_vals_list = [(args.b, args.hq, hk, args.sq, sk)] @@ -150,7 +157,7 @@ def create_benchmark_configs(custom, args): if args.fused_bwd: line_vals = [f"fused-bwd({unit})"] else: - line_vals = [f"fused-bwd({unit})", f"bwd({unit})"] + line_vals = [f"onekernel-bwd({unit})"] else: line_vals = [f"fwd({unit})"] @@ -161,7 +168,7 @@ def create_benchmark_configs(custom, args): if args.fused_bwd: line_vals = [f"fused-bwd({unit})"] else: - line_vals = [f"bwd({unit})"] + line_vals = [f"onekernel-bwd({unit})"] configs.append( triton.testing.Benchmark( @@ -190,6 +197,7 @@ def bench_mha( N_CTX_Q, N_CTX_K, D_HEAD, + D_HEAD_V, dtype, causal, mode, @@ -208,11 +216,16 @@ def bench_mha( return_lse = True return_attn_probs = False varlen = args.layout == "thd" + has_pe = D_HEAD > D_HEAD_V + assert not ( + args.fp8 and has_pe + ), "Positional Encoding (PE) doesn't support FP8 data type." + assert not ( + has_pe and "fused-bwd" in provider + ), "'Fused' backward implementation doesn't support Positional Encoding (PE)." 
global _USE_FUSED_BWD - fused_backward = "fused-bwd" in provider - mha_set_use_fused_bwd_kernel(fused_backward) # Default softmax scale to match standard attention @@ -227,65 +240,117 @@ def bench_mha( f"Testing kernel implementation <{provider}> against Torch with shape:" ) print( - f"BATCH={BATCH}, HQ={HQ}, HK={HK}, N_CTX_Q={N_CTX_Q}, N_CTX_K={N_CTX_K}, D_HEAD={D_HEAD}" + f"BATCH={BATCH}, HQ={HQ}, HK={HK}, N_CTX_Q={N_CTX_Q}, N_CTX_K={N_CTX_K}, D_HEAD={D_HEAD}, D_HEAD_V={D_HEAD_V}" ) - if varlen: - test_mha.test_mha( - BATCH, - N_CTX_Q, - N_CTX_K, - HQ, - HK, - D_HEAD, - dropout, - True, - True, - causal, - args.fp8, - dtype, - ) + if not varlen: + if not has_pe: + test_mha.test_mha( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + dropout, + True, + True, + causal, + args.fp8, + dtype, + ) + else: + test_mha.test_mha_with_pe( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + D_HEAD_V, + dropout, + causal, + ) print("Forward test passed!") - test_mha.test_mha_backward_varlen( - BATCH, - N_CTX_Q, - N_CTX_K, - HQ, - HK, - D_HEAD, - dropout, - causal, - args.fp8, - dtype, - ) + if not has_pe: + test_mha.test_mha_backward_varlen( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + dropout, + causal, + args.fp8, + dtype, + ) + else: + test_mha.test_mha_backward_with_pe( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + D_HEAD_V, + dropout, + causal, + ) print("Backward test passed!") else: - test_mha.test_mha_varlen( - BATCH, - N_CTX_Q, - N_CTX_K, - HQ, - HK, - D_HEAD, - dropout, - True, - True, - causal, - args.fp8, - dtype, - ) + if not has_pe: + test_mha.test_mha_varlen( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + dropout, + True, + True, + causal, + args.fp8, + dtype, + ) + else: + test_mha.test_mha_varlen_with_pe( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + D_HEAD_V, + dropout, + causal, + ) print("Forward test passed!") - test_mha.test_mha_backward( - BATCH, - N_CTX_Q, - N_CTX_K, - HQ, - HK, - D_HEAD, - dropout, - causal, - args.fp8, - dtype, - ) + if not has_pe: + test_mha.test_mha_backward( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + dropout, + causal, + args.fp8, + dtype, + ) + else: + test_mha.test_mha_backward_varlen_with_pe( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + D_HEAD_V, + dropout, + causal, + ) print("Backward test passed!") return 0 @@ -293,13 +358,13 @@ def bench_mha( # Generate base inputs q = torch.randn((BATCH, N_CTX_Q, HQ, D_HEAD), device=device, dtype=dtype) k = torch.randn((BATCH, N_CTX_K, HK, D_HEAD), device=device, dtype=dtype) - v = torch.randn((BATCH, N_CTX_K, HK, D_HEAD), device=device, dtype=dtype) + v = torch.randn((BATCH, N_CTX_K, HK, D_HEAD_V), device=device, dtype=dtype) q.requires_grad = requires_grad k.requires_grad = requires_grad v.requires_grad = requires_grad # FLOPS calculation variables - flops_per_matmul = 0 + total_flops = 0.0 # Input preparation if varlen: @@ -342,9 +407,9 @@ def bench_mha( if seqlen_q > seqlen_k else (seqlen_q * seqlen_k - ((seqlen_q**2 - seqlen_q) / 2)) ) - flops_per_matmul += valid_out_elements * HQ * D_HEAD * 2 + total_flops += valid_out_elements * HQ * (D_HEAD + D_HEAD_V) * 2.0 else: - flops_per_matmul += seqlen_q * seqlen_k * HQ * D_HEAD * 2 + total_flops += seqlen_q * seqlen_k * HQ * (D_HEAD + D_HEAD_V) * 2.0 else: q_input, k_input, v_input = q, k, v @@ -354,9 +419,13 @@ def bench_mha( if N_CTX_Q > N_CTX_K else (N_CTX_Q * N_CTX_K - ((N_CTX_Q**2 - N_CTX_Q) / 2)) ) - flops_per_matmul = 2.0 * BATCH * HQ * valid_out_elements * D_HEAD + total_flops += ( + 2.0 * 
BATCH * HQ * valid_out_elements * (D_HEAD + D_HEAD_V) + ) else: - flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD + total_flops += ( + 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * (D_HEAD + D_HEAD_V) + ) # Benchmark mode if varlen: @@ -441,23 +510,31 @@ def fn(): ms = triton.testing.do_bench(fn) - total_flops = 2 * flops_per_matmul if mode == "bwd": total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) - input_bytes = q.element_size() - output_bytes = q.element_size() if varlen: total_num_tokens_q = cu_seqlens_q[-1].item() total_num_tokens_k = cu_seqlens_k[-1].item() else: total_num_tokens_q = BATCH * N_CTX_Q total_num_tokens_k = BATCH * N_CTX_K - mem = ( - total_num_tokens_q * HQ * D_HEAD * input_bytes - + 2 * total_num_tokens_k * HK * D_HEAD * input_bytes - + total_num_tokens_q * HQ * D_HEAD * output_bytes - ) + q_size = total_num_tokens_q * HQ * D_HEAD * q.element_size() + k_size = total_num_tokens_k * HK * D_HEAD * k.element_size() + v_size = total_num_tokens_k * HK * D_HEAD_V * v.element_size() + o_size = total_num_tokens_q * HQ * D_HEAD_V * q.element_size() + if mode == "fwd": + # read q, k, v + mem_read = q_size + k_size + v_size + # write o + mem_write = o_size + else: + # read q, k, v, do + mem_read = q_size + k_size + v_size + o_size + # write dq, dk, dv + mem_write = q_size + k_size + v_size + mem = mem_read + mem_write + # return ms if "ms" in provider: return ms @@ -505,7 +582,13 @@ def parse_args(): default=False, help="If specified, uses equal sequence lengths with thd layout, i.e t = b * sq", ) - parser.add_argument("-d", type=int, default=0) + parser.add_argument( + "-d", + type=int, + default=0, + help="Q and K head size, if -dv is absent then -d specifies V head size too", + ) + parser.add_argument("-dv", type=int, default=0, help="optional V head size") parser.add_argument("-causal", type=str2bool, default=None) parser.add_argument("-fp8", action="store_true", default=False) parser.add_argument("-quantize_p", action="store_true", default=False) @@ -573,17 +656,19 @@ def main(): assert ( args.layout == "thd" or not args.equal_seqlens or args.model ), "Equal sequence lengths arg must be used with the thd layout or a model config." - if args.hq or args.hk or args.d: + if args.hq or args.hk or args.d or args.dv: custom_config = True + if not args.dv: + args.dv = args.d assert ( - args.b and args.hq and args.sq and args.d + args.b and args.hq and args.sq and args.d and args.dv ), "If custom config is specified, please provide \ all of batch, number of Q heads, Q sequence length \ and head size." if args.model: assert not ( - args.hq or args.hk or args.d + args.hq or args.hk or args.d or args.dv ), "Specifying model fixes hq, hk and d already. Do not provide them!" 
assert ( diff --git a/op_tests/triton_tests/test_mha.py b/op_tests/triton_tests/test_mha.py index 0b100d681c..da10db4313 100644 --- a/op_tests/triton_tests/test_mha.py +++ b/op_tests/triton_tests/test_mha.py @@ -81,7 +81,9 @@ def fp8_assert_close( max_abs_idx = torch.argmax(abs_diff).item() max_rel_idx = torch.argmax(rel_diff).item() - flat_to_idx = lambda flat_idx, shape: np.unravel_index(flat_idx, shape) + flat_to_idx = lambda flat_idx, shape: np.unravel_index( # noqa: E731 + flat_idx, shape + ) max_abs_pos = flat_to_idx(max_abs_idx, tensor_a.shape) max_rel_pos = flat_to_idx(max_rel_idx, tensor_a.shape) @@ -493,7 +495,7 @@ def test_mha_backward( torch.cuda.empty_cache() torch.manual_seed(20) - pytest.skip("Backward accuracy issues due to Triton compiler") + # pytest.skip("Backward accuracy issues due to Triton compiler") if FUSED and CAUSAL: pytest.skip("FUSED+CAUSAL results in NaNs") mha_set_use_fused_bwd_kernel(FUSED) @@ -632,7 +634,7 @@ def test_mha_backward_varlen( ): torch.cuda.empty_cache() torch.manual_seed(20) - pytest.skip("Backward accuracy issues due to Triton compiler") + # pytest.skip("Backward accuracy issues due to Triton compiler") if FUSED and CAUSAL: pytest.skip("FUSED+CAUSAL results in NaNs") @@ -781,3 +783,473 @@ def test_mha_backward_varlen( torch.testing.assert_close( triton_dv, torch_dv.to(triton_out.dtype), atol=1e-2, rtol=1e-2 ) + + +# Run PE tests with: +# pytest op_tests/triton_tests/test_mha.py -k with_pe + + +@pytest.mark.parametrize("BATCH", [1, 3]) +@pytest.mark.parametrize( + "SEQLEN_Q, SEQLEN_K", + [(128, 128), (32, 16), (16, 48), (4096, 4096)], +) +@pytest.mark.parametrize("NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (2, 1), (128, 128)]) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(128, 64), (192, 128)]) +@pytest.mark.parametrize("DROPOUT", [0.0, 0.25]) +@pytest.mark.parametrize("CAUSAL", [True, False]) +def test_mha_with_pe( + BATCH: int, + SEQLEN_Q: int, + SEQLEN_K: int, + NUM_Q_HEADS: int, + NUM_K_HEADS: int, + HEAD_SZ_QK: int, + HEAD_SZ_V: int, + DROPOUT: float, + CAUSAL: bool, +): + HAS_DROPOUT: bool = DROPOUT > 0.0 + device: str = "cuda" + dtype: torch.dtype = torch.bfloat16 + + # Generate tensors + torch.cuda.empty_cache() + torch.manual_seed(20) + q = torch.randn( + (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), device=device, dtype=dtype + ) + k = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_QK), device=device, dtype=dtype + ) + v = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_V), device=device, dtype=dtype + ) + + # Triton + triton_out = flash_attn_func( + q, + k, + v, + dropout_p=DROPOUT, + causal=CAUSAL, + return_lse=HAS_DROPOUT, + return_attn_probs=HAS_DROPOUT, + ) + if HAS_DROPOUT: + assert len(triton_out) == 3 + dropout_mask = triton_out[2] > 0 + triton_out = triton_out[0] + else: + dropout_mask = None + + # Torch + torch_out, _, _ = attention_ref( + q, + k, + v, + dropout_p=DROPOUT, + dropout_mask=dropout_mask, + causal=CAUSAL, + ) + + # Assertion + torch.testing.assert_close(triton_out, torch_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("BATCH", [1, 3]) +@pytest.mark.parametrize( + "SEQLEN_Q, SEQLEN_K", + [(16, 16), (32, 16), (64, 128), (4096, 4096)], +) +@pytest.mark.parametrize("NUM_Q_HEADS, NUM_K_HEADS", [(4, 4), (16, 4), (128, 128)]) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(96, 64), (192, 128)]) +@pytest.mark.parametrize("DROPOUT", [0.0, 0.17]) +@pytest.mark.parametrize("CAUSAL", [True, False]) +def test_mha_varlen_with_pe( + BATCH: int, + SEQLEN_Q: int, + SEQLEN_K: int, + NUM_Q_HEADS: int, + 
NUM_K_HEADS: int, + HEAD_SZ_QK: int, + HEAD_SZ_V: int, + DROPOUT: float, + CAUSAL: bool, +): + HAS_DROPOUT: bool = DROPOUT > 0.0 + device: str = "cuda" + dtype: torch.dtype = torch.bfloat16 + + # Generate tensors + torch.cuda.empty_cache() + torch.manual_seed(77) + q = torch.randn( + (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), device=device, dtype=dtype + ) + k = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_QK), device=device, dtype=dtype + ) + v = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_V), device=device, dtype=dtype + ) + query_padding_mask = generate_random_padding_mask(SEQLEN_Q, BATCH, device) + key_padding_mask = generate_random_padding_mask(SEQLEN_K, BATCH, device) + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + _, + _, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask) + + # Triton + triton_out = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=DROPOUT, + causal=CAUSAL, + return_lse=HAS_DROPOUT, + return_attn_probs=HAS_DROPOUT, + ) + if HAS_DROPOUT: + assert len(triton_out) == 3 + dropout_mask = ( + pad_rearrange_dropout_mask( + triton_out[2] > 0, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + SEQLEN_Q, + SEQLEN_K, + NUM_Q_HEADS, + ) + > 0 + ) + triton_out = triton_out[0] + else: + dropout_mask = None + triton_out = output_pad_fn(triton_out) + + # Torch + torch_out, _, _ = attention_ref( + q, + k, + v, + query_padding_mask=query_padding_mask, + key_padding_mask=key_padding_mask, + dropout_p=DROPOUT, + dropout_mask=dropout_mask, + causal=CAUSAL, + ) + + # Assertion + torch.testing.assert_close(triton_out, torch_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("BATCH", [1, 4]) +@pytest.mark.parametrize( + "SEQLEN_Q, SEQLEN_K", + [(16, 16), (32, 8), (64, 16), (2048, 2048)], +) +@pytest.mark.parametrize("NUM_Q_HEADS, NUM_K_HEADS", [(4, 4), (8, 2), (128, 128)]) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (192, 128)]) +@pytest.mark.parametrize("DROPOUT", [0.0, 0.2]) +@pytest.mark.parametrize("CAUSAL", [True, False]) +def test_mha_backward_with_pe( + BATCH: int, + SEQLEN_Q: int, + SEQLEN_K: int, + NUM_Q_HEADS: int, + NUM_K_HEADS: int, + HEAD_SZ_QK: int, + HEAD_SZ_V: int, + DROPOUT: float, + CAUSAL: bool, +): + HAS_DROPOUT: bool = DROPOUT > 0.0 + + # Causal + Dropout use case is disabled in `test_mha_backward` and `test_mha_backward_varlen`. + # FIXME: We should fix it in the base implementation before adding PE to the mix. + if CAUSAL and HAS_DROPOUT: + pytest.skip( + "Causal + Dropout use case isn't supported in backward with Positional Encoding." 
+ ) + + device: str = "cuda" + dtype: torch.dtype = torch.bfloat16 + + # Generate tensors + torch.cuda.empty_cache() + torch.manual_seed(63) + q = torch.randn( + (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), + device=device, + dtype=dtype, + requires_grad=True, + ) + k = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_QK), + device=device, + dtype=dtype, + requires_grad=True, + ) + v = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_V), + device=device, + dtype=dtype, + requires_grad=True, + ) + do = torch.randn((q.shape[:-1] + v.shape[-1:]), dtype=dtype, device=device) + + # Triton forward + with torch.enable_grad(): + triton_out = flash_attn_func( + q, + k, + v, + dropout_p=DROPOUT, + causal=CAUSAL, + return_lse=HAS_DROPOUT, + return_attn_probs=HAS_DROPOUT, + ) + if HAS_DROPOUT: + assert len(triton_out) == 3 + dropout_mask = triton_out[2] > 0 + triton_out = triton_out[0] + else: + dropout_mask = None + + # Torch forward + with torch.enable_grad(): + torch_out, _, _ = attention_ref( + q, k, v, dropout_p=DROPOUT, dropout_mask=dropout_mask, causal=CAUSAL + ) + + # Forward assertion + torch.testing.assert_close( + triton_out, + torch_out, + atol=1e-2, + rtol=1e-2, + msg=lambda msg: f"fwd mismatch\n\n{msg}\n", + ) + + # Triton backward + # PE support isn't implemented in fused backward. + mha_set_use_fused_bwd_kernel(False) + triton_dq, triton_dk, triton_dv = torch.autograd.grad(triton_out, (q, k, v), do) + + # Torch backward + torch_dq, torch_dk, torch_dv = torch.autograd.grad(torch_out, (q, k, v), do) + + # Backward assertions + # When dropout is active, some cases fail due to less than 1% mismatched elements. + bwd_atol = 1e-1 if HAS_DROPOUT else 1.5e-2 + bwd_rtol = 1e-1 if HAS_DROPOUT else 1.5e-2 + torch.testing.assert_close( + triton_dq, + torch_dq, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dq mismatch\n\n{msg}\n", + ) + torch.testing.assert_close( + triton_dk, + torch_dk, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dk mismatch\n\n{msg}\n", + ) + torch.testing.assert_close( + triton_dv, + torch_dv, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dv mismatch\n\n{msg}\n", + ) + + +@pytest.mark.parametrize("BATCH", [1, 4]) +@pytest.mark.parametrize( + "SEQLEN_Q, SEQLEN_K", + [(8, 8), (32, 8), (16, 64), (64, 64)], +) +@pytest.mark.parametrize("NUM_Q_HEADS, NUM_K_HEADS", [(4, 4), (8, 2), (128, 128)]) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (192, 128)]) +@pytest.mark.parametrize("DROPOUT", [0.0, 0.2]) +@pytest.mark.parametrize("CAUSAL", [True, False]) +def test_mha_backward_varlen_with_pe( + BATCH: int, + SEQLEN_Q: int, + SEQLEN_K: int, + NUM_Q_HEADS: int, + NUM_K_HEADS: int, + HEAD_SZ_QK: int, + HEAD_SZ_V: int, + DROPOUT: float, + CAUSAL: bool, +): + HAS_DROPOUT: bool = DROPOUT > 0.0 + + # Causal + Dropout use case is disabled in `test_mha_backward` and `test_mha_backward_varlen`. + # FIXME: We should fix it in the base implementation before adding PE to the mix. + if CAUSAL and HAS_DROPOUT: + pytest.skip( + "Causal + Dropout use case isn't supported in backward with Positional Encoding." 
+ ) + + device: str = "cuda" + dtype: torch.dtype = torch.bfloat16 + + # Generate tensors + torch.cuda.empty_cache() + torch.manual_seed(133) + q = torch.randn( + (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), + device=device, + dtype=dtype, + requires_grad=True, + ) + k = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_QK), + device=device, + dtype=dtype, + requires_grad=True, + ) + v = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_V), + device=device, + dtype=dtype, + requires_grad=True, + ) + query_padding_mask = generate_random_padding_mask(SEQLEN_Q, BATCH, device) + key_padding_mask = generate_random_padding_mask(SEQLEN_K, BATCH, device) + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask) + q_unpad.requires_grad = True + k_unpad.requires_grad = True + v_unpad.requires_grad = True + do = torch.randn((q.shape[:-1] + v.shape[-1:]), dtype=dtype, device=device) + + # Triton forward + with torch.enable_grad(): + triton_out = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=DROPOUT, + causal=CAUSAL, + return_lse=HAS_DROPOUT, + return_attn_probs=HAS_DROPOUT, + ) + if HAS_DROPOUT: + assert len(triton_out) == 3 + dropout_mask = ( + pad_rearrange_dropout_mask( + triton_out[2] > 0, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + SEQLEN_Q, + SEQLEN_K, + NUM_Q_HEADS, + ) + > 0 + ) + triton_out = triton_out[0] + else: + dropout_mask = None + triton_out = output_pad_fn(triton_out) + + # Torch forward + with torch.enable_grad(): + torch_out, _, _ = attention_ref( + q, + k, + v, + query_padding_mask=query_padding_mask, + key_padding_mask=key_padding_mask, + dropout_p=DROPOUT, + dropout_mask=dropout_mask, + causal=CAUSAL, + ) + + # Forward assertion + torch.testing.assert_close( + triton_out, + torch_out, + atol=1e-2, + rtol=1e-2, + msg=lambda msg: f"fwd mismatch\n\n{msg}\n", + ) + + # Triton backward + # PE support isn't implemented in fused backward. + mha_set_use_fused_bwd_kernel(False) + triton_dq, triton_dk, triton_dv = torch.autograd.grad( + triton_out, (q_unpad, k_unpad, v_unpad), do + ) + triton_dq = dq_pad_fn(triton_dq) + triton_dk = dk_pad_fn(triton_dk) + triton_dv = dk_pad_fn(triton_dv) + + # Torch backward + torch_dq, torch_dk, torch_dv = torch.autograd.grad(torch_out, (q, k, v), do) + + # Backward assertions + bwd_atol = 1e-1 + bwd_rtol = 1e-1 + torch.testing.assert_close( + triton_dq, + torch_dq, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dq mismatch\n\n{msg}\n", + ) + torch.testing.assert_close( + triton_dk, + torch_dk, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dk mismatch\n\n{msg}\n", + ) + torch.testing.assert_close( + triton_dv, + torch_dv, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dv mismatch\n\n{msg}\n", + )
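Reviewer note (not part of the patch): the forward-kernel changes above split the Q/K head dimension into a "NOPE" part that matches V's head size plus an extra positional-encoding (PE) part, and accumulate both into the attention logits (the added `qk += tl.dot(q_pe, k_pe)`), while only the V-sized slice participates in the PV product. Below is a minimal PyTorch sketch of that math for orientation only; it assumes equal Q/K/V head counts, a 1/sqrt(qk_head_dim) softmax scale, and bottom-right-aligned causal masking. The helper name attention_ref_with_pe is illustrative, not an API in this repository.

import torch


def attention_ref_with_pe(q, k, v, causal=False):
    # q, k: [batch, seqlen_q/k, heads, qk_head_dim]; v: [batch, seqlen_k, heads, v_head_dim]
    # qk_head_dim = v_head_dim + pe_head_dim, mirroring pe_head_dim = qk_head_dim - v_head_dim above.
    v_dim = v.shape[-1]
    sm_scale = q.shape[-1] ** -0.5  # assumed default; the kernels take sm_scale explicitly
    q_nope, q_pe = q[..., :v_dim], q[..., v_dim:]
    k_nope, k_pe = k[..., :v_dim], k[..., v_dim:]
    # Logits accumulate both the NOPE and the PE contributions, as in the Triton inner loop.
    scores = torch.einsum("bqhd,bkhd->bhqk", q_nope.float(), k_nope.float())
    scores = scores + torch.einsum("bqhd,bkhd->bhqk", q_pe.float(), k_pe.float())
    scores = scores * sm_scale
    if causal:
        sq, sk = q.shape[1], k.shape[1]
        keep = torch.ones(sq, sk, dtype=torch.bool, device=q.device).tril(sk - sq)
        scores = scores.masked_fill(~keep, float("-inf"))
    p = torch.softmax(scores, dim=-1)
    # Only the v_head_dim-sized V enters the PV product; the PE dims never touch V.
    out = torch.einsum("bhqk,bkhd->bqhd", p, v.float())
    return out.to(v.dtype)


# Hypothetical usage mirroring the new tests (HEAD_SZ_QK=192, HEAD_SZ_V=128, bf16, CUDA):
# q = torch.randn(2, 128, 8, 192, device="cuda", dtype=torch.bfloat16)
# k = torch.randn(2, 128, 8, 192, device="cuda", dtype=torch.bfloat16)
# v = torch.randn(2, 128, 8, 128, device="cuda", dtype=torch.bfloat16)
# ref = attention_ref_with_pe(q, k, v, causal=True)
# The Triton path would be flash_attn_func(q, k, v, causal=True), with the PE forward config
# selected via _get_config(..., has_pe=True); FP8 inputs and the fused backward are rejected
# for PE shapes, per the asserts added in mha.py, mha_onekernel_bwd.py, and mha_fused_bwd.py.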