Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion flash_attn/cute/block_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import cutlass.cute as cute
import torch

from flash_attn.cute.cute_dsl_utils import to_cute_tensor
from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor


def ceildiv(a: int, b: int) -> int:
Expand Down Expand Up @@ -174,6 +174,38 @@ def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool:
return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt))


def get_block_sparse_broadcast_pattern(
    tensors: BlockSparseTensorsTorch,
) -> Tuple[Tuple[bool, ...] | None, ...] | None:
    """Return broadcast pattern for block sparse tensors by checking actual strides.

    Returns a tuple of four per-tensor patterns in the fixed order
    (mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx).
    Each pattern is a tuple of bools indicating which dims have stride=0,
    or None when that tensor is absent.

    This is used in compile keys to ensure kernels are recompiled when
    broadcast patterns change, since CuTe's mark_layout_dynamic() keeps
    stride=0 as static.

    The tensors should already be expanded/normalized before calling this function.

    Returns None if block sparsity is not enabled.
    """
    if not is_block_sparsity_enabled(tensors):
        return None

    # Fixed field order keeps the result stable as a compile-key component.
    return tuple(
        get_broadcast_dims(t) if t is not None else None
        for t in (
            tensors.mask_block_cnt,
            tensors.mask_block_idx,
            tensors.full_block_cnt,
            tensors.full_block_idx,
        )
    )


def to_cute_block_sparse_tensors(
tensors: BlockSparseTensorsTorch, enable_tvm_ffi: bool = True
) -> BlockSparseTensors | None:
Expand Down
10 changes: 10 additions & 0 deletions flash_attn/cute/cute_dsl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,13 @@ def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, ena
if leading_dim == -1:
leading_dim = t.ndim - 1
return tensor.mark_layout_dynamic(leading_dim=leading_dim)


def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:
    """Identify the broadcast dimensions of *tensor*.

    A dimension counts as broadcast when its stride is 0 (e.g. produced by
    ``expand``). The result is intended for compile keys: CuTe's
    mark_layout_dynamic() keeps stride=0 as static, so kernels compiled
    under different broadcast patterns are not interchangeable.
    """
    flags = []
    for stride in tensor.stride():
        flags.append(stride == 0)
    return tuple(flags)
112 changes: 50 additions & 62 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
normalize_block_sparse_tensors,
get_block_sparse_expected_shapes,
get_block_sparse_expected_shapes_bwd,
get_block_sparse_broadcast_pattern,
)

@lru_cache(maxsize=None)
Expand Down Expand Up @@ -340,6 +341,25 @@ def _flash_attn_fwd(
"Block sparsity is not yet supported with SplitKV. TODO: partition sparse block lists per split."
)

# See get_broadcast_dims for why this is needed in compile key
block_sparse_broadcast_pattern = None
normalized_block_sparse_tensors = None
if block_sparse_tensors is not None:
if seqlen_q is None:
raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).")
expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
batch_size, num_head, seqlen_q, seqlen_k,
m_block_size, n_block_size, q_stage,
)
normalized_block_sparse_tensors = normalize_block_sparse_tensors(
block_sparse_tensors,
expected_count_shape=expected_count_shape,
expected_index_shape=expected_index_shape,
)
block_sparse_broadcast_pattern = get_block_sparse_broadcast_pattern(
normalized_block_sparse_tensors
)

compile_key = (
dtype,
head_dim,
Expand All @@ -349,6 +369,7 @@ def _flash_attn_fwd(
score_mod_hash,
mask_mod_hash,
use_block_sparsity,
block_sparse_broadcast_pattern,
len(aux_tensors) if aux_tensors is not None else 0,
lse is None,
cu_seqlens_q is None,
Expand Down Expand Up @@ -397,19 +418,8 @@ def _flash_attn_fwd(
lse_tensor = None

sparse_tensors = None
if block_sparse_tensors is not None:
Comment thread
v0i0 marked this conversation as resolved.
if seqlen_q is None:
raise ValueError("Block sparsity requires fixed-length sequences (seqlen_q must be known).")
expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
batch_size, num_head, seqlen_q, seqlen_k,
m_block_size, n_block_size, q_stage,
)
compile_time_normalized = normalize_block_sparse_tensors(
block_sparse_tensors,
expected_count_shape=expected_count_shape,
expected_index_shape=expected_index_shape,
)
sparse_tensors = to_cute_block_sparse_tensors(compile_time_normalized)
if normalized_block_sparse_tensors is not None:
sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)

cute_aux_tensors = None
if aux_tensors is not None:
Expand Down Expand Up @@ -490,18 +500,6 @@ def _flash_attn_fwd(
options="--enable-tvm-ffi",
)

# Expand block sparse tensors to match actual head count (may be broadcast from 1)
normalized_block_sparse_tensors = None
if block_sparse_tensors is not None:
expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes(
batch_size, num_head, seqlen_q, seqlen_k,
m_block_size, n_block_size, q_stage,
)
normalized_block_sparse_tensors = normalize_block_sparse_tensors(
block_sparse_tensors,
expected_count_shape=expected_count_shape,
expected_index_shape=expected_index_shape,
)
_flash_attn_fwd.compile_cache[compile_key](
q,
k,
Expand Down Expand Up @@ -880,6 +878,28 @@ def _flash_attn_bwd(
if aux_tensors is not None:
cute_aux_tensors = [to_cute_tensor(buf, assumed_align=None, fully_dynamic=True) for buf in aux_tensors]

block_sparse_broadcast_pattern = None
normalized_block_sparse_tensors = None
if block_sparse_tensors is not None:
expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
batch_size, num_head, seqlen_q, seqlen_k,
m_block_size, n_block_size, subtile_factor,
)
normalized_block_sparse_tensors = normalize_block_sparse_tensors(
block_sparse_tensors,
expected_count_shape=expected_count_shape,
expected_index_shape=expected_index_shape,
context="_flash_attn_bwd",
hint=lambda: (
f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
f"(sparse_block_size_q={sparse_block_size_q})."
),
)
block_sparse_broadcast_pattern = get_block_sparse_broadcast_pattern(
normalized_block_sparse_tensors
)

if compute_capability == 9:
compile_key = (
compute_capability,
Expand Down Expand Up @@ -911,6 +931,7 @@ def _flash_attn_bwd(
mask_mod_hash,
num_aux_tensors,
use_block_sparsity,
block_sparse_broadcast_pattern,
)
else:
compile_key = (
Expand All @@ -934,10 +955,11 @@ def _flash_attn_bwd(
mask_mod_hash,
num_aux_tensors,
use_block_sparsity,
block_sparse_broadcast_pattern,
cu_seqlens_q is None,
cu_seqlens_k is None,
seqused_q is None,
seqused_k is None,
seqused_k is None,
)
if compile_key not in _flash_attn_bwd.compile_cache:
q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [
Expand Down Expand Up @@ -1027,23 +1049,8 @@ def _flash_attn_bwd(
# Block sparse tensors for backward use Q-direction indexing (transposed from forward).
# sparse_block_size_q = subtile_factor * tile_m matches BlockMask granularity.
sparse_tensors_compile = None
if block_sparse_tensors is not None:
expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
batch_size, num_head, seqlen_q, seqlen_k,
m_block_size, n_block_size, subtile_factor,
)
compile_time_normalized = normalize_block_sparse_tensors(
block_sparse_tensors,
expected_count_shape=expected_count_shape,
expected_index_shape=expected_index_shape,
context="_flash_attn_bwd",
hint=lambda: (
f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
f"(sparse_block_size_q={sparse_block_size_q})."
),
)
sparse_tensors_compile = to_cute_block_sparse_tensors(compile_time_normalized)
if normalized_block_sparse_tensors is not None:
sparse_tensors_compile = to_cute_block_sparse_tensors(normalized_block_sparse_tensors)

# TODO: check @can_implement
_flash_attn_bwd.compile_cache[compile_key] = cute.compile(
Expand Down Expand Up @@ -1073,25 +1080,6 @@ def _flash_attn_bwd(
sparse_tensors_compile,
options="--enable-tvm-ffi",
)
# Runtime normalization of block sparse tensors for both SM90 and SM100
normalized_block_sparse_tensors = None
if block_sparse_tensors is not None:
expected_count_shape, expected_index_shape = get_block_sparse_expected_shapes_bwd(
batch_size, num_head, seqlen_q, seqlen_k,
m_block_size, n_block_size, subtile_factor,
)
normalized_block_sparse_tensors = normalize_block_sparse_tensors(
block_sparse_tensors,
expected_count_shape=expected_count_shape,
expected_index_shape=expected_index_shape,
context="_flash_attn_bwd",
hint=lambda: (
f"Backward expects Q-direction block-sparse tensors (q_mask_cnt/q_mask_idx, and optionally full_q_cnt/full_q_idx). "
f"Regenerate the backward BlockMask with BLOCK_SIZE=({sparse_block_size_q}, {n_block_size}) "
f"(sparse_block_size_q={sparse_block_size_q})."
),
)

_flash_attn_bwd.compile_cache[compile_key](
q,
k,
Expand Down
96 changes: 96 additions & 0 deletions tests/cute/test_mask_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,5 +893,101 @@ def test_sm90_block_sparse_bwd_mismatched_q_block_granularity_error_message():
)


def test_gqa_block_sparse_broadcast_pattern_recompilation():
    """Test that different block sparse broadcast patterns trigger recompilation.

    This is a regression test for a bug where:
    1. First call with block_mask H=1 (broadcasts across all query heads)
    2. Second call with block_mask H=nheads (no broadcast)
    3. Second call incorrectly reused cached kernel from first call

    The fix adds block_sparse_broadcast_pattern to the compile key so that
    kernels are recompiled when broadcast patterns change. CuTe's
    mark_layout_dynamic() keeps stride=0 as static, so different broadcast
    patterns require different compiled kernels.
    """
    # Fixed seed so both runs (and the flex reference) see identical inputs.
    torch.manual_seed(42)

    batch_size = 2
    nheads = 8
    nheads_kv = 2  # GQA: fewer KV heads than query heads
    seqlen = 257  # non-multiple of the tile size to exercise partial blocks
    headdim = 64
    dtype = torch.bfloat16
    tile_m = 128
    tile_n = 128

    # NOTE(review): on SM100 the forward sparse granularity appears to be
    # 2*tile_m (presumably q_stage=2) — confirm against the kernel config.
    sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m

    def causal_mask(b, h, q, kv):
        # Standard causal predicate: query attends to positions <= itself.
        return q >= kv

    mask_mod_cute, _ = get_mask_pair("causal", seqlen_q=seqlen, seqlen_k=seqlen)

    tensors = create_tensors(batch_size, seqlen, seqlen, nheads, nheads_kv, headdim, headdim, dtype)
    q, k, v = tensors["q"], tensors["k"], tensors["v"]
    grad_out = torch.randn_like(tensors["out"])
    softmax_scale = 1.0 / math.sqrt(headdim)

    def run_with_block_mask_nheads(block_mask_nheads: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Run fwd+bwd with a block mask built for `block_mask_nheads` heads.
        # H=1 produces stride-0 (broadcast) head dims in the sparse tensors;
        # H=nheads produces dense strides — a different broadcast pattern.
        bm = create_block_mask(
            causal_mask, batch_size, block_mask_nheads, seqlen, seqlen,
            device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n),
        )
        # as_tuple() yields both KV-direction (forward) and Q-direction
        # (backward) block lists; trailing fields are ignored.
        (
            _seq_q, _seq_k,
            kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx,
            q_mask_cnt, q_mask_idx, full_q_cnt, full_q_idx, *_,
        ) = bm.as_tuple()

        block_sparse_fwd = BlockSparseTensorsTorch(
            mask_block_cnt=kv_mask_cnt, mask_block_idx=kv_mask_idx,
            full_block_cnt=full_kv_cnt, full_block_idx=full_kv_idx,
        )
        block_sparse_bwd = BlockSparseTensorsTorch(
            mask_block_cnt=q_mask_cnt, mask_block_idx=q_mask_idx,
            full_block_cnt=full_q_cnt, full_block_idx=full_q_idx,
        )

        out = torch.empty_like(tensors["out"])
        lse = torch.empty_like(tensors["lse"])

        # causal=False because causality is expressed via mask_mod + the
        # block-sparse tensors, not the kernel's built-in causal path.
        out_tuple = _flash_attn_fwd(
            q=q, k=k, v=v, out=out, lse=lse,
            softmax_scale=softmax_scale, causal=False,
            window_size_left=-1, window_size_right=-1,
            m_block_size=tile_m, n_block_size=tile_n, pack_gqa=False,
            mask_mod=mask_mod_cute, block_sparse_tensors=block_sparse_fwd,
            return_lse=True,
        )
        out_cute, lse_cute = out_tuple[0], out_tuple[1]

        dq, dk, dv = run_cute_mask_bwd(
            q, k, v, out_cute, lse_cute, grad_out, mask_mod_cute,
            block_sparse_mask_bwd=block_sparse_bwd, tile_m=tile_m, tile_n=tile_n,
        )
        return dq, dk, dv

    # Reference gradients from flex attention in fp32, then cast to match.
    flex_block_mask = create_block_mask(
        causal_mask, batch_size, nheads, seqlen, seqlen,
        device="cuda", BLOCK_SIZE=(tile_m, tile_n),
    )
    _, dq_ref, dk_ref, dv_ref = run_flex_reference_bwd(q, k, v, flex_block_mask, grad_out, dtype=torch.float32)
    dq_ref, dk_ref, dv_ref = dq_ref.to(dtype), dk_ref.to(dtype), dv_ref.to(dtype)

    # Order matters for the regression: the H=1 (broadcast) run compiles and
    # caches first; the H=nheads run must NOT reuse that cached kernel.
    dq_broadcast, dk_broadcast, dv_broadcast = run_with_block_mask_nheads(1)
    dq_no_broadcast, dk_no_broadcast, dv_no_broadcast = run_with_block_mask_nheads(nheads)

    err_broadcast_dq = (dq_broadcast - dq_ref).abs().max().item()
    err_no_broadcast_dq = (dq_no_broadcast - dq_ref).abs().max().item()

    print(f"\nGQA block sparse broadcast pattern test:")
    print(f"  dQ error (H=1 broadcast): {err_broadcast_dq:.2e}")
    print(f"  dQ error (H={nheads} no broadcast): {err_no_broadcast_dq:.2e}")

    # Before the fix, the second (no-broadcast) run reused the broadcast
    # kernel and produced wrong gradients, failing the second assert.
    assert err_broadcast_dq < 0.1, f"Broadcast dQ error too large: {err_broadcast_dq:.2e}"
    assert err_no_broadcast_dq < 0.1, f"No-broadcast dQ error too large: {err_no_broadcast_dq:.2e}"


if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])