From 5cd350119453d009a4756f389fba7f89f8b32c4e Mon Sep 17 00:00:00 2001 From: drisspg Date: Sun, 8 Feb 2026 01:12:00 +0000 Subject: [PATCH] Deterministic backward for blocksparse impl stack-info: PR: https://github.com/Dao-AILab/flash-attention/pull/2253, branch: drisspg/stack/16 --- flash_attn/cute/block_info.py | 17 + flash_attn/cute/block_sparse_utils.py | 22 +- flash_attn/cute/block_sparsity.py | 213 ++++++++- flash_attn/cute/cache_utils.py | 1 - flash_attn/cute/compute_block_sparsity.py | 11 +- flash_attn/cute/flash_bwd_sm100.py | 171 ++++--- flash_attn/cute/flash_fwd_sm100.py | 1 - flash_attn/cute/interface.py | 52 ++- tests/cute/test_mask_mod.py | 525 +++++++++++++++++++++- 9 files changed, 913 insertions(+), 100 deletions(-) diff --git a/flash_attn/cute/block_info.py b/flash_attn/cute/block_info.py index f21013891b4..422da2b66a0 100644 --- a/flash_attn/cute/block_info.py +++ b/flash_attn/cute/block_info.py @@ -137,3 +137,20 @@ def get_n_block_min_before_local_mask( n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q n_idx_left = n_idx - self.window_size_left return cutlass.max(n_block_min, cute.ceil_div(n_idx_left, self.tile_n)) + + @cute.jit + def get_n_block_max_for_m_block( + self, + seqlen_info: SeqlenInfoQK, + m_block: Int32, + n_block_global_max: Int32, + ) -> Int32: + if const_expr(self.is_causal or self.window_size_right is not None): + m_idx_max = (m_block + 1) * self.tile_m + if const_expr(self.qhead_per_kvhead_packgqa > 1): + m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa) + n_idx_right = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q + if const_expr(self.window_size_right is not None): + n_idx_right += self.window_size_right + return min(n_block_global_max, cute.ceil_div(n_idx_right, self.tile_n)) + return n_block_global_max diff --git a/flash_attn/cute/block_sparse_utils.py b/flash_attn/cute/block_sparse_utils.py index 0f81f863673..d664b16dc64 100644 --- a/flash_attn/cute/block_sparse_utils.py +++ b/flash_attn/cute/block_sparse_utils.py @@ -188,7 +188,7 @@ def produce_block_sparse_loads( must be converted to unpacked for sparse tensor indexing. """ - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) @@ -332,7 +332,7 @@ def consume_block_sparse_loads( must be converted to unpacked for sparse tensor indexing. """ - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) @@ -552,7 +552,7 @@ def produce_block_sparse_loads_sm100( """ m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] @@ -629,7 +629,7 @@ def get_total_block_count( ): m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors if const_expr(full_block_cnt is not None): return ( mask_block_cnt[batch_idx, head_idx, m_block_sparse] @@ -780,7 +780,7 @@ def softmax_block_sparse_sm100( warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 m_block_sparse = sparse_tensor_m_block(m_block, qhead_per_kvhead, q_subtile_factor) - mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors + mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, *_ = blocksparse_tensors curr_mask_block_cnt = mask_block_cnt[batch_idx, head_idx, m_block_sparse] curr_mask_block_idx = mask_block_idx[batch_idx, head_idx, m_block_sparse, None] @@ -795,8 +795,6 @@ def softmax_block_sparse_sm100( total_block_cnt = curr_mask_block_cnt + curr_full_block_cnt if total_block_cnt == 0: - # See NOTE [SM100 block-sparse empty tiles: mbarrier contract]. - # pipeline_sm_stats.producer_commit_w_index(stage_idx) sm_stats_barrier.arrive_w_index(index=stage_idx * 4 + warp_idx) else: if curr_mask_block_cnt > 0: @@ -907,7 +905,7 @@ def get_total_q_block_count_bwd( m_block_max: int = 0, ): """Count total tile iterations for given n_block (KV tile) in backward.""" - q_block_cnt, _, full_block_cnt, _ = blocksparse_tensors + q_block_cnt, _, full_block_cnt, _, *_ = blocksparse_tensors total = q_block_cnt[batch_idx, head_idx, n_block] if const_expr(full_block_cnt is not None): total = total + full_block_cnt[batch_idx, head_idx, n_block] @@ -1051,7 +1049,7 @@ def get_block_sparse_iteration_info_bwd( Returns (curr_q_cnt, curr_q_idx, curr_full_cnt, curr_full_idx, total_count). """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] @@ -1175,7 +1173,7 @@ def produce_block_sparse_q_loads_bwd_sm90( Returns updated (producer_state_Q, producer_state_dO). """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] @@ -1270,7 +1268,7 @@ def consume_block_sparse_mma_bwd_sm90( Returns updated (consumer_state_Q, consumer_state_dO). """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] @@ -1396,7 +1394,7 @@ def dQaccum_store_block_sparse_bwd_sm90( Iterates partial blocks first, then full blocks, matching producer/consumer order. """ - q_cnt, q_idx, full_cnt, full_idx = blocksparse_tensors + q_cnt, q_idx, full_cnt, full_idx, *_ = blocksparse_tensors curr_q_cnt = q_cnt[batch_idx, head_idx, n_block] curr_q_idx = q_idx[batch_idx, head_idx, n_block, None] diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py index 3fad8c9f491..4a5726b7493 100644 --- a/flash_attn/cute/block_sparsity.py +++ b/flash_attn/cute/block_sparsity.py @@ -19,9 +19,13 @@ class BlockSparseTensors(NamedTuple): mask_block_idx: cute.Tensor full_block_cnt: cute.Tensor | None full_block_idx: cute.Tensor | None + dq_write_order: cute.Tensor | None = None + dq_write_order_full: cute.Tensor | None = None def __new_from_mlir_values__(self, values): if len(values) == 2: + values = (*values, None, None, None, None) + elif len(values) == 4: values = (*values, None, None) return BlockSparseTensors(*values) @@ -32,6 +36,138 @@ class BlockSparseTensorsTorch(NamedTuple): full_block_cnt: torch.Tensor | None = None full_block_idx: torch.Tensor | None = None block_size: tuple[int, int] | None = None + dq_write_order: torch.Tensor | None = None + dq_write_order_full: torch.Tensor | None = None + spt: bool | None = None + + +def _ordered_to_dense_simple( + num_blocks: torch.Tensor, + indices: torch.Tensor, + num_cols: int, +) -> torch.Tensor: + """Convert ordered sparse representation to dense binary matrix. + + Args: + num_blocks: [B, H, num_rows] count of valid entries per row + indices: [B, H, num_rows, max_entries] column indices (valid entries packed left) + num_cols: total number of columns + + Returns: + dense: [B, H, num_rows, num_cols] binary int32 matrix + """ + B, H, num_rows, max_entries = indices.shape + device = indices.device + dense = torch.zeros(B, H, num_rows, num_cols + 1, dtype=torch.int32, device=device) + col_range = torch.arange(max_entries, device=device) + valid = col_range[None, None, None, :] < num_blocks[:, :, :, None] + safe_indices = torch.where(valid, indices.long(), num_cols) + row_idx = torch.arange(num_rows, device=device)[None, None, :, None].expand_as(indices) + b_idx = torch.arange(B, device=device)[:, None, None, None].expand_as(indices) + h_idx = torch.arange(H, device=device)[None, :, None, None].expand_as(indices) + dense[b_idx, h_idx, row_idx, safe_indices] = 1 + return dense[:, :, :, :num_cols] + + +def compute_dq_write_order( + fwd_mask_cnt: torch.Tensor, + fwd_mask_idx: torch.Tensor, + fwd_full_cnt: torch.Tensor | None, + fwd_full_idx: torch.Tensor | None, + bwd_mask_cnt: torch.Tensor, + bwd_mask_idx: torch.Tensor, + bwd_full_cnt: torch.Tensor | None, + bwd_full_idx: torch.Tensor | None, + spt: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """Compute dQ write-order metadata for deterministic block-sparse backward. + + For each (n_block, i) in the backward iteration, computes the semaphore + lock value: the rank of n_block in the combined (partial + full) sorted + contributor list for the target m_block. + + Lock values are assigned in ascending n_block order (or descending if spt=True) + to guarantee deadlock-freedom with the CTA scheduling order. + + Args: + fwd_mask_cnt: [B, H, num_m_blocks] partial contributor counts per m_block + fwd_mask_idx: [B, H, num_m_blocks, max_kv] partial contributor n_block indices (ascending) + fwd_full_cnt: [B, H, num_m_blocks] full contributor counts per m_block (optional) + fwd_full_idx: [B, H, num_m_blocks, max_kv] full contributor n_block indices (optional) + bwd_mask_cnt: [B, H, num_n_blocks] partial iteration counts per n_block + bwd_mask_idx: [B, H, num_n_blocks, max_q] partial iteration m_block indices + bwd_full_cnt: [B, H, num_n_blocks] full iteration counts per n_block (optional) + bwd_full_idx: [B, H, num_n_blocks, max_q] full iteration m_block indices (optional) + spt: if True, reverse ordering (highest n_block gets lock_value=0) + + Returns: + (dq_write_order, dq_write_order_full): tensors parallel to bwd_mask_idx + and bwd_full_idx respectively, containing lock values. + """ + device = fwd_mask_idx.device + B, H, num_m, max_kv_partial = fwd_mask_idx.shape + _, _, num_n, max_q_partial = bwd_mask_idx.shape + + has_full = fwd_full_cnt is not None and fwd_full_idx is not None + + dense_partial = _ordered_to_dense_simple(fwd_mask_cnt, fwd_mask_idx, num_n) + if has_full: + dense_full = _ordered_to_dense_simple(fwd_full_cnt, fwd_full_idx, num_n) + dense = (dense_partial + dense_full).clamp(max=1) + else: + dense = dense_partial + + cumsum = dense.cumsum(dim=-1) + rank_table = (cumsum - dense).to(torch.int32) + + if spt: + total_per_m = cumsum[:, :, :, -1:] + rank_table = (total_per_m - 1 - rank_table).to(torch.int32) + + def _gather_write_order(bwd_idx, bwd_cnt): + b_i = torch.arange(B, device=device)[:, None, None, None].expand_as(bwd_idx) + h_i = torch.arange(H, device=device)[None, :, None, None].expand_as(bwd_idx) + n_i = torch.arange(bwd_idx.shape[2], device=device)[None, None, :, None].expand_as(bwd_idx) + m_vals = bwd_idx.long().clamp(0, num_m - 1) + return rank_table[b_i, h_i, m_vals, n_i].to(torch.int32) + + dq_write_order = _gather_write_order(bwd_mask_idx, bwd_mask_cnt) + + dq_write_order_full = None + if has_full and bwd_full_cnt is not None and bwd_full_idx is not None: + dq_write_order_full = _gather_write_order(bwd_full_idx, bwd_full_cnt) + + return dq_write_order, dq_write_order_full + + +def compute_dq_write_order_from_block_mask( + block_mask, + spt: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = block_mask.as_tuple() + return compute_dq_write_order( + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + spt=spt, + ) def get_sparse_q_block_size( @@ -110,6 +246,25 @@ def _check_and_expand_block( return expanded_cnt, expanded_idx +def _check_and_expand_metadata_tensor( + name: str, + tensor: torch.Tensor | None, + expected_shape: Tuple[int, ...], + context: str | None, + hint: str | Callable[[], str] | None, + device: torch.device, +) -> torch.Tensor | None: + if tensor is None: + return None + if tensor.dtype != torch.int32: + raise ValueError(f"{name} must have dtype torch.int32") + if tensor.device != device: + raise ValueError(f"{name} must be on the same device as block sparse tensors") + if not tensor.is_cuda: + raise ValueError(f"{name} must live on CUDA") + return _expand_sparsity_tensor(tensor, expected_shape, name, context, hint) + + def get_block_sparse_expected_shapes( batch_size: int, num_head: int, @@ -279,12 +434,37 @@ def normalize_block_sparse_tensors( if full_cnt is not None and mask_cnt.device != full_cnt.device: raise ValueError("All block sparse tensors must be on the same device") + dq_write_order = _check_and_expand_metadata_tensor( + "dq_write_order", + tensors.dq_write_order, + tuple(mask_idx.shape), + context, + hint, + mask_cnt.device, + ) + dq_write_order_full = _check_and_expand_metadata_tensor( + "dq_write_order_full", + tensors.dq_write_order_full, + tuple(full_idx.shape) if full_idx is not None else expected_index_shape, + context, + hint, + mask_cnt.device, + ) + spt = tensors.spt + if spt is not None and not isinstance(spt, bool): + raise ValueError("spt must be a bool when provided") + if spt is not None and dq_write_order is None: + raise ValueError("spt requires dq_write_order to be provided") + return BlockSparseTensorsTorch( mask_block_cnt=mask_cnt, mask_block_idx=mask_idx, full_block_cnt=full_cnt, full_block_idx=full_idx, block_size=tensors.block_size, + dq_write_order=dq_write_order, + dq_write_order_full=dq_write_order_full, + spt=spt, ) @@ -316,6 +496,8 @@ def get_block_sparse_broadcast_pattern( tensors.mask_block_idx, tensors.full_block_cnt, tensors.full_block_idx, + tensors.dq_write_order, + tensors.dq_write_order_full, ): if tensor is not None: patterns.append(get_broadcast_dims(tensor)) @@ -423,30 +605,21 @@ def to_cute_block_sparse_tensors( """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi""" if not is_block_sparsity_enabled(tensors): return None - - ( - mask_block_cnt, - mask_block_idx, - full_block_cnt, - full_block_idx, - *_, - ) = tensors - - ( - mask_block_cnt_tensor, - mask_block_idx_tensor, - ) = [ + mask_block_cnt_tensor, mask_block_idx_tensor = [ to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) - for t in (mask_block_cnt, mask_block_idx) + for t in (tensors.mask_block_cnt, tensors.mask_block_idx) ] - ( - full_block_cnt_tensor, - full_block_idx_tensor, - ) = [ + full_block_cnt_tensor, full_block_idx_tensor = [ + to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) + if t is not None + else None + for t in (tensors.full_block_cnt, tensors.full_block_idx) + ] + dq_write_order_tensor, dq_write_order_full_tensor = [ to_cute_tensor(t, assumed_align=4, leading_dim=-1, enable_tvm_ffi=enable_tvm_ffi) if t is not None else None - for t in (full_block_cnt, full_block_idx) + for t in (tensors.dq_write_order, tensors.dq_write_order_full) ] return BlockSparseTensors( @@ -454,6 +627,8 @@ def to_cute_block_sparse_tensors( mask_block_idx_tensor, full_block_cnt_tensor, full_block_idx_tensor, + dq_write_order_tensor, + dq_write_order_full_tensor, ) diff --git a/flash_attn/cute/cache_utils.py b/flash_attn/cute/cache_utils.py index f1b59700448..658a8d5b656 100644 --- a/flash_attn/cute/cache_utils.py +++ b/flash_attn/cute/cache_utils.py @@ -30,7 +30,6 @@ CompileKeyType: TypeAlias = tuple[Hashable, ...] CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function - # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1" diff --git a/flash_attn/cute/compute_block_sparsity.py b/flash_attn/cute/compute_block_sparsity.py index a2dd98e41d2..69e8309a028 100644 --- a/flash_attn/cute/compute_block_sparsity.py +++ b/flash_attn/cute/compute_block_sparsity.py @@ -54,7 +54,7 @@ def __call__( seqlen_k: Int32, aux_tensors: Optional[list] = None, ): - self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx = blocksparse_tensors + self.mask_cnt, self.mask_idx, self.full_cnt, self.full_idx, *_ = blocksparse_tensors if const_expr(self.compute_full_blocks): assert self.full_cnt is not None and self.full_idx is not None, ( @@ -366,7 +366,14 @@ def compute_block_sparsity( ) compute_block_sparsity.compile_cache[compile_key]( - blocksparse_tensors_torch[:4], + ( + blocksparse_tensors_torch.mask_block_cnt, + blocksparse_tensors_torch.mask_block_idx, + blocksparse_tensors_torch.full_block_cnt, + blocksparse_tensors_torch.full_block_idx, + blocksparse_tensors_torch.dq_write_order, + blocksparse_tensors_torch.dq_write_order_full, + ), seqlen_q, seqlen_k, aux_tensors, diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 4b4083eda9e..9184ddeb029 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -58,6 +58,7 @@ def __init__( tile_n: int = 128, is_persistent: bool = False, deterministic: bool = False, + spt: Optional[bool] = None, cluster_size: int = 1, use_2cta_instrs: bool = False, score_mod: cutlass.Constexpr | None = None, @@ -116,6 +117,7 @@ def __init__( self.qhead_per_kvhead = qhead_per_kvhead self.pack_gqa = False self.deterministic = deterministic + self.spt_override = spt # Score mod and mask mod support self.score_mod = score_mod @@ -705,7 +707,11 @@ def __call__( TileScheduler = SingleTileLPTBwdScheduler else: TileScheduler = SingleTileScheduler - self.spt = (self.is_causal or self.is_local) and self.deterministic + if const_expr(self.spt_override is None): + self.spt = (self.is_causal or self.is_local) and self.deterministic + else: + assert self.spt_override is not None + self.spt = self.spt_override and self.deterministic tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mK.shape[0]), self.cta_tiler[0]), # num_blocks cute.size(mQ.shape[2]), # num_heads = num_query_heads @@ -2989,8 +2995,17 @@ def compute_loop( # prefetch_LSE = not self.is_causal prefetch_LSE = False - # some tiles might be empty due to block sparsity + + curr_q_cnt = Int32(0) + curr_q_idx = None + curr_full_cnt = Int32(0) + curr_full_idx = None + loop_count = m_block_max - m_block_min + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) or m_block_min < m_block_max + ) if const_expr(self.use_block_sparsity): + assert blocksparse_tensors is not None ( curr_q_cnt, curr_q_idx, @@ -3006,17 +3021,14 @@ def compute_loop( m_block_max=m_block_max, ) process_tile = loop_count > Int32(0) - else: - process_tile = ( - const_expr(not self.is_local and not self.is_varlen_q) - or m_block_min < m_block_max - ) - loop_count = m_block_max - m_block_min # Mainloop # Block sparsity: iterate over sparse m_block count and derive actual m_block # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly. for iter_idx in cutlass.range(loop_count, unroll=1): + m_block = m_block_min + iter_idx + m_block_oob = False + is_full_block = False if const_expr(self.use_block_sparsity): m_block, is_full_block = get_m_block_from_iter_bwd( iter_idx, @@ -3028,10 +3040,6 @@ def compute_loop( m_block_max=m_block_max, ) m_block_oob = m_block >= m_block_max - else: - m_block = m_block_min + iter_idx - m_block_oob = False - is_full_block = False # Prefetch 1 stage of LSE pipeline_LSE.consumer_wait(consumer_state_LSE) tSrLSE_s2r = cute.make_fragment(tScS_t2r[None, 0, 0, 0].shape, Float32) @@ -3412,6 +3420,38 @@ def compute_loop( tile_scheduler.advance_to_next_work() work_tile = tile_scheduler.get_current_work() + @cute.jit + def _dq_semaphore_lock_value( + self, + iter_idx: Int32, + curr_q_cnt: Int32, + curr_dq_write_order: Optional[cute.Tensor], + curr_dq_write_order_full: Optional[cute.Tensor], + blocksparse_tensors: Optional[BlockSparseTensors], + block_info: BlockInfo, + seqlen, + m_block: Int32, + n_block: Int32, + n_block_global_max: Int32, + ) -> Int32: + lock_value = n_block + if const_expr(self.spt): + n_block_max_for_m_block = block_info.get_n_block_max_for_m_block( + seqlen, m_block, n_block_global_max + ) + lock_value = n_block_max_for_m_block - 1 - n_block + if const_expr(self.use_block_sparsity): + assert blocksparse_tensors is not None + if const_expr(blocksparse_tensors.dq_write_order is not None): + sparse_iter = iter_idx // self.subtile_factor + if sparse_iter < curr_q_cnt: + assert curr_dq_write_order is not None + lock_value = curr_dq_write_order[sparse_iter] + else: + assert curr_dq_write_order_full is not None + lock_value = curr_dq_write_order_full[sparse_iter - curr_q_cnt] + return lock_value + @cute.jit def dQacc_reduce( self, @@ -3484,13 +3524,24 @@ def dQacc_reduce( ) if const_expr(self.deterministic): + assert mdQ_semaphore is not None mdQ_semaphore_cur = mdQ_semaphore[None, None, head_idx, batch_idx] - # delay_semaphore_release = self.is_causal and not self.tile_hdim == 192 - delay_semaphore_release = not self.tile_hdim == 192 + delay_semaphore_release = not self.tile_hdim == 192 and not self.use_block_sparsity + n_block_global_max = cute.ceil_div(seqlen.seqlen_k, self.tile_n) - # some tiles might be empty due to block sparsity + curr_q_cnt = Int32(0) + curr_q_idx = None + curr_full_cnt = Int32(0) + curr_full_idx = None + curr_dq_write_order = None + curr_dq_write_order_full = None + loop_count = m_block_max - m_block_min + process_tile = ( + const_expr(not self.is_local and not self.is_varlen_q) or m_block_min < m_block_max + ) if const_expr(self.use_block_sparsity): + assert blocksparse_tensors is not None ( curr_q_cnt, curr_q_idx, @@ -3506,17 +3557,25 @@ def dQacc_reduce( m_block_max=m_block_max, ) process_tile = loop_count > Int32(0) - else: - process_tile = ( - const_expr(not self.is_local and not self.is_varlen_q) - or m_block_min < m_block_max - ) - loop_count = m_block_max - m_block_min + if const_expr(self.deterministic and self.use_block_sparsity): + assert blocksparse_tensors is not None + if const_expr(blocksparse_tensors.dq_write_order is not None): + assert blocksparse_tensors.dq_write_order is not None + curr_dq_write_order = blocksparse_tensors.dq_write_order[ + batch_idx, head_idx, n_block, None + ] + if const_expr(blocksparse_tensors.dq_write_order_full is not None): + assert blocksparse_tensors.dq_write_order_full is not None + curr_dq_write_order_full = blocksparse_tensors.dq_write_order_full[ + batch_idx, head_idx, n_block, None + ] # dQacc_reduce mainloop # Block sparsity: iterate over sparse m_block count and derive actual m_block # from Q_IDX/FULL_Q_IDX tensors. Dense: iterate m_block_min..m_block_max directly. for iter_idx in cutlass.range(loop_count, unroll=1): + m_block = m_block_min + iter_idx + m_block_oob_upper = False if const_expr(self.use_block_sparsity): m_block, _ = get_m_block_from_iter_bwd( iter_idx, @@ -3527,10 +3586,7 @@ def dQacc_reduce( subtile_factor=self.subtile_factor, m_block_max=m_block_max, ) - if m_block_max > 0: - m_block = cutlass.min(m_block, m_block_max - 1) - else: - m_block = m_block_min + iter_idx + m_block_oob_upper = m_block >= m_block_max pipeline_dQ.consumer_wait(dQ_consumer_state) # TMEM -> RMEM tdQrdQ_t2r = cute.make_fragment(tdQrdQ_t2r_shape, Float32) @@ -3541,6 +3597,8 @@ def dQacc_reduce( pipeline_dQ.consumer_release(dQ_consumer_state) dQ_consumer_state.advance() + if m_block_max > 0: + m_block = cutlass.min(m_block, m_block_max - 1) gdQaccum_cur = gdQaccum[None, None, m_block] tdQrdQ_shape = ( @@ -3558,22 +3616,28 @@ def dQacc_reduce( cute.arch.fence_view_async_shared() # semaphore acquire if const_expr(self.deterministic and stage == 0): - if const_expr(self.spt): - _, n_block_max_for_m_block = block_info.get_n_block_min_max( - seqlen, m_block + if not m_block_oob_upper: + lock_value = self._dq_semaphore_lock_value( + iter_idx, + curr_q_cnt, + curr_dq_write_order, + curr_dq_write_order_full, + blocksparse_tensors, + block_info, + seqlen, + m_block, + n_block_cta_group, + n_block_global_max, + ) + barrier.wait_eq( + mdQ_semaphore_cur[(m_block, None)].iterator, + tidx, + cta_rank_in_cluster, + lock_value, ) - lock_value = n_block_max_for_m_block - 1 - n_block_cta_group - else: - lock_value = n_block_cta_group - barrier.wait_eq( - mdQ_semaphore_cur[(m_block, None)].iterator, - tidx, - cta_rank_in_cluster, - lock_value, - ) self.reduce_sync_barrier.arrive_and_wait() # Copy from shared memory to global memory - if is_tma_warp: + if is_tma_warp and not m_block_oob_upper: with cute.arch.elect_one(): copy_utils.cpasync_reduce_bulk_add_f32( sdQaccum[None, smem_idx].iterator, @@ -3582,20 +3646,12 @@ def dQacc_reduce( ) cute.arch.cp_async_bulk_commit_group() cute.arch.cp_async_bulk_wait_group(self.sdQaccum_stage - 1, read=read_flag) + elif is_tma_warp: + # Drain pending TMA stores so SMEM buffers are safe to reuse + cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() dQ_tma_store_producer_state.advance() - # Directly add to gmem, much slower - # tdQgdQ = thr_copy_dQaccum_r2s.partition_D(gdQaccum[None, stage, m_block]) - # assert cute.size(tdQrdQ_r2s) == cute.size(tdQgdQ) - # for i in cutlass.range(cute.size(tdQrdQ_r2s) // 4, unroll_full=True): - # copy_utils.atomic_add_fp32x4( - # tdQrdQ_r2s[4 * i], - # tdQrdQ_r2s[4 * i + 1], - # tdQrdQ_r2s[4 * i + 2], - # tdQrdQ_r2s[4 * i + 3], - # utils.elem_pointer(tdQgdQ, 4 * i), - # ) - # semaphore release for prior m_block + if const_expr(self.deterministic and stage == 0 and delay_semaphore_release): if m_block > m_block_min: barrier.arrive_inc( @@ -3617,12 +3673,13 @@ def dQacc_reduce( # NOTE: arrive_inc calls red_release which issues membar if const_expr(self.deterministic and not delay_semaphore_release): if const_expr(self.sdQaccum_stage > 1 and not self.tile_hdim == 192): - if is_tma_warp: + if is_tma_warp and not m_block_oob_upper: cute.arch.cp_async_bulk_wait_group(0, read=read_flag) self.reduce_sync_barrier.arrive_and_wait() - barrier.arrive_inc( - mdQ_semaphore_cur[m_block, None].iterator, tidx, cta_rank_in_cluster, 1 - ) + if not m_block_oob_upper: + barrier.arrive_inc( + mdQ_semaphore_cur[m_block, None].iterator, tidx, cta_rank_in_cluster, 1 + ) if process_tile: if is_tma_warp: @@ -3638,7 +3695,10 @@ def dQacc_reduce( ) if const_expr( - self.deterministic and not self.spt and block_info.window_size_left is not None + self.deterministic + and not self.spt + and not self.use_block_sparsity + and block_info.window_size_left is not None ): m_block_global_max = cute.ceil_div(seqlen.seqlen_q, self.tile_m) for m_block in cutlass.range(m_block_max, m_block_global_max, unroll=1): @@ -3863,6 +3923,7 @@ def epilogue_dK_or_dV_tma( deterministic_KV = self.deterministic and self.qhead_per_kvhead > 1 if const_expr(deterministic_KV): + assert mdKV_semaphore is not None mdKV_semaphore_cur = mdKV_semaphore[n_block, None, head_idx_kv, batch_idx] if const_expr(not self.dKV_postprocess): diff --git a/flash_attn/cute/flash_fwd_sm100.py b/flash_attn/cute/flash_fwd_sm100.py index 9027da189ac..42acbeaec86 100644 --- a/flash_attn/cute/flash_fwd_sm100.py +++ b/flash_attn/cute/flash_fwd_sm100.py @@ -2036,7 +2036,6 @@ def softmax_loop( ) if const_expr(self.use_block_sparsity) or has_work: - # See block_sparse_utils.py NOTE [SM100 block-sparse empty tiles: mbarrier contract]. pipeline_sm_stats.producer_acquire_w_index_phase(stage, sm_stats_producer_phase) sm_stats_producer_phase ^= 1 diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index a11c8debe2b..38c6b707b61 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -964,10 +964,6 @@ def _flash_attn_fwd( compile_args.insert(-3, descale_tensors_tensor) _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi") - # In "fake mode", we will take torch fake tensors as input and the expected behaviors are: - # - Use those fake metadata to populate compilation cache - # - Return "fake" output tensors, which could be needed in follow-up fake operations - # Thus, we skip the actual kernel invocation here. if not is_fake_mode(): q_call, k_call, v_call = q.detach(), k.detach(), v.detach() qv_call = qv.detach() if qv is not None else None @@ -1021,7 +1017,16 @@ def _flash_attn_fwd( if arch // 10 in [10, 11]: call_args.append(descale_tensors) call_args.extend([ - normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, + ( + normalized_block_sparse_tensors.mask_block_cnt, + normalized_block_sparse_tensors.mask_block_idx, + normalized_block_sparse_tensors.full_block_cnt, + normalized_block_sparse_tensors.full_block_idx, + normalized_block_sparse_tensors.dq_write_order, + normalized_block_sparse_tensors.dq_write_order_full, + ) + if normalized_block_sparse_tensors is not None + else None, aux_tensors, ]) _flash_attn_fwd.compile_cache[compile_key](*call_args) @@ -1549,6 +1554,30 @@ def _flash_attn_bwd( block_size=(m_block_size, n_block_size), subtile_factor=subtile_factor, ) + if deterministic: + if normalized_block_sparse_tensors.dq_write_order is None: + raise ValueError( + "deterministic block-sparse backward requires dq_write_order in block_sparse_tensors" + ) + if ( + normalized_block_sparse_tensors.full_block_cnt is not None + and normalized_block_sparse_tensors.dq_write_order_full is None + ): + raise ValueError( + "deterministic block-sparse backward requires dq_write_order_full when full blocks are present" + ) + if normalized_block_sparse_tensors.spt is None: + raise ValueError( + "deterministic block-sparse backward requires block_sparse_tensors.spt " + "to match dq_write_order direction" + ) + if ( + normalized_block_sparse_tensors is not None + and normalized_block_sparse_tensors.spt is not None + ): + spt = normalized_block_sparse_tensors.spt and deterministic + else: + spt = (causal or local) and deterministic if arch // 10 in [8, 9, 12]: compile_key = ( @@ -1610,6 +1639,7 @@ def _flash_attn_bwd( cluster_size, use_2cta_instrs, deterministic, + spt, score_mod_hash, score_mod_bwd_hash, mask_mod_hash, @@ -1745,6 +1775,7 @@ def _flash_attn_bwd( cluster_size=cluster_size, use_2cta_instrs=use_2cta_instrs, deterministic=deterministic, + spt=spt, score_mod=score_mod, score_mod_bwd=score_mod_bwd, mask_mod=mask_mod, @@ -1808,7 +1839,16 @@ def _flash_attn_bwd( dK_semaphore, dV_semaphore, aux_tensors, - normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, + ( + normalized_block_sparse_tensors.mask_block_cnt, + normalized_block_sparse_tensors.mask_block_idx, + normalized_block_sparse_tensors.full_block_cnt, + normalized_block_sparse_tensors.full_block_idx, + normalized_block_sparse_tensors.dq_write_order, + normalized_block_sparse_tensors.dq_write_order_full, + ) + if normalized_block_sparse_tensors is not None + else None, ) # Postprocess: convert dq_accum from float32 to dq in bf16/fp16 # hd=256 2CTA backward has its own internal postprocess, skip here. diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py index 484f5191725..ceef6500b97 100644 --- a/tests/cute/test_mask_mod.py +++ b/tests/cute/test_mask_mod.py @@ -27,6 +27,8 @@ BlockSparseTensorsTorch, fast_sampling, normalize_block_sparse_config, + compute_dq_write_order, + compute_dq_write_order_from_block_mask, ) from flash_attn.cute.cache_utils import get_jit_cache from flash_attn.cute import utils @@ -677,8 +679,12 @@ def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): full_block_idx=full_q_idx, block_size=(sparse_tile_m, tile_n), ) - - + dq_write_order = compute_dq_write_order_from_block_mask(bm, spt=False) + block_sparse_mask_bwd = block_sparse_mask_bwd._replace( + dq_write_order=dq_write_order[0], + dq_write_order_full=dq_write_order[1], + spt=False, + ) out_tuple = _flash_attn_fwd( q=q, k=k, v=v, out=out, lse=lse, cu_seqlens_q=None, cu_seqlens_k=None, @@ -1119,7 +1125,7 @@ def wrapped_normalize(*args, **kwargs): def run_cute_mask_bwd( q, k, v, out, lse, grad_out, mask_mod_cute, block_sparse_mask_bwd=None, tile_m=128, tile_n=128, - aux_tensors=None, + aux_tensors=None, deterministic=False, causal=False, window_size_left=None, window_size_right=None, ): """Run flash attention backward with mask_mod. @@ -1132,6 +1138,7 @@ def run_cute_mask_bwd( block_sparse_mask_bwd: Block sparse tensors for backward pass tile_m, tile_n: Tile sizes aux_tensors: Auxiliary tensors for mask_mod (e.g., doc_ids for document masking) + deterministic: Whether to enable deterministic backward Returns (dq, dk, dv) all in BSHD format. """ @@ -1142,12 +1149,15 @@ def run_cute_mask_bwd( out=out, dout=grad_out, lse=lse, - causal=False, + causal=causal, m_block_size=tile_m, n_block_size=tile_n, + window_size_left=window_size_left, + window_size_right=window_size_right, mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_mask_bwd, aux_tensors=aux_tensors, + deterministic=deterministic, ) return dq, dk, dv @@ -1739,6 +1749,513 @@ def test_persistent_blocksparse_empty_tiles(): assert not out.isnan().any() +def _build_dense_from_ordered(num_blocks, indices, num_cols): + """Build dense binary matrix from ordered sparse representation (test helper).""" + B, H, num_rows, max_entries = indices.shape + batch_is_broadcast = B == 1 or (indices.stride(0) == 0 and num_blocks.stride(0) == 0) + head_is_broadcast = H == 1 or (indices.stride(1) == 0 and num_blocks.stride(1) == 0) + batch_size = 1 if batch_is_broadcast else B + head_size = 1 if head_is_broadcast else H + indices_view = indices[:batch_size, :head_size] + num_blocks_view = num_blocks[:batch_size, :head_size] + dense = torch.zeros( + batch_size, + head_size, + num_rows, + num_cols + 1, + dtype=torch.int32, + device=indices.device, + ) + valid = ( + torch.arange(max_entries, device=indices.device)[None, None, None, :] + < num_blocks_view[:, :, :, None] + ) + safe_indices = torch.where(valid, indices_view.long(), num_cols) + dense.scatter_(-1, safe_indices, valid.to(torch.int32)) + dense = dense[:, :, :, :num_cols] + if batch_size != B or head_size != H: + return dense.expand(B, H, num_rows, num_cols) + return dense + + +def _verify_deadlock_freedom( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, spt=False, +): + """Verify the critical deadlock-freedom invariant for all m_blocks. + + For non-spt: the lowest n_block contributor to each m_block must have lock_value=0. + For spt: the highest n_block contributor must have lock_value=0. + """ + B, H, num_m = kv_mask_cnt.shape + num_n = kv_mask_idx.shape[-1] + + dense = _build_dense_from_ordered(kv_mask_cnt, kv_mask_idx, num_n) + if full_kv_cnt is not None: + dense = dense | _build_dense_from_ordered(full_kv_cnt, full_kv_idx, num_n) + + for b in range(B): + for h in range(H): + for m in range(num_m): + contributors = dense[b, h, m].nonzero(as_tuple=True)[0] + if len(contributors) == 0: + continue + target_n = contributors[-1].item() if spt else contributors[0].item() + + found = False + cnt_partial = q_mask_cnt[b, h, target_n].item() + for i in range(cnt_partial): + if q_mask_idx[b, h, target_n, i].item() == m: + assert dq_wo[b, h, target_n, i].item() == 0, ( + f"n_block={target_n} should get lock_value=0 for m_block={m} (spt={spt})" + ) + found = True + break + if not found and full_q_cnt is not None: + cnt_full = full_q_cnt[b, h, target_n].item() + for i in range(cnt_full): + if full_q_idx[b, h, target_n, i].item() == m: + assert dq_wo_full[b, h, target_n, i].item() == 0, ( + f"n_block={target_n} (full) should get lock_value=0 for m_block={m} (spt={spt})" + ) + found = True + break + assert found, f"target n_block={target_n} not found in backward lists for m_block={m}" + + +def _verify_unique_ranks_per_m_block( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, +): + """Verify that for each m_block, the lock values form a contiguous 0..N-1 range.""" + B, H, num_m = kv_mask_cnt.shape + num_n = kv_mask_idx.shape[-1] + + dense = _build_dense_from_ordered(kv_mask_cnt, kv_mask_idx, num_n) + if full_kv_cnt is not None: + dense = dense | _build_dense_from_ordered(full_kv_cnt, full_kv_idx, num_n) + + for b in range(B): + for h in range(H): + for m in range(num_m): + contributors = dense[b, h, m].nonzero(as_tuple=True)[0] + total = len(contributors) + if total == 0: + continue + lock_vals = set() + for n in contributors.tolist(): + cnt_p = q_mask_cnt[b, h, n].item() + for i in range(cnt_p): + if q_mask_idx[b, h, n, i].item() == m: + lock_vals.add(dq_wo[b, h, n, i].item()) + if full_q_cnt is not None: + cnt_f = full_q_cnt[b, h, n].item() + for i in range(cnt_f): + if full_q_idx[b, h, n, i].item() == m: + lock_vals.add(dq_wo_full[b, h, n, i].item()) + assert lock_vals == set(range(total)), ( + f"m_block={m}: expected ranks {{0..{total-1}}}, got {lock_vals}" + ) + + +def _run_write_order_test(mask_mod_flex, seqlen_q, seqlen_k, block_size, B=1, H=4, spt=False): + bm = create_block_mask( + mask_mod_flex, B, H, seqlen_q, seqlen_k, + device="cuda", BLOCK_SIZE=(block_size, block_size), + ) + ( + _, _, + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_, + ) = bm.as_tuple() + + dq_wo, dq_wo_full = compute_dq_write_order( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + spt=spt, + ) + + _verify_deadlock_freedom( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, spt=spt, + ) + if not spt: + _verify_unique_ranks_per_m_block( + kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, + q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, + dq_wo, dq_wo_full, + ) + + +def _build_block_sparse_masks_for_bwd( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + tile_m, + tile_n, + spt, +): + sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m + bm = create_block_mask( + mask_mod_flex, + batch_size, + nheads, + seqlen_q, + seqlen_k, + device="cuda", + BLOCK_SIZE=(sparse_tile_m, tile_n), + ) + ( + _seq_q, + _seq_k, + kv_mask_cnt, + kv_mask_idx, + full_kv_cnt, + full_kv_idx, + q_mask_cnt, + q_mask_idx, + full_q_cnt, + full_q_idx, + *_, + ) = bm.as_tuple() + + block_sparse_mask_fwd = BlockSparseTensorsTorch( + mask_block_cnt=kv_mask_cnt, + mask_block_idx=kv_mask_idx, + full_block_cnt=full_kv_cnt, + full_block_idx=full_kv_idx, + block_size=(sparse_tile_m, tile_n), + ) + block_sparse_mask_bwd = BlockSparseTensorsTorch( + mask_block_cnt=q_mask_cnt, + mask_block_idx=q_mask_idx, + full_block_cnt=full_q_cnt, + full_block_idx=full_q_idx, + block_size=(sparse_tile_m, tile_n), + ) + dq_write_order = compute_dq_write_order_from_block_mask(bm, spt=spt) + return block_sparse_mask_fwd, block_sparse_mask_bwd._replace( + dq_write_order=dq_write_order[0], + dq_write_order_full=dq_write_order[1], + spt=spt, + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 256), (512, 512), (383, 769)]) +@pytest.mark.parametrize( + "mask_name,window_size", + [ + ("block_diagonal", None), + ("causal", None), + ("sliding_window", 256), + ("document", None), + ], +) +@pytest.mark.parametrize("spt", [False, True]) +@pytest.mark.parametrize("kv_mode", ["mha", "gqa"]) +def test_block_sparse_bwd_deterministic(seqlen_q, seqlen_k, mask_name, window_size, spt, kv_mode): + torch.manual_seed(42) + if mask_name == "sliding_window" and seqlen_q > seqlen_k: + pytest.skip("sliding_window requires seqlen_q <= seqlen_k") + if spt and mask_name not in ("sliding_window", "causal"): + pytest.skip("spt path is only exercised for sliding_window and causal in this test") + + batch_size = 1 + nheads = 4 + nheads_kv = 1 if kv_mode == "gqa" else nheads + pack_gqa = nheads != nheads_kv + headdim = 128 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + + mask_mod_cute, mask_mod_flex = get_mask_pair( + mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size + ) + + aux_tensors_arg = None + if mask_name == "document": + doc_ids = random_doc_id_tensor(nheads, batch_size, max(seqlen_q, seqlen_k), device="cuda").to( + torch.int32 + ) + original_flex_mask = mask_mod_flex + + def mask_mod_flex(b, h, q_idx, kv_idx, doc_ids=doc_ids): + return original_flex_mask(b, h, q_idx, kv_idx, doc_ids) + + aux_tensors_arg = [doc_ids] + + tensors = create_tensors(batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, headdim, headdim, dtype) + q = tensors["q"] + k = tensors["k"] + v = tensors["v"] + block_mask_nheads = 1 if pack_gqa else nheads + block_sparse_mask_fwd, block_sparse_mask_bwd = _build_block_sparse_masks_for_bwd( + mask_mod_flex=mask_mod_flex, + batch_size=batch_size, + nheads=block_mask_nheads, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + tile_m=tile_m, + tile_n=tile_n, + spt=spt, + ) + causal_arg = spt and mask_name == "causal" + window_size_left_arg = window_size if spt and mask_name == "sliding_window" else None + window_size_right_arg = 0 if spt and mask_name == "sliding_window" else None + mask_mod_arg = mask_mod_cute if not spt else None + + out_cute, lse_cute = _flash_attn_fwd( + q=q, + k=k, + v=v, + out=torch.empty(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype), + lse=torch.empty(batch_size, nheads, seqlen_q, device="cuda", dtype=torch.float32), + cu_seqlens_q=None, + cu_seqlens_k=None, + seqused_q=None, + seqused_k=None, + page_table=None, + softmax_scale=1.0 / math.sqrt(headdim), + causal=causal_arg, + softcap=None, + window_size_left=window_size_left_arg, + window_size_right=window_size_right_arg, + learnable_sink=None, + tile_mn=(tile_m, tile_n), + pack_gqa=pack_gqa, + _arch=None, + score_mod=None, + mask_mod=mask_mod_arg, + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + aux_tensors=aux_tensors_arg, + ) + + grad_out = torch.randn_like(out_cute) + dq0, dk0, dv0 = run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + mask_mod_arg, + block_sparse_mask_bwd=block_sparse_mask_bwd, + tile_m=tile_m, + tile_n=tile_n, + aux_tensors=aux_tensors_arg, + deterministic=True, + causal=causal_arg, + window_size_left=window_size_left_arg, + window_size_right=window_size_right_arg, + ) + + num_repeats = 3 if spt else 50 + for _ in range(num_repeats): + dq, dk, dv = run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + mask_mod_arg, + block_sparse_mask_bwd=block_sparse_mask_bwd, + tile_m=tile_m, + tile_n=tile_n, + aux_tensors=aux_tensors_arg, + deterministic=True, + causal=causal_arg, + window_size_left=window_size_left_arg, + window_size_right=window_size_right_arg, + ) + assert torch.equal(dq, dq0) + assert torch.equal(dk, dk0) + assert torch.equal(dv, dv0) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def _setup_block_sparse_deterministic_validation_case(): + torch.manual_seed(42) + batch_size = 1 + nheads = 4 + seqlen_q = 256 + seqlen_k = 256 + headdim = 128 + tile_m = 128 + tile_n = 128 + dtype = torch.bfloat16 + + _, mask_mod_flex = get_mask_pair( + "block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k + ) + + q = torch.randn(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype) + k = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + v = torch.randn(batch_size, seqlen_k, nheads, headdim, device="cuda", dtype=dtype) + block_sparse_mask_fwd, block_sparse_mask_bwd = _build_block_sparse_masks_for_bwd( + mask_mod_flex=mask_mod_flex, + batch_size=batch_size, + nheads=nheads, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + tile_m=tile_m, + tile_n=tile_n, + spt=False, + ) + out_cute, lse_cute = _flash_attn_fwd( + q=q, + k=k, + v=v, + out=torch.empty(batch_size, seqlen_q, nheads, headdim, device="cuda", dtype=dtype), + lse=torch.empty(batch_size, nheads, seqlen_q, device="cuda", dtype=torch.float32), + softmax_scale=1.0 / math.sqrt(headdim), + tile_mn=(tile_m, tile_n), + block_sparse_tensors=block_sparse_mask_fwd, + return_lse=True, + ) + + return q, k, v, out_cute, lse_cute, torch.randn_like(out_cute), block_sparse_mask_bwd, tile_m, tile_n + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def test_block_sparse_bwd_deterministic_missing_dq_write_order_raises(): + q, k, v, out_cute, lse_cute, grad_out, block_sparse_mask_bwd, tile_m, tile_n = ( + _setup_block_sparse_deterministic_validation_case() + ) + block_sparse_mask_bwd_no_dq_write_order = block_sparse_mask_bwd._replace( + dq_write_order=None, + dq_write_order_full=None, + spt=None, + ) + + with pytest.raises(ValueError, match="requires dq_write_order in block_sparse_tensors"): + run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + None, + block_sparse_mask_bwd=block_sparse_mask_bwd_no_dq_write_order, + tile_m=tile_m, + tile_n=tile_n, + deterministic=True, + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def test_block_sparse_bwd_deterministic_missing_dq_write_order_full_raises(): + q, k, v, out_cute, lse_cute, grad_out, block_sparse_mask_bwd, tile_m, tile_n = ( + _setup_block_sparse_deterministic_validation_case() + ) + block_sparse_mask_bwd_no_dq_write_order_full = block_sparse_mask_bwd._replace( + full_block_cnt=torch.zeros_like(block_sparse_mask_bwd.mask_block_cnt), + full_block_idx=torch.zeros_like(block_sparse_mask_bwd.mask_block_idx), + dq_write_order_full=None, + spt=False, + ) + + with pytest.raises(ValueError, match="requires dq_write_order_full when full blocks are present"): + run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + None, + block_sparse_mask_bwd=block_sparse_mask_bwd_no_dq_write_order_full, + tile_m=tile_m, + tile_n=tile_n, + deterministic=True, + ) + + +@pytest.mark.skipif(COMPUTE_CAPABILITY not in (10, 11), reason="deterministic bwd only supported on sm100/sm110") +def test_block_sparse_bwd_deterministic_missing_spt_raises(): + q, k, v, out_cute, lse_cute, grad_out, block_sparse_mask_bwd, tile_m, tile_n = ( + _setup_block_sparse_deterministic_validation_case() + ) + block_sparse_mask_bwd_no_spt = block_sparse_mask_bwd._replace(spt=None) + + with pytest.raises(ValueError, match="requires block_sparse_tensors.spt"): + run_cute_mask_bwd( + q, + k, + v, + out_cute, + lse_cute, + grad_out, + None, + block_sparse_mask_bwd=block_sparse_mask_bwd_no_spt, + tile_m=tile_m, + tile_n=tile_n, + deterministic=True, + ) + + +WRITE_ORDER_SEQLENS = [ + (256, 256), + (512, 512), + (1024, 1024), + (2048, 2048), + (4096, 4096), + (512, 1024), + (1024, 512), + (384, 768), +] + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", WRITE_ORDER_SEQLENS) +@pytest.mark.parametrize("mask_name", ["block_diagonal", "mini_causal", "prefix_lm", "dilated_sliding_window"]) +@pytest.mark.parametrize("spt", [False, True]) +def test_dq_write_order_static_masks(seqlen_q, seqlen_k, mask_name, spt): + torch.manual_seed(42) + _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k) + _run_write_order_test(mask_mod_flex, seqlen_q, seqlen_k, block_size=128, spt=spt) + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", WRITE_ORDER_SEQLENS) +@pytest.mark.parametrize( + "mask_name,window_size", + [ + ("causal", None), + ("block_causal", None), + ("sliding_window", 128), + ("sliding_window", 256), + ("sliding_window", 512), + ], +) +@pytest.mark.parametrize("spt", [False, True]) +def test_dq_write_order_parameterized_masks(seqlen_q, seqlen_k, mask_name, window_size, spt): + torch.manual_seed(42) + if mask_name == "sliding_window" and seqlen_q > seqlen_k: + pytest.skip("sliding_window requires seqlen_q <= seqlen_k") + _, mask_mod_flex = get_mask_pair(mask_name, seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=window_size) + _run_write_order_test(mask_mod_flex, seqlen_q, seqlen_k, block_size=128, spt=spt) + + +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(512, 512), (1024, 1024), (2048, 2048)]) +@pytest.mark.parametrize("spt", [False, True]) +def test_dq_write_order_document_mask(seqlen_q, seqlen_k, spt): + torch.manual_seed(42) + B, H = 1, 4 + doc_ids = random_doc_id_tensor(H, B, max(seqlen_q, seqlen_k), device="cuda").to(torch.int32) + + def doc_mask(b, h, q_idx, kv_idx): + return doc_ids[b, h, q_idx] == doc_ids[b, h, kv_idx] + + _run_write_order_test(doc_mask, seqlen_q, seqlen_k, block_size=128, B=B, H=H, spt=spt) + def test_compact_block_sparse_indices(): """Test that compact block sparse index tensors (idx.shape[3] < n_blocks) work correctly.