From 13948bcec8a06946264e5c8bbf8369a757c4d46c Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 4 Nov 2025 21:28:17 +0000 Subject: [PATCH] [Cute] Add block-sparsity support to SM100 - Implement block-sparse attention in flash_fwd_sm100.py - Update interface.py to handle SM100 block size calculations (2x multiplier for m_block_size since 1 CTA handles 2*tile_m rows) - Add mask_mod parameter support in mask.py for block-sparse masking - Add SM100 test fixtures and tile size handling in test_mask_mod.py This enables block-sparsity on SM 10.0 architecture, including mask_mod support and proper block size accounting. --- flash_attn/cute/block_sparse_utils.py | 381 ++++++++++++++++++- flash_attn/cute/compute_block_sparsity.py | 11 +- flash_attn/cute/flash_bwd_sm100.py | 52 ++- flash_attn/cute/flash_fwd_sm100.py | 438 ++++++++++++++-------- flash_attn/cute/interface.py | 32 +- flash_attn/cute/mask.py | 36 +- tests/cute/test_mask_mod.py | 71 +++- 7 files changed, 819 insertions(+), 202 deletions(-) diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index d1cb95e18ed..f117498fd2c 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -7,12 +7,14 @@ from typing import Callable from functools import partial +import math import cutlass import cutlass.cute as cute -from cutlass import const_expr +from cutlass import Float32, Int32, const_expr # Import data structures from block_sparsity from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute import utils @cute.jit @@ -143,8 +145,13 @@ def produce_block_sparse_loads( curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] - curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] - curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = 
full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None mask_empty = curr_mask_block_cnt == 0 full_empty = curr_full_block_cnt == 0 @@ -417,3 +424,371 @@ def consume_block_sparse_loads( O_should_accumulate = True return kv_consumer_state, O_should_accumulate, processed_any + + +@cute.jit +def load_block_list_sm100( + block_indices: cute.Tensor, + block_count, + load_q_with_first: cutlass.Constexpr, + m_block, + q_stage: cutlass.Constexpr, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_kv, +): + """SM100 version of load_block_list (no intra_wg_overlap, no extra_tx_count).""" + if block_count > 0: + # First iteration: load Q alongside K if requested + n_block_first = block_indices[block_count - 1] + + if const_expr(load_q_with_first): + # SM100 loads Q0 and optionally Q1 + load_Q(block=q_stage * m_block + 0, stage=0) + if const_expr(q_stage == 2): + load_Q(block=q_stage * m_block + 1, stage=1) + + # SM100 doesn't use producer_acquire for pipeline_kv in load path + # The pipeline barriers are handled inside load_KV + load_K(block=n_block_first, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + load_V(block=n_block_first, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + + # Remaining blocks + for offset in cutlass.range(1, block_count): + n_block = block_indices[block_count - 1 - offset] + load_K(block=n_block, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=None) + kv_producer_state.advance() + + return kv_producer_state + + +# SM100-specific tile processor using SM100 helpers +@cute.jit +def produce_block_sparse_loads_sm100( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + 
pipeline_kv, + q_stage: cutlass.Constexpr, + q_producer_phase: Int32, +): + """SM100 entry point for sparse block iteration. + + SM100 uses PipelineTmaUmma which doesn't support extra_tx_count, so we use + simplified block processing that just calls producer_acquire without extras. + """ + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + + mask_empty = curr_mask_block_cnt == 0 + full_empty = curr_full_block_cnt == 0 + + q_phase_flipped = False + + if mask_empty: + # No masked blocks: process full list with Q loading + kv_producer_state = load_block_list_sm100( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=True, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + q_phase_flipped = not full_empty + else: + # Process masked blocks with Q loading + kv_producer_state = load_block_list_sm100( + curr_mask_block_idx, + curr_mask_block_cnt, + load_q_with_first=True, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + q_phase_flipped = True + + if not full_empty: + # Process full blocks without Q loading + kv_producer_state = load_block_list_sm100( + curr_full_block_idx, + curr_full_block_cnt, + load_q_with_first=False, + m_block=m_block, + q_stage=q_stage, + kv_producer_state=kv_producer_state, + load_Q=load_Q, + load_K=load_K, + load_V=load_V, + pipeline_kv=pipeline_kv, + ) + + if q_phase_flipped: + 
q_producer_phase ^= 1 + + return kv_producer_state, q_producer_phase + + +@cute.jit +def get_total_block_count( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, +): + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + if const_expr(full_block_cnt is not None): + return ( + mask_block_cnt[batch_idx, head_idx, m_block] + + full_block_cnt[batch_idx, head_idx, m_block] + ) + else: + return mask_block_cnt[batch_idx, head_idx, m_block] + + +@cute.jit +def handle_block_sparse_empty_tile_correction_sm100( + tidx: Int32, + q_stage: cutlass.Constexpr, + m_block_size: cutlass.Constexpr, + qhead_per_kvhead, + pack_gqa: cutlass.Constexpr, + is_split_kv: cutlass.Constexpr, + learnable_sink, + mLSE, + seqlen, + m_block: Int32, + head_idx: Int32, + batch_idx: Int32, + split_idx: Int32, + sScale: cute.Tensor, + stats: list, + correction_epilogue: Callable, + thr_mma_pv: cute.core.ThrMma, + tOtOs: tuple[cute.Tensor], + sO: cute.Tensor, + mbar_ptr, + mbar_softmax_corr_full_offset: Int32, + mbar_softmax_corr_empty_offset: Int32, + mbar_P_full_O_rescaled_offset: Int32, + mbar_P_full_2_offset: Int32, + mbar_corr_epi_full_offset: Int32, + mbar_corr_epi_empty_offset: Int32, + softmax_corr_consumer_phase: Int32, + o_corr_consumer_phase: Int32, + corr_epi_producer_phase: Int32, + softmax_scale_log2: Float32, +): + """Handle the block-sparse case where a tile is fully masked: + * zero staged results + * seed stats + * satisfy the usual barrier protocol so downstream warps continue to make progress. 
+ """ + LOG2_E = Float32(math.log2(math.e)) + + for stage in cutlass.range_constexpr(q_stage): + row_sum_value = Float32(1.0) + row_max_value = ( + -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None + ) + if const_expr(learnable_sink is not None): + sink_val = -Float32.inf + if const_expr(not pack_gqa): + sink_val = Float32(learnable_sink[head_idx]) + elif tidx < m_block_size: + q_head_idx = ( + (q_stage * m_block + stage) * m_block_size + tidx + ) % qhead_per_kvhead + head_idx * qhead_per_kvhead + sink_val = Float32(learnable_sink[q_head_idx]) + if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0): + if row_max_value == -Float32.inf: + row_max_value = sink_val * (LOG2_E / softmax_scale_log2) + row_sum_value = Float32(1.0) + else: + row_sum_value = row_sum_value + utils.exp2f( + sink_val * LOG2_E - row_max_value * softmax_scale_log2 + ) + if tidx < m_block_size: + scale_row_idx = tidx + stage * m_block_size + sScale[scale_row_idx] = row_sum_value + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[scale_row_idx + m_block_size * 2] = row_max_value + acc_flag = row_sum_value == Float32(0.0) or row_sum_value != row_sum_value + stats[stage] = (row_sum_value, row_max_value, acc_flag) + + cute.arch.mbarrier_wait( + mbar_ptr + mbar_softmax_corr_full_offset + stage, + softmax_corr_consumer_phase, + ) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage) + + cute.arch.mbarrier_wait( + mbar_ptr + mbar_corr_epi_empty_offset + stage, + corr_epi_producer_phase, + ) + correction_epilogue( + thr_mma_pv, + tOtOs[stage], + tidx, + Float32(0.0), # zero scale ensures empty tile writes zeros into staged outputs + sO[None, None, stage], + ) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_corr_epi_full_offset + stage) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage) + + 
softmax_corr_consumer_phase ^= 1 + o_corr_consumer_phase ^= 1 + corr_epi_producer_phase ^= 1 + + return ( + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + ) + + +@cute.jit +def softmax_block_sparse_sm100( + blocksparse_tensors: BlockSparseTensors, + batch_idx, + head_idx, + m_block, + softmax_step: Callable, + mask_fn: Callable, + mask_fn_none: Callable, + mma_si_consumer_phase: Int32, + si_corr_producer_phase: Int32, + s0_s1_sequence_phase: Int32, + mbar_ptr, + mbar_softmax_corr_full_offset: Int32, + mbar_softmax_corr_empty_offset: Int32, + mbar_P_full_O_rescaled_offset: Int32, + mbar_P_full_2_offset: Int32, + q_stage: cutlass.Constexpr, + stage_idx: Int32, +): + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + + curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block] + curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block, None] + + if const_expr(full_block_cnt is not None): + curr_full_block_cnt = full_block_cnt[batch_idx, head_idx, m_block] + curr_full_block_idx = full_block_idx[batch_idx, head_idx, m_block, None] + else: + curr_full_block_cnt = Int32(0) + curr_full_block_idx = None + + total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt + + if total_block_cnt == 0: + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_full_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_O_rescaled_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_P_full_2_offset + stage_idx) + cute.arch.mbarrier_arrive(mbar_ptr + mbar_softmax_corr_empty_offset + stage_idx) + else: + if curr_mask_block_cnt > 0: + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + mask_n_block, + is_first=True, + mask_fn=partial(mask_fn, mask_seqlen=True), # last block could oob + ) + for 
i in cutlass.range(1, curr_mask_block_cnt): + mask_n_block = curr_mask_block_idx[curr_mask_block_cnt - 1 - i] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + mask_n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + + if curr_full_block_cnt > 0: + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1] + if curr_mask_block_cnt == 0: + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + is_first=True, + mask_fn=partial(mask_fn_none, mask_seqlen=True), + ) + else: + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + is_first=False, + mask_fn=partial(mask_fn_none, mask_seqlen=False), + ) + for i in cutlass.range(1, curr_full_block_cnt): + full_n_block = curr_full_block_idx[curr_full_block_cnt - 1 - i] + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + ) = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + full_n_block, + mask_fn=partial(mask_fn_none, mask_seqlen=False), + ) + + return ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + total_block_cnt == 0, + ) diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py index bec6fe5701f..acaeac794c5 100644 --- a/flash_attn/cute/compute_block_sparsity.py +++ b/flash_attn/cute/compute_block_sparsity.py @@ -1,11 +1,8 @@ from functools import partial -import math -import operator -from typing import Callable, Optional, Tuple, Type +from typing import Callable, Optional, Tuple -import cuda.bindings.driver as cuda import cutlass -from cutlass import Boolean, Constexpr, Float32, 
Int32, Int8, const_expr +from cutlass import Boolean, Int32, Int8, const_expr import cutlass.cute as cute from cutlass.cute.runtime import from_dlpack import torch @@ -276,11 +273,11 @@ def compute_block_sparsity( batch_size: The batch size. num_heads: The number of heads. seqlen_q: The sequence length for the query. - seqlen_k: The sequence length for the key. + seqlen_k: The sequence length for the key. mask_mod: The `mask_mod` callable to use. aux_tensors: A list of auxiliary tensors. device: The device to use. - compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed. + compute_full_blocks: Whether to compute full blocks. If False, only partially-masked blocks are computed. use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient. Returns: diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 3b9aa00cb33..0a29ce462a8 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -315,7 +315,7 @@ def _setup_smem_layout(self): ) self.sdKV_epi_tile = ( self.tile_n, - 128 // (self.dk_dtype.width // 8), # 64 or 32 + 128 // (self.dk_dtype.width // 8), # 64 or 32 ) # subtiles mma_tiler_dsq[:2] = mma_tiler_pdo[:2] self.num_epi_stages = (self.tile_hdim // 2) // self.sdKV_epi_tile[1] self.sdKV_flat_epi_tile = self.tile_n * (self.tile_hdim // 2) // self.num_epi_stages @@ -326,12 +326,10 @@ def _setup_smem_layout(self): self.dk_dtype, LayoutEnum.ROW_MAJOR, self.sdKV_epi_tile, - 2, # num compute wgs + 2, # num compute wgs ) else: - self.sdKV_layout = cute.make_layout( - (self.tile_n * self.dK_reduce_ncol, 2) - ) + self.sdKV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2)) @cute.jit def __call__( @@ -389,9 +387,7 @@ def __call__( ] layout_transpose = [1, 3, 2, 0] # (b, s, n, h) --> (s, h, n, b) - mQ, mK, mV, mdO = [ - utils.select(t, 
mode=layout_transpose) for t in (mQ, mK, mV, mdO) - ] + mQ, mK, mV, mdO = [utils.select(t, mode=layout_transpose) for t in (mQ, mK, mV, mdO)] LSE_dPsum_dQaccum_transpose = [2, 1, 0] # (b, n, s) --> (s, n, b) mLSE, mdPsum, mdQaccum = [ utils.select(t, mode=LSE_dPsum_dQaccum_transpose) for t in (mLSE, mdPsum, mdQaccum) @@ -400,10 +396,8 @@ def __call__( layout_dKV_transpose = layout_transpose else: layout_dKV_transpose = LSE_dPsum_dQaccum_transpose - mdK, mdV = [ - utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV) - ] - dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, h) + mdK, mdV = [utils.select(t, mode=layout_dKV_transpose) for t in (mdK, mdV)] + dO_transpose = [1, 0, 2, 3] # (s, h, n, b) --> (h, s, n, b) mdO = utils.select(mdO, mode=dO_transpose) semaphore_transpose = [2, 3, 1, 0] # (b, n, block, stage) -> (block, stage, n, b) @@ -451,7 +445,7 @@ def __call__( raise RuntimeError("The layout of mdK is wrong") if const_expr(dV_major_mode != tcgen05.OperandMajorMode.K): raise RuntimeError("The layout of mdV is wrong") - + if const_expr(self.use_tma_store and self.qhead_per_kvhead == 1): tma_copy_op_dKV = cpasync.CopyBulkTensorTileS2GOp() tma_atom_dK, mdK_tma_tensor = cpasync.make_tiled_tma_atom( @@ -2253,32 +2247,32 @@ def epilogue_dK_or_dV_tma( if const_expr(self.qhead_per_kvhead == 1): sdKV = sdKV[None, None, wg_idx] # (tile_n, 64) for bf16 else: - sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32 - + sdKV = sdKV[None, wg_idx] # (tile_n * 32) for fp32 + # (8, tile_n / 128, 64 / 8) = (8, 1, 8) or (4, tile_n * 32 / (128 * 4)) = (4, 8) tdKVsdKV_r2s = thr_copy_r2s_dKV.partition_D(sdKV) head_idx_kv = head_idx // self.qhead_per_kvhead if const_expr(self.qhead_per_kvhead == 1): - mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim) + mdKV_cur = mdKV[None, None, head_idx_kv, batch_idx] # (seqlen, hdim) gdKV_p = cute.local_tile( mdKV_cur, (self.tile_n, self.tile_hdim), (n_block, 0) - ) # (tile_n, hdim) - gdKV = self.split_wg(gdKV_p, 
wg_idx, num_wg) # (tile_n, hdim / 2) + ) # (tile_n, hdim) + gdKV = self.split_wg(gdKV_p, wg_idx, num_wg) # (tile_n, hdim / 2) gdKV_epi = cute.local_tile( gdKV, self.sdKV_epi_tile, (0, None) - ) # (tile_n, 64, epi_stage = (hdim / 2) / 64) + ) # (tile_n, 64, epi_stage = (hdim / 2) / 64) else: - mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) + mdKV_cur = mdKV[None, head_idx_kv, batch_idx] # (seqlen * hdim) gdKV_p = cute.local_tile( - mdKV_cur, (self.tile_n * self.tile_hdim, ), (n_block, ) - ) # (tile_n * hdim) - gdKV = cute.logical_divide( - gdKV_p, (self.tile_n * self.tile_hdim // num_wg, ) - )[((None, wg_idx), )] # (tile_n * hdim / 2) + mdKV_cur, (self.tile_n * self.tile_hdim,), (n_block,) + ) # (tile_n * hdim) + gdKV = cute.logical_divide(gdKV_p, (self.tile_n * self.tile_hdim // num_wg,))[ + ((None, wg_idx),) + ] # (tile_n * hdim / 2) gdKV_epi = cute.flat_divide( - gdKV, (self.sdKV_flat_epi_tile, ) - ) # (tile_n * hdim / 2 / epi_stage, epi_stage) + gdKV, (self.sdKV_flat_epi_tile,) + ) # (tile_n * hdim / 2 / epi_stage, epi_stage) if const_expr(self.deterministic and self.qhead_per_kvhead > 1): mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] @@ -2290,7 +2284,7 @@ def epilogue_dK_or_dV_tma( cute.make_layout(1), cute.group_modes(sdKV, 0, 2), cute.group_modes(gdKV_epi, 0, 2), - ) # (TMA) and (TMA, EPI_STAGE) + ) # (TMA) and (TMA, EPI_STAGE) assert len(tdKVsdKV.shape) == 1, "Wrong rank for SMEM fragment tdKVsdKV" assert len(tdKVgdKV.shape) == 2, "Wrong rank for GMEM fragment tdKVgdKV" num_epi_stages = cute.size(tdKVgdKV.shape[1]) @@ -2344,7 +2338,7 @@ def epilogue_dK_or_dV_tma( tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1] = utils.mul_packed_f32x2( (tdKVrdKV_t2r[2 * i], tdKVrdKV_t2r[2 * i + 1]), (scale, scale) ) - tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) # (32 columns) + tdKVrdKV = cute.make_fragment(tdKVrdKV_t2r.shape, self.dv_dtype) # (32 columns) tdKVrdKV.store(tdKVrdKV_t2r.load().to(self.dv_dtype)) # 
RMEM -> SMEM -- copy, fence and barrier diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 915315d461b..521e1325a8f 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -36,6 +36,12 @@ from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + get_total_block_count, + produce_block_sparse_loads_sm100, + softmax_block_sparse_sm100, + handle_block_sparse_empty_tile_correction_sm100, +) from flash_attn.cute.pack_gqa import PackGQA from flash_attn.cute import mma_sm100_desc as sm100_desc from flash_attn.cute import blackwell_helpers as sm100_utils @@ -76,6 +82,7 @@ def __init__( n_block_size: int = 128, is_persistent: bool = True, score_mod: cutlass.Constexpr | None = None, + mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, paged_kv_non_tma: bool = False, ): @@ -116,6 +123,7 @@ def __init__( "SplitKV is not supported for hdim >= 192" ) self.score_mod = score_mod + self.mask_mod = mask_mod if cutlass.const_expr(has_aux_tensors): self.vec_size: cutlass.Constexpr = 1 else: @@ -652,6 +660,10 @@ class SharedStorage: seqlen_k_divmod = FastDivmod.create(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) + if cutlass.const_expr(self.use_block_sparsity and mPageTable is not None): + raise NotImplementedError("Block sparsity + paged KV not supported on SM100") + # Launch the kernel synchronously self.kernel( mQ, @@ -673,6 +685,7 @@ class SharedStorage: window_size_left, window_size_right, learnable_sink, + blocksparse_tensors, sQ_layout, sK_layout, tP_layout, @@ -717,6 +730,7 @@ def kernel( window_size_left: Optional[Int32], window_size_right: Optional[Int32], learnable_sink: Optional[cute.Tensor], + blocksparse_tensors: 
Optional[BlockSparseTensors], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, tP_layout: cute.ComposedLayout, @@ -941,6 +955,7 @@ def kernel( num_splits, SeqlenInfoCls, TileSchedulerCls, + blocksparse_tensors, ) # /////////////////////////////////////////////////////////////////////////////// @@ -970,6 +985,7 @@ def kernel( num_splits, SeqlenInfoCls, TileSchedulerCls, + blocksparse_tensors, ) # if warp_idx == self.mma_warp_id: @@ -1024,6 +1040,7 @@ def kernel( TileSchedulerCls=TileSchedulerCls, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, + blocksparse_tensors=blocksparse_tensors, ) if const_expr(not self.s0_s1_barrier): @@ -1070,6 +1087,7 @@ def kernel( num_splits, SeqlenInfoCls, TileSchedulerCls, + blocksparse_tensors, ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_tmem_dealloc_offset) @@ -1096,6 +1114,7 @@ def load( num_splits: Int32, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], ): num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE tidx = cute.arch.thread_idx()[0] % num_load_threads @@ -1207,40 +1226,58 @@ def load( K_or_V="V", ) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: - if const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE: - load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 - n_block_first = n_block_max - 1 if n_block_max > 0 else 0 - page_idx = ( - mPageTable[batch_idx, n_block_first] - if const_expr(mPageTable is not None and self.use_tma_KV) - else None + if const_expr(not self.use_block_sparsity): + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, m_block, split_idx, num_splits ) - if const_expr(not self.use_tma_KV): - paged_kv_manager.load_page_table(n_block_first) - load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 - kv_producer_state.advance() - if 
const_expr(self.q_stage == 2) and (const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE): - load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 - q_producer_phase ^= 1 - load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 - kv_producer_state.advance() - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): - n_block = n_block_max - 2 - i + if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE: + load_Q(block=self.q_stage * m_block + 0, stage=0) # Q0 + n_block_first = n_block_max - 1 if n_block_max > 0 else 0 page_idx = ( - mPageTable[batch_idx, n_block] + mPageTable[batch_idx, n_block_first] if const_expr(mPageTable is not None and self.use_tma_KV) else None ) if const_expr(not self.use_tma_KV): - paged_kv_manager.load_page_table(n_block) - # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, page_idx = {}", n_block, page_idx) - load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki + paged_kv_manager.load_page_table(n_block_first) + load_K(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # K0 kv_producer_state.advance() - load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi + if const_expr(self.q_stage == 2) and (const_expr(self.use_tma_KV) or tidx < cute.arch.WARP_SIZE): + load_Q(block=self.q_stage * m_block + 1, stage=1) # Q1 + q_producer_phase ^= 1 + load_V(block=n_block_max - 1, producer_state=kv_producer_state, page_idx=page_idx) # V0 kv_producer_state.advance() + for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + n_block = n_block_max - 2 - i + page_idx = ( + mPageTable[batch_idx, n_block] + if const_expr(mPageTable is not None and self.use_tma_KV) + else None + ) + if const_expr(not self.use_tma_KV): + paged_kv_manager.load_page_table(n_block) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("n_block = {}, 
page_idx = {}", n_block, page_idx) + load_K(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Ki + kv_producer_state.advance() + load_V(block=n_block, producer_state=kv_producer_state, page_idx=page_idx) # Vi + kv_producer_state.advance() + + else: + kv_producer_state, q_producer_phase = produce_block_sparse_loads_sm100( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + kv_producer_state, + load_Q, + load_K, + load_V, + pipeline_kv, + self.q_stage, + q_producer_phase, + ) + tile_scheduler.prefetch_next_work() tile_scheduler.advance_to_next_work() @@ -1264,6 +1301,7 @@ def mma( num_splits: Int32, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors], ): tSrQ = tiled_mma_qk.make_fragment_A(sQ) tSrK = tiled_mma_qk.make_fragment_B(sK) @@ -1308,15 +1346,28 @@ def mma( while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) - n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + block_iter_count = Int32(0) + process_tile = False + + if const_expr(self.use_block_sparsity): + block_iter_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + process_tile = block_iter_count > Int32(0) + else: + n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + block_iter_count = n_block_max - n_block_min + if const_expr(not self.is_split_kv): + process_tile = True + else: + process_tile = n_block_min < n_block_max + + if process_tile: for stage in cutlass.range_constexpr(self.q_stage): # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) # 1. wait for Q0 / Q1 cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase - ) + mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase + ) # 2. 
wait for K0 if const_expr(stage == 0): pipeline_kv.consumer_wait(mma_kv_consumer_state) @@ -1345,8 +1396,9 @@ def mma( # so we need to release them after the seqlen_kv loop # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate + block_loop_count = block_iter_count - 1 O_should_accumulate = False - for i in cutlass.range(n_block_max - 1 - n_block_min, unroll=1): + for i in cutlass.range(block_loop_count, unroll=1): # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop # 1. wait for V0 pipeline_kv.consumer_wait(mma_kv_consumer_state) @@ -1444,7 +1496,7 @@ def mma( ) # 4. release accumulated O0_partial # We do need O_full here since for the last tile, by the time the softmax warp - # has signaled to the correction warp, the softmax warp has just finished compute + # has signaled to the correction warps, the softmax warp has just finished compute # the row sum of the current tile. It does not guarantee that the 1st tile # of the next work tile has been computed yet. with cute.arch.elect_one(): @@ -1461,6 +1513,7 @@ def mma( work_tile = tile_scheduler.get_current_work() # End of persistent scheduler loop + # for both softmax0 and softmax1 warp group @cute.jit def softmax_loop( @@ -1481,6 +1534,7 @@ def softmax_loop( TileSchedulerCls: Callable, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): """Compute softmax on attention scores from QK matrix multiplication. 
@@ -1548,115 +1602,173 @@ def softmax_loop( seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: - mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) - mask_fn = partial( + mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + shared_mask_kwargs = dict( + m_block=self.q_stage * m_block + stage, + thr_mma=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + mask_causal=self.is_causal, + mask_local=self.is_local, + batch_idx=batch_idx, + head_idx=head_idx, + aux_tensors=aux_tensors, + ) + block_mask_mod = self.mask_mod if const_expr(self.use_block_sparsity) else None + mask_fn = partial( + mask.apply_mask_sm100, + mask_mod=block_mask_mod, + **shared_mask_kwargs, + ) + if const_expr(self.use_block_sparsity): + # Full blocks dont need mask_mod + mask_fn_none = partial( mask.apply_mask_sm100, - m_block=self.q_stage * m_block + stage, - thr_mma=thr_mma_qk, - thr_tmem_load=thr_tmem_load, - mask_causal=self.is_causal, - mask_local=self.is_local, - ) - softmax = SoftmaxSm100.create( - softmax_scale_log2, - rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, - softmax_scale=softmax_scale, - ) - softmax.reset() - - softmax_step = partial( - self.softmax_step, - softmax=softmax, - mbar_ptr=mbar_ptr, - mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, - thr_mma_qk=thr_mma_qk, - thr_tmem_load=thr_tmem_load, - thr_tmem_store=thr_tmem_store, - thr_tmem_store_scale=thr_tmem_store_scale, - tStS_t2r=tStS_t2r, - tStScale_r2t=tStScale_r2t, - tStP_r2t=tStP_r2t, - sScale=sScale, - stage=stage, - batch_idx=batch_idx, - head_idx=head_idx, - m_block=self.q_stage * m_block + stage, - seqlen=seqlen, - aux_tensors=aux_tensors, - fastdiv_mods=fastdiv_mods, + mask_mod=None, + **shared_mask_kwargs, ) + else: + mask_fn_none = None + + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=8.0 if 
const_expr(self.q_dtype.width == 16) else 0.0, + softmax_scale=softmax_scale, + ) + softmax.reset() + + if const_expr(self.use_block_sparsity): + tile_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + has_work = tile_block_count > Int32(0) + else: + tile_block_count = n_block_max - n_block_min + has_work = const_expr(not self.is_split_kv) or tile_block_count > Int32(0) + + softmax_step = partial( + self.softmax_step, + softmax=softmax, + mbar_ptr=mbar_ptr, + mbar_s0_s1_sequence_offset=mbar_s0_s1_sequence_offset, + thr_mma_qk=thr_mma_qk, + thr_tmem_load=thr_tmem_load, + thr_tmem_store=thr_tmem_store, + thr_tmem_store_scale=thr_tmem_store_scale, + tStS_t2r=tStS_t2r, + tStScale_r2t=tStScale_r2t, + tStP_r2t=tStP_r2t, + sScale=sScale, + stage=stage, + batch_idx=batch_idx, + head_idx=head_idx, + m_block=self.q_stage * m_block + stage, + seqlen=seqlen, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + if has_work: + # Softmax acts as the producer: wait until correction signals the stage is empty cute.arch.mbarrier_wait( mbar_ptr + self.mbar_softmax_corr_empty_offset + stage, si_corr_producer_phase ) si_corr_producer_phase ^= 1 - # 1 masking iter - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + + # Block sparse or dense iteration + if const_expr(self.use_block_sparsity): + ( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + empty_tile, + ) = softmax_block_sparse_sm100( + blocksparse_tensors, + batch_idx, + head_idx, + m_block, + softmax_step, + mask_fn, + mask_fn_none, mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, - n_block_max - 1, - is_first=True, - mask_fn=partial(mask_fn, mask_seqlen=True), + mbar_ptr, + self.mbar_softmax_corr_full_offset, + self.mbar_softmax_corr_empty_offset, + self.mbar_P_full_O_rescaled_offset, + self.mbar_P_full_2_offset, + self.q_stage, + Int32(stage), ) - 
n_block_max -= 1 - # Next couple of iterations with causal masking - if const_expr(self.is_causal or self.is_local): - n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( - seqlen, m_block, n_block_min + if not empty_tile: + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[ + tidx + stage * self.m_block_size + self.m_block_size * 2 + ] = softmax.row_max[0] + # if tidx == 0: + # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) + # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) + else: + if const_expr(not self.is_split_kv) or tile_block_count > Int32(0): + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block_max - 1, + is_first=True, + mask_fn=partial(mask_fn, mask_seqlen=True), ) - for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): - n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( - softmax_step( - mma_si_consumer_phase, - si_corr_producer_phase, - s0_s1_sequence_phase, - n_block, - mask_fn=partial(mask_fn, mask_seqlen=False), - ) + n_block_max -= 1 + # Next couple of iterations with causal masking + if const_expr(self.is_causal or self.is_local): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, m_block, n_block_min ) - n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) - # The remaining iterations have no masking - n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( - seqlen, m_block, n_block_min - ) - for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): - 
n_block = n_block_max - n_tile - 1 - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block - ) - # Separate iterations with local masking on the left - if const_expr(self.is_local and block_info.window_size_left is not None): - n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) - for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): - n_block = n_block_max - 1 - n_tile - mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( - softmax_step( - mma_si_consumer_phase, - si_corr_producer_phase, - s0_s1_sequence_phase, - n_block, - mask_fn=partial(mask_fn, mask_seqlen=False), + for n_tile in cutlass.range(n_block_max - n_block_min_causal_local_mask, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) ) - ) - # Now that we no longer already have the 1st iteration, need mask_seqlen=True here - - # tSrScale_r2t_shape = thr_tmem_store_scale.partition_S(tScScale).shape - # tSrScale_r2t = cute.make_fragment(tSrScale_r2t_shape, Float32) - # tSrScale_r2t[0] = softmax.row_sum[0] - # cute.copy(thr_tmem_store_scale, tSrScale_r2t, tStScale_r2t) - # cute.arch.fence_view_async_tmem_store() - sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] - if const_expr(mLSE is not None or learnable_sink is not None): - sScale[tidx + stage * self.m_block_size + self.m_block_size * 2] = softmax.row_max[ - 0 - ] - # if tidx == 0: - # cute.printf("softmax row sum stage %d: %f, row_max = %f\n", stage, softmax.row_sum[0], softmax.row_max[0]) - cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) - # if tidx == 0: cute.printf("softmax row sum stage %d: %f\n", stage, softmax.row_sum[0]) + 
n_block_max = cutlass.min(n_block_max, n_block_min_causal_local_mask) + # The remaining iterations have no masking + n_block_min_before_local_mask = block_info.get_n_block_min_before_local_mask( + seqlen, m_block, n_block_min + ) + for n_tile in cutlass.range(n_block_max - n_block_min_before_local_mask, unroll=1): + n_block = n_block_max - n_tile - 1 + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = softmax_step( + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase, n_block + ) + # Separate iterations with local masking on the left + if const_expr(self.is_local and block_info.window_size_left is not None): + n_block_max = cutlass.min(n_block_max, n_block_min_before_local_mask) + for n_tile in cutlass.range(0, n_block_max - n_block_min, unroll=1): + n_block = n_block_max - 1 - n_tile + mma_si_consumer_phase, si_corr_producer_phase, s0_s1_sequence_phase = ( + softmax_step( + mma_si_consumer_phase, + si_corr_producer_phase, + s0_s1_sequence_phase, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + ) + # Now that we no longer already have the 1st iteration, need mask_seqlen=True here + + # Dense path always writes scale / signals + sScale[tidx + stage * self.m_block_size] = softmax.row_sum[0] + if const_expr(mLSE is not None or learnable_sink is not None): + sScale[ + tidx + stage * self.m_block_size + self.m_block_size * 2 + ] = softmax.row_max[0] + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_full_offset + stage) # # Write LSE to gmem # if const_expr(mLSE is not None): @@ -1826,6 +1938,7 @@ def correction_loop( num_splits: Int32, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) tScS = thr_mma_qk.partition_C(cute.make_identity_tensor(self.mma_tiler_qk[:2])) @@ -1862,7 +1975,14 @@ def correction_loop( # Default LSE to -inf for invalid split_idx tiles 
stats = [(0.0, -Float32.inf if const_expr(mLSE is not None or learnable_sink is not None) else None, True)] * self.q_stage - if const_expr(not self.is_split_kv) or n_block_min < n_block_max: + if const_expr(self.use_block_sparsity): + total_block_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) + has_work = total_block_count > Int32(0) + else: + total_block_count = n_block_max - n_block_min + has_work = const_expr(not self.is_split_kv) or total_block_count > Int32(0) + + if has_work: # Ignore first signal from softmax as no correction is required cute.arch.mbarrier_wait( mbar_ptr + self.mbar_softmax_corr_full_offset + 0, softmax_corr_consumer_phase @@ -1874,7 +1994,7 @@ def correction_loop( softmax_corr_consumer_phase ^= 1 tSrScale_t2r = cute.make_fragment(tSrScale_t2r_shape, Float32) - for i in cutlass.range(n_block_max - n_block_min - 1, unroll=1): + for i in cutlass.range(total_block_count - 1, unroll=1): for stage in cutlass.range_constexpr(2): # wait for S0 / S1 cute.arch.mbarrier_wait( @@ -1969,6 +2089,44 @@ def correction_loop( o_corr_consumer_phase ^= 1 softmax_corr_consumer_phase ^= 1 corr_epi_producer_phase ^= 1 + else: + if const_expr(self.use_block_sparsity): + ( + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + ) = handle_block_sparse_empty_tile_correction_sm100( + tidx, + self.q_stage, + self.m_block_size, + self.qhead_per_kvhead, + self.pack_gqa, + self.is_split_kv, + learnable_sink, + mLSE, + seqlen, + m_block, + head_idx, + batch_idx, + split_idx, + sScale, + stats, + self.correction_epilogue, + thr_mma_pv, + tOtOs, + sO, + mbar_ptr, + self.mbar_softmax_corr_full_offset, + self.mbar_softmax_corr_empty_offset, + self.mbar_P_full_O_rescaled_offset, + self.mbar_P_full_2_offset, + self.mbar_corr_epi_full_offset, + self.mbar_corr_epi_empty_offset, + softmax_corr_consumer_phase, + o_corr_consumer_phase, + corr_epi_producer_phase, + softmax_scale_log2, + ) if const_expr(mLSE is not None): 
if const_expr(not seqlen.has_cu_seqlens_q): @@ -2006,28 +2164,6 @@ def correction_loop( # This actually just works with PackGQA too gLSE[tidx] = lse - # gO_qdhb = cute.local_tile(mO, cute.select(self.mma_tiler_pv, mode=[0, 1]), (None, 0, None, None)) - # gO = gO_qdhb[None, None, None, head_idx, batch_idx] - # tOsO, tOgO = cpasync.tma_partition( - # tma_atom_O, - # 0, - # cute.make_layout(1), - # cute.group_modes(sO, 0, 2), - # cute.group_modes(gO, 0, 2), - # ) - # warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 - # stage = warp_idx_in_wg - # if stage < self.q_stage: - # # wait from corr, issue tma store on smem - # # 1. wait for O0 / O1 final - # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_corr_epi_full_offset + stage, corr_epi_producer_phase) - # # 2. copy O0 / O1 to gmem - # cute.copy(tma_atom_O, tOsO[None, stage], tOgO[None, self.q_stage * m_block + stage]) - # cute.arch.cp_async_bulk_commit_group() - # # Ensure O0 / O1 buffer is ready to be released - # cute.arch.cp_async_bulk_wait_group(0, read=True) - # cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_corr_epi_empty_offset + stage) - # Advance to next tile tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index fb36bfd492b..db7930de537 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -259,11 +259,25 @@ def _flash_attn_fwd( if page_table is not None else None ) + compute_capability = ( + torch.cuda.get_device_capability()[0] + if _compute_capability is None + else _compute_capability + ) + + assert compute_capability in [9, 10], "Unsupported compute capability. 
Supported: 9.x, 10.x" + + sparse_tensors = None if block_sparse_tensors is not None: if seqlen_q is None: raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).") - expected_m_blocks = (seqlen_q + m_block_size - 1) // m_block_size + m_block_size_block = m_block_size + if compute_capability == 10: + # TODO: This multiplier should really be q_stage, wire up in later PR + # 1 CTA handles 2*tile_m rows + m_block_size_block = 2 * m_block_size + expected_m_blocks = (seqlen_q + m_block_size_block - 1) // m_block_size_block expected_n_blocks = (seqlen_k + n_block_size - 1) // n_block_size block_sparse_tensors = normalize_block_sparse_tensors( block_sparse_tensors, @@ -286,12 +300,6 @@ def _flash_attn_fwd( else: causal, local = False, False - compute_capability = ( - torch.cuda.get_device_capability()[0] - if _compute_capability is None - else _compute_capability - ) - assert compute_capability in [9, 10], "Unsupported compute capability. Supported: 9.x, 10.x" current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) if compute_capability == 9: # TODO: tune block size according to hdim. @@ -383,6 +391,10 @@ def _flash_attn_fwd( raise NotImplementedError( "Block sparsity is not yet supported with pack_gqa=True. This will be fixed in a future PR." ) + if is_split_kv: + raise NotImplementedError( + "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split." 
+ ) cute_aux_tensors = None if aux_tensors is not None: @@ -415,7 +427,6 @@ def _flash_attn_fwd( compute_capability, page_size not in [None, 128], # paged KV non-TMA ) - if compile_key not in _flash_attn_fwd.compile_cache: if compute_capability == 9: assert page_table is None, "paged KV not supported on SM 9.0" @@ -442,8 +453,6 @@ def _flash_attn_fwd( has_aux_tensors=aux_tensors is not None, ) elif compute_capability == 10: - if sparse_tensors is not None: - raise NotImplementedError("BlockSparsity not yet supported on SM 10.0") fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, @@ -452,12 +461,15 @@ def _flash_attn_fwd( is_local=local, is_split_kv=is_split_kv, pack_gqa=pack_gqa, + m_block_size=m_block_size, + n_block_size=n_block_size, is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None and not is_split_kv, score_mod=score_mod, + mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None, paged_kv_non_tma=page_size not in [None, 128], ) diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index aa18566cb23..c5e0a7fe2bf 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -298,6 +298,10 @@ def apply_mask_sm100( mask_seqlen: cutlass.Constexpr[bool], mask_causal: cutlass.Constexpr[bool], mask_local: cutlass.Constexpr[bool] = False, + mask_mod: cutlass.Constexpr[Optional[Callable]] = None, + batch_idx: Int32 = None, + head_idx: Int32 = None, + aux_tensors: Optional[list] = None, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_shape = (self.tile_m, self.tile_n) @@ -311,7 +315,7 @@ def apply_mask_sm100( n_block = 0 seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n r2p = True - if const_expr(not mask_causal and not mask_local): + if const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): if const_expr(not r2p): for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): @@ -321,6 +325,36 @@ 
def apply_mask_sm100( acc_S[i] = -Float32.inf if tScS_t2r[i][1] >= seqlenk_col_limit else acc_S[i] else: mask_r2p(acc_S, seqlenk_col_limit, arch=100, rank1=True) + + elif const_expr(not mask_causal and not mask_local and mask_mod is not None): + # Block sparse case w/ mask_mod + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) + row_coord_first = tScS_t2r[0][0] + global_row = row_coord_first + m_block * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa != 1): + mask_row = global_row // self.qhead_per_kvhead_packgqa + else: + mask_row = global_row + mask_row_ssa = utils.scalar_to_ssa(mask_row, cutlass.Int32) + + ncol = const_expr(cute.size(tScS_t2r.shape)) + for i in cutlass.range_constexpr(ncol): + col_coord = tScS_t2r[i][1] if not self.swap_AB else tScS_t2r[i][0] + global_col = col_coord + n_block * self.tile_n + mask_value = mask_mod( + batch_idx_ssa, + head_idx_ssa, + mask_row_ssa, + utils.scalar_to_ssa(global_col, cutlass.Int32), + aux_tensors, + ) + cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) + acc_S[i] = acc_S[i] if cond else -Float32.inf + if const_expr(mask_seqlen): + out_of_bounds = (global_row >= self.seqlen_q) or (global_col >= self.seqlen_k) + acc_S[i] = -Float32.inf if out_of_bounds else acc_S[i] + else: # Causal or local causal_row_offset = 1 + self.seqlen_k - n_block * self.tile_n - self.seqlen_q row_idx = tScS_t2r[0][0] + m_block * self.tile_m diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 07e63e2bc7f..4c68fad0eba 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -28,8 +28,20 @@ random_doc_id_tensor, ) from flash_attn.cute.testing import attention_ref +COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] +@pytest.fixture(autouse=True) +def reset_torch_state(): + """Reset torch dynamo/compile state between tests to avoid state pollution.""" + torch._dynamo.reset() + torch.cuda.empty_cache() + + 
yield + + torch._dynamo.reset() + torch.cuda.empty_cache() + def create_tensors( batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype ): @@ -142,6 +154,7 @@ def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: Optional[tup (256, 256), (113, 203), (1024, 1024), + (128, 8192) ] @@ -208,6 +221,11 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): ) # Compute block sparsity for mask_mod + if COMPUTE_CAPABILITY == 10: + sparse_tile_m = 2 * tile_m + else: + sparse_tile_m = tile_m + bm = create_block_mask( mask_mod_flex, batch_size, @@ -215,7 +233,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): seqlen_q, seqlen_k, device="cuda", - BLOCK_SIZE=(tile_m, tile_n), + BLOCK_SIZE=(sparse_tile_m, tile_n), ) _, _, mask_cnt, mask_idx, full_cnt, full_idx, *_ = bm.as_tuple() @@ -348,6 +366,9 @@ def test_static_masks( - block_diagonal: Masks by 64-element diagonal blocks - mini_causal: Local causal within 128-element tiles """ + if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128): + pytest.skip("TODO: Non-128x128 tiles currently not supported on SM 10.0. due to TMEM") + _run_mask_test( seqlen_q=seqlen_q, seqlen_k=seqlen_k, @@ -393,6 +414,9 @@ def test_parameterized_masks( - sliding_window: Requires window size and offset parameters - document: Slower to check """ + if COMPUTE_CAPABILITY == 10 and (tile_m, tile_n) != (128, 128): + pytest.skip("TODO: Non-128x128 tiles currently not supported on SM 10.0. 
due to TMEM") + _run_mask_test( seqlen_q=seqlen_q, seqlen_k=seqlen_k, @@ -409,5 +433,50 @@ def test_parameterized_masks( ) +def test_sm100_block_sparse_sink_all_masked(): + """Block-sparse regression for the sink path""" + if torch.cuda.get_device_capability()[0] != 10: + pytest.skip("SM100-only test") + device = "cuda" + dtype = torch.bfloat16 + batch_size = 1 + seqlen_q = 256 + seqlen_k = 128 + nheads = 8 + headdim = 128 + q = torch.randn(batch_size, seqlen_q, nheads, headdim, dtype=dtype, device=device) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, dtype=dtype, device=device) + learnable_sink = torch.full((nheads,), 0.5, dtype=torch.bfloat16, device=device) + zero_cnt = torch.zeros((batch_size, nheads, 1), dtype=torch.int32, device=device) + zero_idx = torch.zeros((batch_size, nheads, 1, 1), dtype=torch.int32, device=device) + sparse = BlockSparseTensorsTorch( + mask_block_cnt=zero_cnt, + mask_block_idx=zero_idx, + full_block_cnt=zero_cnt, + full_block_idx=zero_idx, + ) + softmax_scale = 1.0 / math.sqrt(headdim) + _, lse = _flash_attn_fwd( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + causal=False, + window_size_left=None, + window_size_right=None, + learnable_sink=learnable_sink, + m_block_size=128, + n_block_size=128, + num_threads=384, + pack_gqa=False, + block_sparse_tensors=sparse, + return_lse=True, + ) + # Fully masked tile ⇒ probability mass sits entirely on the sink, so LSE equals sink logit. + expected = learnable_sink.float()[None, :, None].expand_as(lse) + assert torch.allclose(lse, expected, atol=0.0, rtol=0.0) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"])