# =============================================================================
# Backward-specific block-sparse helpers (SM100)
# =============================================================================
#
# In backward, iteration is transposed compared to forward:
#   - Forward:  outer loop over m_blocks (Q tiles), inner loop over n_blocks (KV tiles)
#   - Backward: outer loop over n_blocks (KV tiles), inner loop over m_blocks (Q tiles)
#
# The backward block-sparse tensors use "Q direction" indexing:
#   - q_block_cnt[batch, head, n_block]    -> count of m_blocks to process for this KV tile
#   - q_block_idx[batch, head, n_block, :] -> indices of m_blocks to process


@cute.jit
def get_total_q_block_count_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Count total tile iterations for given n_block (KV tile) in backward.

    Args:
        blocksparse_tensors: (q_block_cnt, q_block_idx, full_q_block_cnt,
            full_q_block_idx) tuple; the two "full" entries may be None.
        batch_idx, head_idx, n_block: coordinates selecting the per-KV-tile
            sparse-block count.
        subtile_factor: Compile-time factor; each sparse block expands into
            this many tile iterations.
        m_block_max: Maximum m_block index from causal/local masking constraints.
            Computed by block_info.get_m_block_min_max() based on sequence lengths
            and attention mask type. When > 0, caps the result to ensure we don't
            count sparse blocks that fall outside the valid causal/local window.

    Returns min(sparse_block_count * subtile_factor, m_block_max) when m_block_max > 0.
    """
    q_block_cnt, _, full_q_block_cnt, _ = blocksparse_tensors
    total = q_block_cnt[batch_idx, head_idx, n_block]
    # Full (unmasked) blocks are tracked separately when present; both kinds
    # contribute iterations for this KV tile.
    if const_expr(full_q_block_cnt is not None):
        total = total + full_q_block_cnt[batch_idx, head_idx, n_block]
    result = total * subtile_factor
    if m_block_max > 0:
        result = cutlass.min(result, m_block_max)
    return result


@cute.jit
def produce_block_sparse_q_loads_bwd_sm100(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    # Pipeline states (will be returned after advancing)
    producer_state_Q_LSE,
    producer_state_dO_dPsum,
    # Pipelines
    pipeline_Q,
    pipeline_LSE,
    pipeline_dO,
    pipeline_dPsum,
    # Load functions
    load_K,
    load_V,
    load_Q,
    load_dO,
    copy_stats,
    # Global tensors for LSE/dPsum
    gLSE,
    sLSE,
    gdPsum,
    sdPsum,
    # TMA copy bytes for extra_tx_count
    tma_copy_bytes_K,
    tma_copy_bytes_V,
    # Flags for which loads to perform
    should_load_Q: cutlass.Constexpr,
    should_load_dO: cutlass.Constexpr,
    # Subtiling factor and bounds
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """SM100 backward block sparse loading with subtiling.

    Returns updated (producer_state_Q_LSE, producer_state_dO_dPsum).
    First iteration loads K/V alongside Q/dO; subsequent iterations load only Q/dO.
    """
    (
        curr_q_cnt,
        curr_q_idx,
        curr_full_cnt,
        curr_full_idx,
        loop_count,
    ) = get_block_sparse_iteration_info_bwd(
        blocksparse_tensors, batch_idx, head_idx, n_block, subtile_factor, m_block_max
    )

    for iter_idx in cutlass.range(loop_count, unroll=1):
        # Only the m_block index is needed here; whether the block is "full"
        # matters to the compute warps (masking), not to the loads.
        m_block, _ = get_m_block_from_iter_bwd(
            iter_idx,
            curr_q_cnt,
            curr_q_idx,
            curr_full_cnt,
            curr_full_idx,
            subtile_factor,
        )

        if iter_idx == 0:
            # First block: load K/V alongside Q/dO.
            # extra_tx_count piggybacks the K (resp. V) TMA bytes on the Q
            # (resp. dO) pipeline barrier so a single arrival covers both loads.
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE, extra_tx_count=tma_copy_bytes_K)
                load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE))
                load_Q(m_block, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(
                    producer_state_dO_dPsum, extra_tx_count=tma_copy_bytes_V
                )
                load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum))
                load_dO(m_block, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()
        else:
            # Subsequent blocks: just load Q/dO (K/V already loaded)
            if const_expr(should_load_Q):
                pipeline_Q.producer_acquire(producer_state_Q_LSE)
                load_Q(m_block, producer_state=producer_state_Q_LSE)
                pipeline_Q.producer_commit(producer_state_Q_LSE)
                pipeline_LSE.producer_acquire(producer_state_Q_LSE)
                with cute.arch.elect_one():
                    copy_stats(
                        gLSE[None, m_block],
                        sLSE[None, producer_state_Q_LSE.index],
                        mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE),
                    )
                producer_state_Q_LSE.advance()
            if const_expr(should_load_dO):
                pipeline_dO.producer_acquire(producer_state_dO_dPsum)
                load_dO(m_block, producer_state=producer_state_dO_dPsum)
                pipeline_dO.producer_commit(producer_state_dO_dPsum)
                pipeline_dPsum.producer_acquire(producer_state_dO_dPsum)
                with cute.arch.elect_one():
                    copy_stats(
                        gdPsum[None, m_block],
                        sdPsum[None, producer_state_dO_dPsum.index],
                        mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum),
                    )
                producer_state_dO_dPsum.advance()

    return producer_state_Q_LSE, producer_state_dO_dPsum


@cute.jit
def get_block_sparse_iteration_info_bwd(
    blocksparse_tensors: BlockSparseTensors,
    batch_idx,
    head_idx,
    n_block,
    subtile_factor: cutlass.Constexpr = 1,
    m_block_max: int = 0,
):
    """Extract block-sparse iteration info for backward pass.

    Args:
        m_block_max: Maximum m_block index from causal/local masking constraints.
            Computed by block_info.get_m_block_min_max() based on sequence lengths
            and attention mask type. When > 0, caps total_count to ensure we don't
            process sparse blocks that fall outside the valid causal/local window.
            This combines block sparsity with causal/local masking.

    Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count).
    """
    q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors
    curr_q_cnt = q_cnt[batch_idx, head_idx, n_block]
    # Slice out the per-(batch, head, n_block) row of block indices.
    curr_q_idx = q_idx[batch_idx, head_idx, n_block, None]

    if const_expr(full_cnt is not None):
        curr_full_cnt = full_cnt[batch_idx, head_idx, n_block]
        curr_full_idx = full_idx[batch_idx, head_idx, n_block, None]
    else:
        # No separate "full block" tracking: zero count and no index tensor.
        # curr_full_idx is never dereferenced in this case (see
        # get_m_block_from_iter_bwd: the full branch is unreachable when
        # curr_full_cnt == 0).
        curr_full_cnt = Int32(0)
        curr_full_idx = None

    sparse_block_count = curr_q_cnt
    if const_expr(full_cnt is not None):
        sparse_block_count = sparse_block_count + curr_full_cnt

    total_count = sparse_block_count * subtile_factor
    if m_block_max > 0:
        total_count = cutlass.min(total_count, m_block_max)

    return curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count


@cute.jit
def get_m_block_from_iter_bwd(
    iter_idx,
    curr_q_cnt,
    curr_q_idx: cute.Tensor,
    curr_full_cnt,
    curr_full_idx: Optional[cute.Tensor],
    subtile_factor: cutlass.Constexpr = 1,
):
    """Derive m_block index and is_full_block flag from iteration index.

    In backward, we iterate in FORWARD order: masked blocks first (low to high),
    then full blocks (low to high). This ensures that when loop_count is capped
    to m_block_max, we skip the high (potentially out-of-bounds) m_blocks at the
    end of iteration rather than in the middle.

    With subtiling (subtile_factor > 1):
      - sparse_iter_idx = iter_idx // subtile_factor (which sparse block)
      - subtile_offset = iter_idx % subtile_factor (which subtile within sparse block)
      - m_block = sparse_m_block * subtile_factor + subtile_offset

    Returns (m_block, is_full_block):
      - m_block: The actual Q-tile block index (after subtiling)
      - is_full_block: True if this is a full block (no mask_mod needed)
    Note: All subtiles within a sparse block share the same is_full_block status.
    """
    sparse_iter_idx = iter_idx // subtile_factor
    subtile_offset = iter_idx % subtile_factor

    sparse_m_block = Int32(0)
    is_full_block = False

    # Forward order: process low sparse block indices first.
    if sparse_iter_idx < curr_q_cnt:
        # Masked (partial) block region of the index list.
        sparse_m_block = curr_q_idx[sparse_iter_idx]
        is_full_block = False
    else:
        # Full block region; only reachable when curr_full_cnt > 0, in which
        # case curr_full_idx is a valid tensor (see
        # get_block_sparse_iteration_info_bwd).
        full_iter = sparse_iter_idx - curr_q_cnt
        sparse_m_block = curr_full_idx[full_iter]
        is_full_block = True

    m_block = sparse_m_block * subtile_factor + subtile_offset

    return m_block, is_full_block
+ """ + sparse_block_size_q = subtile_factor * m_block_size + expected_m_blocks = ceildiv(seqlen_q, sparse_block_size_q) + expected_n_blocks = ceildiv(seqlen_k, n_block_size) + expected_count_shape = (batch_size, num_head, expected_n_blocks) + expected_index_shape = (batch_size, num_head, expected_n_blocks, expected_m_blocks) + return expected_count_shape, expected_index_shape + + def normalize_block_sparse_tensors( tensors: BlockSparseTensorsTorch, *, diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 4f7640c5bad..f7044f2958c 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -31,6 +31,13 @@ from flash_attn.cute import barrier from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( + get_total_q_block_count_bwd, + get_block_sparse_iteration_info_bwd, + get_m_block_from_iter_bwd, + produce_block_sparse_q_loads_bwd_sm100, +) class FlashAttentionBackwardSm100: @@ -50,7 +57,9 @@ def __init__( cluster_size: int = 1, score_mod: cutlass.Constexpr | None = None, score_mod_bwd: cutlass.Constexpr | None = None, + mask_mod: cutlass.Constexpr | None = None, has_aux_tensors: cutlass.Constexpr = False, + subtile_factor: cutlass.Constexpr[int] = 1, ): # padding head_dim to a multiple of 16 as k_block_size hdim_multiple_of = 16 @@ -93,10 +102,12 @@ def __init__( self.use_tma_store = True self.deterministic = deterministic - # Score mod support + # Score mod and mask mod support self.score_mod = score_mod self.score_mod_bwd = score_mod_bwd + self.mask_mod = mask_mod self.has_aux_tensors = has_aux_tensors + self.subtile_factor = subtile_factor # For score_mod, use vec_size=1 (like forward) to handle per-element indices if cutlass.const_expr(has_aux_tensors): self.vec_size: cutlass.Constexpr = 1 @@ 
-377,6 +388,8 @@ def __call__( mdK_semaphore: Optional[cute.Tensor] = None, mdV_semaphore: Optional[cute.Tensor] = None, aux_tensors: Optional[list] = None, + # Block-sparse tensors (Q direction - for iterating m_blocks per n_block): + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), ( "Variable sequence length is not supported yet in FlashAttentionBackwardSm100" @@ -703,6 +716,7 @@ class SharedStorage: seqlen_q_divmod = FastDivmodDivisor(seqlen_q) seqlen_k_divmod = FastDivmodDivisor(seqlen_k) fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod) + self.use_block_sparsity = cutlass.const_expr(blocksparse_tensors is not None) self.kernel( tma_tensor_Q, @@ -753,6 +767,7 @@ class SharedStorage: tile_sched_params, aux_tensors, fastdiv_mods, + blocksparse_tensors, ).launch( grid=grid_dim, block=[self.threads_per_cta, 1, 1], @@ -813,6 +828,7 @@ def kernel( tile_sched_params: ParamsBase, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) @@ -1097,6 +1113,7 @@ def kernel( block_info, SeqlenInfoCls, TileSchedulerCls, + blocksparse_tensors, should_load_Q=True, should_load_dO=True, ) @@ -1143,6 +1160,7 @@ def kernel( block_info, SeqlenInfoCls, TileSchedulerCls, + blocksparse_tensors, ) cute.arch.relinquish_tmem_alloc_permit() tmem_ptr = cute.arch.retrieve_tmem_ptr( @@ -1194,6 +1212,7 @@ def kernel( mdV_semaphore, aux_tensors, fastdiv_mods, + blocksparse_tensors, ) cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) @@ -1211,6 +1230,7 @@ def kernel( SeqlenInfoCls, TileSchedulerCls, mdQ_semaphore, + blocksparse_tensors, ) return @@ -1245,6 +1265,7 @@ def load( block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, should_load_Q: bool = True, should_load_dO: bool = True, ): @@ 
-1330,68 +1351,83 @@ def load( # gdPsum = cute.logical_divide(gdPsum, (64,))[(None, block_in_cluster_coord_vmnk[1]), None] # copy_stats = partial(cute.copy, copy_atom_stats, mcast_mask=q_do_mcast_mask) - if const_expr(not self.is_local) or m_block_min < m_block_max: - # First iteration: load K together w Q & LSE, then V together w dO & dPsum - if const_expr(should_load_Q): - # K & Q - pipeline_Q.producer_acquire( - producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] - ) - load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) - load_Q(m_block_min, producer_state=producer_state_Q_LSE) - pipeline_Q.producer_commit(producer_state_Q_LSE) - # LSE - pipeline_LSE.producer_acquire(producer_state_Q_LSE) - with cute.arch.elect_one(): - copy_stats( - gLSE[None, m_block_min], - sLSE[None, producer_state_Q_LSE.index], - mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), + # some tiles might be empty due to block sparsity + if const_expr(self.use_block_sparsity): + total_m_block_cnt = get_total_q_block_count_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + process_tile = total_m_block_cnt > Int32(0) + else: + process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + + if process_tile: + if const_expr(self.use_block_sparsity): + producer_state_Q_LSE, producer_state_dO_dPsum = ( + produce_block_sparse_q_loads_bwd_sm100( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + producer_state_Q_LSE, + producer_state_dO_dPsum, + pipeline_Q, + pipeline_LSE, + pipeline_dO, + pipeline_dPsum, + load_K, + load_V, + load_Q, + load_dO, + copy_stats, + gLSE, + sLSE, + gdPsum, + sdPsum, + self.tma_copy_bytes["K"], + self.tma_copy_bytes["V"], + should_load_Q=should_load_Q, + should_load_dO=should_load_dO, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, ) - producer_state_Q_LSE.advance() - if const_expr(should_load_dO): - # V & 
dO - pipeline_dO.producer_acquire( - producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] ) - load_V(tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum)) - load_dO(m_block_min, producer_state=producer_state_dO_dPsum) - pipeline_dO.producer_commit(producer_state_dO_dPsum) - # dPsum - pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) - with cute.arch.elect_one(): - copy_stats( - gdPsum[None, m_block_min], - sdPsum[None, producer_state_dO_dPsum.index], - mbar_ptr=pipeline_dPsum.producer_get_barrier(producer_state_dO_dPsum), - ) - producer_state_dO_dPsum.advance() + else: + first_m_block = m_block_min - for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + # First iteration: load K together w Q & LSE, then V together w dO & dPsum if const_expr(should_load_Q): - # Q - pipeline_Q.producer_acquire(producer_state_Q_LSE) - load_Q(m_block, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_acquire( + producer_state_Q_LSE, extra_tx_count=self.tma_copy_bytes["K"] + ) + load_K(tma_bar_ptr=pipeline_Q.producer_get_barrier(producer_state_Q_LSE)) + load_Q(first_m_block, producer_state=producer_state_Q_LSE) pipeline_Q.producer_commit(producer_state_Q_LSE) - # LSE pipeline_LSE.producer_acquire(producer_state_Q_LSE) with cute.arch.elect_one(): copy_stats( - gLSE[None, m_block], + gLSE[None, first_m_block], sLSE[None, producer_state_Q_LSE.index], mbar_ptr=pipeline_LSE.producer_get_barrier(producer_state_Q_LSE), ) producer_state_Q_LSE.advance() if const_expr(should_load_dO): - # dO - pipeline_dO.producer_acquire(producer_state_dO_dPsum) - load_dO(m_block, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_acquire( + producer_state_dO_dPsum, extra_tx_count=self.tma_copy_bytes["V"] + ) + load_V( + tma_bar_ptr=pipeline_dO.producer_get_barrier(producer_state_dO_dPsum) + ) + load_dO(first_m_block, producer_state=producer_state_dO_dPsum) pipeline_dO.producer_commit(producer_state_dO_dPsum) - # dPsum 
pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) with cute.arch.elect_one(): copy_stats( - gdPsum[None, m_block], + gdPsum[None, first_m_block], sdPsum[None, producer_state_dO_dPsum.index], mbar_ptr=pipeline_dPsum.producer_get_barrier( producer_state_dO_dPsum @@ -1399,6 +1435,37 @@ def load( ) producer_state_dO_dPsum.advance() + # Dense path: iterate from m_block_min+1 to m_block_max + for m_block in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + if const_expr(should_load_Q): + pipeline_Q.producer_acquire(producer_state_Q_LSE) + load_Q(m_block, producer_state=producer_state_Q_LSE) + pipeline_Q.producer_commit(producer_state_Q_LSE) + pipeline_LSE.producer_acquire(producer_state_Q_LSE) + with cute.arch.elect_one(): + copy_stats( + gLSE[None, m_block], + sLSE[None, producer_state_Q_LSE.index], + mbar_ptr=pipeline_LSE.producer_get_barrier( + producer_state_Q_LSE + ), + ) + producer_state_Q_LSE.advance() + if const_expr(should_load_dO): + pipeline_dO.producer_acquire(producer_state_dO_dPsum) + load_dO(m_block, producer_state=producer_state_dO_dPsum) + pipeline_dO.producer_commit(producer_state_dO_dPsum) + pipeline_dPsum.producer_acquire(producer_state_dO_dPsum) + with cute.arch.elect_one(): + copy_stats( + gdPsum[None, m_block], + sdPsum[None, producer_state_dO_dPsum.index], + mbar_ptr=pipeline_dPsum.producer_get_barrier( + producer_state_dO_dPsum + ), + ) + producer_state_dO_dPsum.advance() + if const_expr(should_load_Q): pipeline_Q.producer_tail( producer_state_Q_LSE.clone() @@ -1446,6 +1513,7 @@ def mma( block_info: BlockInfo, SeqlenInfoCls: Callable, TileSchedulerCls: Callable, + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): # [2025-10-21] For reasons I don't understand, putting these partitioning in the main # kernel (before warp specialization) is a lot slower tha putting them here. 
@@ -1535,7 +1603,22 @@ def mma( m_block_min, m_block_max = block_info.get_m_block_min_max( seqlen, n_block // self.cluster_shape_mnk[0] ) - if const_expr(not self.is_local) or m_block_min < m_block_max: + + if const_expr(self.use_block_sparsity): + block_iter_count = get_total_q_block_count_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + process_tile = block_iter_count > Int32(0) + else: + block_iter_count = m_block_max - m_block_min + process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + + if process_tile: accumulate_dK = False # ----------------------------------------------------------- ###### Prologue @@ -1575,7 +1658,14 @@ def mma( # 4. dP = V @ dO.T # 5. dV = P.T @ dO - for _ in cutlass.range(m_block_min + 1, m_block_max, unroll=1): + # For block sparsity, we use block_iter_count; for dense, use m_block range + # MMA doesn't need actual m_block indices, just the iteration count + main_loop_iters = ( + block_iter_count - 1 + if const_expr(self.use_block_sparsity) + else m_block_max - m_block_min - 1 + ) + for _ in cutlass.range(main_loop_iters, unroll=1): # 1) S = K @ Q_i handle_Q_next = pipeline_Q_consumer.wait_and_advance() # Don't need to wait for S, as P must have been ready ealier, i.e., S is ready @@ -1820,6 +1910,7 @@ def compute_loop( mdV_semaphore: Optional[cute.Tensor], aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): sLSE_2D = cute.make_tensor( sLSE.iterator, @@ -1936,13 +2027,53 @@ def compute_loop( mask_seqlen=True, mask_causal=self.is_causal, mask_local=self.is_local, + mask_mod=self.mask_mod, + batch_idx=batch_idx, + head_idx=head_idx, + aux_tensors=aux_tensors, + fastdiv_mods=fastdiv_mods, ) # prefetch_LSE = not self.is_causal prefetch_LSE = False + # some tiles might be empty due to block sparsity + if const_expr(self.use_block_sparsity): + ( + curr_q_cnt, + 
curr_q_idx, + curr_full_cnt, + curr_full_idx, + loop_count, + ) = get_block_sparse_iteration_info_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + process_tile = loop_count > Int32(0) + else: + process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + loop_count = m_block_max - m_block_min + # Mainloop - for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): + # Block sparsity: iterate over sparse m_block count and derive actual m_block + # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly. + for iter_idx in cutlass.range(loop_count, unroll=1): + if const_expr(self.use_block_sparsity): + m_block, is_full_block = get_m_block_from_iter_bwd( + iter_idx, + curr_q_cnt, + curr_q_idx, + curr_full_cnt, + curr_full_idx, + subtile_factor=self.subtile_factor, + ) + else: + m_block = m_block_min + iter_idx + is_full_block = False # Prefetch 1 stage of LSE pipeline_LSE.consumer_wait(consumer_state_LSE) tSrLSE_s2r = cute.make_fragment(tScS_t2r[None, 0, 0, 0].shape, Float32) @@ -1956,14 +2087,11 @@ def compute_loop( cute.copy(thr_copy_t2r, tStS_t2r, tSrS_t2r) if const_expr(self.score_mod is not None): - # Preserve unscaled S for backward score_mod BEFORE masking + # Preserve unscaled S for backward score_mod BEFORE any modification tSrS_pre = cute.make_fragment_like(tSrS_t2r) cute.autovec_copy(tSrS_t2r, tSrS_pre) - #### APPLY MASK - mask_fn(tSrS_t2r, m_block=m_block) - - if const_expr(self.score_mod is not None): + # Apply score_mod FIRST -> matches forward self.apply_score_mod( tSrS_t2r, thr_copy_t2r, @@ -1978,6 +2106,15 @@ def compute_loop( fastdiv_mods, ) + #### APPLY MASK (after score_mod, matching forward pass order) + check_m_boundary = (m_block + 1) * self.tile_m > seqlen.seqlen_q + mask_fn( + tSrS_t2r, + m_block=m_block, + is_full_block=is_full_block, + check_m_boundary=check_m_boundary, + ) + num_stages = cute.size(tScS_t2r, 
mode=[1]) # --------------------------------------------- @@ -2123,7 +2260,8 @@ def compute_loop( producer_state_dS.advance() # Epilogue - if const_expr(not self.is_local) or m_block_min < m_block_max: + # Run epilogue if we processed any m_blocks for this n_block + if process_tile: if const_expr(not self.use_tma_store): consumer_state_dKV = self.epilogue_dKV( dp_idx, @@ -2179,10 +2317,18 @@ def compute_loop( int(NamedBarrierBwdSm100.EpilogueWG1), # barrier_id mdK_semaphore, ) - if const_expr(self.qhead_per_kvhead == 1 and self.is_local): - if m_block_min >= m_block_max: - # if tidx == 0: - # cute.printf("m_block_min = {}, m_block_max = {}", m_block_min, m_block_max) + # Zero dK/dV for empty tiles (local attention or block sparsity) + # When total_m_block_cnt == 0 for block sparsity, no Q tiles contribute to this KV tile + if const_expr(self.qhead_per_kvhead == 1): + should_zero_dKV = False + if const_expr(self.is_local): + should_zero_dKV = m_block_min >= m_block_max + if const_expr(self.use_block_sparsity): + # For block sparsity, zero when no m_blocks contribute to this n_block + if not process_tile: + should_zero_dKV = True + + if should_zero_dKV: # like other epis, currently assumes hdim == hdimv gmem_tiled_copy_zero_dKV = copy_utils.tiled_copy_2d( self.dk_dtype, @@ -2228,6 +2374,7 @@ def dQacc_reduce( SeqlenInfoCls: Callable, TileSchedulerCls: Callable, mdQ_semaphore: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors] = None, ): num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids) tidx = cute.arch.thread_idx()[0] % num_reduce_threads @@ -2279,7 +2426,42 @@ def dQacc_reduce( delay_semaphore_release = self.is_causal n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n) - for m_block in cutlass.range(m_block_min, m_block_max, unroll=1): + # some tiles might be empty due to block sparsity + if const_expr(self.use_block_sparsity): + ( + curr_q_cnt, + curr_q_idx, + curr_full_cnt, + curr_full_idx, + loop_count, + ) 
= get_block_sparse_iteration_info_bwd( + blocksparse_tensors, + batch_idx, + head_idx, + n_block, + subtile_factor=self.subtile_factor, + m_block_max=m_block_max, + ) + process_tile = loop_count > Int32(0) + else: + process_tile = const_expr(not self.is_local) or m_block_min < m_block_max + loop_count = m_block_max - m_block_min + + # dQacc_reduce mainloop + # Block sparsity: iterate over sparse m_block count and derive actual m_block + # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly. + for iter_idx in cutlass.range(loop_count, unroll=1): + if const_expr(self.use_block_sparsity): + m_block, _ = get_m_block_from_iter_bwd( + iter_idx, + curr_q_cnt, + curr_q_idx, + curr_full_cnt, + curr_full_idx, + subtile_factor=self.subtile_factor, + ) + else: + m_block = m_block_min + iter_idx pipeline_dQ.consumer_wait(dQ_consumer_state) # TMEM -> RMEM tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32) diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index 383d317038c..103eb55f5a0 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -52,6 +52,7 @@ def _get_device_capability(): to_cute_block_sparse_tensors, normalize_block_sparse_tensors, get_block_sparse_expected_shapes, + get_block_sparse_expected_shapes_bwd, ) def maybe_contiguous(x): @@ -575,7 +576,9 @@ def _flash_attn_bwd( dv: Optional[torch.Tensor] = None, score_mod: Optional[Callable] = None, score_mod_bwd: Optional[Callable] = None, + mask_mod: Optional[Callable] = None, aux_tensors: Optional[list[torch.Tensor]] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: compute_capability = _get_device_capability() assert compute_capability in [9, 10], "Unsupported compute capability. 
Supported: 9.x, 10.x" @@ -637,6 +640,8 @@ def _flash_attn_bwd( else: causal, local = False, True + use_block_sparsity = block_sparse_tensors is not None + if cu_seqlens_k is None: assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) @@ -699,6 +704,9 @@ def _flash_attn_bwd( device = q.device out_torch_dtype = q.dtype + # nb: this could be derived from the block_sparse_tensors but for now we hardcode it to 2 + subtile_factor = 2 + if dq is None: dq = torch.empty_like(q) else: @@ -869,6 +877,7 @@ def _flash_attn_bwd( # Hash callables for compile key score_mod_hash = utils.hash_callable(score_mod) if score_mod else False score_mod_bwd_hash = utils.hash_callable(score_mod_bwd) if score_mod_bwd else False + mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod else False num_aux_tensors = len(aux_tensors) if aux_tensors else 0 # Convert aux_tensors to cute tensors cute_aux_tensors = None @@ -892,7 +901,9 @@ def _flash_attn_bwd( deterministic, score_mod_hash, score_mod_bwd_hash, + mask_mod_hash, num_aux_tensors, + use_block_sparsity, ) num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: @@ -970,8 +981,26 @@ def _flash_attn_bwd( deterministic=deterministic, score_mod=score_mod, score_mod_bwd=score_mod_bwd, + mask_mod=mask_mod, has_aux_tensors=aux_tensors is not None and len(aux_tensors) > 0, + subtile_factor=subtile_factor, ) + + # Block sparse tensors for backward use Q-direction indexing (transposed from forward). + # sparse_block_size_q = 2*tile_m matches forward's q_stage=2 pipelining. 
+ sparse_tensors_compile = None + if block_sparse_tensors is not None and compute_capability == 10: + expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd( + batch_size, num_head, seqlen_q, seqlen_k, + m_block_size, n_block_size, subtile_factor, + ) + compile_time_normalized = normalize_block_sparse_tensors( + block_sparse_tensors, + expected_count_shape=expected_count_shape, + expected_index_shape=expected_index_shape, + ) + sparse_tensors_compile = to_cute_block_sparse_tensors(compile_time_normalized) + # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( fa_bwd_obj, @@ -997,8 +1026,21 @@ def _flash_attn_bwd( dK_semaphore_tensor, dV_semaphore_tensor, cute_aux_tensors, + sparse_tensors_compile, options="--enable-tvm-ffi", ) + normalized_block_sparse_tensors = None + if block_sparse_tensors is not None and compute_capability == 10: + expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd( + batch_size, num_head, seqlen_q, seqlen_k, + m_block_size, n_block_size, subtile_factor, + ) + normalized_block_sparse_tensors = normalize_block_sparse_tensors( + block_sparse_tensors, + expected_count_shape=expected_count_shape, + expected_index_shape=expected_index_shape, + ) + _flash_attn_bwd.compile_cache[compile_key]( q, k, @@ -1022,6 +1064,7 @@ def _flash_attn_bwd( dK_semaphore, dV_semaphore, aux_tensors, + normalized_block_sparse_tensors, ) num_threads = 256 if compute_capability == 9 else 128 diff --git a/flash_attn/cute/mask.py b/flash_attn/cute/mask.py index 430c7d26fc5..385e208cbe5 100644 --- a/flash_attn/cute/mask.py +++ b/flash_attn/cute/mask.py @@ -440,9 +440,24 @@ def apply_mask_sm100_transposed( mask_seqlen: cutlass.Constexpr, mask_causal: cutlass.Constexpr, mask_local: cutlass.Constexpr, + mask_mod: cutlass.Constexpr[Optional[Callable]] = None, + batch_idx: Int32 = None, + head_idx: Int32 = None, + aux_tensors: Optional[list] = None, + fastdiv_mods=(None, None), + 
is_full_block: bool = False, + check_m_boundary: bool = True, ) -> None: """ Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q. + + Coordinate conventio: + - ROW corresponds to Q (m_block) + - COL corresponds to KV (n_block) + + is_full_block: If True, skip mask_mod (all elements valid). Only apply seqlen masking. + check_m_boundary: If False, skip seqlen_q boundary check (optimization for non-boundary m_blocks). + When iterating m_blocks in forward order, only the last m_block may be partial. """ assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" ROW = 0 if const_expr(not self.swap_AB) else 1 @@ -450,7 +465,81 @@ def apply_mask_sm100_transposed( assert t0ScS_t2r[0][COL] == 0, "col0 == 0" thr_col_offset = tScS_t2r[0][COL] seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset - if const_expr(not mask_causal and not mask_local): + + if const_expr(not mask_causal and not mask_local and mask_mod is not None): + # Block sparse case with mask_mod (backward) + # + # Coordinate convention: ROW → Q (m_block), COL → KV (n_block). + # These already account for swap_AB. + # + # FULL blocks: mask_mod returns True for all elements, so skip it. + # Still need seqlen bounds check (elements may be OOB on last m_block). + # PARTIAL blocks: apply mask_mod element-wise, then seqlen bounds. 
+ if is_full_block: + if const_expr(mask_seqlen): + if seqlenk_col_limit <= 0: + # Entire tile is OOB for K + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + acc_S[i] = -cutlass.Float32.inf + elif check_m_boundary: + # Last m_block: check Q and K boundaries + ncol = const_expr(cute.size(tScS_t2r.shape)) + for i in cutlass.range_constexpr(ncol): + row_coord = tScS_t2r[i][ROW] + col_coord = tScS_t2r[i][COL] + global_q = row_coord + m_block * self.tile_m + global_kv = col_coord + n_block * self.tile_n + q_out_of_bounds = global_q >= self.seqlen_q + kv_out_of_bounds = global_kv >= self.seqlen_k + out_of_bounds = q_out_of_bounds or kv_out_of_bounds + acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i] + else: + # Partial block + has_fastdiv = const_expr( + fastdiv_mods is not None + and fastdiv_mods[0] is not None + and fastdiv_mods[1] is not None + ) + wrap_aux_indices = const_expr( + has_fastdiv and mask_seqlen and const_expr(aux_tensors is not None) + ) + batch_idx_ssa = utils.scalar_to_ssa(batch_idx, cutlass.Int32) + head_idx_ssa = utils.scalar_to_ssa(head_idx, cutlass.Int32) + + ncol = const_expr(cute.size(tScS_t2r.shape)) + for i in cutlass.range_constexpr(ncol): + row_coord = tScS_t2r[i][ROW] + col_coord = tScS_t2r[i][COL] + global_q = row_coord + m_block * self.tile_m + global_kv = col_coord + n_block * self.tile_n + + q_idx_for_mod = global_q + kv_idx_for_mod = global_kv + if const_expr(wrap_aux_indices): + _, q_idx_for_mod = divmod(global_q, fastdiv_mods[0]) + _, kv_idx_for_mod = divmod(global_kv, fastdiv_mods[1]) + + q_idx_ssa = utils.scalar_to_ssa(q_idx_for_mod, cutlass.Int32) + kv_idx_ssa = utils.scalar_to_ssa(kv_idx_for_mod, cutlass.Int32) + + mask_value = mask_mod( + batch_idx_ssa, + head_idx_ssa, + q_idx_ssa, + kv_idx_ssa, + aux_tensors, + ) + cond = cutlass.Boolean(utils.ssa_to_scalar(mask_value)) + acc_S[i] = acc_S[i] if cond else -cutlass.Float32.inf + + if const_expr(mask_seqlen): + # check_m_boundary=False skips q 
check for non-boundary m_blocks + q_out_of_bounds = check_m_boundary and (global_q >= self.seqlen_q) + kv_out_of_bounds = global_kv >= self.seqlen_k + out_of_bounds = q_out_of_bounds or kv_out_of_bounds + acc_S[i] = -cutlass.Float32.inf if out_of_bounds else acc_S[i] + + elif const_expr(not mask_causal and not mask_local): if const_expr(mask_seqlen): if seqlenk_col_limit <= 0: for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 9c2db48f22b..f43a9c6dd9e 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -20,14 +20,9 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention import torch.nn.functional as F -from flash_attn.cute.interface import _flash_attn_fwd +from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch -from flash_attn.cute.mask_definitions import ( - get_mask_pair, - STATIC_MASKS, - random_doc_id_tensor, -) -from flash_attn.cute.testing import attention_ref +from flash_attn.cute.mask_definitions import get_mask_pair, random_doc_id_tensor COMPUTE_CAPABILITY = torch.cuda.get_device_capability()[0] @@ -59,35 +54,14 @@ def create_tensors( lse = torch.empty(batch_size, nheads, seqlen_q, device=device, dtype=torch.float32) return { - "q": q.contiguous(), - "k": k.contiguous(), - "v": v.contiguous(), - "out": out.contiguous(), - "lse": lse.contiguous(), + "q": q, + "k": k, + "v": v, + "out": out, + "lse": lse, } -def compute_reference_flash_attn(tensors, causal, window_size, dtype_ref, upcast=True): - """Compute reference using FlashAttention's attention_ref function""" - q = tensors["q"].to(dtype_ref) - k = tensors["k"].to(dtype_ref) - v = tensors["v"].to(dtype_ref) - - out_ref, attn_ref = attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - causal=causal, - window_size=window_size, - upcast=upcast, - 
reorder_ops=False, - ) - - return out_ref - - def compute_reference_flex_attn(tensors, mask_mod_flex, block_size: Optional[tuple[int, int]] = None): """Compute reference using flex_attention for custom mask_mods""" batch_size, seqlen_q, nheads, headdim = tensors["q"].shape @@ -172,6 +146,7 @@ def _run_mask_test( tile_m, tile_n, use_block_sparsity, + needs_backward=False, ): torch.manual_seed(42) @@ -230,7 +205,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim_v, dtype ) - # Compute block sparsity for mask_mod + # SM100 uses sparse_tile_m = 2*tile_m to match forward q_stage=2 pipelining if COMPUTE_CAPABILITY == 10: sparse_tile_m = 2 * tile_m else: @@ -245,29 +220,35 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n), ) - _, _, mask_cnt, mask_idx, full_cnt, full_idx, *_ = bm.as_tuple() + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm.as_tuple() softmax_scale = 1.0 / math.sqrt(headdim) - # if full_cnt is not None: - # print(f"Block sparsity info for {mask_name}:") - # print(f" full_cnt shape: {full_cnt.shape}") - # print(f" full_idx shape: {full_idx.shape}") - # print(f" mask_cnt shape: {mask_cnt.shape}") - # print(f" mask_idx shape: {mask_idx.shape}") - # print(f" full_cnt: {full_cnt}") - # print(f" full_idx: {full_idx}") - # print(f" mask_cnt: {mask_cnt}") - # print(f" mask_idx: {mask_idx}") - # if full_cnt[0,0,0] > 0: - # print(f" First Q block - full indices: {full_idx[0,0,0,:full_cnt[0,0,0].item()]}") - # if mask_cnt[0,0,0] > 0: - # print(f" First Q block - mask indices: {mask_idx[0,0,0,:mask_cnt[0,0,0].item()]}") - block_sparse_mask = BlockSparseTensorsTorch( - mask_block_cnt=mask_cnt, - mask_block_idx=mask_idx, - full_block_cnt=full_cnt, - full_block_idx=full_idx, + block_sparse_mask_fwd = BlockSparseTensorsTorch( + 
mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + ) if use_block_sparsity else None + + # Backward uses Q-direction (transposed) sparse tensors + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, ) if use_block_sparsity else None out_tuple = _flash_attn_fwd( @@ -294,12 +275,13 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): _compute_capability=None, score_mod=None, mask_mod=mask_mod_cute, - block_sparse_tensors=block_sparse_mask, + block_sparse_tensors=block_sparse_mask_fwd, return_lse=True, aux_tensors=aux_tensors_arg, ) out_cute = out_tuple[0] + lse_cute = out_tuple[1] tensors_fp32 = { k: v.float() if v.dtype in [torch.float16, torch.bfloat16] else v for k, v in tensors.items() @@ -356,6 +338,65 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias): f"Kernel error {cute_error:.2e} exceeds {rtol}x PyTorch error {pt_error:.2e} + {fwd_atol:.2e}" ) + # Backward pass (SM100 only) + if needs_backward and COMPUTE_CAPABILITY == 10 and kv_mode == "mha": + q = tensors["q"] + k = tensors["k"] + v = tensors["v"] + + # Create grad_out once and reuse + grad_out = torch.randn_like(out_cute) + + # Create block_mask for flex reference + flex_block_mask = create_block_mask( + mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k, + device="cuda", BLOCK_SIZE=(tile_m, tile_n), + ) + + dq_cute, dk_cute, dv_cute = run_cute_mask_bwd( + q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute, + block_sparse_mask_bwd=block_sparse_mask_bwd, tile_m=tile_m, tile_n=tile_n, + aux_tensors=aux_tensors_arg, + ) + _, dq_ref_fp32, dk_ref_fp32, dv_ref_fp32 = run_flex_reference_bwd( + q, k, v, flex_block_mask, grad_out, dtype=torch.float32 + ) + _, dq_pt, dk_pt, dv_pt = run_flex_reference_bwd( + q, k, v, flex_block_mask, grad_out + ) + + # Check for invalid values + assert not torch.isnan(dq_cute).any(), "dQ 
contains NaN" + assert not torch.isnan(dk_cute).any(), "dK contains NaN" + assert not torch.isnan(dv_cute).any(), "dV contains NaN" + + bwd_rtol = 2 + bwd_atol_floor = 1e-5 + dq_atol = max(bwd_atol_floor, 2 * (dq_ref_fp32 + 0.3 - 0.3 - dq_ref_fp32).abs().max().item()) + dk_atol = max(bwd_atol_floor, 2 * (dk_ref_fp32 + 0.3 - 0.3 - dk_ref_fp32).abs().max().item()) + dv_atol = max(bwd_atol_floor, 2 * (dv_ref_fp32 + 0.3 - 0.3 - dv_ref_fp32).abs().max().item()) + + dq_ref = dq_ref_fp32.to(dtype) + dk_ref = dk_ref_fp32.to(dtype) + dv_ref = dv_ref_fp32.to(dtype) + + pt_dq_err = (dq_pt - dq_ref).abs().max().item() + pt_dk_err = (dk_pt - dk_ref).abs().max().item() + pt_dv_err = (dv_pt - dv_ref).abs().max().item() + + cute_dq_err = (dq_cute - dq_ref).abs().max().item() + cute_dk_err = (dk_cute - dk_ref).abs().max().item() + cute_dv_err = (dv_cute - dv_ref).abs().max().item() + + print(" Backward comparison:") + print(f" dQ: PT err={pt_dq_err:.2e}, CuTE err={cute_dq_err:.2e}, atol={dq_atol:.2e}") + print(f" dK: PT err={pt_dk_err:.2e}, CuTE err={cute_dk_err:.2e}, atol={dk_atol:.2e}") + print(f" dV: PT err={pt_dv_err:.2e}, CuTE err={cute_dv_err:.2e}, atol={dv_atol:.2e}") + + assert cute_dq_err <= bwd_rtol * pt_dq_err + dq_atol, f"dQ error too large: {cute_dq_err:.2e}" + assert cute_dk_err <= bwd_rtol * pt_dk_err + dk_atol, f"dK error too large: {cute_dk_err:.2e}" + assert cute_dv_err <= bwd_rtol * pt_dv_err + dv_atol, f"dV error too large: {cute_dv_err:.2e}" + def test_mask_mod_ima_partial_block(): _run_mask_test( @@ -372,6 +413,59 @@ def test_mask_mod_ima_partial_block(): tile_m=128, tile_n=128, use_block_sparsity=True, + needs_backward=True, + ) + + +# Q boundary seqlens: NOT multiples of tile_m (128) +# These exercise the fix for is_full_block tiles not masking OOB Q rows in backward +Q_BOUNDARY_SEQLEN_PAIRS = [ + (200, 200), # Last m_block: rows 128-199 valid, 200-255 should be masked + (300, 300), # Last m_block: rows 256-299 valid, 300-383 should be masked + (129, 129), # 
Just 1 element into second tile + (255, 255), # Just 1 element short of 2 full tiles + (500, 512), # Q boundary only (K aligned) + (512, 500), # K boundary only (Q aligned) + (333, 444), # Both non-aligned +] + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", Q_BOUNDARY_SEQLEN_PAIRS) +@pytest.mark.parametrize("mask_name", ["block_diagonal", "document"]) +def test_q_boundary_masking_block_sparse_bwd(seqlen_q, seqlen_k, mask_name): + """Test Q boundary masking for block-sparse backward pass. + + This test specifically exercises the fix for the bug where Q rows beyond seqlen_q + were not masked in backward pass for is_full_block=True tiles. + + The bug occurred because: + - In forward, apply_mask_sm100 always checks both Q and K bounds + - In backward, apply_mask_sm100_transposed with is_full_block=True only checked K bounds + - Result: partial last m_blocks had unmasked garbage Q rows contributing to gradients + + Key conditions: + - seqlen_q NOT a multiple of tile_m (128): creates partial last m_block + - Block-sparse with mask_mod: exercises is_full_block=True path + - Backward pass: where the bug manifested + """ + if COMPUTE_CAPABILITY != 10: + pytest.skip("SM100-only backward test") + + _run_mask_test( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + nheads=4, + kv_mode="mha", + headdim=128, + dtype=torch.bfloat16, + mask_name=mask_name, + window_size=None, + window_left=None, + window_right=None, + tile_m=128, + tile_n=128, + use_block_sparsity=True, + needs_backward=True, ) @@ -412,6 +506,7 @@ def test_static_masks( tile_m=tile_m, tile_n=tile_n, use_block_sparsity=use_block_sparsity, + needs_backward=True, ) @@ -462,6 +557,7 @@ def test_parameterized_masks( tile_m=tile_m, tile_n=tile_n, use_block_sparsity=use_block_sparsity, + needs_backward=True, ) @@ -510,6 +606,83 @@ def test_sm100_block_sparse_sink_all_masked(): assert torch.allclose(lse, expected, atol=0.0, rtol=0.0) +# ============================================================================= +# Backward 
Helper Functions +# ============================================================================= + +def run_cute_mask_bwd( + q, k, v, out, lse, grad_out, mask_mod_cute, + block_sparse_mask_bwd=None, tile_m=128, tile_n=128, + aux_tensors=None, +): + """Run flash attention backward with mask_mod. + + Args: + q, k, v: Input tensors in BSHD format + out: Forward output tensor + lse: Log-sum-exp from forward pass + grad_out: Gradient of output + mask_mod_cute: CuTE mask modification function + block_sparse_mask_bwd: Block sparse tensors for backward pass + tile_m, tile_n: Tile sizes + aux_tensors: Auxiliary tensors for mask_mod (e.g., doc_ids for document masking) + + Returns (dq, dk, dv) all in BSHD format. + """ + dq, dk, dv = _flash_attn_bwd( + q=q, + k=k, + v=v, + out=out, + dout=grad_out, + lse=lse, + causal=False, + m_block_size=tile_m, + n_block_size=tile_n, + mask_mod=mask_mod_cute, + block_sparse_tensors=block_sparse_mask_bwd, + aux_tensors=aux_tensors, + ) + + return dq, dk, dv + + +def run_flex_reference_bwd(q, k, v, block_mask, grad_out, dtype=None): + """Run flex_attention forward + backward for reference. + + Args: + q, k, v: Input tensors in BSHD format + block_mask: Pre-created block mask for flex_attention + grad_out: Gradient of output in BSHD format + dtype: Optional dtype to cast inputs to (e.g., torch.float32 for reference) + + Returns (out, dq, dk, dv) all in BSHD format. 
+ """ + # Transpose to BHSD for flex_attention + if dtype is not None: + q_ref = q.transpose(1, 2).to(dtype).requires_grad_(True) + k_ref = k.transpose(1, 2).to(dtype).requires_grad_(True) + v_ref = v.transpose(1, 2).to(dtype).requires_grad_(True) + grad_out_ref = grad_out.transpose(1, 2).to(dtype) + else: + q_ref = q.transpose(1, 2).requires_grad_(True) + k_ref = k.transpose(1, 2).requires_grad_(True) + v_ref = v.transpose(1, 2).requires_grad_(True) + grad_out_ref = grad_out.transpose(1, 2) + + # Use flex_attention directly without torch.compile for backward tests + # torch.compile can hang on certain mask patterns (e.g., mini_causal with float32) + out_ref = flex_attention(q_ref, k_ref, v_ref, block_mask=block_mask) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), grad_out_ref) + + # Transpose back to BSHD + return ( + out_ref.transpose(1, 2), + dq_ref.transpose(1, 2), + dk_ref.transpose(1, 2), + dv_ref.transpose(1, 2), + ) + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"]) - \ No newline at end of file diff --git a/tests/cute/test_score_mod.py b/tests/cute/test_score_mod.py index d354f93ffc8..26cdecde431 100644 --- a/tests/cute/test_score_mod.py +++ b/tests/cute/test_score_mod.py @@ -614,6 +614,16 @@ def score_mod_bwd_identity(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info return grad +@cute.jit +def score_mod_bwd_causal(grad, score, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): + """Backward for causal masking: d(where(mask, score, -inf))/d(score) = where(mask, 1, 0). + + At unmasked positions (q_idx >= kv_idx), grad passes through. + At masked positions (q_idx < kv_idx), the kernel already zeros grad because P=0. 
+ """ + return grad + + @cute.jit def score_mod_squared(tSrS_ssa, b_idx, h_idx, q_idx, kv_idx, seqlen_info, aux_tensors): """Forward: score ** 2.""" @@ -634,6 +644,7 @@ def score_squared_eager(score, b, h, q_idx, kv_idx): (score_mod_5, score_mod_bwd_5, times_two_eager), (score_mod_3, score_mod_bwd_3, relative_bias_eager), (score_mod_squared, score_mod_bwd_squared, score_squared_eager), + (score_mod_2, score_mod_bwd_causal, causal_mask_eager), ] BWD_TEST_PAIRS_WITH_AUX = [