Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions flash_attn/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -2394,8 +2394,6 @@ def correction_epilogue(
tOcO = gmem_thr_copy_O.partition_S(cO)
t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
tOpO = utils.predicate_k(tOcO, limit=mO_cur.shape[1])
# TODO: the pack_gqa case isn't correct right now (it sometimes hits an illegal memory access), so it is disabled
assert not self.pack_gqa
pack_gqa = PackGQA(
self.m_block_size,
self.head_dim_v_padded,
Expand Down Expand Up @@ -2488,8 +2486,6 @@ def epilogue_s2g(
tOcO = gmem_thr_copy_O.partition_S(cO)
t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
# TODO: the pack_gqa case isn't correct right now (it sometimes hits an illegal memory access), so it is disabled
assert not self.pack_gqa
pack_gqa = PackGQA(
self.m_block_size,
self.head_dim_v_padded,
Expand Down
2 changes: 0 additions & 2 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,9 @@ def _flash_attn_fwd(
n_block_size = 192

if compute_capability == 10:
# TODO: fix the varlen case
if (
pack_gqa
and (128 % qhead_per_kvhead != 0)
or (cu_seqlens_q is not None or seqused_q is not None)
):
pack_gqa = False
# TODO: fix GQA + SplitKV + non-varlen
Expand Down
3 changes: 2 additions & 1 deletion flash_attn/cute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,8 @@ def shuffle_sync(
mask = cute.arch.WARP_SIZE - width
clamp = cute.arch.WARP_SIZE - 1
mask_and_clamp = mask << 8 | clamp
val = cute.make_fragment(1, type(value))
# Important: the layout needs stride 1, not 0, for recast_tensor to work
val = cute.make_rmem_tensor(cute.make_layout((1, ), stride=(1, )), type(value))
val[0] = value
val_i32 = cute.recast_tensor(val, cutlass.Int32)
for i in cutlass.range_constexpr(cute.size(val_i32)):
Expand Down
8 changes: 4 additions & 4 deletions tests/cute/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,9 @@ def test_flash_attn_output(
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
# num_splits_vals = [1, 3]
# pack_gqa_vals = [False, True, None]
pack_gqa_vals = [False, True, None]
# SplitKV is not supported for hdim >= 192
pack_gqa_vals = [False]
# pack_gqa_vals = [False]
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
out, lse = flash_attn_func(
Expand Down Expand Up @@ -600,8 +600,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
rtol = 2 if softcap == 0.0 else 3

# pack_gqa_vals = [False, True, None]
pack_gqa_vals = [False]
pack_gqa_vals = [False, True, None]
# pack_gqa_vals = [False]
# num_splits_vals = [1, 3]
# SplitKV is not supported for hdim >= 192
num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]
Expand Down
Loading