# Reviewer preamble. Everything before the first "diff --git" header lies outside every hunk.
#
# What this patch does:
#   * flash_attn/cute/block_sparsity.py: adds get_block_sparse_broadcast_pattern(). It returns
#     None when is_block_sparsity_enabled() is false; otherwise a 4-tuple, in the order
#     mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx, holding
#     get_broadcast_dims(tensor) for each tensor that is present and None for each that is not.
#   * flash_attn/cute/cute_dsl_utils.py: adds get_broadcast_dims(), which maps tensor.stride()
#     to a tuple of `stride == 0` booleans.
#   * flash_attn/cute/interface.py: in _flash_attn_fwd and _flash_attn_bwd the block-sparse
#     tensors are normalized once, before compile_key is built; block_sparse_broadcast_pattern
#     is added to compile_key; the compile path converts the already-normalized tensors; and
#     the second normalization block that used to follow cute.compile(...) is removed.
#   * tests/cute/test_mask_mod.py: adds a regression test that runs forward + backward with a
#     block mask built for H=1 and then for H=nheads, comparing dQ with the flex reference.
#
# NOTE(review): get_block_sparse_broadcast_pattern is annotated
#   `Tuple[Tuple[bool, ...], ...] | None`, but its body appends None for a missing tensor, so
#   individual entries may be None as well.
# NOTE(review): the new code uses `Tuple` (both files) and `torch.Tensor` (cute_dsl_utils.py);
#   the hunks do not show whether cute_dsl_utils.py already imports them -- confirm.
# NOTE(review): the new test computes dK/dV for both runs and for the reference, but only dQ
#   is asserted.
# NOTE(review): the `seqused_k is None,` -/+ pair in _flash_attn_bwd differs only in
#   whitespace, and that difference is not recoverable from this copy -- confirm against the
#   target tree.
# NOTE(review): line breaks and indentation were restored from a whitespace-collapsed copy and
#   checked against the hunk line counts. Indentation of context lines and whitespace inside
#   string literals (e.g. the test's print() strings) are a best-effort reconstruction --
#   verify before applying (or apply with whitespace-insensitive matching).
diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py
index 1607a8b80b5..59b0c017f3a 100644
--- a/flash_attn/cute/block_sparsity.py
+++ b/flash_attn/cute/block_sparsity.py
@@ -7,7 +7,7 @@
 import cutlass.cute as cute
 import torch
 
-from flash_attn.cute.cute_dsl_utils import to_cute_tensor
+from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor
 
 
 def ceildiv(a: int, b: int) -> int:
@@ -174,6 +174,38 @@ def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:
     return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt))
 
 
+def get_block_sparse_broadcast_pattern(
+    tensors: BlockSparseTensorsTorch,
+) -> Tuple[Tuple[bool, ...], ...] | None:
+    """Return broadcast pattern for block sparse tensors by checking actual strides.
+
+    Returns a tuple of broadcast patterns (one per tensor) where each pattern
+    is a tuple of bools indicating which dims have stride=0.
+    This is used in compile keys to ensure kernels are recompiled when
+    broadcast patterns change, since CuTe's mark_layout_dynamic() keeps
+    stride=0 as static.
+
+    The tensors should already be expanded/normalized before calling this function.
+
+    Returns None if block sparsity is not enabled.
+    """
+    if not is_block_sparsity_enabled(tensors):
+        return None
+
+    patterns = []
+    for tensor in (
+        tensors.mask_block_cnt,
+        tensors.mask_block_idx,
+        tensors.full_block_cnt,
+        tensors.full_block_idx,
+    ):
+        if tensor is not None:
+            patterns.append(get_broadcast_dims(tensor))
+        else:
+            patterns.append(None)
+    return tuple(patterns)
+
+
 def to_cute_block_sparse_tensors(
     tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True
 ) -> BlockSparseTensors | None:
diff --git a/flash_attn/cute/cute_dsl_utils.py b/flash_attn/cute/cute_dsl_utils.py
index 9d6ee345d00..14723872b85 100644
--- a/flash_attn/cute/cute_dsl_utils.py
+++ b/flash_attn/cute/cute_dsl_utils.py
@@ -132,3 +132,13 @@ def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, ena
     if leading_dim == -1:
         leading_dim = t.ndim - 1
     return tensor.mark_layout_dynamic(leading_dim=leading_dim)
