From fd00014cfacf18f834567c34c81e30f0ab667347 Mon Sep 17 00:00:00 2001 From: imbr92 <40306754+imbr92@users.noreply.github.com> Date: Mon, 30 Mar 2026 16:53:06 +0000 Subject: [PATCH 1/7] Initial split kv impl --- flash_attn/cute/flash_fwd.py | 162 ++++++++++++++++++------------ flash_attn/cute/flash_fwd_sm90.py | 57 ++++++++--- flash_attn/cute/interface.py | 2 +- 3 files changed, 137 insertions(+), 84 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 4d47fab109f..35696486862 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -167,6 +167,7 @@ def can_implement( return False return True + # TODO: fix to allow split kv --> mO type fp32 def _check_type( self, mQ_type: Type[cutlass.Numeric], @@ -336,30 +337,34 @@ def epilogue( m_block: Int32, head_idx: Int32, batch_idx: Int32, + split_idx: Int32 = Int32(0), ): - # store acc_O - rO = cute.make_fragment_like(acc_O, self.dtype) - rO.store(acc_O.load().to(self.dtype)) - # Make sure all threads have finished reading V - cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads - ) - smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype) - smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) - taccOrO = smem_thr_copy_O.retile(rO) - taccOsO = smem_thr_copy_O.partition_D(sO) - # taccOsO = copy_utils.partition_D_position_independent(smem_thr_copy_O, sO) - # copy acc O from rmem to smem with the smem copy atom - cute.copy(smem_copy_atom_O, taccOrO, taccOsO) - cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv)) pack_gqa = PackGQA( self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead ) + # Non-split: acc_O -> bf16 -> smem (before LSE write, preserving original order) + if const_expr(not self.is_split_kv): + rO = cute.make_fragment_like(acc_O, self.dtype) + rO.store(acc_O.load().to(self.dtype)) + # Make sure all threads have finished reading V + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads + ) + smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype) + smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx) + taccOrO = smem_thr_copy_O.retile(rO) + taccOsO = smem_thr_copy_O.partition_D(sO) + # copy acc O from rmem to smem with the smem copy atom + cute.copy(smem_copy_atom_O, taccOrO, taccOsO) + # Write LSE from rmem -> gmem if const_expr(mLSE is not None): - mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx] + if const_expr(self.is_split_kv): + mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx, split_idx] + else: + mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx] if const_expr(not self.pack_gqa): gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,)) gLSE_expanded_layout = cute.append( @@ -382,64 +387,87 @@ def epilogue( else: pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) + # Write O to gmem ragged = self.use_tma_O and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) - mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx] - # thr_mma = tiled_mma.get_slice(tidx) - # taccOgO = thr_mma.partition_C(gO) - # cute.autovec_copy(rO, taccOgO) - # sync to make sure all smem stores are done - if const_expr(self.use_tma_O): - # ensure smem writes are visible to TMA - cute.arch.fence_view_async_shared() - cute.arch.barrier_arrive( - barrier_id=int(NamedBarrierFwd.Epilogue), - number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, - ) - gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) - store_O, _, _ = copy_utils.tma_get_copy_fn( - tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True - ) - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - if warp_idx == 4: - cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.Epilogue), - number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, - ) - store_O() - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) + if const_expr(self.is_split_kv): + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] else: + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx] + + if const_expr(self.is_split_kv): + # SplitKV: fp32 acc_O directly from registers to gmem, bypassing smem cute.arch.barrier( - barrier_id=int(NamedBarrierFwd.Epilogue), - number_of_threads=self.num_epilogue_threads, + barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads ) - gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) - tOsO = gmem_thr_copy_O.partition_S(sO) - tOrO = cute.make_fragment_like(tOsO, self.dtype) - # load acc O from smem to rmem for wider vectorization - cute.autovec_copy(tOsO, tOrO) if const_expr(not self.pack_gqa): gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) - tOgO = gmem_thr_copy_O.partition_D(gO) - tOcO = gmem_thr_copy_O.partition_S(cO) - t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) - tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) - # copy acc O from rmem to gmem - for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): - if ( - t0OcO[0, rest_m, 0][0] - < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0] - ): - cute.copy( - gmem_tiled_copy_O, - tOrO[None, rest_m, None], - tOgO[None, rest_m, None], - pred=tOpO[None, rest_m, None] - if const_expr(self.check_hdim_v_oob) - else None, - ) + thr_mma = tiled_mma.get_slice(tidx) + taccOgO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(gO)) + taccOcO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cO)) + taccOrO = layout_utils.reshape_acc_to_mn(acc_O) + seqlen_q_limit = seqlen.seqlen_q - m_block * self.tile_m + for k in cutlass.range_constexpr(cute.size(taccOrO.shape[0])): + if taccOcO[k, 0][0] < seqlen_q_limit: + for m in cutlass.range_constexpr(cute.size(taccOrO.shape[1])): + if const_expr(not self.check_hdim_v_oob) or taccOcO[k, m][1] < mO.shape[1]: + taccOgO[k, m] = taccOrO[k, m] else: - pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q) + # TODO: pack_gqa + split_kv needs store_O_splitkv + pack_gqa.store_O_splitkv(mO_cur, acc_O, tiled_mma, tidx, m_block, seqlen.seqlen_q, head_idx) + else: + # Non-split: smem -> gmem (TMA or non-TMA) + if const_expr(self.use_tma_O): + # ensure smem writes are visible to TMA + cute.arch.fence_view_async_shared() + cute.arch.barrier_arrive( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) + gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True + ) + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + if warp_idx == 4: + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE, + ) + store_O() + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + else: + cute.arch.barrier( + barrier_id=int(NamedBarrierFwd.Epilogue), + number_of_threads=self.num_epilogue_threads, + ) + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + tOrO = cute.make_fragment_like(tOsO, self.dtype) + # load acc O from smem to rmem for wider vectorization + cute.autovec_copy(tOsO, tOrO) + if const_expr(not self.pack_gqa): + gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0)) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOcO = gmem_thr_copy_O.partition_S(cO) + t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO) + tOpO = utils.predicate_k(tOcO, limit=mO.shape[1]) + # copy acc O from rmem to gmem + for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])): + if ( + t0OcO[0, rest_m, 0][0] + < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0] + ): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None], + pred=tOpO[None, rest_m, None] + if const_expr(self.check_hdim_v_oob) + else None, + ) + else: + pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q) @cute.jit def advance_pipeline(self, pipeline_index): diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 4108ce451ff..d01ceecdc9f 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -55,6 +55,7 @@ def __init__( intra_wg_overlap: bool = True, mma_pv_is_rs: bool = True, paged_kv_non_tma: bool = False, + is_split_kv: bool = False, **kwargs, ): super().__init__(*args, **kwargs) @@ -62,6 +63,7 @@ def __init__( self.mma_pv_is_rs = mma_pv_is_rs self.buffer_align_bytes = 1024 self.use_tma_KV = not paged_kv_non_tma + self.is_split_kv = is_split_kv assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), ( "Paged KV does not support irregular head dim" ) @@ -181,21 +183,35 @@ def __call__( (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ - self._check_type( - *( - t.element_type if t is not None else None - for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) + # TODO: do something better... + if const_expr(not self.is_split_kv): + self._check_type( + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) + ) ) - ) self.varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)] - QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] - mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)] + Q_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + mQ = layout_utils.select(mQ, Q_layout_transpose) + num_splits = Int32(1) + if const_expr(not self.is_split_kv): + O_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + else: + # Fixed-len O: (num_splits, batch, seqlen_q, heads, d) -> (seqlen_q, d, heads, batch, num_splits) + # Varlen O: (num_splits, total_q, heads, d) -> (total_q, d, heads, num_splits) + O_layout_transpose = [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0] + # Fixed-len LSE: (num_splits, batch, heads, seqlen_q) -> (seqlen_q, heads, batch, num_splits) + # Varlen LSE: (num_splits, heads, total_q) -> (total_q, heads, num_splits) + LSE_layout_transpose = [3, 2, 1, 0] if const_expr(mCuSeqlensQ is None) else [2, 1, 0] + num_splits = mO.shape[0] + mO = layout_utils.select(mO, O_layout_transpose) KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)] - LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] mLSE = ( layout_utils.select(mLSE, LSE_layout_transpose) if const_expr(mLSE is not None) @@ -224,7 +240,7 @@ def __call__( self.use_tma_Q = self.arch >= Arch.sm_90 and not ( self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0 ) - self.use_tma_O = self.use_tma_Q + self.use_tma_O = self.use_tma_Q and not self.is_split_kv # Producer needs more registers when doing cp.async Q or KV loads if const_expr(self.num_wg_mma == 2 and (not self.use_tma_Q or not self.use_tma_KV)): self.num_mma_regs, self.num_producer_regs = 224, 40 @@ -237,7 +253,8 @@ def __call__( (mQ, (self.tile_m, self.tile_hdim), None), (mK, (self.tile_n, self.tile_hdim), self.num_stages), (mV, (self.tile_n, self.tile_hdimv), self.num_stages), - (mO, (self.tile_m, self.tile_hdimv), None), + # sO layout dtype possibly different from mO dtype when using splitkv (fp32) + (mQ, (self.tile_m, self.tile_hdimv), None), ] ] self.sP_layout = None @@ -325,7 +342,7 @@ def __call__( cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) else cute.size(mCuSeqlensQ.shape[0] - 1), - 1, # num_splits + num_splits, cute.size(mK.shape[0]) if const_expr(mPageTable is None) else mK.shape[0] * mPageTable.shape[1], @@ -341,6 +358,7 @@ def __call__( element_size=self.dtype.width // 8, is_persistent=False, lpt=self.is_causal or self.is_local, + is_split_kv=self.is_split_kv, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) @@ -388,6 +406,7 @@ def __call__( tile_sched_params, TileScheduler, SharedStorage, + num_splits, aux_tensors, fastdiv_mods, ).launch( @@ -434,6 +453,7 @@ def kernel( tile_sched_params: ParamsBase, TileScheduler: cutlass.Constexpr[Callable], SharedStorage: cutlass.Constexpr[Callable], + num_splits: Int32 = Int32(1), aux_tensors=Optional[list[cute.Tensor]], fastdiv_mods=None, ): @@ -538,7 +558,7 @@ def kernel( self.tile_n, self.is_causal, self.is_local, - False, # is_split_kv + self.is_split_kv, window_size_left, window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, @@ -589,6 +609,7 @@ def kernel( block_info, SeqlenInfoCls, TileSchedulerCls, + num_splits, ) else: # Consumer @@ -624,6 +645,7 @@ def kernel( blocksparse_tensors, aux_tensors, fastdiv_mods, + num_splits, ) @cute.jit @@ -647,6 +669,7 @@ def load( block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + num_splits: Int32 = Int32(1), ): warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 tidx, _, _ = cute.arch.thread_idx() @@ -666,7 +689,7 @@ def load( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: # if work_tile.is_valid_tile: - m_block, head_idx, batch_idx, _ = work_tile.tile_idx + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] head_idx_kv = ( @@ -754,7 +777,7 @@ def load( ) if const_expr(not self.use_block_sparsity): - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) # if cute.arch.thread_idx()[0] == 0: # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) # Clamp n_block to 0 when n_block_max == 0 (can happen with causal @@ -951,6 +974,7 @@ def mma( blocksparse_tensors: Optional[BlockSparseTensors], aux_tensors: Optional[list], fastdiv_mods=None, + num_splits: Int32 = Int32(1), ): warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) warp_group_thread_layout = cute.make_layout( @@ -1036,7 +1060,7 @@ def mma( # if work_tile.is_valid_tile: # shape: (atom_v_m * rest_m) - m_block, head_idx, batch_idx, _ = work_tile.tile_idx + m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) # Recompute fastdiv_mods if necessary for varlen with aux_tensors @@ -1084,7 +1108,7 @@ def mma( mma_one_n_block = partial( mma_one_n_block_all, seqlen=seqlen, softmax=softmax, score_mod_fn=score_mod_fn ) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block) + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) pipeline_q.consumer_wait_w_index_phase(0, q_consumer_phase) # For performance reason, we separate out two kinds of iterations: # those that need masking on S, and those that don't. @@ -1250,6 +1274,7 @@ def mma( m_block, head_idx, batch_idx, + split_idx, ) tile_scheduler.advance_to_next_work() diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index e4b55456ffe..2cecd39e96f 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -685,7 +685,6 @@ def _flash_attn_fwd( has_aux_tensors=aux_tensors is not None, ) elif arch // 10 == 9: - assert not is_split_kv, "SplitKV not supported on SM 9.0" fa_fwd = FlashAttentionForwardSm90( dtype, head_dim, @@ -693,6 +692,7 @@ def _flash_attn_fwd( qhead_per_kvhead, is_causal=causal, is_local=local, + is_split_kv=is_split_kv, pack_gqa=pack_gqa, tile_m=tile_m, tile_n=tile_n, From c701feefb51392ccb46178c942c3ae397349421b Mon Sep 17 00:00:00 2001 From: imbr92 <40306754+imbr92@users.noreply.github.com> Date: Mon, 30 Mar 2026 17:04:34 +0000 Subject: [PATCH 2/7] Fix split kv with empty splits --- flash_attn/cute/flash_fwd_sm90.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index d01ceecdc9f..0c53ece57f1 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -1109,6 +1109,7 @@ def mma( mma_one_n_block_all, seqlen=seqlen, softmax=softmax, score_mod_fn=score_mod_fn ) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + n_block_max_orig = n_block_max pipeline_q.consumer_wait_w_index_phase(0, q_consumer_phase) # For performance reason, we separate out two kinds of iterations: # those that need masking on S, and those that don't. @@ -1257,6 +1258,12 @@ def mma( row_scale = softmax.finalize(sink_val=sink_val) softmax.rescale_O(acc_O, row_scale) + # Override empty splits so combine kernel gives zero weight + if const_expr(self.is_split_kv): + if n_block_min >= n_block_max_orig: + acc_O.fill(Float32(0.0)) + softmax.row_sum.fill(-Float32.inf) + # /////////////////////////////////////////////////////////////////////////////// # Epilogue # /////////////////////////////////////////////////////////////////////////////// From 44c3b0c1817fffce0d19c0b41e6233bccac3617c Mon Sep 17 00:00:00 2001 From: imbr92 <40306754+imbr92@users.noreply.github.com> Date: Mon, 30 Mar 2026 18:01:47 +0000 Subject: [PATCH 3/7] Enable split kv for sm90 in tests --- tests/cute/test_flash_attn.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py index 69e6308fb60..7b1b510d08a 100644 --- a/tests/cute/test_flash_attn.py +++ b/tests/cute/test_flash_attn.py @@ -34,7 +34,6 @@ # When operating fake tensors, we cannot perform data-dependent operations (e.g., `tensor.max()`). USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1 DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE" -# SplitKV is not supported on SM90 IS_SM90 = torch.cuda.get_device_capability()[0] == 9 IS_SM100 = torch.cuda.get_device_capability()[0] == 10 TEST_BWD_ONLY = False @@ -250,9 +249,6 @@ def test_flash_attn_output( # pack_gqa_vals = [False] num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - # SplitKV not supported on SM90 - skip this iteration - if IS_SM90 and num_splits > 1: - continue if IS_SM100 and (d >= 192 and dv >= 192): # hdim 192 and 256 not support on SM100 continue out, lse = flash_attn_func( @@ -682,9 +678,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): # SplitKV is not supported for hdim >= 192 num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): - # SplitKV not supported on SM90 - skip this iteration - if IS_SM90 and num_splits > 1: - continue out_unpad, lse = flash_attn_varlen_func( q_unpad if unpad_q else q, k_unpad if unpad_kv else k, @@ -1291,9 +1284,6 @@ def test_flash_attn_kvcache( for num_splits, precompute_metadata in itertools.product( num_splits_vals, precompute_metadata_vals ): - # SplitKV not supported on SM90 - skip this iteration - if IS_SM90 and num_splits > 1: - continue # if precompute_metadata: # scheduler_metadata = get_scheduler_metadata( # batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, From 61a9620144e3cd715393d15ded6450a14ab22312 Mon Sep 17 00:00:00 2001 From: imbr92 <40306754+imbr92@users.noreply.github.com> Date: Mon, 30 Mar 2026 20:17:18 +0000 Subject: [PATCH 4/7] Get splitkv working with local attn --- flash_attn/cute/flash_fwd_sm90.py | 6 ++++++ flash_attn/cute/softmax.py | 12 +++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 0c53ece57f1..48dc91ad6b8 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -1253,6 +1253,12 @@ def mma( row = m_block * self.tile_m + tScS_mn[r][0] q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead sink_val[r] = Float32(learnable_sink[q_head_idx]) + if const_expr(self.is_split_kv): + if split_idx > 0: + if const_expr(not self.pack_gqa): + sink_val = -Float32.inf + else: + sink_val.fill(-Float32.inf) # normalize acc_O by row_sum and calculate the lse row_scale = softmax.finalize(sink_val=sink_val) diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py index eed55a0b721..bbbc038e483 100644 --- a/flash_attn/cute/softmax.py +++ b/flash_attn/cute/softmax.py @@ -134,9 +134,15 @@ def finalize( if cutlass.const_expr(sink_val is not None): sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r] LOG2_E = math.log2(math.e) - row_sum[r] += cute.math.exp2( - sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True - ) + # if all scores are masked (row_max=-inf), exp2(sink - (-inf)) overflows + # set row_max/row_sum so the sink is the sole softmax contributor (matching SM100 logic) + if row_max[r] == -Float32.inf: + row_max[r] = sink_val_cur * (LOG2_E / scale_log2) + row_sum[r] = Float32(1.0) + else: + row_sum[r] += cute.math.exp2( + sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True + ) # if row_sum is zero or nan, set acc_O_mn_row to 1.0 acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] From cf0d5c863fb87f7eafc4c4039821fd197d7c43b9 Mon Sep 17 00:00:00 2001 From: imbr92 <40306754+imbr92@users.noreply.github.com> Date: Mon, 30 Mar 2026 22:23:57 +0000 Subject: [PATCH 5/7] Support splitkv + packgqa --- flash_attn/cute/flash_fwd.py | 9 +++++++-- flash_attn/cute/pack_gqa.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 35696486862..4db6451f6f2 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -412,8 +412,13 @@ def epilogue( if const_expr(not self.check_hdim_v_oob) or taccOcO[k, m][1] < mO.shape[1]: taccOgO[k, m] = taccOrO[k, m] else: - # TODO: pack_gqa + split_kv needs store_O_splitkv - pack_gqa.store_O_splitkv(mO_cur, acc_O, tiled_mma, tidx, m_block, seqlen.seqlen_q, head_idx) + # mO_gqa is ((qheads_per_kvhead, seqlen_q), d, h_kv) + if const_expr(not seqlen.has_cu_seqlens_q): + mO_gqa = mO[None, None, None, batch_idx, split_idx] + else: + offset = (0, seqlen.offset_q) + mO_gqa = cute.domain_offset((offset, 0, 0), mO[None, None, None, split_idx]) + pack_gqa.store_O_splitkv(mO_gqa, acc_O, tiled_mma, tidx, m_block, seqlen.seqlen_q, head_idx) else: # Non-split: smem -> gmem (TMA or non-TMA) if const_expr(self.use_tma_O): diff --git a/flash_attn/cute/pack_gqa.py b/flash_attn/cute/pack_gqa.py index e87df018671..3d5eb624825 100644 --- a/flash_attn/cute/pack_gqa.py +++ b/flash_attn/cute/pack_gqa.py @@ -261,3 +261,34 @@ def store_O( mO_cur_copy[None, ki], pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None, ) + + @cute.jit + def store_O_splitkv( + self, + mO: cute.Tensor, # ((qhd_per_kv, seqlen_q), d, nhead_kv) + acc_O: cute.Tensor, # MMA accumulator, fp32 + tiled_mma: cute.TiledMma, + tidx: cutlass.Int32, + block: cutlass.Int32, + seqlen: cutlass.Int32, # seqlen_q for this batch element + head_idx_kv: cutlass.Int32, + ): + # Writing acc_O directly to gmem for packgqa + splitkv in sm90 + thr_mma = tiled_mma.get_slice(tidx) + cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded)) + taccOcO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cO)) + t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO)) + taccOrO = layout_utils.reshape_acc_to_mn(acc_O) + for r in cutlass.range_constexpr(cute.size(taccOrO.shape[0])): + if ( + t0accOcO[r, 0][0] + < seqlen * self.qhead_per_kvhead - block * self.m_block_size - taccOcO[0][0] + ): + row = taccOcO[r, 0][0] + packed_idx = block * self.m_block_size + row + m_idx = packed_idx // self.qhead_per_kvhead + h_idx = packed_idx - m_idx * self.qhead_per_kvhead + for c in cutlass.range_constexpr(cute.size(taccOrO.shape[1])): + col = taccOcO[r, c][1] + if cutlass.const_expr(not self.check_hdim_oob) or col < mO.shape[1]: + mO[(h_idx, m_idx), col, head_idx_kv] = taccOrO[r, c] From 2351ca15332a54ede2dd24b0580b85005de4ea25 Mon Sep 17 00:00:00 2001 From: imbr92 <40306754+imbr92@users.noreply.github.com> Date: Tue, 31 Mar 2026 00:41:44 +0000 Subject: [PATCH 6/7] cleanup --- flash_attn/cute/flash_fwd.py | 17 +++++++++-------- flash_attn/cute/flash_fwd_sm90.py | 19 +++++++------------ flash_attn/cute/pack_gqa.py | 6 +++--- 3 files changed, 19 insertions(+), 23 deletions(-) diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py index 4db6451f6f2..33031d0f2cc 100644 --- a/flash_attn/cute/flash_fwd.py +++ b/flash_attn/cute/flash_fwd.py @@ -167,7 +167,6 @@ def can_implement( return False return True - # TODO: fix to allow split kv --> mO type fp32 def _check_type( self, mQ_type: Type[cutlass.Numeric], @@ -179,10 +178,16 @@ def _check_type( mCuSeqlensK_type: Type[cutlass.Numeric] | None, mSeqUsedQ_type: Type[cutlass.Numeric] | None, mSeqUsedK_type: Type[cutlass.Numeric] | None, + is_split_kv: bool = False, ): - # Get the data type and check if it is fp16 or bf16 - if const_expr(not (mQ_type == mK_type == mV_type == mO_type)): - raise TypeError("All tensors must have the same data type") + if is_split_kv: + if const_expr(not (mQ_type == mK_type == mV_type)): + raise TypeError("Q, K, V tensors must have the same data type") + if const_expr(mO_type != Float32): + raise TypeError("O tensor must be Float32 for split_kv") + else: + if const_expr(not (mQ_type == mK_type == mV_type == mO_type)): + raise TypeError("All tensors must have the same data type") if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]): raise TypeError("Only Float16 or BFloat16 is supported") if const_expr(mLSE_type not in [None, Float32]): @@ -344,7 +349,6 @@ def epilogue( self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead ) - # Non-split: acc_O -> bf16 -> smem (before LSE write, preserving original order) if const_expr(not self.is_split_kv): rO = cute.make_fragment_like(acc_O, self.dtype) rO.store(acc_O.load().to(self.dtype)) @@ -387,7 +391,6 @@ def epilogue( else: pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q) - # Write O to gmem ragged = self.use_tma_O and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q) if const_expr(self.is_split_kv): mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] @@ -395,7 +398,6 @@ def epilogue( mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx] if const_expr(self.is_split_kv): - # SplitKV: fp32 acc_O directly from registers to gmem, bypassing smem cute.arch.barrier( barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads ) @@ -420,7 +422,6 @@ def epilogue( mO_gqa = cute.domain_offset((offset, 0, 0), mO[None, None, None, split_idx]) pack_gqa.store_O_splitkv(mO_gqa, acc_O, tiled_mma, tidx, m_block, seqlen.seqlen_q, head_idx) else: - # Non-split: smem -> gmem (TMA or non-TMA) if const_expr(self.use_tma_O): # ensure smem writes are visible to TMA cute.arch.fence_view_async_shared() diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 48dc91ad6b8..8a4a81b89a1 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -183,14 +183,13 @@ def __call__( (batch_size, seqlen_q, num_head, head_dim):(_, _, _, 1) """ - # TODO: do something better... - if const_expr(not self.is_split_kv): - self._check_type( - *( - t.element_type if t is not None else None - for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) - ) - ) + self._check_type( + *( + t.element_type if t is not None else None + for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK) + ), + is_split_kv=self.is_split_kv, + ) self.varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None @@ -202,11 +201,7 @@ def __call__( O_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] else: - # Fixed-len O: (num_splits, batch, seqlen_q, heads, d) -> (seqlen_q, d, heads, batch, num_splits) - # Varlen O: (num_splits, total_q, heads, d) -> (total_q, d, heads, num_splits) O_layout_transpose = [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0] - # Fixed-len LSE: (num_splits, batch, heads, seqlen_q) -> (seqlen_q, heads, batch, num_splits) - # Varlen LSE: (num_splits, heads, total_q) -> (total_q, heads, num_splits) LSE_layout_transpose = [3, 2, 1, 0] if const_expr(mCuSeqlensQ is None) else [2, 1, 0] num_splits = mO.shape[0] mO = layout_utils.select(mO, O_layout_transpose) diff --git a/flash_attn/cute/pack_gqa.py b/flash_attn/cute/pack_gqa.py index 3d5eb624825..8049b264d3c 100644 --- a/flash_attn/cute/pack_gqa.py +++ b/flash_attn/cute/pack_gqa.py @@ -265,12 +265,12 @@ def store_O( @cute.jit def store_O_splitkv( self, - mO: cute.Tensor, # ((qhd_per_kv, seqlen_q), d, nhead_kv) - acc_O: cute.Tensor, # MMA accumulator, fp32 + mO: cute.Tensor, + acc_O: cute.Tensor, tiled_mma: cute.TiledMma, tidx: cutlass.Int32, block: cutlass.Int32, - seqlen: cutlass.Int32, # seqlen_q for this batch element + seqlen: cutlass.Int32, head_idx_kv: cutlass.Int32, ): # Writing acc_O directly to gmem for packgqa + splitkv in sm90 From 698f05df612872c75a73c0303f095de5cf4fe366 Mon Sep 17 00:00:00 2001 From: imbr92 <40306754+imbr92@users.noreply.github.com> Date: Tue, 31 Mar 2026 13:19:53 +0000 Subject: [PATCH 7/7] ruff on flash_fwd_sm90 --- flash_attn/cute/flash_fwd_sm90.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py index 8a4a81b89a1..357414e0051 100644 --- a/flash_attn/cute/flash_fwd_sm90.py +++ b/flash_attn/cute/flash_fwd_sm90.py @@ -201,7 +201,9 @@ def __call__( O_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] else: - O_layout_transpose = [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0] + O_layout_transpose = ( + [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0] + ) LSE_layout_transpose = [3, 2, 1, 0] if const_expr(mCuSeqlensQ is None) else [2, 1, 0] num_splits = mO.shape[0] mO = layout_utils.select(mO, O_layout_transpose) @@ -249,7 +251,7 @@ def __call__( (mK, (self.tile_n, self.tile_hdim), self.num_stages), (mV, (self.tile_n, self.tile_hdimv), self.num_stages), # sO layout dtype possibly different from mO dtype when using splitkv (fp32) - (mQ, (self.tile_m, self.tile_hdimv), None), + (mQ, (self.tile_m, self.tile_hdimv), None), ] ] self.sP_layout = None @@ -772,7 +774,9 @@ def load( ) if const_expr(not self.use_block_sparsity): - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx, num_splits + ) # if cute.arch.thread_idx()[0] == 0: # cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max) # Clamp n_block to 0 when n_block_max == 0 (can happen with causal @@ -1103,7 +1107,9 @@ def mma( mma_one_n_block = partial( mma_one_n_block_all, seqlen=seqlen, softmax=softmax, score_mod_fn=score_mod_fn ) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx, num_splits + ) n_block_max_orig = n_block_max pipeline_q.consumer_wait_w_index_phase(0, q_consumer_phase) # For performance reason, we separate out two kinds of iterations: