Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,10 +511,6 @@ def _flash_attn_fwd(
if intra_wg_overlap is None:
intra_wg_overlap = fwd_cfg.intra_wg_overlap

# TODO: fix GQA + SplitKV + non-varlen
if pack_gqa and num_splits != 1 and cu_seqlens_q is None:
pack_gqa = False

if pack_gqa and qv is not None and 128 % qhead_per_kvhead != 0:
pack_gqa = False

Expand Down
36 changes: 36 additions & 0 deletions tests/cute/test_mask_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -2258,6 +2258,42 @@ def doc_mask(b, h, q_idx, kv_idx):
_run_write_order_test(doc_mask, seqlen_q, seqlen_k, block_size=128, B=B, H=H, spt=spt)


@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 SplitKV forward only")
def test_pack_gqa_splitkv_causal_fixed_len_matches_reference():
torch.manual_seed(456)
batch_size = 1
nheads = 8
nheads_kv = 2
seqlen_q = 257
seqlen_k = 513
headdim = 64
dtype = torch.bfloat16

mask_mod_cute, mask_mod_flex = get_mask_pair("causal", seqlen_q=seqlen_q, seqlen_k=seqlen_k)
tensors = create_tensors(
batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim, dtype
)
out_split, _ = _flash_attn_fwd(
q=tensors["q"],
k=tensors["k"],
v=tensors["v"],
out=tensors["out"].clone(),
lse=tensors["lse"].clone(),
softmax_scale=1.0 / math.sqrt(headdim),
causal=True,
pack_gqa=True,
num_splits=16,
return_lse=True,
)

out_ref = compute_reference_flex_attn(tensors, mask_mod_flex)
out_ref_fp32 = compute_reference_flex_attn(
{name: tensor.float() for name, tensor in tensors.items()},
mask_mod_flex,
)
assert_fwd_matches_reference(out_split, out_ref_fp32, out_ref)


@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="SM100/SM110 SplitKV block sparse forward only")
def test_block_sparse_splitkv_matches_unsplit():
torch.manual_seed(123)
Expand Down