+
+
+def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:
+    """Return tuple of bools indicating which dims have stride=0 (broadcast).
+
+    This is useful for compile keys since CuTe's mark_layout_dynamic() keeps
+    stride=0 as static, meaning kernels compiled with different broadcast
+    patterns are not interchangeable.
+    """
+    return tuple(s == 0 for s in tensor.stride())
diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py
index 9d5b25b25e0..8d240698ce9 100644
--- a/flash_attn/cute/interface.py
+++ b/flash_attn/cute/interface.py
@@ -48,6 +48,7 @@
     normalize_block_sparse_tensors,
     get_block_sparse_expected_shapes,
     get_block_sparse_expected_shapes_bwd,
+    get_block_sparse_broadcast_pattern,
 )
 
 @lru_cache(maxsize=None)
@@ -340,6 +341,25 @@ def _flash_attn_fwd(
                 "Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
             )
 
+    # See get_broadcast_dims for why this is needed in compile key
+    block_sparse_broadcast_pattern = None
+    normalized_block_sparse_tensors = None
+    if block_sparse_tensors is not None:
+        if seqlen_q is None:
+            raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).")
+        expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
+            batch_size, num_head, seqlen_q, seqlen_k,
+            m_block_size, n_block_size, q_stage,
+        )
+        normalized_block_sparse_tensors = normalize_block_sparse_tensors(
+            block_sparse_tensors,
+            expected_count_shape=expected_count_shape,
+            expected_index_shape=expected_index_shape,
+        )
+        block_sparse_broadcast_pattern = get_block_sparse_broadcast_pattern(
+            normalized_block_sparse_tensors
+        )
+
     compile_key = (
         dtype,
         head_dim,
@@ -349,6 +369,7 @@ def _flash_attn_fwd(
         score_mod_hash,
         mask_mod_hash,
         use_block_sparsity,
+        block_sparse_broadcast_pattern,
         len(aux_tensors) if aux_tensors is not None else 0,
         lse is None,
         cu_seqlens_q is None,
@@ -397,19 +418,8 @@ def _flash_attn_fwd(
             lse_tensor = None
 
         sparse_tensors = None
-        if block_sparse_tensors is not None:
-            if seqlen_q is None:
-                raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).")
-            expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
-                batch_size, num_head, seqlen_q, seqlen_k,
-                m_block_size, n_block_size, q_stage,
-            )
-            compile_time_normalized = normalize_block_sparse_tensors(
-                block_sparse_tensors,
-                expected_count_shape=expected_count_shape,
-                expected_index_shape=expected_index_shape,
-            )
-            sparse_tensors = to_cute_block_sparse_tensors(compile_time_normalized)
+        if normalized_block_sparse_tensors is not None:
+            sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)
 
         cute_aux_tensors = None
         if aux_tensors is not None:
@@ -490,18 +500,6 @@ def _flash_attn_fwd(
             options="--enable-tvm-ffi",
         )
 
