diff --git a/flash_attn/cute/flash_fwd.py b/flash_attn/cute/flash_fwd.py
index d1a43cfd24..ecca680bc9 100644
--- a/flash_attn/cute/flash_fwd.py
+++ b/flash_attn/cute/flash_fwd.py
@@ -178,10 +178,16 @@ def _check_type(
     mCuSeqlensK_type: Type[cutlass.Numeric] | None,
     mSeqUsedQ_type: Type[cutlass.Numeric] | None,
     mSeqUsedK_type: Type[cutlass.Numeric] | None,
+    is_split_kv: bool = False,
 ):
-    # Get the data type and check if it is fp16 or bf16
-    if const_expr(not (mQ_type == mK_type == mV_type == mO_type)):
-        raise TypeError("All tensors must have the same data type")
+    if is_split_kv:
+        if const_expr(not (mQ_type == mK_type == mV_type)):
+            raise TypeError("Q, K, V tensors must have the same data type")
+        if const_expr(mO_type != Float32):
+            raise TypeError("O tensor must be Float32 for split_kv")
+    else:
+        if const_expr(not (mQ_type == mK_type == mV_type == mO_type)):
+            raise TypeError("All tensors must have the same data type")
     if const_expr(mQ_type not in [cutlass.Float16, cutlass.BFloat16]):
         raise TypeError("Only Float16 or BFloat16 is supported")
     if const_expr(mLSE_type not in [None, Float32]):
@@ -336,30 +342,33 @@ def epilogue(
         m_block: Int32,
         head_idx: Int32,
         batch_idx: Int32,
+        split_idx: Int32 = Int32(0),
     ):
-        # store acc_O
-        rO = cute.make_fragment_like(acc_O, self.dtype)
-        rO.store(acc_O.load().to(self.dtype))
-        # Make sure all threads have finished reading V
-        cute.arch.barrier(
-            barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
-        )
-        smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype)
-        smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
-        taccOrO = smem_thr_copy_O.retile(rO)
-        taccOsO = smem_thr_copy_O.partition_D(sO)
-        # taccOsO = copy_utils.partition_D_position_independent(smem_thr_copy_O, sO)
-        # copy acc O from rmem to smem with the smem copy atom
-        cute.copy(smem_copy_atom_O, taccOrO, taccOsO)
-
         cO = cute.make_identity_tensor((self.tile_m, self.tile_hdimv))
         pack_gqa = PackGQA(
             self.tile_m, self.tile_hdimv, self.check_hdim_v_oob, self.qhead_per_kvhead
         )
+        if const_expr(not self.is_split_kv):
+            rO = cute.make_fragment_like(acc_O, self.dtype)
+            rO.store(acc_O.load().to(self.dtype))
+            # Make sure all threads have finished reading V
+            cute.arch.barrier(
+                barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
+            )
+            smem_copy_atom_O = utils.get_smem_store_atom(self.arch.major * 10 + self.arch.minor, self.dtype)
+            smem_thr_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma).get_slice(tidx)
+            taccOrO = smem_thr_copy_O.retile(rO)
+            taccOsO = smem_thr_copy_O.partition_D(sO)
+            # copy acc O from rmem to smem with the smem copy atom
+            cute.copy(smem_copy_atom_O, taccOrO, taccOsO)
+
         # Write LSE from rmem -> gmem
         if const_expr(mLSE is not None):
-            mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx]
+            if const_expr(self.is_split_kv):
+                mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx, split_idx]
+            else:
+                mLSE_cur = seqlen.offset_batch_Q(mLSE, batch_idx, dim=2)[None, head_idx]
             if const_expr(not self.pack_gqa):
                 gLSE = cute.local_tile(mLSE_cur, (self.tile_m,), (m_block,))
                 gLSE_expanded_layout = cute.append(
@@ -383,63 +392,88 @@ def epilogue(
                 pack_gqa.store_LSE(mLSE_cur, lse, tiled_mma, tidx, m_block, seqlen.seqlen_q)

         ragged = self.use_tma_O and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
-        mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx]
-        # thr_mma = tiled_mma.get_slice(tidx)
-        # taccOgO = thr_mma.partition_C(gO)
-        # cute.autovec_copy(rO, taccOgO)
-        # sync to make sure all smem stores are done
-        if const_expr(self.use_tma_O):
-            # ensure smem writes are visible to TMA
-            cute.arch.fence_view_async_shared()
-            cute.arch.barrier_arrive(
-                barrier_id=int(NamedBarrierFwd.Epilogue),
-                number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
-            )
-            gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
-            store_O, _, _ = copy_utils.tma_get_copy_fn(
-                tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True
-            )
-            warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
-            if warp_idx == 4:
-                cute.arch.barrier(
-                    barrier_id=int(NamedBarrierFwd.Epilogue),
-                    number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
-                )
-                store_O()
-                cute.arch.cp_async_bulk_commit_group()
-                cute.arch.cp_async_bulk_wait_group(0, read=True)
+        if const_expr(self.is_split_kv):
+            mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx]
         else:
+            mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=ragged)[None, None, head_idx]
+
+        if const_expr(self.is_split_kv):
             cute.arch.barrier(
-                barrier_id=int(NamedBarrierFwd.Epilogue),
-                number_of_threads=self.num_epilogue_threads,
+                barrier_id=int(NamedBarrierFwd.Epilogue), number_of_threads=self.num_epilogue_threads
             )
-            gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
-            tOsO = gmem_thr_copy_O.partition_S(sO)
-            tOrO = cute.make_fragment_like(tOsO, self.dtype)
-            # load acc O from smem to rmem for wider vectorization
-            cute.autovec_copy(tOsO, tOrO)
             if const_expr(not self.pack_gqa):
                 gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
-                tOgO = gmem_thr_copy_O.partition_D(gO)
-                tOcO = gmem_thr_copy_O.partition_S(cO)
-                t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
-                tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
-                # copy acc O from rmem to gmem
-                for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
-                    if (
-                        t0OcO[0, rest_m, 0][0]
-                        < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]
-                    ):
-                        cute.copy(
-                            gmem_tiled_copy_O,
-                            tOrO[None, rest_m, None],
-                            tOgO[None, rest_m, None],
-                            pred=tOpO[None, rest_m, None]
-                            if const_expr(self.check_hdim_v_oob)
-                            else None,
-                        )
+                thr_mma = tiled_mma.get_slice(tidx)
+                taccOgO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(gO))
+                taccOcO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cO))
+                taccOrO = layout_utils.reshape_acc_to_mn(acc_O)
+                seqlen_q_limit = seqlen.seqlen_q - m_block * self.tile_m
+                for k in cutlass.range_constexpr(cute.size(taccOrO.shape[0])):
+                    if taccOcO[k, 0][0] < seqlen_q_limit:
+                        for m in cutlass.range_constexpr(cute.size(taccOrO.shape[1])):
+                            if const_expr(not self.check_hdim_v_oob) or taccOcO[k, m][1] < mO.shape[1]:
+                                taccOgO[k, m] = taccOrO[k, m]
             else:
-                pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q)
+                # mO_gqa is ((qheads_per_kvhead, seqlen_q), d, h_kv)
+                if const_expr(not seqlen.has_cu_seqlens_q):
+                    mO_gqa = mO[None, None, None, batch_idx, split_idx]
+                else:
+                    offset = (0, seqlen.offset_q)
+                    mO_gqa = cute.domain_offset((offset, 0, 0), mO[None, None, None, split_idx])
+                pack_gqa.store_O_splitkv(mO_gqa, acc_O, tiled_mma, tidx, m_block, seqlen.seqlen_q, head_idx)
+        else:
+            if const_expr(self.use_tma_O):
+                # ensure smem writes are visible to TMA
+                cute.arch.fence_view_async_shared()
+                cute.arch.barrier_arrive(
+                    barrier_id=int(NamedBarrierFwd.Epilogue),
+                    number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
+                )
+                gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
+                store_O, _, _ = copy_utils.tma_get_copy_fn(
+                    tma_atom_O, 0, cute.make_layout(1), sO, gO, single_stage=True
+                )
+                warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
+                if warp_idx == 4:
+                    cute.arch.barrier(
+                        barrier_id=int(NamedBarrierFwd.Epilogue),
+                        number_of_threads=self.num_epilogue_threads + cute.arch.WARP_SIZE,
+                    )
+                    store_O()
+                    cute.arch.cp_async_bulk_commit_group()
+                    cute.arch.cp_async_bulk_wait_group(0, read=True)
+            else:
+                cute.arch.barrier(
+                    barrier_id=int(NamedBarrierFwd.Epilogue),
+                    number_of_threads=self.num_epilogue_threads,
+                )
+                gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx)
+                tOsO = gmem_thr_copy_O.partition_S(sO)
+                tOrO = cute.make_fragment_like(tOsO, self.dtype)
+                # load acc O from smem to rmem for wider vectorization
+                cute.autovec_copy(tOsO, tOrO)
+                if const_expr(not self.pack_gqa):
+                    gO = cute.local_tile(mO_cur, (self.tile_m, self.tile_hdimv), (m_block, 0))
+                    tOgO = gmem_thr_copy_O.partition_D(gO)
+                    tOcO = gmem_thr_copy_O.partition_S(cO)
+                    t0OcO = gmem_tiled_copy_O.get_slice(0).partition_S(cO)
+                    tOpO = utils.predicate_k(tOcO, limit=mO.shape[1])
+                    # copy acc O from rmem to gmem
+                    for rest_m in cutlass.range_constexpr(cute.size(tOrO.shape[1])):
+                        if (
+                            t0OcO[0, rest_m, 0][0]
+                            < seqlen.seqlen_q - m_block * self.tile_m - tOcO[0][0]
+                        ):
+                            cute.copy(
+                                gmem_tiled_copy_O,
+                                tOrO[None, rest_m, None],
+                                tOgO[None, rest_m, None],
+                                pred=tOpO[None, rest_m, None]
+                                if const_expr(self.check_hdim_v_oob)
+                                else None,
+                            )
+                else:
+                    pack_gqa.store_O(mO_cur, tOrO, gmem_tiled_copy_O, tidx, m_block, seqlen.seqlen_q)

     @cute.jit
     def advance_pipeline(self, pipeline_index):
diff --git a/flash_attn/cute/flash_fwd_sm90.py b/flash_attn/cute/flash_fwd_sm90.py
index 4108ce451f..357414e005 100644
--- a/flash_attn/cute/flash_fwd_sm90.py
+++ b/flash_attn/cute/flash_fwd_sm90.py
@@ -55,6 +55,7 @@ def __init__(
         intra_wg_overlap: bool = True,
         mma_pv_is_rs: bool = True,
         paged_kv_non_tma: bool = False,
+        is_split_kv: bool = False,
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
@@ -62,6 +63,7 @@
         self.mma_pv_is_rs = mma_pv_is_rs
         self.buffer_align_bytes = 1024
         self.use_tma_KV = not paged_kv_non_tma
+        self.is_split_kv = is_split_kv
         assert self.use_tma_KV or not (self.check_hdim_oob or self.check_hdim_v_oob), (
             "Paged KV does not support irregular head dim"
         )
@@ -185,17 +187,28 @@ def __call__(
             *(
                 t.element_type if t is not None else None
                 for t in (mQ, mK, mV, mO, mLSE, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)
-            )
+            ),
+            is_split_kv=self.is_split_kv,
         )
         self.varlen_q = mCuSeqlensQ is not None or mSeqUsedQ is not None
         mQ, mK, mV, mO = [assume_tensor_aligned(t) for t in (mQ, mK, mV, mO)]
-        QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
-        mQ, mO = [layout_utils.select(t, QO_layout_transpose) for t in (mQ, mO)]
+        Q_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
+        mQ = layout_utils.select(mQ, Q_layout_transpose)
+        num_splits = Int32(1)
+        if const_expr(not self.is_split_kv):
+            O_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1]
+            LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
+        else:
+            O_layout_transpose = (
+                [2, 4, 3, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 3, 2, 0]
+            )
+            LSE_layout_transpose = [3, 2, 1, 0] if const_expr(mCuSeqlensQ is None) else [2, 1, 0]
+            num_splits = mO.shape[0]
+        mO = layout_utils.select(mO, O_layout_transpose)
         KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1]
         mK, mV = [layout_utils.select(t, KV_layout_transpose) for t in (mK, mV)]
-        LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0]
         mLSE = (
             layout_utils.select(mLSE, LSE_layout_transpose)
             if const_expr(mLSE is not None)
@@ -224,7 +237,7 @@
         self.use_tma_Q = self.arch >= Arch.sm_90 and not (
             self.pack_gqa and self.tile_m % self.qhead_per_kvhead != 0
         )
-        self.use_tma_O = self.use_tma_Q
+        self.use_tma_O = self.use_tma_Q and not self.is_split_kv
         # Producer needs more registers when doing cp.async Q or KV loads
         if const_expr(self.num_wg_mma == 2 and (not self.use_tma_Q or not self.use_tma_KV)):
             self.num_mma_regs, self.num_producer_regs = 224, 40
@@ -237,7 +250,8 @@
                 (mQ, (self.tile_m, self.tile_hdim), None),
                 (mK, (self.tile_n, self.tile_hdim), self.num_stages),
                 (mV, (self.tile_n, self.tile_hdimv), self.num_stages),
-                (mO, (self.tile_m, self.tile_hdimv), None),
+                # sO layout dtype possibly different from mO dtype when using splitkv (fp32)
+                (mQ, (self.tile_m, self.tile_hdimv), None),
             ]
         ]
         self.sP_layout = None
@@ -325,7 +339,7 @@
            cute.size(mQ.shape[3])
            if const_expr(mCuSeqlensQ is None)
            else cute.size(mCuSeqlensQ.shape[0] - 1),
-            1,  # num_splits
+            num_splits,
            cute.size(mK.shape[0])
            if const_expr(mPageTable is None)
            else mK.shape[0] * mPageTable.shape[1],
@@ -341,6 +355,7 @@
            element_size=self.dtype.width // 8,
            is_persistent=False,
            lpt=self.is_causal or self.is_local,
+            is_split_kv=self.is_split_kv,
        )
        tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
        grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
@@ -388,6 +403,7 @@
            tile_sched_params,
            TileScheduler,
            SharedStorage,
+            num_splits,
            aux_tensors,
            fastdiv_mods,
        ).launch(
@@ -434,6 +450,7 @@ def kernel(
        tile_sched_params: ParamsBase,
        TileScheduler: cutlass.Constexpr[Callable],
        SharedStorage: cutlass.Constexpr[Callable],
+        num_splits: Int32 = Int32(1),
        aux_tensors=Optional[list[cute.Tensor]],
        fastdiv_mods=None,
    ):
@@ -538,7 +555,7 @@
            self.tile_n,
            self.is_causal,
            self.is_local,
-            False,  # is_split_kv
+            self.is_split_kv,
            window_size_left,
            window_size_right,
            qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
@@ -589,6 +606,7 @@
                block_info,
                SeqlenInfoCls,
                TileSchedulerCls,
+                num_splits,
            )
        else:
            # Consumer
@@ -624,6 +642,7 @@
                blocksparse_tensors,
                aux_tensors,
                fastdiv_mods,
+                num_splits,
            )

    @cute.jit
    def load(
@@ -647,6 +666,7 @@ def load(
        block_info: BlockInfo,
        SeqlenInfoCls: Callable,
        TileSchedulerCls: Callable,
+        num_splits: Int32 = Int32(1),
    ):
        warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4
        tidx, _, _ = cute.arch.thread_idx()
@@ -666,7 +686,7 @@ def load(
        work_tile = tile_scheduler.initial_work_tile_info()
        while work_tile.is_valid_tile:
            # if work_tile.is_valid_tile:
-            m_block, head_idx, batch_idx, _ = work_tile.tile_idx
+            m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)
            mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx]
            head_idx_kv = (
@@ -754,7 +774,9 @@
            )

            if const_expr(not self.use_block_sparsity):
-                n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
+                n_block_min, n_block_max = block_info.get_n_block_min_max(
+                    seqlen, m_block, split_idx, num_splits
+                )
                # if cute.arch.thread_idx()[0] == 0:
                #     cute.printf("m_block = %d, n_block_min: %d, n_block_max: %d", m_block, n_block_min, n_block_max)
                # Clamp n_block to 0 when n_block_max == 0 (can happen with causal
@@ -951,6 +973,7 @@ def mma(
        blocksparse_tensors: Optional[BlockSparseTensors],
        aux_tensors: Optional[list],
        fastdiv_mods=None,
+        num_splits: Int32 = Int32(1),
    ):
        warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group)
        warp_group_thread_layout = cute.make_layout(
@@ -1036,7 +1059,7 @@
            # if work_tile.is_valid_tile:
            # shape: (atom_v_m * rest_m)
-            m_block, head_idx, batch_idx, _ = work_tile.tile_idx
+            m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx
            seqlen = SeqlenInfoCls(batch_idx)

            # Recompute fastdiv_mods if necessary for varlen with aux_tensors
@@ -1084,7 +1107,10 @@
            mma_one_n_block = partial(
                mma_one_n_block_all, seqlen=seqlen, softmax=softmax, score_mod_fn=score_mod_fn
            )
-            n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block)
+            n_block_min, n_block_max = block_info.get_n_block_min_max(
+                seqlen, m_block, split_idx, num_splits
+            )
+            n_block_max_orig = n_block_max
            pipeline_q.consumer_wait_w_index_phase(0, q_consumer_phase)
            # For performance reason, we separate out two kinds of iterations:
            # those that need masking on S, and those that don't.
@@ -1228,11 +1254,23 @@
                        row = m_block * self.tile_m + tScS_mn[r][0]
                        q_head_idx = row % self.qhead_per_kvhead + head_idx * self.qhead_per_kvhead
                        sink_val[r] = Float32(learnable_sink[q_head_idx])
+            if const_expr(self.is_split_kv):
+                if split_idx > 0:
+                    if const_expr(not self.pack_gqa):
+                        sink_val = -Float32.inf
+                    else:
+                        sink_val.fill(-Float32.inf)

            # normalize acc_O by row_sum and calculate the lse
            row_scale = softmax.finalize(sink_val=sink_val)
            softmax.rescale_O(acc_O, row_scale)

+            # Override empty splits so combine kernel gives zero weight
+            if const_expr(self.is_split_kv):
+                if n_block_min >= n_block_max_orig:
+                    acc_O.fill(Float32(0.0))
+                    softmax.row_sum.fill(-Float32.inf)
+
            # ///////////////////////////////////////////////////////////////////////////////
            # Epilogue
            # ///////////////////////////////////////////////////////////////////////////////
@@ -1250,6 +1288,7 @@
                m_block,
                head_idx,
                batch_idx,
+                split_idx,
            )
            tile_scheduler.advance_to_next_work()

diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index a11c8debe2..2bf865dbee 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -811,7 +811,6 @@ def _flash_attn_fwd(
             has_aux_tensors=aux_tensors is not None,
         )
     elif arch // 10 == 9:
-        assert not is_split_kv, "SplitKV not supported on SM 9.0"
         fa_fwd = FlashAttentionForwardSm90(
             dtype,
             head_dim,
@@ -819,6 +818,7 @@ def _flash_attn_fwd(
             qhead_per_kvhead,
             is_causal=causal,
             is_local=local,
+            is_split_kv=is_split_kv,
             pack_gqa=pack_gqa,
             tile_m=tile_m,
             tile_n=tile_n,
diff --git a/flash_attn/cute/pack_gqa.py b/flash_attn/cute/pack_gqa.py
index e87df01867..8049b264d3 100644
--- a/flash_attn/cute/pack_gqa.py
+++ b/flash_attn/cute/pack_gqa.py
@@ -261,3 +261,34 @@ def store_O(
                         mO_cur_copy[None, ki],
                         pred=tOpO[None, m, k] if cutlass.const_expr(self.check_hdim_oob) else None,
                     )
+
+    @cute.jit
+    def store_O_splitkv(
+        self,
+        mO: cute.Tensor,
+        acc_O: cute.Tensor,
+        tiled_mma: cute.TiledMma,
+        tidx: cutlass.Int32,
+        block: cutlass.Int32,
+        seqlen: cutlass.Int32,
+        head_idx_kv: cutlass.Int32,
+    ):
+        # Writing acc_O directly to gmem for packgqa + splitkv in sm90
+        thr_mma = tiled_mma.get_slice(tidx)
+        cO = cute.make_identity_tensor((self.m_block_size, self.head_dim_padded))
+        taccOcO = layout_utils.reshape_acc_to_mn(thr_mma.partition_C(cO))
+        t0accOcO = layout_utils.reshape_acc_to_mn(thr_mma.get_slice(0).partition_C(cO))
+        taccOrO = layout_utils.reshape_acc_to_mn(acc_O)
+        for r in cutlass.range_constexpr(cute.size(taccOrO.shape[0])):
+            if (
+                t0accOcO[r, 0][0]
+                < seqlen * self.qhead_per_kvhead - block * self.m_block_size - taccOcO[0][0]
+            ):
+                row = taccOcO[r, 0][0]
+                packed_idx = block * self.m_block_size + row
+                m_idx = packed_idx // self.qhead_per_kvhead
+                h_idx = packed_idx - m_idx * self.qhead_per_kvhead
+                for c in cutlass.range_constexpr(cute.size(taccOrO.shape[1])):
+                    col = taccOcO[r, c][1]
+                    if cutlass.const_expr(not self.check_hdim_oob) or col < mO.shape[1]:
+                        mO[(h_idx, m_idx), col, head_idx_kv] = taccOrO[r, c]
diff --git a/flash_attn/cute/softmax.py b/flash_attn/cute/softmax.py
index cc9b9d401d..4e8bb7271f 100644
--- a/flash_attn/cute/softmax.py
+++ b/flash_attn/cute/softmax.py
@@ -134,9 +134,15 @@ def finalize(
             if cutlass.const_expr(sink_val is not None):
                 sink_val_cur = sink_val if not isinstance(sink_val, cute.Tensor) else sink_val[r]
                 LOG2_E = math.log2(math.e)
-                row_sum[r] += cute.math.exp2(
-                    sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True
-                )
+                # if all scores are masked (row_max=-inf), exp2(sink - (-inf)) overflows;
+                # set row_max/row_sum so the sink is the sole softmax contributor (matching SM100 logic)
+                if row_max[r] == -Float32.inf:
+                    row_max[r] = sink_val_cur * (LOG2_E / scale_log2)
+                    row_sum[r] = Float32(1.0)
+                else:
+                    row_sum[r] += cute.math.exp2(
+                        sink_val_cur * LOG2_E - row_max[r] * scale_log2, fastmath=True
+                    )
             # if row_sum is zero or nan, set acc_O_mn_row to 1.0
             acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r]
diff --git a/tests/cute/test_flash_attn.py b/tests/cute/test_flash_attn.py
index 4446c053b5..f7f2251047 100644
--- a/tests/cute/test_flash_attn.py
+++ b/tests/cute/test_flash_attn.py
@@ -56,7 +56,6 @@ def wrapper(*args, **kwargs):
 # When operating fake tensors, we cannot perform data-dependent operations (e.g., `tensor.max()`).
 USE_FAKE_TENSOR = int(os.getenv("FLASH_ATTENTION_FAKE_TENSOR", 0)) == 1
 DISABLE_SPLIT = os.getenv("FLASH_ATTENTION_DISABLE_SPLIT", "FALSE") == "TRUE"
-# SplitKV is not supported on SM90
 IS_SM90 = torch.cuda.get_device_capability()[0] == 9
 IS_SM100 = torch.cuda.get_device_capability()[0] == 10
 TEST_BWD_ONLY = False
@@ -300,9 +299,7 @@ def test_flash_attn_output(
     # pack_gqa_vals = [False]
     num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY and not has_qv else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
-        # SplitKV not supported on SM90 - skip this iteration
-        if IS_SM90 and num_splits > 1:
-            continue
+        if IS_SM100 and (d >= 192 and dv >= 192) and not (d == 256 and dv == 256):
+            continue
         # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel does not support pack_gqa yet.
@@ -769,9 +766,6 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
     # SplitKV is not supported for hdim >= 192
     num_splits_vals = [1, 3] if d < 192 and not DISABLE_SPLIT and not TEST_BWD_ONLY else [1]
     for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
-        # SplitKV not supported on SM90 - skip this iteration
-        if IS_SM90 and num_splits > 1:
-            continue
         # TODO(wangsiyu): SM100 head_dim=256 2CTA kernel does not support pack_gqa yet.
         # pack_gqa=None means auto-enable for GQA/MQA (qhead_per_kvhead > 1)
         # Remove this when support is added.
@@ -1391,9 +1385,6 @@ def test_flash_attn_kvcache(
     for num_splits, precompute_metadata in itertools.product(
         num_splits_vals, precompute_metadata_vals
     ):
-        # SplitKV not supported on SM90 - skip this iteration
-        if IS_SM90 and num_splits > 1:
-            continue
         # if precompute_metadata:
         #     scheduler_metadata = get_scheduler_metadata(
         #         batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,
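Note on the split-KV math this patch relies on: each split writes a partial output (accumulated in fp32, hence the Float32 requirement on O) plus a per-row LSE, and a separate combine kernel merges them; an empty split is written as acc_O = 0 with row_sum = -inf so its LSE is -inf and it receives zero weight in the combine. A minimal PyTorch sketch of that combine is below, for illustration only; it is not part of this patch, and the tensor names and shapes are assumptions.

import torch

def combine_splits(o_partial: torch.Tensor, lse_partial: torch.Tensor) -> torch.Tensor:
    # o_partial: (num_splits, seqlen_q, head_dim) fp32 partial outputs, one per split
    # lse_partial: (num_splits, seqlen_q) log-sum-exp over the keys each split visited
    lse = torch.logsumexp(lse_partial, dim=0)            # global LSE across splits
    weights = torch.exp(lse_partial - lse.unsqueeze(0))  # empty split: exp(-inf - lse) = 0
    weights = torch.nan_to_num(weights, nan=0.0)         # rows where every split was empty
    return (weights.unsqueeze(-1) * o_partial).sum(dim=0)  # (seqlen_q, head_dim)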