Merged
flash_attn/cute/block_sparsity.py (10 additions, 2 deletions)
@@ -98,6 +98,12 @@ def _check_and_expand_block(
     expanded_cnt = _expand_sparsity_tensor(
         cnt, expected_count_shape, f"{name}_block_cnt", context, hint
     )
+    # [Note] Allow Compact block sparse indices
+    # Allow the last dimension (n_blocks) of idx to be <= expected, since
+    # FA4 only accesses indices 0..cnt-1 per query tile. This enables compact
+    # index tensors that avoid O(N^2) memory at long sequence lengths.
+    if idx.ndim == 4 and idx.shape[3] <= expected_index_shape[3]:
+        expected_index_shape = (*expected_index_shape[:3], idx.shape[3])
     expanded_idx = _expand_sparsity_tensor(
         idx, expected_index_shape, f"{name}_block_idx", context, hint
     )
@@ -200,9 +206,11 @@ def infer_block_sparse_expected_shapes(
         raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
     if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:
         raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.")
-    if mask_block_idx.shape[3] != expected_n_blocks:
+    # [Note] Allow Compact block sparse indices: FA4 only accesses indices 0..cnt-1
+    # per query tile, so idx.shape[3] can be <= expected_n_blocks.
+    if mask_block_idx.shape[3] > expected_n_blocks:
         raise ValueError(
-            f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}."
+            f"Block sparse tensors{context} n-block dimension must be <= {expected_n_blocks}."
         )
     if expected_m_blocks != num_m_blocks:
         raise ValueError(
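What the relaxed check buys is memory: a full index tensor has shape (batch, heads, m_blocks, n_blocks), which is O(N^2) in sequence length, while the kernel only ever reads idx[b, h, m, :cnt[b, h, m]]. Below is a minimal sketch of how a caller could shrink a full index tensor to its compact form under these new rules; the helper name compact_block_idx is hypothetical, and the (cnt, idx) layout is assumed to match what torch's create_block_mask produces.

import torch

def compact_block_idx(cnt: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # cnt: (B, H, m_blocks) number of valid entries per query tile.
    # idx: (B, H, m_blocks, n_blocks); only idx[b, h, m, :cnt[b, h, m]]
    # is ever read by the kernel, so the unused tail can be dropped.
    max_k = max(int(cnt.max().item()), 1)  # keep at least one column
    # Memory shrinks from O(m_blocks * n_blocks) to O(m_blocks * max_k).
    return idx[..., :max_k].contiguous()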
tests/cute/test_mask_mod.py (86 additions, 0 deletions)
@@ -1712,5 +1712,91 @@ def test_persistent_blocksparse_empty_tiles():



def test_compact_block_sparse_indices():
    """Test that compact block sparse index tensors (idx.shape[3] < n_blocks) work correctly.

    FA4 only accesses indices 0..cnt-1 per query tile, so the index tensor's last
    dimension does not need to be as large as ceil(seqlen_k / block_size_n). This
    test verifies that truncated (compact) index tensors produce identical output
    to full-sized ones.
    """
    torch.manual_seed(42)
    batch_size = 1
    nheads = 4
    seqlen_q = 1024
    seqlen_k = 1024
    headdim = 128
    tile_m = 128
    tile_n = 128
    dtype = torch.bfloat16

    sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m

    mask_mod_cute, mask_mod_flex = get_mask_pair(
        "block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=None
    )
    tensors = create_tensors(
        batch_size, seqlen_q, seqlen_k, nheads, nheads, headdim, headdim, dtype
    )

    bm = create_block_mask(
        mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k,
        device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n),
    )
    (_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple()

    # Determine the max count across all query tiles: this is the compact last dim.
    max_mask_k = kv_mask_cnt.max().item() if kv_mask_cnt is not None else 0
    max_full_k = full_kv_cnt.max().item() if full_kv_cnt is not None else 0
    max_k = max(max_mask_k, max_full_k, 1)

    # Truncate index tensors to compact size.
    kv_mask_idx_compact = kv_mask_idx[:, :, :, :max_k].contiguous()
    full_kv_idx_compact = (
        full_kv_idx[:, :, :, :max_k].contiguous() if full_kv_idx is not None else None
    )

    block_sparse_compact = BlockSparseTensorsTorch(
        mask_block_cnt=kv_mask_cnt,
        mask_block_idx=kv_mask_idx_compact,
        full_block_cnt=full_kv_cnt,
        full_block_idx=full_kv_idx_compact,
        block_size=(sparse_tile_m, tile_n),
    )

    out_compact, _ = _flash_attn_fwd(
        q=tensors["q"], k=tensors["k"], v=tensors["v"],
        out=tensors["out"].clone(), lse=tensors["lse"].clone(),
        softmax_scale=1.0 / math.sqrt(headdim),
        causal=False, mask_mod=mask_mod_cute,
        block_sparse_tensors=block_sparse_compact,
        return_lse=True,
    )

    # Reference: use full-sized index tensors.
    block_sparse_full = BlockSparseTensorsTorch(
        mask_block_cnt=kv_mask_cnt,
        mask_block_idx=kv_mask_idx,
        full_block_cnt=full_kv_cnt,
        full_block_idx=full_kv_idx,
        block_size=(sparse_tile_m, tile_n),
    )

    out_full, _ = _flash_attn_fwd(
        q=tensors["q"], k=tensors["k"], v=tensors["v"],
        out=tensors["out"].clone(), lse=tensors["lse"].clone(),
        softmax_scale=1.0 / math.sqrt(headdim),
        causal=False, mask_mod=mask_mod_cute,
        block_sparse_tensors=block_sparse_full,
        return_lse=True,
    )

    assert not torch.isnan(out_compact).any(), "Compact output has NaN"
    assert torch.isfinite(out_compact).all(), "Compact output has Inf"
    # Compact and full should produce bit-identical results.
    assert torch.equal(out_compact, out_full), (
        f"Compact and full block sparse outputs differ: "
        f"max diff = {(out_compact - out_full).abs().max().item():.2e}"
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
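To put a number on the O(N^2) savings the test docstring mentions, here is a quick back-of-the-envelope check with hypothetical long-context sizes (not taken from this PR):

# Hypothetical sizes, for scale only.
seqlen_k, tile_n = 131_072, 128
n_blocks = (seqlen_k + tile_n - 1) // tile_n  # 1024 columns in the full idx tensor
max_k = 64  # assumed: densest query tile touches 64 key tiles
print(f"compact idx is {n_blocks // max_k}x smaller")  # prints: compact idx is 16x smaller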