diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py
index 701dda997d3..3426d8a31e7 100644
--- a/flash_attn/cute/flash_fwd_sm100.py
+++ b/flash_attn/cute/flash_fwd_sm100.py
@@ -2394,8 +2394,6 @@ def correction_epilogue(
         tOcO = gmem_thr_copy_O.partition_S(cO)
         t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
         tOpO = utils.predicate_k(tOcO, limit=mO_cur.shape[1])
-        # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it
-        assert not self.pack_gqa
         pack_gqa = PackGQA(
             self.m_block_size,
             self.head_dim_v_padded,
@@ -2488,8 +2486,6 @@ def epilogue_s2g(
         tOcO = gmem_thr_copy_O.partition_S(cO)
         t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
         tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
-        # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it
-        assert not self.pack_gqa
         pack_gqa = PackGQA(
             self.m_block_size,
             self.head_dim_v_padded,
diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index 103eb55f5a0..f5c64f597a7 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -276,11 +276,9 @@ def _flash_attn_fwd(
         n_block_size = 192
     if compute_capability == 10:
-        # TODO: fix the varlen case
         if (
             pack_gqa
             and (128 % qhead_per_kvhead != 0)
-            or (cu_seqlens_q is not None or seqused_q is not None)
         ):
             pack_gqa = False
         # TODO: fix GQA + SplitKV + non-varlen
diff --git a/flash_attn/cute/utils.py b/flash_attn/cute/utils.py
index be703e56caf..70346e9c884 100644
--- a/flash_attn/cute/utils.py
+++ b/flash_attn/cute/utils.py
@@ -527,7 +527,8 @@ def shuffle_sync(
     mask = cute.arch.WARP_SIZE - width
     clamp = cute.arch.WARP_SIZE - 1
     mask_and_clamp = mask << 8 | clamp
-    val = cute.make_fragment(1, type(value))
+    # important: need stride 1 and not 0 for recast_tensor to work
+    val = cute.make_rmem_tensor(cute.make_layout((1, ), stride=(1, )), type(value))
     val[0] = value
     val_i32 = cute.recast_tensor(val, cutlass.Int32)
     for i in cutlass.range_constexpr(cute.size(val_i32)):
diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py
index 83d2b9d3bf5..cd864ff26cc 100644
--- a/tests/cute/test_flash_attn.py
+++ b/tests/cute/test_flash_attn.py
@@ -233,9 +233,9 @@ def test_flash_attn_output(
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
     print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
     # num_splits_vals = [1, 3]
-    # pack_gqa_vals = [False, True, None]
+    pack_gqa_vals = [False, True, None]
     # SplitKV is not supported for hdim >= 192
-    pack_gqa_vals = [False]
+    # pack_gqa_vals = [False]
     num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
         out, lse = flash_attn_func(
@@ -600,8 +600,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
     fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
     rtol = 2 if softcap == 0.0 else 3
-    # pack_gqa_vals = [False, True, None]
-    pack_gqa_vals = [False]
+    pack_gqa_vals = [False, True, None]
+    # pack_gqa_vals = [False]
     # num_splits_vals = [1, 3]
     # SplitKV is not supported for hdim >= 192
     num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1]
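
For context on the shuffle_sync change above: the new comment notes that the one-element register fragment needs stride 1 rather than stride 0 for recast_tensor to reinterpret it as Int32. The snippet below is a rough, self-contained NumPy analogy of that constraint, not the cute DSL API itself (the arrays and calls here are illustrative only): a stride-0 layout has no per-element storage to reinterpret, while a unit-stride buffer can be bit-cast to a wider element type.

```python
import numpy as np

# Stride-0 "broadcast" layout: two logical fp16 elements alias a single 2-byte value,
# so reinterpreting their bytes as a different element type is ill-defined.
broadcast = np.broadcast_to(np.float16(1.0), (2,))
print(broadcast.strides)  # (0,) -- no contiguous per-element storage
try:
    broadcast.view(np.int32)  # bitwise recast rejected
except ValueError as err:
    print("recast rejected:", err)

# Stride-1 layout: two fp16 elements own 4 contiguous bytes, so viewing them as one
# int32 is a well-defined bitwise reinterpretation -- the same property shuffle_sync
# relies on when packing values into 32-bit warp-shuffle registers.
contiguous = np.ones((2,), dtype=np.float16)
print(contiguous.strides)          # (2,) -- unit element stride
print(contiguous.view(np.int32))   # one int32 containing the packed fp16 bits
```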