diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index f117498fd2c..96a5dc2da84 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -5,7 +5,7 @@ These utilities are used by CUTE DSL kernels to produce and consume block-sparse loads. """ -from typing import Callable +from typing import Callable, Optional from functools import partial import math import cutlass @@ -606,6 +606,9 @@ def handle_block_sparse_empty_tile_correction_sm100( o_corr_consumer_phase: Int32, corr_epi_producer_phase: Int32, softmax_scale_log2: Float32, + mO_cur: Optional[cute.Tensor] = None, + gO: Optional[cute.Tensor] = None, + gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, ): """Handle the block-sparse case where a tile is fully masked: * zero staged results @@ -650,18 +653,26 @@ def handle_block_sparse_empty_tile_correction_sm100( ) cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage) - cute.arch.mbarrier_wait( - mbar_ptr + mbar_corr_epi_empty_offset + stage, - corr_epi_producer_phase, - ) + if const_expr(gmem_tiled_copy_O is None): + cute.arch.mbarrier_wait( + mbar_ptr + mbar_corr_epi_empty_offset + stage, + corr_epi_producer_phase, + ) correction_epilogue( thr_mma_pv, tOtOs[stage], tidx, + stage, + m_block, + seqlen.seqlen_q, Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs sO[None, None, stage], + mO_cur, + gO, + gmem_tiled_copy_O, ) - cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage) + if const_expr(gmem_tiled_copy_O is None): + cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage) cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage) cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage) diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 521e1325a8f..05520fca25d 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -56,8 +56,8 @@ ) -# class NamedBarrierFwd(enum.IntEnum): -# Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() +class NamedBarrierFwd(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() # WarpSchedulerWG1 = enum.auto() # WarpSchedulerWG2 = enum.auto() # WarpSchedulerWG3 = enum.auto() @@ -85,6 +85,7 @@ def __init__( mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, paged_kv_non_tma: bool = False, + is_varlen_q: bool = False, ): self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype @@ -112,6 +113,8 @@ def __init__( self.is_persistent = is_persistent self.is_causal = is_causal self.is_local = is_local + self.is_varlen_q = is_varlen_q + self.use_correction_warps_for_epi = is_varlen_q self.qhead_per_kvhead = qhead_per_kvhead self.is_split_kv = is_split_kv self.pack_gqa = pack_gqa @@ -146,8 +149,8 @@ def __init__( self.softmax1_warp_ids = (4, 5, 6, 7) self.correction_warp_ids = (8, 9, 10, 11) self.mma_warp_id = 12 - self.load_warp_ids = (13,) - self.epilogue_warp_ids = (14,) + self.epilogue_warp_ids = (13,) + self.load_warp_ids = (14,) self.empty_warp_ids = (15,) SM100_TMEM_CAPACITY_COLUMNS = 512 self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS @@ -164,6 +167,15 @@ def __init__( ) ) + if not self.use_tma_KV: + self.load_warp_ids = (14, 15) + self.empty_warp_ids = () + if self.use_correction_warps_for_epi: + self.empty_warp_ids = self.empty_warp_ids + self.epilogue_warp_ids + self.epilogue_warp_ids = self.correction_warp_ids + elif self.is_varlen_q: # fallback + self.epilogue_warp_ids = (13, 14) + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [ self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded @@ -506,19 +518,11 @@ def __call__( self.cluster_layout_vmnk.shape, ) else: - assert self.use_tma_O, "Loading O and K/V will contend for the empty warp." - self.epilogue_warp_ids = (13,) - self.load_warp_ids = (14, 15) - self.empty_warp_ids = () tma_atom_K = None tma_atom_V = None o_cta_v_layout = cute.composition(cute.make_identity_layout(mO.shape), self.epi_tile) - # print(sO_layout.outer) - if const_expr(not self.use_tma_O): - self.epilogue_warp_ids = (14, 15) - self.empty_warp_ids = () self.num_epilogue_threads = cute.arch.WARP_SIZE * len(self.epilogue_warp_ids) if const_expr(self.use_tma_O): tma_atom_O, mO = cpasync.make_tiled_tma_atom( @@ -546,7 +550,6 @@ def __call__( assert self.m_block_size % tO_layout.shape[0] == 0 vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) - print("gmem_tiled_copy_O: ", gmem_tiled_copy_O) if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): TileScheduler = SingleTileVarlenScheduler @@ -799,7 +802,7 @@ def kernel( cute.arch.mbarrier_init( mbar_ptr + self.mbar_s0_s1_sequence_offset + i, cute.arch.WARP_SIZE ) - if warp_idx == 4: + if const_expr(not self.use_correction_warps_for_epi) and warp_idx == 4: for i in cutlass.range_constexpr(self.q_stage): cute.arch.mbarrier_init( mbar_ptr + self.mbar_corr_epi_full_offset + i, @@ -931,6 +934,12 @@ def kernel( if warp_idx == self.empty_warp_ids[0]: cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + if const_expr(len(self.empty_warp_ids) > 1): + if warp_idx == self.empty_warp_ids[1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + assert len(self.empty_warp_ids) <= 2 + # /////////////////////////////////////////////////////////////////////////////// # LOAD # /////////////////////////////////////////////////////////////////////////////// @@ -1004,19 +1013,20 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// # Epilogue # /////////////////////////////////////////////////////////////////////////////// - if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: - cute.arch.warpgroup_reg_dealloc(self.num_regs_other) - self.epilogue_s2g( - mO, - sO, - gmem_tiled_copy_O, - tma_atom_O, - mbar_ptr, - block_info, - num_splits, - SeqlenInfoCls, - TileSchedulerCls, - ) + if const_expr(not self.use_correction_warps_for_epi): + if warp_idx >= self.epilogue_warp_ids[0] and warp_idx <= self.epilogue_warp_ids[-1]: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + self.epilogue_s2g( + mO, + sO, + gmem_tiled_copy_O, + tma_atom_O, + mbar_ptr, + block_info, + num_splits, + SeqlenInfoCls, + TileSchedulerCls, + ) # /////////////////////////////////////////////////////////////////////////////// # Softmax @@ -1080,6 +1090,7 @@ def kernel( mLSE, sO, learnable_sink, + gmem_tiled_copy_O, tma_atom_O, mbar_ptr, softmax_scale_log2, @@ -1931,6 +1942,7 @@ def correction_loop( mLSE: cute.Tensor, sO: cute.Tensor, learnable_sink: Optional[cute.Tensor], + gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: cute.CopyAtom, mbar_ptr: cute.Pointer, softmax_scale_log2: Float32, @@ -1972,6 +1984,12 @@ def correction_loop( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + if const_expr(self.is_split_kv): + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] + else: + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx] + gO = cute.local_tile(mO_cur, (self.m_block_size, self.head_dim_v_padded), (None, 0)) + # Default LSE to -inf for invalid split_idx tiles stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage @@ -2070,17 +2088,25 @@ def correction_loop( cute.arch.mbarrier_wait( mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase ) - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase - ) + if const_expr(not self.use_correction_warps_for_epi): + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_corr_epi_empty_offset + stage, corr_epi_producer_phase + ) self.correction_epilogue( thr_mma_pv, tOtOs[stage], tidx, + stage, + m_block, + seqlen.seqlen_q, scale, sO[None, None, stage], + mO_cur, + gO, + gmem_tiled_copy_O, ) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) + if const_expr(not self.use_correction_warps_for_epi): + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_full_offset + stage) # Signal for the next work tile that O buffers in tmem are already read, so # mma warp can write to them cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) @@ -2090,6 +2116,11 @@ def correction_loop( softmax_corr_consumer_phase ^= 1 corr_epi_producer_phase ^= 1 else: + # WARNING: we need some code before the const_expr, see https://github.com/NVIDIA/cutlass/issues/2781 + if const_expr(self.use_correction_warps_for_epi): + gmem_tiled_copy_O_for_empty_tile = gmem_tiled_copy_O + else: + gmem_tiled_copy_O_for_empty_tile = None if const_expr(self.use_block_sparsity): ( softmax_corr_consumer_phase, @@ -2126,6 +2157,9 @@ def correction_loop( o_corr_consumer_phase, corr_epi_producer_phase, softmax_scale_log2, + mO_cur, + gO, + gmem_tiled_copy_O_for_empty_tile, ) if const_expr(mLSE is not None): @@ -2228,8 +2262,14 @@ def correction_epilogue( thr_mma: cute.core.ThrMma, tOtO: cute.Tensor, tidx: Int32, + stage: Int32, + m_block: Int32, + seqlen_q: Int32, scale: Float32, sO: cute.Tensor, + mO_cur: Optional[cute.Tensor] = None, + gO: Optional[cute.Tensor] = None, + gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, ): """Apply final scaling and transformation to attention output before writing to global memory. @@ -2302,6 +2342,57 @@ def correction_epilogue( space=cute.arch.SharedSpace.shared_cta, ) + if const_expr(self.use_correction_warps_for_epi): + assert(not self.use_tma_O) + assert(gmem_tiled_copy_O is not None) + cute.arch.barrier(barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=len(self.epilogue_warp_ids) * cute.arch.WARP_SIZE) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_v_padded)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO_cur.shape[1]) + # TODO: the packgqa case isn't correct rn (sometimes IMA), disabling it + assert not self.pack_gqa + pack_gqa = PackGQA( + self.m_block_size, + self.head_dim_v_padded, + self.check_hdim_v_oob, + self.qhead_per_kvhead, + ) + + # load acc O from smem to rmem for wider vectorization + tOrO = cute.make_fragment_like(tOsO, self.o_dtype) + cute.autovec_copy(tOsO, tOrO) + # copy acc O from rmem to gmem + if const_expr(not self.pack_gqa): + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] + < seqlen_q + - (self.q_stage * m_block + stage) * self.m_block_size + - tOcO[0][0] + ): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None, self.q_stage * m_block + stage], + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, + ) + else: + pack_gqa.store_O( + mO_cur, + tOrO, + gmem_tiled_copy_O, + tidx, + self.q_stage * m_block + stage, + seqlen_q, + ) + @cute.jit def epilogue_s2g( self, @@ -2389,7 +2480,7 @@ def epilogue_s2g( tOrO[None, rest_m, None], tOgO[None, rest_m, None, self.q_stage * m_block + stage], pred=tOpO[None, rest_m, None] - if self.check_hdim_v_oob + if const_expr(self.check_hdim_v_oob) else None, ) else: diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index db7930de537..28bcb994ee7 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -464,14 +464,16 @@ def _flash_attn_fwd( m_block_size=m_block_size, n_block_size=n_block_size, is_persistent=not causal - and not local - and cu_seqlens_q is None - and seqused_q is None - and not is_split_kv, + and not local + and cu_seqlens_q is None + and seqused_q is None + and not is_split_kv, score_mod=score_mod, mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, paged_kv_non_tma=page_size not in [None, 128], + is_varlen_q=cu_seqlens_q is not None + or seqused_q is not None, ) else: raise ValueError( diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 14034fa9fd2..4b3398dd479 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -100,8 +100,8 @@ def test_flash_attn_output( mha_type, dtype, ): - if (causal or local) and seqlen_k < seqlen_q: - pytest.skip("Causal attention requires seqlen_k >= seqlen_q") + # if (causal or local) and seqlen_k < seqlen_q: + # pytest.skip("Causal attention requires seqlen_k >= seqlen_q") device = "cuda" # set seed torch.random.manual_seed(0) @@ -228,7 +228,7 @@ def test_flash_attn_output( # pack_gqa_vals = [False, True, None] # SplitKV is not supported for hdim >= 192 pack_gqa_vals = [False] - num_splits_vals = [1] # [1, 3] if d < 192 and not DISABLE_SPLIT else [1] + num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out, lse = flash_attn_func( q, @@ -267,6 +267,7 @@ def test_flash_attn_output( and learnable_sink is None # and mha_type == "mha" # and False + and not ((causal or local) and seqlen_k < seqlen_q) ): g = torch.randn_like(out) # do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) @@ -388,7 +389,7 @@ def test_flash_attn_varlen_output( ): if ( causal or local - ): # Right now we only support causal attention with seqlen_k == seqlen_q + ): # Right now reference only supports causal attention with seqlen_k == seqlen_q seqlen_k = seqlen_q device = "cuda" # set seed @@ -572,7 +573,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() rtol = 2 if softcap == 0.0 else 3 - pack_gqa_vals = [False, True, None] + # pack_gqa_vals = [False, True, None] + pack_gqa_vals = [False] # num_splits_vals = [1, 3] # SplitKV is not supported for hdim >= 192 num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT else [1] @@ -721,8 +723,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("local", [False, True]) # @pytest.mark.parametrize("local", [False]) -# @pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) # @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) @@ -738,14 +740,14 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_leftpad", [False]) # @pytest.mark.parametrize("has_batch_idx", [False, True]) @pytest.mark.parametrize("has_batch_idx", [False]) -# @pytest.mark.parametrize("varlen_q", [False, True]) -@pytest.mark.parametrize("varlen_q", [False]) +@pytest.mark.parametrize("varlen_q", [False, True]) +# @pytest.mark.parametrize("varlen_q", [False]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) # @pytest.mark.parametrize("d", [128]) -@pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("d", [64, 128]) # @pytest.mark.parametrize("d", [192]) @pytest.mark.parametrize( "seqlen_q,seqlen_k",