From 6433a8054c40c2512bef2d760cb833dd1c0b3f84 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Tue, 5 May 2026 00:00:19 +0000
Subject: [PATCH] Remove stale guard

Drop the guard in _flash_attn_fwd that forced pack_gqa=False whenever
num_splits != 1 and cu_seqlens_q is None, and add a test that runs that
combination (causal, fixed-length) against the flex-attention reference.

stack-info: PR: https://github.com/Dao-AILab/flash-attention/pull/2537, branch: drisspg/stack/39
---
NOTE(review): line breaks and indentation in this patch were reconstructed
from a whitespace-collapsed copy. If a context line does not match the tree,
apply with `git apply --ignore-whitespace` and confirm the result by eye.

 flash_attn/cute/interface.py |  4 ----
 tests/cute/test_mask_mod.py  | 41 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 41 insertions(+), 4 deletions(-)

diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index 7a332b9f73b..bc074ce0b39 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -511,10 +511,6 @@ def _flash_attn_fwd(
     if intra_wg_overlap is None:
         intra_wg_overlap = fwd_cfg.intra_wg_overlap
 
-    # TODO: fix GQA + SplitKV + non-varlen
-    if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
-        pack_gqa = False
-
     if pack_gqa and qv is not None and 128 % qhead_per_kvhead != 0:
         pack_gqa = False
 
diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py
index 52f78e8d26d..d38843b7d8f 100644
--- a/tests/cute/test_mask_mod.py
+++ b/tests/cute/test_mask_mod.py
@@ -2258,6 +2258,47 @@ def doc_mask(b, h, q_idx, kv_idx):
     _run_write_order_test(doc_mask, seqlen_q, seqlen_k, block_size=128, B=B, H=H, spt=spt)
 
 
+@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 SplitKV forward only")
+def test_pack_gqa_splitkv_causal_fixed_len_matches_reference():
+    """PackGQA + SplitKV causal forward without cu_seqlens_q matches the reference.
+
+    Covers pack_gqa=True with num_splits != 1 and no cu_seqlens_q, the case
+    _flash_attn_fwd used to force back to pack_gqa=False.
+    """
+    torch.manual_seed(456)
+    batch_size = 1
+    nheads = 8
+    nheads_kv = 2
+    seqlen_q = 257
+    seqlen_k = 513
+    headdim = 64
+    dtype = torch.bfloat16
+
+    _, mask_mod_flex = get_mask_pair("causal", seqlen_q=seqlen_q, seqlen_k=seqlen_k)
+    tensors = create_tensors(
+        batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim, dtype
+    )
+    out_split, _ = _flash_attn_fwd(
+        q=tensors["q"],
+        k=tensors["k"],
+        v=tensors["v"],
+        out=tensors["out"].clone(),
+        lse=tensors["lse"].clone(),
+        softmax_scale=1.0 / math.sqrt(headdim),
+        causal=True,
+        pack_gqa=True,
+        num_splits=16,
+        return_lse=True,
+    )
+
+    out_ref = compute_reference_flex_attn(tensors, mask_mod_flex)
+    out_ref_fp32 = compute_reference_flex_attn(
+        {name: tensor.float() for name, tensor in tensors.items()},
+        mask_mod_flex,
+    )
+    assert_fwd_matches_reference(out_split, out_ref_fp32, out_ref)
+
+
 @pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 SplitKV block sparse forward only")
 def test_block_sparse_splitkv_matches_unsplit():
     torch.manual_seed(123)