-    # Expand block sparse tensors to match actual head count (may be broadcast from 1)
-    normalized_block_sparse_tensors = None
-    if block_sparse_tensors is not None:
-        expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
-            batch_size, num_head, seqlen_q, seqlen_k,
-            m_block_size, n_block_size, q_stage,
-        )
-        normalized_block_sparse_tensors = normalize_block_sparse_tensors(
-            block_sparse_tensors,
-            expected_count_shape=expected_count_shape,
-            expected_index_shape=expected_index_shape,
-        )
     _flash_attn_fwd.compile_cache[compile_key](
         q,
         k,
@@ -880,6 +878,28 @@ def _flash_attn_bwd(
     if aux_tensors is not None:
         cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]
 
+    block_sparse_broadcast_pattern = None
+    normalized_block_sparse_tensors = None
+    if block_sparse_tensors is not None:
+        expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
+            batch_size, num_head, seqlen_q, seqlen_k,
+            m_block_size, n_block_size, subtile_factor,
+        )
+        normalized_block_sparse_tensors = normalize_block_sparse_tensors(
+            block_sparse_tensors,
+            expected_count_shape=expected_count_shape,
+            expected_index_shape=expected_index_shape,
+            context="_flash_attn_bwd",
+            hint=lambda: (
+                f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
+                f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
+                f"(sparse_block_size_q={sparse_block_size_q})."
+            ),
+        )
+        block_sparse_broadcast_pattern = get_block_sparse_broadcast_pattern(
+            normalized_block_sparse_tensors
+        )
+
     if compute_capability == 9:
         compile_key = (
             compute_capability,
@@ -911,6 +931,7 @@ def _flash_attn_bwd(
             mask_mod_hash,
             num_aux_tensors,
             use_block_sparsity,
+            block_sparse_broadcast_pattern,
         )
     else:
         compile_key = (
@@ -934,10 +955,11 @@ def _flash_attn_bwd(
             mask_mod_hash,
             num_aux_tensors,
             use_block_sparsity,
+            block_sparse_broadcast_pattern,
             cu_seqlens_q is None,
             cu_seqlens_k is None,
             seqused_q is None,
-            seqused_k is None,
+            seqused_k is None,
         )
     if compile_key not in _flash_attn_bwd.compile_cache:
         q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
@@ -1027,23 +1049,8 @@ def _flash_attn_bwd(
         # Block sparse tensors for backward use Q-direction indexing (transposed from forward).
         # sparse_block_size_q = subtile_factor * tile_m matches BlockMask granularity.
         sparse_tensors_compile = None
-        if block_sparse_tensors is not None:
-            expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
-                batch_size, num_head, seqlen_q, seqlen_k,
-                m_block_size, n_block_size, subtile_factor,
-            )
-            compile_time_normalized = normalize_block_sparse_tensors(
-                block_sparse_tensors,
-                expected_count_shape=expected_count_shape,
-                expected_index_shape=expected_index_shape,
-                context="_flash_attn_bwd",
-                hint=lambda: (
-                    f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
-                    f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
-                    f"(sparse_block_size_q={sparse_block_size_q})."
-                ),
-            )
-            sparse_tensors_compile = to_cute_block_sparse_tensors(compile_time_normalized)
+        if normalized_block_sparse_tensors is not None:
+            sparse_tensors_compile = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)
 
         # TODO: check @can_implement
         _flash_attn_bwd.compile_cache[compile_key] = cute.compile(
@@ -1073,25 +1080,6 @@ def _flash_attn_bwd(
             sparse_tensors_compile,
             options="--enable-tvm-ffi",
         )
-    # Runtime normalization of block sparse tensors for both SM90 and SM100
-    normalized_block_sparse_tensors = None
-    if block_sparse_tensors is not None:
-        expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
-            batch_size, num_head, seqlen_q, seqlen_k,
-            m_block_size, n_block_size, subtile_factor,
-        )
-        normalized_block_sparse_tensors = normalize_block_sparse_tensors(
-            block_sparse_tensors,
-            expected_count_shape=expected_count_shape,
-            expected_index_shape=expected_index_shape,
-            context="_flash_attn_bwd",
-            hint=lambda: (
-                f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
-                f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
-                f"(sparse_block_size_q={sparse_block_size_q})."
-            ),
-        )
-
     _flash_attn_bwd.compile_cache[compile_key](
         q,
         k,
diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py
index 96e051c5655..a4b5bf27107 100644
--- a/tests/cute/test_mask_mod.py
+++ b/tests/cute/test_mask_mod.py
@@ -893,5 +893,101 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message():
         )
 
 
+def test_gqa_block_sparse_broadcast_pattern_recompilation():
+    """Test that different block sparse broadcast patterns trigger recompilation.
+
+    This is a regression test for a bug where:
+    1. First call with block_mask H=1 (broadcasts across all query heads)
+    2. Second call with block_mask H=nheads (no broadcast)
+    3. Second call incorrectly reused cached kernel from first call
+
+    The fix adds block_sparse_broadcast_pattern to the compile key so that
+    kernels are recompiled when broadcast patterns change. CuTe's
+    mark_layout_dynamic() keeps stride=0 as static, so different broadcast
+    patterns require different compiled kernels.
+    """
+    torch.manual_seed(42)
+
+    batch_size = 2
+    nheads = 8
+    nheads_kv = 2
+    seqlen = 257
+    headdim = 64
+    dtype = torch.bfloat16
+    tile_m = 128
+    tile_n = 128
+
+    sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m
+
+    def causal_mask(b, h, q, kv):
+        return q >= kv
+
+    mask_mod_cute, _ = get_mask_pair("causal", seqlen_q=seqlen, seqlen_k=seqlen)
+
+    tensors = create_tensors(batch_size, seqlen, seqlen, nheads, nheads_kv, headdim, headdim, dtype)
+    q, k, v = tensors["q"], tensors["k"], tensors["v"]
+    grad_out = torch.randn_like(tensors["out"])
+    softmax_scale = 1.0 / math.sqrt(headdim)
+
+    def run_with_block_mask_nheads(block_mask_nheads: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        bm = create_block_mask(
+            causal_mask, batch_size, block_mask_nheads, seqlen, seqlen,
+            device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n),
+        )
+        (
+            _seq_q, _seq_k,
+            kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx,
+            q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_,
+        ) = bm.as_tuple()
+
+        block_sparse_fwd = BlockSparseTensorsTorch(
+            mask_block_cnt=kv_mask_cnt, mask_block_idx=kv_mask_idx,
+            full_block_cnt=full_kv_cnt, full_block_idx=full_kv_idx,
+        )
+        block_sparse_bwd = BlockSparseTensorsTorch(
+            mask_block_cnt=q_mask_cnt, mask_block_idx=q_mask_idx,
+            full_block_cnt=full_q_cnt, full_block_idx=full_q_idx,
+        )
+
+        out = torch.empty_like(tensors["out"])
+        lse = torch.empty_like(tensors["lse"])
+
+        out_tuple = _flash_attn_fwd(
+            q=q, k=k, v=v, out=out, lse=lse,
+            softmax_scale=softmax_scale, causal=False,
+            window_size_left=-1, window_size_right=-1,
+            m_block_size=tile_m, n_block_size=tile_n, pack_gqa=False,
+            mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_fwd,
+            return_lse=True,
+        )
+        out_cute, lse_cute = out_tuple[0], out_tuple[1]
+
+        dq, dk, dv = run_cute_mask_bwd(
+            q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute,
+            block_sparse_mask_bwd=block_sparse_bwd, tile_m=tile_m, tile_n=tile_n,
+        )
+        return dq, dk, dv
+
+    flex_block_mask = create_block_mask(
+        causal_mask, batch_size, nheads, seqlen, seqlen,
+        device="cuda", BLOCK_SIZE=(tile_m, tile_n),
+    )
+    _, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out, dtype=torch.float32)
+    dq_ref, dk_ref, dv_ref = dq_ref.to(dtype), dk_ref.to(dtype), dv_ref.to(dtype)
+
+    dq_broadcast, dk_broadcast, dv_broadcast = run_with_block_mask_nheads(1)
+    dq_no_broadcast, dk_no_broadcast, dv_no_broadcast = run_with_block_mask_nheads(nheads)
+
+    err_broadcast_dq = (dq_broadcast - dq_ref).abs().max().item()
+    err_no_broadcast_dq = (dq_no_broadcast - dq_ref).abs().max().item()
+
+    print(f"\nGQA block sparse broadcast pattern test:")
+    print(f" dQ error (H=1 broadcast): {err_broadcast_dq:.2e}")
+    print(f" dQ error (H={nheads} no broadcast): {err_no_broadcast_dq:.2e}")
+
+    assert err_broadcast_dq < 0.1, f"Broadcast dQ error too large: {err_broadcast_dq:.2e}"
+    assert err_no_broadcast_dq < 0.1, f"No-broadcast dQ error too large: {err_no_broadcast_dq:.2e}"
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v", "-s"])