From ac0a11e63cb131bd74423acc610d1e2bef57219d Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Mon, 13 Oct 2025 11:10:38 -0300 Subject: [PATCH 01/18] Add PE support to forward MHA kernel --- aiter/ops/triton/_triton_kernels/mha.py | 100 +++++++++++++++++++++++- aiter/ops/triton/mha.py | 42 +++++++--- 2 files changed, 127 insertions(+), 15 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/mha.py b/aiter/ops/triton/_triton_kernels/mha.py index 9fdd6cbdc0..9f8e66a8a1 100644 --- a/aiter/ops/triton/_triton_kernels/mha.py +++ b/aiter/ops/triton/_triton_kernels/mha.py @@ -79,7 +79,9 @@ def _attn_fwd_inner( l_i, m_i, q, + q_pe, k_ptrs, + k_pe_ptrs, v_ptrs, stride_kn, stride_vk, @@ -107,6 +109,7 @@ def _attn_fwd_inner( BLOCK_N: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL_POW2: tl.constexpr, + BLOCK_DMODEL_PE: tl.constexpr, # it's zero or a power of 2 SM_SCALE: tl.constexpr, IS_CAUSAL: tl.constexpr, MASK_STEPS: tl.constexpr, @@ -115,12 +118,17 @@ def _attn_fwd_inner( PADDED_HEAD: tl.constexpr, IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, + ENABLE_PIPELINING: tl.constexpr, ): RCP_LN2: tl.constexpr = 1.4426950408889634 + HAS_PE: tl.constexpr = BLOCK_DMODEL_PE > 0 # loop over k, v, and update accumulator - for start_n in range(block_min, block_max, BLOCK_N): + num_stages: tl.constexpr = ( + None if ENABLE_PIPELINING else 1 + ) # Set num_stages==1 if we want to disable pipelining + for start_n in tl.range(block_min, block_max, BLOCK_N, num_stages=num_stages): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. 
if MASK_STEPS: @@ -129,6 +137,14 @@ def _attn_fwd_inner( k_offs_n = None k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL_POW2) k = _load_fn(k_ptrs, k_offs_k, k_offs_n, BLOCK_DMODEL, seqlen_k) + if HAS_PE: + k_pe = _load_fn( + k_pe_ptrs, + None, + k_offs_n, + (BLOCK_DMODEL + BLOCK_DMODEL_PE), + seqlen_k, + ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need @@ -163,6 +179,8 @@ def _attn_fwd_inner( qk += tl.dot(q, k) * descale_q * descale_k else: qk += tl.dot(q, k) + if HAS_PE: + qk += tl.dot(q_pe, k_pe) if IS_CAUSAL: causal_boundary = start_n + offs_n_causal @@ -229,6 +247,8 @@ def _attn_fwd_inner( acc += tl.dot(p.to(v.type.element_ty), v) k_ptrs += BLOCK_N * stride_kn + if HAS_PE: + k_pe_ptrs += BLOCK_N * stride_kn v_ptrs += BLOCK_N * stride_vk if RETURN_SCORES: sd_mask_ptrs += BLOCK_N * stride_sn @@ -296,6 +316,7 @@ def _attn_fwd( BLOCK_N: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL_POW2: tl.constexpr, + BLOCK_DMODEL_PE: tl.constexpr, # it's zero or a power of 2 RETURN_SCORES: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_FP8: tl.constexpr, @@ -321,6 +342,9 @@ def _attn_fwd( offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL_POW2) + HAS_PE: tl.constexpr = BLOCK_DMODEL_PE > 0 + if HAS_PE: + offs_pe = BLOCK_DMODEL + tl.arange(0, BLOCK_DMODEL_PE) # NOTE: # Workaround for int64 strides, In the absence of strides being int64, parts of the offset @@ -395,6 +419,38 @@ def _attn_fwd( stride_lse_h = stride_lse_h_in stride_lse_m = stride_lse_m_in + tl.assume(stride_qz_in >= 0) + tl.assume(stride_qh_in >= 0) + tl.assume(stride_qm_in >= 0) + tl.assume(stride_qk_in >= 0) + tl.assume(stride_kz_in >= 0) + tl.assume(stride_kh_in >= 0) + tl.assume(stride_kn_in >= 0) + tl.assume(stride_kk_in >= 0) + tl.assume(stride_vz_in >= 0) + tl.assume(stride_vh_in >= 0) + tl.assume(stride_vn_in >= 0) + tl.assume(stride_vk_in >= 
0) + if IS_FP8: + tl.assume(stride_descale_q_z_in >= 0) + tl.assume(stride_descale_k_z_in >= 0) + tl.assume(stride_descale_v_z_in >= 0) + tl.assume(stride_oz_in >= 0) + tl.assume(stride_oh_in >= 0) + tl.assume(stride_om_in >= 0) + tl.assume(stride_on_in >= 0) + tl.assume(stride_alibi_z_in >= 0) + tl.assume(stride_alibi_h_in >= 0) + # NOTE: philox offset is needed in dropout pointer calculations + tl.assume(philox_offset_base_in >= 0) + tl.assume(stride_sd_z_in >= 0) + tl.assume(stride_sd_h_in >= 0) + tl.assume(stride_sd_m_in >= 0) + tl.assume(stride_sd_n_in >= 0) + tl.assume(stride_lse_z_in >= 0) + tl.assume(stride_lse_h_in >= 0) + tl.assume(stride_lse_m_in >= 0) + if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) @@ -479,6 +535,17 @@ def _attn_fwd( + offs_d[None, :] * stride_qk ) q_ptrs = q_ptr + q_offs + if HAS_PE: + q_pe_offs = ( + off_z * stride_qz + + off_q_head * stride_qh + + cu_seqlens_q_start * stride_qm + + offs_m[:, None] * stride_qm + + offs_pe[None, :] * stride_qk + ) + q_pe_ptrs = q_ptr + q_pe_offs + else: + q_pe_ptrs = None k_offs = ( off_z * stride_kz @@ -488,6 +555,17 @@ def _attn_fwd( + offs_n[None, :] * stride_kn ) k_ptrs = k_ptr + k_offs + if HAS_PE: + k_pe_offs = ( + off_z * stride_kz + + off_k_head * stride_kh + + cu_seqlens_k_start * stride_kn + + offs_pe[:, None] * stride_kk + + offs_n[None, :] * stride_kn + ) + k_pe_ptrs = k_ptr + k_pe_offs + else: + k_pe_ptrs = None v_offs = ( off_z * stride_vz @@ -545,6 +623,11 @@ def _attn_fwd( else: q_mask = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < BLOCK_DMODEL) q = tl.load(q_ptrs, mask=q_mask, other=0.0) + if HAS_PE: + q_pe = tl.load(q_pe_ptrs, mask=q_mask, other=0.0) + else: + q_pe = None + if IS_FP8: descale_q = tl.load(descale_q_ptr + off_z * stride_descale_q_z + off_q_head) descale_k = tl.load(descale_k_ptr + off_z * stride_descale_k_z + off_k_head) @@ -584,7 +667,9 @@ def _attn_fwd( l_i, m_i, q, + q_pe, k_ptrs, + k_pe_ptrs, 
v_ptrs, stride_kn, stride_vn, @@ -612,6 +697,7 @@ def _attn_fwd( BLOCK_N, BLOCK_DMODEL, BLOCK_DMODEL_POW2, + BLOCK_DMODEL_PE, sm_scale, False, MASK_STEPS=False, @@ -620,6 +706,7 @@ def _attn_fwd( PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, + ENABLE_PIPELINING=True, ) block_min = block_max block_max = n_blocks * BLOCK_N @@ -631,6 +718,8 @@ def _attn_fwd( else: offs_n_causal = 0 k_ptrs += n_full_blocks * BLOCK_N * stride_kn + if HAS_PE: + k_pe_ptrs += n_full_blocks * BLOCK_N * stride_kn v_ptrs += n_full_blocks * BLOCK_N * stride_vn if RETURN_SCORES: s_dmask_ptrs += n_full_blocks * BLOCK_N * stride_sd_n @@ -641,7 +730,9 @@ def _attn_fwd( l_i, m_i, q, + q_pe, k_ptrs, + k_pe_ptrs, v_ptrs, stride_kn, stride_vn, @@ -669,6 +760,7 @@ def _attn_fwd( BLOCK_N, BLOCK_DMODEL, BLOCK_DMODEL_POW2, + BLOCK_DMODEL_PE, sm_scale, IS_CAUSAL, MASK_STEPS=True, @@ -677,6 +769,7 @@ def _attn_fwd( PADDED_HEAD=BLOCK_DMODEL != BLOCK_DMODEL_POW2, IS_FP8=IS_FP8, FP8_MAX=FP8_MAX, + ENABLE_PIPELINING=False, ) # epilogue # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. 
@@ -759,6 +852,7 @@ def _attn_fwd( def _get_config( enable_dropout: bool, dtype: torch.dtype, + has_pe: bool = False, ): if not hasattr(_get_config, "_config_dict"): dev = arch_info.get_device() @@ -768,7 +862,9 @@ def _get_config( config = json.load(file) _get_config._config_dict["default"] = config - if enable_dropout or dtype == torch.float32: + if has_pe and "pe" in _get_config._config_dict["default"]["fwd"]: + return _get_config._config_dict["default"]["fwd"]["pe"] + elif enable_dropout or dtype == torch.float32: return _get_config._config_dict["default"]["fwd"]["dropout_or_fp32"] else: return _get_config._config_dict["default"]["fwd"]["default"] diff --git a/aiter/ops/triton/mha.py b/aiter/ops/triton/mha.py index ca2d75d586..1c84a1d742 100644 --- a/aiter/ops/triton/mha.py +++ b/aiter/ops/triton/mha.py @@ -168,35 +168,50 @@ def _flash_attn_forward( is_varlen = True if cu_seqlens_q is not None else False if IS_FP8: - o = torch.zeros_like(q, dtype=torch.float32) + o = torch.zeros( + (q.shape[:-1] + v.shape[-1:]), dtype=torch.float32, device=q.device + ) else: - o = torch.zeros_like(q) + o = torch.zeros((q.shape[:-1] + v.shape[-1:]), dtype=q.dtype, device=q.device) if is_varlen: - # Layout for q,k,v is thd ie [total_tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = ( + # Layout is thd. + # q and k are [total_tokens, num_head, head_dim_qk]. + # v is [total_tokens, num_head, head_dim_v]. + batch, seqlen_q, num_q_heads = ( len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], - q.shape[2], ) - seqlen_k, num_k_heads = max_seqlen_k, k.shape[1] + num_k_heads = k.shape[1] q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) else: - # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape - seqlen_k = k.shape[1] + # Layout is bshd. 
+ # q and k are [batch, seq_len, num_head, head_dim_qk]. + # v is [batch, seq_len, num_head, head_dim_v]. + batch, seqlen_q, num_q_heads = q.shape[:-1] num_k_heads = k.shape[2] q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + qk_head_dim = q.shape[-1] + v_head_dim = v.shape[-1] + pe_head_dim = qk_head_dim - v_head_dim # padding for head_dim. Power of 2 or 16 - BLOCK_DMODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_DMODEL_POW2 = max(BLOCK_DMODEL_POW2, 16) + BLOCK_DMODEL_POW2 = max(triton.next_power_of_2(v_head_dim), 16) + BLOCK_DMODEL_PE_POW2 = ( + 0 if pe_head_dim == 0 else max(triton.next_power_of_2(pe_head_dim), 16) + ) + assert (pe_head_dim == 0 and BLOCK_DMODEL_PE_POW2 == 0) or ( + v_head_dim == BLOCK_DMODEL_POW2 and pe_head_dim == BLOCK_DMODEL_PE_POW2 + ), "Positional encoding support requires NOPE and PE head sizes to be unpadded powers of 2." + assert (not IS_FP8) or ( + IS_FP8 and pe_head_dim == 0 + ), "Positional encoding doesn't support FP8." 
# softmax_lse [batch, num_q_heads, seqlen_q] if is_varlen: @@ -242,7 +257,7 @@ def _flash_attn_forward( dropout_mask = None if config is None: - config = _get_config(enable_dropout, q.dtype) + config = _get_config(enable_dropout, q.dtype, has_pe=pe_head_dim > 0) """ # Tuned for MI300x @@ -309,8 +324,9 @@ def _flash_attn_forward( IS_CAUSAL=causal, NUM_Q_HEADS=num_q_heads, NUM_K_HEADS=num_k_heads, - BLOCK_DMODEL=head_sz, + BLOCK_DMODEL=v_head_dim, BLOCK_DMODEL_POW2=BLOCK_DMODEL_POW2, + BLOCK_DMODEL_PE=pe_head_dim, RETURN_SCORES=return_softmax, ENABLE_DROPOUT=enable_dropout, IS_FP8=IS_FP8, From e5fc9554155fb8579e1b30bb29d61aa8aa47e1e3 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Fri, 10 Oct 2025 14:58:44 -0300 Subject: [PATCH 02/18] Add MHA PE forward unit tests --- op_tests/triton_tests/test_mha.py | 180 ++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) diff --git a/op_tests/triton_tests/test_mha.py b/op_tests/triton_tests/test_mha.py index 0b100d681c..376619fa7c 100644 --- a/op_tests/triton_tests/test_mha.py +++ b/op_tests/triton_tests/test_mha.py @@ -781,3 +781,183 @@ def test_mha_backward_varlen( torch.testing.assert_close( triton_dv, torch_dv.to(triton_out.dtype), atol=1e-2, rtol=1e-2 ) + + +# Run PE tests with: +# pytest op_tests/triton_tests/test_mha.py -k with_pe + + +@pytest.mark.parametrize("BATCH", [1, 3]) +@pytest.mark.parametrize( + "SEQLEN_Q, SEQLEN_K", + [(1, 1), (4, 4), (128, 128), (2, 1), (1, 2), (32, 16), (16, 48), (4096, 4096)], +) +@pytest.mark.parametrize( + "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (4, 4), (2, 1), (128, 128)] +) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(48, 32), (128, 64), (192, 128)]) +@pytest.mark.parametrize("DROPOUT", [0.0, 0.25]) +@pytest.mark.parametrize("CAUSAL", [True, False]) +def test_mha_with_pe( + BATCH: int, + SEQLEN_Q: int, + SEQLEN_K: int, + NUM_Q_HEADS: int, + NUM_K_HEADS: int, + HEAD_SZ_QK: int, + HEAD_SZ_V: int, + DROPOUT: float, + CAUSAL: bool, +): + HAS_DROPOUT: bool = DROPOUT > 0.0 
+ device: str = "cuda" + dtype: torch.dtype = torch.bfloat16 + + # Generate tensors + torch.cuda.empty_cache() + torch.manual_seed(20) + q = torch.randn( + (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), device=device, dtype=dtype + ) + k = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_QK), device=device, dtype=dtype + ) + v = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_V), device=device, dtype=dtype + ) + + # Triton + triton_out = flash_attn_func( + q, + k, + v, + dropout_p=DROPOUT, + causal=CAUSAL, + return_lse=HAS_DROPOUT, + return_attn_probs=HAS_DROPOUT, + ) + if HAS_DROPOUT: + assert len(triton_out) == 3 + dropout_mask = triton_out[2] > 0 + triton_out = triton_out[0] + else: + dropout_mask = None + + # Torch + torch_out, _ = attention_ref( + q, + k, + v, + dropout_p=DROPOUT, + dropout_mask=dropout_mask, + causal=CAUSAL, + ) + + # Assertion + torch.testing.assert_close(triton_out, torch_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("BATCH", [1, 3]) +@pytest.mark.parametrize( + "SEQLEN_Q, SEQLEN_K", + [(1, 1), (2, 2), (4, 1), (1, 4), (16, 16), (32, 16), (64, 128), (4096, 4096)], +) +@pytest.mark.parametrize( + "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (4, 4), (16, 4), (128, 128)] +) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (96, 64), (192, 128)]) +@pytest.mark.parametrize("DROPOUT", [0.0, 0.17]) +@pytest.mark.parametrize("CAUSAL", [True, False]) +def test_mha_varlen_with_pe( + BATCH: int, + SEQLEN_Q: int, + SEQLEN_K: int, + NUM_Q_HEADS: int, + NUM_K_HEADS: int, + HEAD_SZ_QK: int, + HEAD_SZ_V: int, + DROPOUT: float, + CAUSAL: bool, +): + HAS_DROPOUT: bool = DROPOUT > 0.0 + device: str = "cuda" + dtype: torch.dtype = torch.bfloat16 + + # Generate tensors + torch.cuda.empty_cache() + torch.manual_seed(77) + q = torch.randn( + (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), device=device, dtype=dtype + ) + k = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_QK), device=device, dtype=dtype + ) + v = torch.randn( + (BATCH, 
SEQLEN_K, NUM_K_HEADS, HEAD_SZ_V), device=device, dtype=dtype + ) + query_padding_mask = generate_random_padding_mask(SEQLEN_Q, BATCH, device) + key_padding_mask = generate_random_padding_mask(SEQLEN_K, BATCH, device) + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + _, + _, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask) + + # Triton + triton_out = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=DROPOUT, + causal=CAUSAL, + return_lse=HAS_DROPOUT, + return_attn_probs=HAS_DROPOUT, + ) + if HAS_DROPOUT: + assert len(triton_out) == 3 + dropout_mask = ( + pad_rearrange_dropout_mask( + triton_out[2] > 0, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + SEQLEN_Q, + SEQLEN_K, + NUM_Q_HEADS, + ) + > 0 + ) + triton_out = triton_out[0] + else: + dropout_mask = None + triton_out = output_pad_fn(triton_out) + + # Torch + torch_out, _ = attention_ref( + q, + k, + v, + query_padding_mask=query_padding_mask, + key_padding_mask=key_padding_mask, + dropout_p=DROPOUT, + dropout_mask=dropout_mask, + causal=CAUSAL, + ) + + # Assertion + torch.testing.assert_close(triton_out, torch_out, atol=1e-2, rtol=1e-2) From ddcdce259d3bfc141e5ebd6416192f5b38530eab Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Fri, 10 Oct 2025 14:59:51 -0300 Subject: [PATCH 03/18] Add best config for MHA PE forward --- aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json | 8 ++++++++ aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json b/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json index de5c09cd1e..cd5c8f3540 100644 --- a/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json +++ b/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json @@ -15,6 +15,14 @@ "num_warps": 4, "num_ctas": 1, "num_stages": 1 + }, + "pe": { 
+ "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 1, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 1 } }, "bkwd_fused" : { diff --git a/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json b/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json index 9cc497755b..d853fa8159 100644 --- a/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json +++ b/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json @@ -15,6 +15,14 @@ "num_warps": 4, "num_ctas": 1, "num_stages": 1 + }, + "pe": { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 2, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 4 } }, "bkwd_fused" : { From 383fe84208f5fba7daefb4ae2027bd60f91f8874 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Fri, 10 Oct 2025 15:00:29 -0300 Subject: [PATCH 04/18] Enable MHA backward unit tests --- op_tests/triton_tests/test_mha.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/op_tests/triton_tests/test_mha.py b/op_tests/triton_tests/test_mha.py index 376619fa7c..209aa99d72 100644 --- a/op_tests/triton_tests/test_mha.py +++ b/op_tests/triton_tests/test_mha.py @@ -493,7 +493,7 @@ def test_mha_backward( torch.cuda.empty_cache() torch.manual_seed(20) - pytest.skip("Backward accuracy issues due to Triton compiler") + # pytest.skip("Backward accuracy issues due to Triton compiler") if FUSED and CAUSAL: pytest.skip("FUSED+CAUSAL results in NaNs") mha_set_use_fused_bwd_kernel(FUSED) @@ -632,7 +632,7 @@ def test_mha_backward_varlen( ): torch.cuda.empty_cache() torch.manual_seed(20) - pytest.skip("Backward accuracy issues due to Triton compiler") + # pytest.skip("Backward accuracy issues due to Triton compiler") if FUSED and CAUSAL: pytest.skip("FUSED+CAUSAL results in NaNs") From 043ca6112726f68dc01a3cd74b1395c411606a9d Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Mon, 13 Oct 2025 11:48:58 -0300 Subject: [PATCH 05/18] Add PE support to backward MHA kernels * Only "one kernel" backward implementation supports PE, "fused" backward implementation lacks PE support. 
--- .../_triton_kernels/mha_onekernel_bwd.py | 168 ++++++++++++++++-- aiter/ops/triton/mha_fused_bwd.py | 5 + aiter/ops/triton/mha_onekernel_bwd.py | 44 +++-- 3 files changed, 188 insertions(+), 29 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py b/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py index 9808636501..52f5bea82c 100644 --- a/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py +++ b/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py @@ -108,10 +108,12 @@ def _bwd_preprocess( # The main inner-loop logic for computing dK and dV. @triton.jit def _bwd_dkdv_inner( - dk, + dk, # output + dk_pe, # optional output, pass None for non-PE case dv, # output Q, k, + k_pe, v, DO, M, @@ -128,6 +130,7 @@ def _bwd_dkdv_inner( BLOCK_N: tl.constexpr, # 128 HEAD_DIM: tl.constexpr, # ACTUAL_HEAD_DIM: tl.constexpr, # + PE_HEAD_DIM: tl.constexpr, dropout_p, philox_seed, batch_philox_offset, @@ -154,15 +157,20 @@ def _bwd_dkdv_inner( ): # if HEAD_DIM is padded PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM + HAS_PE: tl.constexpr = PE_HEAD_DIM > 0 delta_qk = seqlen_q - seqlen_k offs_m = start_m + tl.arange(0, BLOCK_M) # start_m + (0, 15) offs_n = start_n + tl.arange(0, BLOCK_N) # start_m + (0, 127) offs_k = tl.arange(0, HEAD_DIM) + if HAS_PE: + offs_k_pe = HEAD_DIM + tl.arange(0, PE_HEAD_DIM) # mask to make sure not OOB of seqlen_q mask_n = offs_n < seqlen_k # Q and DO are (seqlen_q, head_dim) # qT_ptrs = (1, BLOCK_M) + (HEAD_DIM, 1), transpose of q qT_ptrs = Q + offs_m[None, :] * stride_qm + offs_k[:, None] * stride_qk + if HAS_PE: + qT_pe_ptrs = Q + offs_m[None, :] * stride_qm + offs_k_pe[:, None] * stride_qk # do_ptrs = (BLOCK_M, 1) + (1, HEAD_DIM), NOT transposed do_ptrs = DO + offs_m[:, None] * stride_dom + offs_k[None, :] * stride_dok # BLOCK_N must be a multiple of BLOCK_M, otherwise the code wouldn't work. 
@@ -186,6 +194,8 @@ def _bwd_dkdv_inner( mask_qT &= offs_k[:, None] < ACTUAL_HEAD_DIM mask_do &= offs_k[None, :] < ACTUAL_HEAD_DIM qT = tl.load(qT_ptrs, mask=mask_qT, other=0.0) + if HAS_PE: + qT_pe = tl.load(qT_pe_ptrs, mask=mask_qT, other=0.0) # generate dropout mask if ENABLE_DROPOUT: # NOTE: dropout is transposed because it is used to mask pT @@ -210,6 +220,8 @@ def _bwd_dkdv_inner( qkT = tl.dot(k, qT) * descale_q * descale_k else: qkT = tl.dot(k, qT) + if HAS_PE: + qkT += tl.dot(k_pe, qT_pe) qkT_scaled = qkT * sm_scale if USE_ALIBI: @@ -291,18 +303,24 @@ def _bwd_dkdv_inner( ) else: dk += tl.dot(dsT.to(qT.type.element_ty), tl.trans(qT)) + if HAS_PE: + dk_pe += tl.dot(dsT.to(qT_pe.type.element_ty), tl.trans(qT_pe)) # Increment pointers. curr_m += step_m qT_ptrs += step_m * stride_qm + if HAS_PE: + qT_pe_ptrs += step_m * stride_qm do_ptrs += step_m * stride_dom - return dk, dv + return dk, dk_pe, dv # the main inner-loop logic for computing dQ @triton.jit def _bwd_dq_inner( dq, # output + dq_pe, # optional output, pass None for non-PE case q, + q_pe, K, V, do, @@ -325,6 +343,7 @@ def _bwd_dq_inner( BLOCK_N2: tl.constexpr, # HEAD_DIM: tl.constexpr, ACTUAL_HEAD_DIM: tl.constexpr, # + PE_HEAD_DIM: tl.constexpr, dropout_p, philox_seed, batch_philox_offset, @@ -350,15 +369,20 @@ def _bwd_dq_inner( ): # if HEAD_DIM is padded PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM + HAS_PE: tl.constexpr = PE_HEAD_DIM > 0 delta_qk = seqlen_q - seqlen_k offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) offs_k = tl.arange(0, HEAD_DIM) + if HAS_PE: + offs_k_pe = HEAD_DIM + tl.arange(0, PE_HEAD_DIM) # mask to make sure not OOB of seqlen_q mask_m = offs_m < seqlen_q kT_ptrs = K + offs_n[None, :] * stride_kn + offs_k[:, None] * stride_kk + if HAS_PE: + kT_pe_ptrs = K + offs_n[None, :] * stride_kn + offs_k_pe[:, None] * stride_kk vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk # D (= delta) is pre-divided by ds_scale. 
Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0) @@ -388,6 +412,8 @@ def _bwd_dq_inner( mask_kT &= offs_k[:, None] < ACTUAL_HEAD_DIM kT = tl.load(kT_ptrs, mask=mask_kT, other=0.0) + if HAS_PE: + kT_pe = tl.load(kT_pe_ptrs, mask=mask_kT, other=0.0) vT = tl.load(vT_ptrs, mask=mask_kT, other=0.0) if ENABLE_DROPOUT: @@ -412,6 +438,8 @@ def _bwd_dq_inner( qk = tl.dot(q, kT) * descale_q * descale_k else: qk = tl.dot(q, kT) + if HAS_PE: + qk += tl.dot(q_pe, kT_pe) qk_scaled = qk * sm_scale if USE_ALIBI: @@ -451,11 +479,15 @@ def _bwd_dq_inner( ) else: dq += tl.dot(ds.to(kT.type.element_ty), tl.trans(kT)) + if HAS_PE: + dq_pe += tl.dot(ds.to(kT_pe.type.element_ty), tl.trans(kT_pe)) # Increment pointers. curr_n += step_n kT_ptrs += step_n * stride_kn + if HAS_PE: + kT_pe_ptrs += step_n * stride_kn vT_ptrs += step_n * stride_vn - return dq + return dq, dq_pe @triton.jit @@ -533,6 +565,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea BLK_SLICE_FACTOR: tl.constexpr, HEAD_DIM: tl.constexpr, ACTUAL_HEAD_DIM: tl.constexpr, + PE_HEAD_DIM: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -656,7 +689,10 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea if DEBUG_TRITON: print(f"delta_qk = {delta_qk}") # noqa: E701 PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM + HAS_PE: tl.constexpr = PE_HEAD_DIM > 0 offs_d = tl.arange(0, HEAD_DIM) + if HAS_PE: + offs_d_pe = HEAD_DIM + tl.arange(0, PE_HEAD_DIM) GROUP_SIZE: tl.constexpr = HQ // HK # align the delta_qk @@ -664,6 +700,11 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea if start_n < seqlen_k: # This section does dk and dv dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + if HAS_PE: + dk_pe = tl.zeros([BLOCK_N1, PE_HEAD_DIM], dtype=tl.float32) + else: + # Couldn't assign None to dk_pe because _bwd_dkdv_inner can't return None. 
+ dk_pe = dk dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) # q > k: diretcly skip all the way until the start of causal block @@ -703,6 +744,14 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd ) + if HAS_PE: + adj_k_pe = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d_pe[None, :] * stride_kd + ) adj_v = ( bid * stride_vb + hkid * stride_vh @@ -712,6 +761,10 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea ) # load K and V: they stay in SRAM throughout the inner loop. k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + if HAS_PE: + k_pe = tl.load(K + adj_k_pe, mask=mask_kv, other=0.0) + else: + k_pe = None v = tl.load(V + adj_v, mask=mask_kv, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. # hqid = hkid @@ -781,11 +834,13 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea print( f"Masked: start_n: {start_n}; start_m: {start_m}, num_steps: {num_steps}" ) # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, - dv, # output tensors + dk, dk_pe, dv = _bwd_dkdv_inner( + dk, # output tensor + dk_pe, # optional output tensor + dv, # output tensor Q_ptr, k, + k_pe, v, DO_ptr, M_ptr, @@ -802,6 +857,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -839,11 +895,13 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea ) # noqa: E701 if DEBUG_TRITON: print("unMasked") # noqa: E701 - dk, dv = _bwd_dkdv_inner( - dk, - dv, # output tensors + dk, dk_pe, dv = _bwd_dkdv_inner( + dk, # output tensor + dk_pe, # optional output tensor + dv, # output tensor Q_ptr, k, + k_pe, v, DO_ptr, M_ptr, @@ -860,6 +918,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q 
// BLOCK_M2), batch, nhea BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -893,6 +952,10 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd dk *= sm_scale tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + if HAS_PE: + offs_dk_pe = offs_n[:, None] * stride_dkn + offs_d_pe[None, :] * stride_dkd + dk_pe *= sm_scale + tl.store(DK + adj_dk + offs_dk_pe, dk_pe, mask=mask_kv) # This part does dq start_m = pid * BLOCK_M2 @@ -916,6 +979,8 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea mask_d = offs_d < ACTUAL_HEAD_DIM mask_q &= mask_d[None, :] offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + if HAS_PE: + offs_q_pe = offs_m[:, None] * stride_qm + offs_d_pe[None, :] * stride_qd offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod # NOTE: don't assume that the strides for k and v are the same! K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn @@ -956,6 +1021,10 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea Dropout_mask + bid * stride_dropoutb + hqid * stride_dropouth ) q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + if HAS_PE: + q_pe = tl.load(Q + adj_q + offs_q_pe, mask=mask_q, other=0.0) + else: + q_pe = None do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) m = m[:, None] @@ -974,9 +1043,15 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0 dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, + if HAS_PE: + dq_pe = tl.zeros([BLOCK_M2, PE_HEAD_DIM], dtype=tl.float32) + else: + dq_pe = dq # Couldn't assign None to dq_pe because _bwd_dq_inner can't return None. 
+ dq, dq_pe = _bwd_dq_inner( + dq, # output tensor + dq_pe, # optional output tensor q, + q_pe, K, V, do, @@ -998,6 +1073,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea MASK_BLOCK_N2, HEAD_DIM, ACTUAL_HEAD_DIM, + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -1027,9 +1103,11 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea print( f"unMasked: start_m: {start_m}, start_n: {start_n}, end_n: {end_n}, num_steps: {num_steps}" ) # noqa: E701 - dq = _bwd_dq_inner( - dq, + dq, dq_pe = _bwd_dq_inner( + dq, # output tensor + dq_pe, # optional output tensor q, + q_pe, K, V, do, @@ -1051,6 +1129,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea BLOCK_N2, HEAD_DIM, ACTUAL_HEAD_DIM, + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -1078,6 +1157,12 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + if HAS_PE: + offs_dq_pe = ( + offs_m[:, None] * stride_dqm + offs_d_pe[None, :] * stride_dqd + ) + dq_pe *= sm_scale + tl.store(DQ + adj_dq + offs_dq_pe, dq_pe, mask=mask_q) # end of GQA/MQA of dq @@ -1156,6 +1241,7 @@ def bwd_kernel_noncausal( BLK_SLICE_FACTOR: tl.constexpr, HEAD_DIM: tl.constexpr, ACTUAL_HEAD_DIM: tl.constexpr, + PE_HEAD_DIM: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, IS_VARLEN: tl.constexpr, USE_ALIBI: tl.constexpr, @@ -1276,12 +1362,20 @@ def bwd_kernel_noncausal( seqlen_k = k_end - k_start PADDED_HEAD: tl.constexpr = ACTUAL_HEAD_DIM != HEAD_DIM + HAS_PE: tl.constexpr = PE_HEAD_DIM > 0 offs_d = tl.arange(0, HEAD_DIM) + if HAS_PE: + offs_d_pe = HEAD_DIM + tl.arange(0, PE_HEAD_DIM) GROUP_SIZE: tl.constexpr = HQ // HK start_n = pid * BLOCK_N1 if start_n < seqlen_k: dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + if HAS_PE: + dk_pe = tl.zeros([BLOCK_N1, 
PE_HEAD_DIM], dtype=tl.float32) + else: + # Couldn't assign None to dk_pe because _bwd_dkdv_inner can't return None. + dk_pe = dk dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) offs_n = start_n + tl.arange(0, BLOCK_N1) @@ -1299,6 +1393,14 @@ def bwd_kernel_noncausal( + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd ) + if HAS_PE: + adj_k_pe = ( + bid * stride_kb + + hkid * stride_kh + + k_start * stride_kn + + offs_n[:, None] * stride_kn + + offs_d_pe[None, :] * stride_kd + ) adj_v = ( bid * stride_vb + hkid * stride_vh @@ -1308,6 +1410,10 @@ def bwd_kernel_noncausal( ) # load K and V: they stay in SRAM throughout the inner loop. k = tl.load(K + adj_k, mask=mask_kv, other=0.0) + if HAS_PE: + k_pe = tl.load(K + adj_k_pe, mask=mask_kv, other=0.0) + else: + k_pe = None v = tl.load(V + adj_v, mask=mask_kv, other=0.0) # If MQA / GQA, set the K and V head offsets appropriately. for hqid in range(hkid * GROUP_SIZE, hkid * GROUP_SIZE + GROUP_SIZE): @@ -1351,11 +1457,13 @@ def bwd_kernel_noncausal( # because there is no causal, we always start from the beginning start_m = 0 num_steps = tl.cdiv(seqlen_q, BLOCK_M1) - dk, dv = _bwd_dkdv_inner( - dk, - dv, # output tensors + dk, dk_pe, dv = _bwd_dkdv_inner( + dk, # output tensor + dk_pe, # optional output tensor + dv, # output tensor Q_ptr, k, + k_pe, v, DO_ptr, M_ptr, @@ -1372,6 +1480,7 @@ def bwd_kernel_noncausal( BLOCK_N1, # block dim HEAD_DIM, ACTUAL_HEAD_DIM, # head dim + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -1405,6 +1514,10 @@ def bwd_kernel_noncausal( offs_dk = offs_n[:, None] * stride_dkn + offs_d[None, :] * stride_dkd dk *= sm_scale tl.store(DK + adj_dk + offs_dk, dk, mask=mask_kv) + if HAS_PE: + offs_dk_pe = offs_n[:, None] * stride_dkn + offs_d_pe[None, :] * stride_dkd + dk_pe *= sm_scale + tl.store(DK + adj_dk + offs_dk_pe, dk_pe, mask=mask_kv) # THIS PART DOES DQ start_m = pid * BLOCK_M2 @@ -1416,6 +1529,8 @@ def bwd_kernel_noncausal( mask_d = offs_d < ACTUAL_HEAD_DIM 
mask_q &= mask_d[None, :] offs_q = offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd + if HAS_PE: + offs_q_pe = offs_m[:, None] * stride_qm + offs_d_pe[None, :] * stride_qd offs_do = offs_m[:, None] * stride_dom + offs_d[None, :] * stride_dod K += bid * stride_kb + hkid * stride_kh + k_start * stride_kn V += bid * stride_vb + hkid * stride_vh + k_start * stride_vn @@ -1448,6 +1563,10 @@ def bwd_kernel_noncausal( ) q = tl.load(Q + adj_q + offs_q, mask=mask_q, other=0.0) + if HAS_PE: + q_pe = tl.load(Q + adj_q + offs_q_pe, mask=mask_q, other=0.0) + else: + q_pe = None do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0) m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q) m = m[:, None] @@ -1466,9 +1585,15 @@ def bwd_kernel_noncausal( num_steps = tl.cdiv(seqlen_k, BLOCK_N2) dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) - dq = _bwd_dq_inner( - dq, + if HAS_PE: + dq_pe = tl.zeros([BLOCK_M2, PE_HEAD_DIM], dtype=tl.float32) + else: + dq_pe = dq # Couldn't assign None to dq_pe because _bwd_dq_inner can't return None. 
+ dq, dq_pe = _bwd_dq_inner( + dq, # output tensor + dq_pe, # optional output tensor q, + q_pe, K, V, do, @@ -1490,6 +1615,7 @@ def bwd_kernel_noncausal( BLOCK_N2, HEAD_DIM, ACTUAL_HEAD_DIM, + PE_HEAD_DIM, dropout_p, philox_seed, batch_philox_offset, @@ -1517,6 +1643,12 @@ def bwd_kernel_noncausal( offs_dq = offs_m[:, None] * stride_dqm + offs_d[None, :] * stride_dqd dq *= sm_scale tl.store(DQ + adj_dq + offs_dq, dq, mask=mask_q) + if HAS_PE: + offs_dq_pe = ( + offs_m[:, None] * stride_dqm + offs_d_pe[None, :] * stride_dqd + ) + dq_pe *= sm_scale + tl.store(DQ + adj_dq + offs_dq_pe, dq_pe, mask=mask_q) @functools.lru_cache(maxsize=1024) diff --git a/aiter/ops/triton/mha_fused_bwd.py b/aiter/ops/triton/mha_fused_bwd.py index 518b300158..d29483a02d 100644 --- a/aiter/ops/triton/mha_fused_bwd.py +++ b/aiter/ops/triton/mha_fused_bwd.py @@ -14,6 +14,7 @@ _bwd_kernel_dkdvdq_noncausal, _get_config, ) +from aiter.ops.triton.utils.device_info import get_num_xcds _LOGGER = AiterTritonLogger() @@ -53,6 +54,10 @@ def flash_attn_fused_backward( ) if dbias is not None: raise ValueError("Bias is not supported yet in the Triton Backend") + if q.shape[-1] == k.shape[-1] and k.shape[-1] > v.shape[-1]: + raise ValueError( + "'Fused' backward doesn't support Positional Encoding (PE). Please use 'one kernel' backward implementation for PE." + ) IS_FP8 = _is_fp8(q) if IS_FP8: diff --git a/aiter/ops/triton/mha_onekernel_bwd.py b/aiter/ops/triton/mha_onekernel_bwd.py index 59e5e53ec6..0e1c36d6ea 100644 --- a/aiter/ops/triton/mha_onekernel_bwd.py +++ b/aiter/ops/triton/mha_onekernel_bwd.py @@ -90,12 +90,13 @@ def flash_attn_onekernel_backward( # get strides and shape if IS_VARLEN: - # Layout for q,k,v is thd ie [total tokens, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = ( + # Layout is thd. + # q and k are [total_tokens, num_head, head_dim_qk]. + # v is [total_tokens, num_head, head_dim_v]. 
+ batch, seqlen_q, num_q_heads = ( len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], - q.shape[2], ) _, num_k_heads = max_seqlen_k, k.shape[1] q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) @@ -108,8 +109,10 @@ def flash_attn_onekernel_backward( dv_strides = (0, dv.stride(1), dv.stride(0), dv.stride(2)) do_strides = (0, do.stride(1), do.stride(0), do.stride(2)) else: - # Layout for q,k,v is bshd ie [batch, seq_len, num_head, head_dim] - batch, seqlen_q, num_q_heads, head_sz = q.shape + # Layout is bshd. + # q and k are [batch, seq_len, num_head, head_dim_qk]. + # v is [batch, seq_len, num_head, head_dim_v] + batch, seqlen_q, num_q_heads = q.shape[:-1] _, num_k_heads = k.shape[1], k.shape[2] q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) @@ -120,10 +123,21 @@ def flash_attn_onekernel_backward( dv_strides = (dv.stride(0), dv.stride(2), dv.stride(1), dv.stride(3)) do_strides = (do.stride(0), do.stride(2), do.stride(1), do.stride(3)) + qk_head_dim = q.shape[-1] + v_head_dim = v.shape[-1] + pe_head_dim = qk_head_dim - v_head_dim # BLOCK_D_MODEL, BLOCK_D_MODEL_POW2 # padding for head_dim. Power of 2 or 16 - BLOCK_D_MODEL_POW2 = triton.next_power_of_2(head_sz) - BLOCK_D_MODEL_POW2 = max(BLOCK_D_MODEL_POW2, 16) + BLOCK_D_MODEL_POW2 = max(triton.next_power_of_2(v_head_dim), 16) + BLOCK_D_MODEL_PE_POW2 = ( + 0 if pe_head_dim == 0 else max(triton.next_power_of_2(pe_head_dim), 16) + ) + assert (pe_head_dim == 0 and BLOCK_D_MODEL_PE_POW2 == 0) or ( + v_head_dim == BLOCK_D_MODEL_POW2 and pe_head_dim == BLOCK_D_MODEL_PE_POW2 + ), "Positional encoding support requires NOPE and PE head sizes to be unpadded powers of 2." + assert (not IS_FP8) or ( + IS_FP8 and pe_head_dim == 0 + ), "Positional encoding doesn't support FP8." 
# Configs if config is None: @@ -156,7 +170,7 @@ def flash_attn_onekernel_backward( max_seqlen_q, descale_do, BLOCK_M=config["preprocess_kernel"]["PRE_BLOCK"], - BLOCK_D_MODEL=head_sz, + BLOCK_D_MODEL=v_head_dim, BLOCK_D_MODEL_POW2=BLOCK_D_MODEL_POW2, IS_VARLEN=IS_VARLEN, IS_FP8=IS_FP8, @@ -177,7 +191,13 @@ def flash_attn_onekernel_backward( seqlen = max(max_seqlen_q, max_seqlen_k) - config_onekernel = config["onekernel"] + # "onekernel_pe" is for Positional Encoding (PE) causal case, it's going to be + # used if present. Otherwise, fallback to default "onekernel" config. + config_onekernel = ( + config["onekernel_pe"] + if (pe_head_dim > 0 and causal and "onekernel_pe" in config) + else config["onekernel"] + ) grid = ( num_k_heads, triton.cdiv(seqlen, config_onekernel["BLOCK_N1"]), @@ -223,8 +243,9 @@ def flash_attn_onekernel_backward( descale_k, descale_v, descale_do, - HEAD_DIM=head_sz, + HEAD_DIM=v_head_dim, ACTUAL_HEAD_DIM=BLOCK_D_MODEL_POW2, + PE_HEAD_DIM=pe_head_dim, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, USE_ALIBI=use_alibi, @@ -276,8 +297,9 @@ def flash_attn_onekernel_backward( descale_k, descale_v, descale_do, - HEAD_DIM=head_sz, + HEAD_DIM=v_head_dim, ACTUAL_HEAD_DIM=BLOCK_D_MODEL_POW2, + PE_HEAD_DIM=pe_head_dim, ENABLE_DROPOUT=use_dropout, IS_VARLEN=IS_VARLEN, USE_ALIBI=use_alibi, From 9077d58df916d16e10ce0e2722c0d93d3764a9d1 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Fri, 10 Oct 2025 15:04:14 -0300 Subject: [PATCH 06/18] Add MHA PE backward unit tests --- op_tests/triton_tests/test_mha.py | 301 ++++++++++++++++++++++++++++++ 1 file changed, 301 insertions(+) diff --git a/op_tests/triton_tests/test_mha.py b/op_tests/triton_tests/test_mha.py index 209aa99d72..4ade0416dd 100644 --- a/op_tests/triton_tests/test_mha.py +++ b/op_tests/triton_tests/test_mha.py @@ -961,3 +961,304 @@ def test_mha_varlen_with_pe( # Assertion torch.testing.assert_close(triton_out, torch_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("BATCH", [1, 4]) 
+@pytest.mark.parametrize( + "SEQLEN_Q, SEQLEN_K", + [(1, 1), (8, 8), (4, 1), (1, 2), (16, 16), (32, 8), (64, 16), (4096, 4096)], +) +@pytest.mark.parametrize( + "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (4, 4), (8, 2), (128, 128)] +) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (128, 64), (192, 128)]) +@pytest.mark.parametrize("DROPOUT", [0.0, 0.13]) +@pytest.mark.parametrize("CAUSAL", [True, False]) +def test_mha_backward_with_pe( + BATCH: int, + SEQLEN_Q: int, + SEQLEN_K: int, + NUM_Q_HEADS: int, + NUM_K_HEADS: int, + HEAD_SZ_QK: int, + HEAD_SZ_V: int, + DROPOUT: float, + CAUSAL: bool, +): + HAS_DROPOUT: bool = DROPOUT > 0.0 + + # Causal + Dropout use case is disabled in `test_mha_backward` and `test_mha_backward_varlen`. + # FIXME: We should fix it in the base implementation before adding PE to the mix. + if CAUSAL and HAS_DROPOUT: + pytest.skip( + "Causal + Dropout use case isn't supported in backward with Positional Encoding." + ) + + device: str = "cuda" + dtype: torch.dtype = torch.bfloat16 + + # Generate tensors + torch.cuda.empty_cache() + torch.manual_seed(63) + q = torch.randn( + (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), + device=device, + dtype=dtype, + requires_grad=True, + ) + k = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_QK), + device=device, + dtype=dtype, + requires_grad=True, + ) + v = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_V), + device=device, + dtype=dtype, + requires_grad=True, + ) + do = torch.randn((q.shape[:-1] + v.shape[-1:]), dtype=dtype, device=device) + + # Triton forward + with torch.enable_grad(): + triton_out = flash_attn_func( + q, + k, + v, + dropout_p=DROPOUT, + causal=CAUSAL, + return_lse=HAS_DROPOUT, + return_attn_probs=HAS_DROPOUT, + ) + if HAS_DROPOUT: + assert len(triton_out) == 3 + dropout_mask = triton_out[2] > 0 + triton_out = triton_out[0] + else: + dropout_mask = None + + # Torch forward + with torch.enable_grad(): + torch_out, _ = attention_ref( + q, k, v, dropout_p=DROPOUT, 
dropout_mask=dropout_mask, causal=CAUSAL + ) + + # Forward assertion + torch.testing.assert_close( + triton_out, + torch_out, + atol=1e-2, + rtol=1e-2, + msg=lambda msg: f"fwd mismatch\n\n{msg}\n", + ) + + # Triton backward + # PE support isn't implemented in fused backward. + mha_set_use_fused_bwd_kernel(False) + triton_dq, triton_dk, triton_dv = torch.autograd.grad(triton_out, (q, k, v), do) + + # Torch backward + torch_dq, torch_dk, torch_dv = torch.autograd.grad(torch_out, (q, k, v), do) + + # Backward assertions + # When dropout is active, some cases fail due to less than 1% mismatched elements. + bwd_atol = 1e-1 if HAS_DROPOUT else 1.5e-2 + bwd_rtol = 1e-1 if HAS_DROPOUT else 1.5e-2 + torch.testing.assert_close( + triton_dq, + torch_dq, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dq mismatch\n\n{msg}\n", + ) + torch.testing.assert_close( + triton_dk, + torch_dk, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dk mismatch\n\n{msg}\n", + ) + torch.testing.assert_close( + triton_dv, + torch_dv, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dv mismatch\n\n{msg}\n", + ) + + +@pytest.mark.parametrize("BATCH", [1, 4]) +@pytest.mark.parametrize( + "SEQLEN_Q, SEQLEN_K", + [(1, 1), (8, 8), (4, 1), (1, 2), (16, 16), (32, 8), (64, 16), (4096, 4096)], +) +@pytest.mark.parametrize( + "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (4, 4), (8, 2), (128, 128)] +) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (128, 64), (192, 128)]) +@pytest.mark.parametrize("DROPOUT", [0.0, 0.15]) +@pytest.mark.parametrize("CAUSAL", [True, False]) +def test_mha_backward_varlen_with_pe( + BATCH: int, + SEQLEN_Q: int, + SEQLEN_K: int, + NUM_Q_HEADS: int, + NUM_K_HEADS: int, + HEAD_SZ_QK: int, + HEAD_SZ_V: int, + DROPOUT: float, + CAUSAL: bool, +): + HAS_DROPOUT: bool = DROPOUT > 0.0 + + # Causal + Dropout use case is disabled in `test_mha_backward` and `test_mha_backward_varlen`. 
+ # FIXME: We should fix it in the base implementation before adding PE to the mix. + if CAUSAL and HAS_DROPOUT: + pytest.skip( + "Causal + Dropout use case isn't supported in backward with Positional Encoding." + ) + + if (SEQLEN_Q, SEQLEN_K) == (4096, 4096) and HAS_DROPOUT: + pytest.skip("Dropout with large sequence length raises torch.OutOfMemoryError.") + + device: str = "cuda" + dtype: torch.dtype = torch.bfloat16 + + # Generate tensors + torch.cuda.empty_cache() + torch.manual_seed(133) + q = torch.randn( + (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), + device=device, + dtype=dtype, + requires_grad=True, + ) + k = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_QK), + device=device, + dtype=dtype, + requires_grad=True, + ) + v = torch.randn( + (BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ_V), + device=device, + dtype=dtype, + requires_grad=True, + ) + query_padding_mask = generate_random_padding_mask(SEQLEN_Q, BATCH, device) + key_padding_mask = generate_random_padding_mask(SEQLEN_K, BATCH, device) + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask) + q_unpad.requires_grad = True + k_unpad.requires_grad = True + v_unpad.requires_grad = True + do = torch.randn((q.shape[:-1] + v.shape[-1:]), dtype=dtype, device=device) + + # Triton forward + with torch.enable_grad(): + triton_out = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=DROPOUT, + causal=CAUSAL, + return_lse=HAS_DROPOUT, + return_attn_probs=HAS_DROPOUT, + ) + if HAS_DROPOUT: + assert len(triton_out) == 3 + dropout_mask = ( + pad_rearrange_dropout_mask( + triton_out[2] > 0, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + SEQLEN_Q, + SEQLEN_K, + NUM_Q_HEADS, + ) + > 0 + ) + triton_out = triton_out[0] + else: + dropout_mask = None + 
triton_out = output_pad_fn(triton_out) + + # Torch forward + with torch.enable_grad(): + torch_out, _ = attention_ref( + q, + k, + v, + query_padding_mask=query_padding_mask, + key_padding_mask=key_padding_mask, + dropout_p=DROPOUT, + dropout_mask=dropout_mask, + causal=CAUSAL, + ) + + # Forward assertion + torch.testing.assert_close( + triton_out, + torch_out, + atol=1e-2, + rtol=1e-2, + msg=lambda msg: f"fwd mismatch\n\n{msg}\n", + ) + + # Triton backward + # PE support isn't implemented in fused backward. + mha_set_use_fused_bwd_kernel(False) + triton_dq, triton_dk, triton_dv = torch.autograd.grad( + triton_out, (q_unpad, k_unpad, v_unpad), do + ) + triton_dq = dq_pad_fn(triton_dq) + triton_dk = dk_pad_fn(triton_dk) + triton_dv = dk_pad_fn(triton_dv) + + # Torch backward + torch_dq, torch_dk, torch_dv = torch.autograd.grad(torch_out, (q, k, v), do) + + # Backward assertions + bwd_atol = 1e-1 + bwd_rtol = 1e-1 + torch.testing.assert_close( + triton_dq, + torch_dq, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dq mismatch\n\n{msg}\n", + ) + torch.testing.assert_close( + triton_dk, + torch_dk, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dk mismatch\n\n{msg}\n", + ) + torch.testing.assert_close( + triton_dv, + torch_dv, + atol=bwd_atol, + rtol=bwd_rtol, + msg=lambda msg: f"bwd dv mismatch\n\n{msg}\n", + ) From dec80ad44008f671bb24ef9ca46ff3c348d61b05 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Fri, 10 Oct 2025 15:05:23 -0300 Subject: [PATCH 07/18] Add best config for MHA PE backward --- aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json b/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json index cd5c8f3540..a38732610e 100644 --- a/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json +++ b/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json @@ -63,6 +63,18 @@ "num_warps": 4, "num_ctas": 1, "num_stages": 1 + }, + "onekernel_pe" : { + 
"BLOCK_M1": 32, + "BLOCK_N1": 64, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 2 } } } From 378730a88292ce507c7a53821eacfcc0b572cac3 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Fri, 10 Oct 2025 15:06:06 -0300 Subject: [PATCH 08/18] Add PE support to MHA benchmark script --- op_tests/op_benchmarks/triton/bench_mha.py | 230 ++++++++++++++------- 1 file changed, 155 insertions(+), 75 deletions(-) diff --git a/op_tests/op_benchmarks/triton/bench_mha.py b/op_tests/op_benchmarks/triton/bench_mha.py index e661d0119f..8b7492cddf 100644 --- a/op_tests/op_benchmarks/triton/bench_mha.py +++ b/op_tests/op_benchmarks/triton/bench_mha.py @@ -114,6 +114,7 @@ def create_benchmark_configs(custom, args): hk = args.hq if not args.hk else args.hk sk = args.sq if not args.sk else args.sk head_size = 128 if not args.d else args.d + head_size_v = head_size if not args.dv else args.dv mode = args.mode x_names = ["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K"] causal = args.causal @@ -121,7 +122,13 @@ def create_benchmark_configs(custom, args): configs = [] plot_name = get_caller_name_no_ext() - extra_args = {"D_HEAD": head_size, "dtype": dtype, "causal": causal, "mode": mode} + extra_args = { + "D_HEAD": head_size, + "D_HEAD_V": head_size_v, + "dtype": dtype, + "causal": causal, + "mode": mode, + } if custom: x_vals_list = [(args.b, args.hq, hk, args.sq, sk)] @@ -150,7 +157,7 @@ def create_benchmark_configs(custom, args): if args.fused_bwd: line_vals = [f"fused-bwd({unit})"] else: - line_vals = [f"fused-bwd({unit})", f"bwd({unit})"] + line_vals = [f"onekernel-bwd({unit})"] else: line_vals = [f"fwd({unit})"] @@ -161,7 +168,7 @@ def create_benchmark_configs(custom, args): if args.fused_bwd: line_vals = [f"fused-bwd({unit})"] else: - line_vals = [f"bwd({unit})"] + line_vals = [f"onekernel-bwd({unit})"] configs.append( triton.testing.Benchmark( @@ -190,6 +197,7 @@ def 
bench_mha( N_CTX_Q, N_CTX_K, D_HEAD, + D_HEAD_V, dtype, causal, mode, @@ -208,11 +216,16 @@ def bench_mha( return_lse = True return_attn_probs = False varlen = args.layout == "thd" + has_pe = D_HEAD > D_HEAD_V + assert not ( + args.fp8 and has_pe + ), "Positional Encoding (PE) doesn't support FP8 data type." + assert not ( + has_pe and "fused-bwd" in provider + ), "'Fused' backward implementation doesn't support Positional Encoding (PE)." global _USE_FUSED_BWD - fused_backward = "fused-bwd" in provider - mha_set_use_fused_bwd_kernel(fused_backward) # Default softmax scale to match standard attention @@ -227,65 +240,117 @@ def bench_mha( f"Testing kernel implementation <{provider}> against Torch with shape:" ) print( - f"BATCH={BATCH}, HQ={HQ}, HK={HK}, N_CTX_Q={N_CTX_Q}, N_CTX_K={N_CTX_K}, D_HEAD={D_HEAD}" + f"BATCH={BATCH}, HQ={HQ}, HK={HK}, N_CTX_Q={N_CTX_Q}, N_CTX_K={N_CTX_K}, D_HEAD={D_HEAD}, D_HEAD_V={D_HEAD_V}" ) - if varlen: - test_mha.test_mha( - BATCH, - N_CTX_Q, - N_CTX_K, - HQ, - HK, - D_HEAD, - dropout, - True, - True, - causal, - args.fp8, - dtype, - ) + if not varlen: + if not has_pe: + test_mha.test_mha( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + dropout, + True, + True, + causal, + args.fp8, + dtype, + ) + else: + test_mha.test_mha_with_pe( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + D_HEAD_V, + dropout, + causal, + ) print("Forward test passed!") - test_mha.test_mha_backward_varlen( - BATCH, - N_CTX_Q, - N_CTX_K, - HQ, - HK, - D_HEAD, - dropout, - causal, - args.fp8, - dtype, - ) + if not has_pe: + test_mha.test_mha_backward_varlen( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + dropout, + causal, + args.fp8, + dtype, + ) + else: + test_mha.test_mha_backward_with_pe( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + D_HEAD_V, + dropout, + causal, + ) print("Backward test passed!") else: - test_mha.test_mha_varlen( - BATCH, - N_CTX_Q, - N_CTX_K, - HQ, - HK, - D_HEAD, - dropout, - True, - True, - causal, - 
args.fp8, - dtype, - ) + if not has_pe: + test_mha.test_mha_varlen( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + dropout, + True, + True, + causal, + args.fp8, + dtype, + ) + else: + test_mha.test_mha_varlen_with_pe( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + D_HEAD_V, + dropout, + causal, + ) print("Forward test passed!") - test_mha.test_mha_backward( - BATCH, - N_CTX_Q, - N_CTX_K, - HQ, - HK, - D_HEAD, - dropout, - causal, - args.fp8, - dtype, - ) + if not has_pe: + test_mha.test_mha_backward( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + dropout, + causal, + args.fp8, + dtype, + ) + else: + test_mha.test_mha_backward_varlen_with_pe( + BATCH, + N_CTX_Q, + N_CTX_K, + HQ, + HK, + D_HEAD, + D_HEAD_V, + dropout, + causal, + ) print("Backward test passed!") return 0 @@ -293,13 +358,13 @@ def bench_mha( # Generate base inputs q = torch.randn((BATCH, N_CTX_Q, HQ, D_HEAD), device=device, dtype=dtype) k = torch.randn((BATCH, N_CTX_K, HK, D_HEAD), device=device, dtype=dtype) - v = torch.randn((BATCH, N_CTX_K, HK, D_HEAD), device=device, dtype=dtype) + v = torch.randn((BATCH, N_CTX_K, HK, D_HEAD_V), device=device, dtype=dtype) q.requires_grad = requires_grad k.requires_grad = requires_grad v.requires_grad = requires_grad # FLOPS calculation variables - flops_per_matmul = 0 + total_flops = 0.0 # Input preparation if varlen: @@ -342,9 +407,9 @@ def bench_mha( if seqlen_q > seqlen_k else (seqlen_q * seqlen_k - ((seqlen_q**2 - seqlen_q) / 2)) ) - flops_per_matmul += valid_out_elements * HQ * D_HEAD * 2 + total_flops += valid_out_elements * HQ * (D_HEAD + D_HEAD_V) * 2.0 else: - flops_per_matmul += seqlen_q * seqlen_k * HQ * D_HEAD * 2 + total_flops += seqlen_q * seqlen_k * HQ * (D_HEAD + D_HEAD_V) * 2.0 else: q_input, k_input, v_input = q, k, v @@ -354,9 +419,13 @@ def bench_mha( if N_CTX_Q > N_CTX_K else (N_CTX_Q * N_CTX_K - ((N_CTX_Q**2 - N_CTX_Q) / 2)) ) - flops_per_matmul = 2.0 * BATCH * HQ * valid_out_elements * D_HEAD + total_flops += ( + 
2.0 * BATCH * HQ * valid_out_elements * (D_HEAD + D_HEAD_V) + ) else: - flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD + total_flops += ( + 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * (D_HEAD + D_HEAD_V) + ) # Benchmark mode if varlen: @@ -441,22 +510,25 @@ def fn(): ms = triton.testing.do_bench(fn) - total_flops = 2 * flops_per_matmul if mode == "bwd": total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) - input_bytes = q.element_size() - output_bytes = q.element_size() if varlen: total_num_tokens_q = cu_seqlens_q[-1].item() total_num_tokens_k = cu_seqlens_k[-1].item() else: total_num_tokens_q = BATCH * N_CTX_Q total_num_tokens_k = BATCH * N_CTX_K + # TODO: Is it right for backward? Backward reads do as well, and writes dq, dk and dv... mem = ( - total_num_tokens_q * HQ * D_HEAD * input_bytes - + 2 * total_num_tokens_k * HK * D_HEAD * input_bytes - + total_num_tokens_q * HQ * D_HEAD * output_bytes + # read q + total_num_tokens_q * HQ * D_HEAD * q.element_size() + # read k + + total_num_tokens_k * HK * D_HEAD * k.element_size() + # read v + + total_num_tokens_k * HK * D_HEAD_V * v.element_size() + # write output + + total_num_tokens_q * HQ * D_HEAD_V * q.element_size() ) # return ms if "ms" in provider: @@ -505,7 +577,13 @@ def parse_args(): default=False, help="If specified, uses equal sequence lengths with thd layout, i.e t = b * sq", ) - parser.add_argument("-d", type=int, default=0) + parser.add_argument( + "-d", + type=int, + default=0, + help="Q and K head size, if -dv is absent then -d specifies V head size too", + ) + parser.add_argument("-dv", type=int, default=0, help="optional V head size") parser.add_argument("-causal", type=str2bool, default=None) parser.add_argument("-fp8", action="store_true", default=False) parser.add_argument("-quantize_p", action="store_true", default=False) @@ -573,17 +651,19 @@ def main(): assert ( args.layout == "thd" or not args.equal_seqlens or args.model ), "Equal sequence lengths arg must be used with the thd layout 
or a model config." - if args.hq or args.hk or args.d: + if args.hq or args.hk or args.d or args.dv: custom_config = True + if not args.dv: + args.dv = args.d assert ( - args.b and args.hq and args.sq and args.d + args.b and args.hq and args.sq and args.d and args.dv ), "If custom config is specified, please provide \ all of batch, number of Q heads, Q sequence length \ and head size." if args.model: assert not ( - args.hq or args.hk or args.d + args.hq or args.hk or args.d or args.dv ), "Specifying model fixes hq, hk and d already. Do not provide them!" assert ( From ffe9f1e8b75fbe4142507ba83057fd2ac568514b Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Fri, 10 Oct 2025 15:08:01 -0300 Subject: [PATCH 09/18] [EXTRA] Omit linter warning in MHA unit tests --- op_tests/triton_tests/test_mha.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/op_tests/triton_tests/test_mha.py b/op_tests/triton_tests/test_mha.py index 4ade0416dd..0154b79209 100644 --- a/op_tests/triton_tests/test_mha.py +++ b/op_tests/triton_tests/test_mha.py @@ -81,7 +81,9 @@ def fp8_assert_close( max_abs_idx = torch.argmax(abs_diff).item() max_rel_idx = torch.argmax(rel_diff).item() - flat_to_idx = lambda flat_idx, shape: np.unravel_index(flat_idx, shape) + flat_to_idx = lambda flat_idx, shape: np.unravel_index( # noqa: E731 + flat_idx, shape + ) max_abs_pos = flat_to_idx(max_abs_idx, tensor_a.shape) max_rel_pos = flat_to_idx(max_rel_idx, tensor_a.shape) From 01d31e33c9a72556cc11d5afc050663fbe5993e8 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Mon, 13 Oct 2025 15:07:22 +0000 Subject: [PATCH 10/18] [EXTRA] Remove unused imports --- aiter/ops/triton/_triton_kernels/mha.py | 4 +--- aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py | 3 --- aiter/ops/triton/mha.py | 4 ---- aiter/ops/triton/mha_fused_bwd.py | 1 - 4 files changed, 1 insertion(+), 11 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/mha.py b/aiter/ops/triton/_triton_kernels/mha.py index 
9f8e66a8a1..6ce43fde2b 100644 --- a/aiter/ops/triton/_triton_kernels/mha.py +++ b/aiter/ops/triton/_triton_kernels/mha.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -from typing import Optional, Tuple import functools import json import torch @@ -10,9 +9,8 @@ from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH -from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd +from ..utils._triton.pid_preprocessing import remap_xcd from ..utils._triton.mha_kernel_utils import _compute_fp8_scaling_factors -from ..utils.device_info import get_num_xcds @triton.jit diff --git a/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py b/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py index 52f5bea82c..f6e8870349 100644 --- a/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py +++ b/aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py @@ -1,15 +1,12 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
-from typing import Optional, Dict import functools import json -import torch import triton # type: ignore import triton.language as tl # type: ignore from ..utils._triton import arch_info from ..utils.core import AITER_TRITON_CONFIGS_PATH -from ..utils._triton.pid_preprocessing import pid_grid, remap_xcd from ..utils._triton.mha_kernel_utils import _compute_fp8_scaling_factors diff --git a/aiter/ops/triton/mha.py b/aiter/ops/triton/mha.py index 1c84a1d742..43248c0ed2 100644 --- a/aiter/ops/triton/mha.py +++ b/aiter/ops/triton/mha.py @@ -7,12 +7,8 @@ import triton.language as tl import aiter.ops.triton.utils.types as types -import aiter.ops.triton.utils._triton.arch_info as arch_info -from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH -from aiter.ops.triton.utils._triton.pid_preprocessing import pid_grid, remap_xcd from aiter.ops.triton.mha_onekernel_bwd import flash_attn_onekernel_backward from aiter.ops.triton.mha_fused_bwd import flash_attn_fused_backward -from aiter.ops.triton.utils.mha_kernel_utils import _compute_fp8_scaling_factors from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton.utils.device_info import get_num_xcds from aiter.ops.triton._triton_kernels.mha import _attn_fwd, _get_config diff --git a/aiter/ops/triton/mha_fused_bwd.py b/aiter/ops/triton/mha_fused_bwd.py index d29483a02d..275701fdd3 100644 --- a/aiter/ops/triton/mha_fused_bwd.py +++ b/aiter/ops/triton/mha_fused_bwd.py @@ -4,7 +4,6 @@ from typing import Optional, Dict import torch import triton -import triton.language as tl from aiter.ops.triton.utils.types import _is_fp8 from aiter.ops.triton.utils.logger import AiterTritonLogger From c5e0c6ac922ecb2b5d8cd1ea50fd895e6e5470cb Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Fri, 10 Oct 2025 15:08:40 -0300 Subject: [PATCH 11/18] [TEMP] Add dirty tuning scripts TODO: Drop this commit before merging. 
--- tune_mha_bwd.sh | 143 ++++++++++++++++++++++++++++++++++++++++++++++++ tune_mha_fwd.sh | 139 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 282 insertions(+) create mode 100644 tune_mha_bwd.sh create mode 100644 tune_mha_fwd.sh diff --git a/tune_mha_bwd.sh b/tune_mha_bwd.sh new file mode 100644 index 0000000000..d156ac30eb --- /dev/null +++ b/tune_mha_bwd.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash + +os=$(uname --kernel-name) +script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) +tune_configs_dir="${script_dir}/tune_mha_bwd_configs" + +# MI300X config: +aiter_config_file="${script_dir}/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json" +# MI350X config: +# aiter_config_file="${script_dir}/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json" + + +os_path() { + local path="${1}" + if [ "${os}" != "Linux" ]; then + path=$(cygpath --windows "${path}") + fi + echo "${path}" +} + + +# Stolen from Pure Bash Bible: +# https://github.com/dylanaraps/pure-bash-bible?tab=readme-ov-file#get-the-base-name-of-a-file-path +basename() { + # Usage: basename "path" ["suffix"] + local tmp + + tmp=${1%"${1##*[!/]}"} + tmp=${tmp##*/} + tmp=${tmp%"${2/"$tmp"}"} + + printf '%s\n' "${tmp:-/}" +} + + +gen_tune_configs() { + rm --recursive --force "${tune_configs_dir}" + mkdir --parents "${tune_configs_dir}" + local os_tune_configs_dir + os_tune_configs_dir=$(os_path "${tune_configs_dir}") + python < /dev/null; then + # Test failed. + result='fail,NA' + else + # Test passed, run benchmark. 
+ time=$(python "${script_dir}/op_tests/op_benchmarks/triton/bench_mha.py" \ + --dtype bf16 -mode bwd -b 1 -hq 128 -hk 128 -sq 4096 -sk 4096 -d 192 -dv 128 \ + -causal true --layout bshd -metric time 2> /dev/null \ + | tail -1 | tr --squeeze-repeats ' ' | cut --delimiter=' ' --fields=7) + result="pass,${time}" + fi + + echo "${tune_config_name},${result}" +done diff --git a/tune_mha_fwd.sh b/tune_mha_fwd.sh new file mode 100644 index 0000000000..cadfd1dbe4 --- /dev/null +++ b/tune_mha_fwd.sh @@ -0,0 +1,139 @@ +#!/usr/bin/env bash + +os=$(uname --kernel-name) +script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) +tune_configs_dir="${script_dir}/tune_mha_fwd_configs" + +# MI300X config: +aiter_config_file="${script_dir}/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json" +# MI350X config: +# aiter_config_file="${script_dir}/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json" + + +os_path() { + local path="${1}" + if [ "${os}" != "Linux" ]; then + path=$(cygpath --windows "${path}") + fi + echo "${path}" +} + + +# Stolen from Pure Bash Bible: +# https://github.com/dylanaraps/pure-bash-bible?tab=readme-ov-file#get-the-base-name-of-a-file-path +basename() { + # Usage: basename "path" ["suffix"] + local tmp + + tmp=${1%"${1##*[!/]}"} + tmp=${tmp##*/} + tmp=${tmp%"${2/"$tmp"}"} + + printf '%s\n' "${tmp:-/}" +} + + +gen_tune_configs() { + rm --recursive --force "${tune_configs_dir}" + mkdir --parents "${tune_configs_dir}" + local os_tune_configs_dir + os_tune_configs_dir=$(os_path "${tune_configs_dir}") + python < /dev/null; then + # Test failed. + result='fail,NA' + else + # Test passed, run benchmark. 
+ time=$(python "${script_dir}/op_tests/op_benchmarks/triton/bench_mha.py" \ + --dtype bf16 -mode fwd -b 1 -hq 128 -hk 128 -sq 4096 -sk 4096 -d 192 -dv 128 \ + -causal true --layout bshd -metric time 2> /dev/null \ + | tail -1 | tr --squeeze-repeats ' ' | cut --delimiter=' ' --fields=7) + result="pass,${time}" + fi + + echo "${tune_config_name},${result}" +done From 8d262612c6c6142ffa3b8e6c73280645f378ecf2 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Mon, 13 Oct 2025 17:38:23 +0000 Subject: [PATCH 12/18] Make all tests pass after rebase on `main` --- aiter/ops/triton/mha_fused_bwd.py | 1 - op_tests/triton_tests/test_mha.py | 11 +++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/aiter/ops/triton/mha_fused_bwd.py b/aiter/ops/triton/mha_fused_bwd.py index 275701fdd3..bc45c3ccbc 100644 --- a/aiter/ops/triton/mha_fused_bwd.py +++ b/aiter/ops/triton/mha_fused_bwd.py @@ -272,7 +272,6 @@ def flash_attn_fused_backward( FP8_MAX=FP8_MAX, NUM_SMS=NUM_SMS, USE_INT64_STRIDES=USE_INT64_STRIDES, - NUM_XCD=get_num_xcds(), **config_dkdvdq, ) diff --git a/op_tests/triton_tests/test_mha.py b/op_tests/triton_tests/test_mha.py index 0154b79209..5d8b47b622 100644 --- a/op_tests/triton_tests/test_mha.py +++ b/op_tests/triton_tests/test_mha.py @@ -846,7 +846,7 @@ def test_mha_with_pe( dropout_mask = None # Torch - torch_out, _ = attention_ref( + torch_out, _, _ = attention_ref( q, k, v, @@ -950,7 +950,7 @@ def test_mha_varlen_with_pe( triton_out = output_pad_fn(triton_out) # Torch - torch_out, _ = attention_ref( + torch_out, _, _ = attention_ref( q, k, v, @@ -996,6 +996,9 @@ def test_mha_backward_with_pe( "Causal + Dropout use case isn't supported in backward with Positional Encoding." 
) + if (SEQLEN_Q, SEQLEN_K) == (4096, 4096) and HAS_DROPOUT: + pytest.skip("Dropout with large sequence length raises torch.OutOfMemoryError.") + device: str = "cuda" dtype: torch.dtype = torch.bfloat16 @@ -1042,7 +1045,7 @@ def test_mha_backward_with_pe( # Torch forward with torch.enable_grad(): - torch_out, _ = attention_ref( + torch_out, _, _ = attention_ref( q, k, v, dropout_p=DROPOUT, dropout_mask=dropout_mask, causal=CAUSAL ) @@ -1207,7 +1210,7 @@ def test_mha_backward_varlen_with_pe( # Torch forward with torch.enable_grad(): - torch_out, _ = attention_ref( + torch_out, _, _ = attention_ref( q, k, v, From 64cd5b6d2f4224cfee9b832f7dd9d5f41c1104f1 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Mon, 13 Oct 2025 15:54:49 -0300 Subject: [PATCH 13/18] Revert "[TEMP] Add dirty tuning scripts" This reverts commit 00357787b5b1549b114d7dc32afd6c2a71b4cd11. --- tune_mha_bwd.sh | 143 ------------------------------------------------ tune_mha_fwd.sh | 139 ---------------------------------------------- 2 files changed, 282 deletions(-) delete mode 100644 tune_mha_bwd.sh delete mode 100644 tune_mha_fwd.sh diff --git a/tune_mha_bwd.sh b/tune_mha_bwd.sh deleted file mode 100644 index d156ac30eb..0000000000 --- a/tune_mha_bwd.sh +++ /dev/null @@ -1,143 +0,0 @@ -#!/usr/bin/env bash - -os=$(uname --kernel-name) -script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -tune_configs_dir="${script_dir}/tune_mha_bwd_configs" - -# MI300X config: -aiter_config_file="${script_dir}/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json" -# MI350X config: -# aiter_config_file="${script_dir}/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json" - - -os_path() { - local path="${1}" - if [ "${os}" != "Linux" ]; then - path=$(cygpath --windows "${path}") - fi - echo "${path}" -} - - -# Stolen from Pure Bash Bible: -# https://github.com/dylanaraps/pure-bash-bible?tab=readme-ov-file#get-the-base-name-of-a-file-path -basename() { - # Usage: basename "path" ["suffix"] - local tmp - 
- tmp=${1%"${1##*[!/]}"} - tmp=${tmp##*/} - tmp=${tmp%"${2/"$tmp"}"} - - printf '%s\n' "${tmp:-/}" -} - - -gen_tune_configs() { - rm --recursive --force "${tune_configs_dir}" - mkdir --parents "${tune_configs_dir}" - local os_tune_configs_dir - os_tune_configs_dir=$(os_path "${tune_configs_dir}") - python < /dev/null; then - # Test failed. - result='fail,NA' - else - # Test passed, run benchmark. - time=$(python "${script_dir}/op_tests/op_benchmarks/triton/bench_mha.py" \ - --dtype bf16 -mode bwd -b 1 -hq 128 -hk 128 -sq 4096 -sk 4096 -d 192 -dv 128 \ - -causal true --layout bshd -metric time 2> /dev/null \ - | tail -1 | tr --squeeze-repeats ' ' | cut --delimiter=' ' --fields=7) - result="pass,${time}" - fi - - echo "${tune_config_name},${result}" -done diff --git a/tune_mha_fwd.sh b/tune_mha_fwd.sh deleted file mode 100644 index cadfd1dbe4..0000000000 --- a/tune_mha_fwd.sh +++ /dev/null @@ -1,139 +0,0 @@ -#!/usr/bin/env bash - -os=$(uname --kernel-name) -script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) -tune_configs_dir="${script_dir}/tune_mha_fwd_configs" - -# MI300X config: -aiter_config_file="${script_dir}/aiter/ops/triton/configs/MI300X-MHA-DEFAULT.json" -# MI350X config: -# aiter_config_file="${script_dir}/aiter/ops/triton/configs/MI350X-MHA-DEFAULT.json" - - -os_path() { - local path="${1}" - if [ "${os}" != "Linux" ]; then - path=$(cygpath --windows "${path}") - fi - echo "${path}" -} - - -# Stolen from Pure Bash Bible: -# https://github.com/dylanaraps/pure-bash-bible?tab=readme-ov-file#get-the-base-name-of-a-file-path -basename() { - # Usage: basename "path" ["suffix"] - local tmp - - tmp=${1%"${1##*[!/]}"} - tmp=${tmp##*/} - tmp=${tmp%"${2/"$tmp"}"} - - printf '%s\n' "${tmp:-/}" -} - - -gen_tune_configs() { - rm --recursive --force "${tune_configs_dir}" - mkdir --parents "${tune_configs_dir}" - local os_tune_configs_dir - os_tune_configs_dir=$(os_path "${tune_configs_dir}") - python < /dev/null; then - # Test failed. 
- result='fail,NA' - else - # Test passed, run benchmark. - time=$(python "${script_dir}/op_tests/op_benchmarks/triton/bench_mha.py" \ - --dtype bf16 -mode fwd -b 1 -hq 128 -hk 128 -sq 4096 -sk 4096 -d 192 -dv 128 \ - -causal true --layout bshd -metric time 2> /dev/null \ - | tail -1 | tr --squeeze-repeats ' ' | cut --delimiter=' ' --fields=7) - result="pass,${time}" - fi - - echo "${tune_config_name},${result}" -done From 7e9b3d330abfa5fd0a165030a631e041101af2c5 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Tue, 14 Oct 2025 15:20:40 +0000 Subject: [PATCH 14/18] Calculates the amount of bytes differently for forward and backward --- op_tests/op_benchmarks/triton/bench_mha.py | 27 +++++++++++++--------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/op_tests/op_benchmarks/triton/bench_mha.py b/op_tests/op_benchmarks/triton/bench_mha.py index 8b7492cddf..cfebe1b39a 100644 --- a/op_tests/op_benchmarks/triton/bench_mha.py +++ b/op_tests/op_benchmarks/triton/bench_mha.py @@ -519,17 +519,22 @@ def fn(): else: total_num_tokens_q = BATCH * N_CTX_Q total_num_tokens_k = BATCH * N_CTX_K - # TODO: Is it right for backward? Backward reads do as well, and writes dq, dk and dv... 
- mem = ( - # read q - total_num_tokens_q * HQ * D_HEAD * q.element_size() - # read k - + total_num_tokens_k * HK * D_HEAD * k.element_size() - # read v - + total_num_tokens_k * HK * D_HEAD_V * v.element_size() - # write output - + total_num_tokens_q * HQ * D_HEAD_V * q.element_size() - ) + q_size = total_num_tokens_q * HQ * D_HEAD * q.element_size() + k_size = total_num_tokens_k * HK * D_HEAD * k.element_size() + v_size = total_num_tokens_k * HK * D_HEAD_V * v.element_size() + o_size = total_num_tokens_q * HQ * D_HEAD_V * q.element_size() + if mode == "fwd": + # read q, k, v + mem_read = q_size + k_size + v_size + # write o + mem_write = o_size + else: + # read q, k, v, do + mem_read = q_size + k_size + v_size + o_size + # write dq, dk, dv + mem_write = q_size + k_size + v_size + mem = mem_read + mem_write + # return ms if "ms" in provider: return ms From 307853111f4fdc6675e491f823a015bad541caaf Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Tue, 14 Oct 2025 15:24:43 +0000 Subject: [PATCH 15/18] Remove `tl.constexpr` from `SEQLEN_Q/K` forward kernel arguments --- aiter/ops/triton/_triton_kernels/mha.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiter/ops/triton/_triton_kernels/mha.py b/aiter/ops/triton/_triton_kernels/mha.py index 6ce43fde2b..28aae55da2 100644 --- a/aiter/ops/triton/_triton_kernels/mha.py +++ b/aiter/ops/triton/_triton_kernels/mha.py @@ -305,8 +305,8 @@ def _attn_fwd( dropout_p, philox_seed, philox_offset_base_in, - SEQLEN_Q: tl.constexpr, - SEQLEN_K: tl.constexpr, + SEQLEN_Q, + SEQLEN_K, IS_CAUSAL: tl.constexpr, NUM_Q_HEADS: tl.constexpr, NUM_K_HEADS: tl.constexpr, From b207ae3c587856a7716ac79b7dba8ae65e21e4d2 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Thu, 16 Oct 2025 10:43:08 +0000 Subject: [PATCH 16/18] Reduce maximum sequence length on backward PE tests The goal is to avoid out of memory errors while computing the reference gradients with PyTorch autograd. 
--- op_tests/triton_tests/test_mha.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/op_tests/triton_tests/test_mha.py b/op_tests/triton_tests/test_mha.py index 5d8b47b622..6e6fc78332 100644 --- a/op_tests/triton_tests/test_mha.py +++ b/op_tests/triton_tests/test_mha.py @@ -968,13 +968,13 @@ def test_mha_varlen_with_pe( @pytest.mark.parametrize("BATCH", [1, 4]) @pytest.mark.parametrize( "SEQLEN_Q, SEQLEN_K", - [(1, 1), (8, 8), (4, 1), (1, 2), (16, 16), (32, 8), (64, 16), (4096, 4096)], + [(1, 1), (8, 8), (4, 1), (1, 2), (16, 16), (32, 8), (64, 16), (2048, 2048)], ) @pytest.mark.parametrize( "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (4, 4), (8, 2), (128, 128)] ) @pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (128, 64), (192, 128)]) -@pytest.mark.parametrize("DROPOUT", [0.0, 0.13]) +@pytest.mark.parametrize("DROPOUT", [0.0, 0.2]) @pytest.mark.parametrize("CAUSAL", [True, False]) def test_mha_backward_with_pe( BATCH: int, @@ -996,9 +996,6 @@ def test_mha_backward_with_pe( "Causal + Dropout use case isn't supported in backward with Positional Encoding." 
) - if (SEQLEN_Q, SEQLEN_K) == (4096, 4096) and HAS_DROPOUT: - pytest.skip("Dropout with large sequence length raises torch.OutOfMemoryError.") - device: str = "cuda" dtype: torch.dtype = torch.bfloat16 @@ -1096,13 +1093,13 @@ def test_mha_backward_with_pe( @pytest.mark.parametrize("BATCH", [1, 4]) @pytest.mark.parametrize( "SEQLEN_Q, SEQLEN_K", - [(1, 1), (8, 8), (4, 1), (1, 2), (16, 16), (32, 8), (64, 16), (4096, 4096)], + [(1, 1), (8, 8), (4, 1), (1, 2), (16, 16), (32, 8), (64, 16), (64, 64)], ) @pytest.mark.parametrize( "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (4, 4), (8, 2), (128, 128)] ) @pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (128, 64), (192, 128)]) -@pytest.mark.parametrize("DROPOUT", [0.0, 0.15]) +@pytest.mark.parametrize("DROPOUT", [0.0, 0.2]) @pytest.mark.parametrize("CAUSAL", [True, False]) def test_mha_backward_varlen_with_pe( BATCH: int, @@ -1124,9 +1121,6 @@ def test_mha_backward_varlen_with_pe( "Causal + Dropout use case isn't supported in backward with Positional Encoding." 
) - if (SEQLEN_Q, SEQLEN_K) == (4096, 4096) and HAS_DROPOUT: - pytest.skip("Dropout with large sequence length raises torch.OutOfMemoryError.") - device: str = "cuda" dtype: torch.dtype = torch.bfloat16 From 1b7ce77597beefdae266a57e8b7780cdcb291a78 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Thu, 16 Oct 2025 18:08:33 +0000 Subject: [PATCH 17/18] Fix warning about deprecated logical operator This commit fixes: ``` UserWarning: Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead ``` --- aiter/ops/triton/_triton_kernels/mha.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/ops/triton/_triton_kernels/mha.py b/aiter/ops/triton/_triton_kernels/mha.py index 28aae55da2..45e50ce0e3 100644 --- a/aiter/ops/triton/_triton_kernels/mha.py +++ b/aiter/ops/triton/_triton_kernels/mha.py @@ -183,7 +183,7 @@ def _attn_fwd_inner( if IS_CAUSAL: causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] - mask = mask and causal_mask + mask = mask & causal_mask qk = tl.where(mask, qk, float("-inf")) From 5e9afa9ddaa4982dff5524525729b5d5cc6ff5c4 Mon Sep 17 00:00:00 2001 From: Bruno Mazzotti Date: Fri, 17 Oct 2025 14:28:03 +0000 Subject: [PATCH 18/18] Reduce the number of PE tests to speedup CI Batch: 2 options Sequence length: 4 options Number of heads: 3 options Head size: 2 options Dropout probability: 2 options Causal or non-causal: 2 options Forward or backward: 2 options BSHD or THD layout: 2 options Total of 768 tests --- op_tests/triton_tests/test_mha.py | 32 ++++++++++++------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/op_tests/triton_tests/test_mha.py b/op_tests/triton_tests/test_mha.py index 6e6fc78332..da10db4313 100644 --- a/op_tests/triton_tests/test_mha.py +++ b/op_tests/triton_tests/test_mha.py @@ -792,12 +792,10 @@ def test_mha_backward_varlen( @pytest.mark.parametrize("BATCH", [1, 3]) @pytest.mark.parametrize( 
"SEQLEN_Q, SEQLEN_K", - [(1, 1), (4, 4), (128, 128), (2, 1), (1, 2), (32, 16), (16, 48), (4096, 4096)], + [(128, 128), (32, 16), (16, 48), (4096, 4096)], ) -@pytest.mark.parametrize( - "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (4, 4), (2, 1), (128, 128)] -) -@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(48, 32), (128, 64), (192, 128)]) +@pytest.mark.parametrize("NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (2, 1), (128, 128)]) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(128, 64), (192, 128)]) @pytest.mark.parametrize("DROPOUT", [0.0, 0.25]) @pytest.mark.parametrize("CAUSAL", [True, False]) def test_mha_with_pe( @@ -862,12 +860,10 @@ def test_mha_with_pe( @pytest.mark.parametrize("BATCH", [1, 3]) @pytest.mark.parametrize( "SEQLEN_Q, SEQLEN_K", - [(1, 1), (2, 2), (4, 1), (1, 4), (16, 16), (32, 16), (64, 128), (4096, 4096)], -) -@pytest.mark.parametrize( - "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (4, 4), (16, 4), (128, 128)] + [(16, 16), (32, 16), (64, 128), (4096, 4096)], ) -@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (96, 64), (192, 128)]) +@pytest.mark.parametrize("NUM_Q_HEADS, NUM_K_HEADS", [(4, 4), (16, 4), (128, 128)]) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(96, 64), (192, 128)]) @pytest.mark.parametrize("DROPOUT", [0.0, 0.17]) @pytest.mark.parametrize("CAUSAL", [True, False]) def test_mha_varlen_with_pe( @@ -968,12 +964,10 @@ def test_mha_varlen_with_pe( @pytest.mark.parametrize("BATCH", [1, 4]) @pytest.mark.parametrize( "SEQLEN_Q, SEQLEN_K", - [(1, 1), (8, 8), (4, 1), (1, 2), (16, 16), (32, 8), (64, 16), (2048, 2048)], + [(16, 16), (32, 8), (64, 16), (2048, 2048)], ) -@pytest.mark.parametrize( - "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (4, 4), (8, 2), (128, 128)] -) -@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (128, 64), (192, 128)]) +@pytest.mark.parametrize("NUM_Q_HEADS, NUM_K_HEADS", [(4, 4), (8, 2), (128, 128)]) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (192, 128)]) 
@pytest.mark.parametrize("DROPOUT", [0.0, 0.2]) @pytest.mark.parametrize("CAUSAL", [True, False]) def test_mha_backward_with_pe( @@ -1093,12 +1087,10 @@ def test_mha_backward_with_pe( @pytest.mark.parametrize("BATCH", [1, 4]) @pytest.mark.parametrize( "SEQLEN_Q, SEQLEN_K", - [(1, 1), (8, 8), (4, 1), (1, 2), (16, 16), (32, 8), (64, 16), (64, 64)], -) -@pytest.mark.parametrize( - "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (4, 4), (8, 2), (128, 128)] + [(8, 8), (32, 8), (16, 64), (64, 64)], ) -@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (128, 64), (192, 128)]) +@pytest.mark.parametrize("NUM_Q_HEADS, NUM_K_HEADS", [(4, 4), (8, 2), (128, 128)]) +@pytest.mark.parametrize("HEAD_SZ_QK, HEAD_SZ_V", [(32, 16), (192, 128)]) @pytest.mark.parametrize("DROPOUT", [0.0, 0.2]) @pytest.mark.parametrize("CAUSAL", [True, False]) def test_mha_backward_varlen_with_pe(