diff --git a/flash_attn/cute/block_sparsity.py b/flash_attn/cute/block_sparsity.py
index 6c72f00f728..3fad8c9f491 100644
--- a/flash_attn/cute/block_sparsity.py
+++ b/flash_attn/cute/block_sparsity.py
@@ -98,6 +98,12 @@ def _check_and_expand_block(
     expanded_cnt = _expand_sparsity_tensor(
         cnt, expected_count_shape, f"{name}_block_cnt", context, hint
     )
+    # [Note] Allow Compact block sparse indices
+    # Allow the last dimension (n_blocks) of idx to be <= expected, since
+    # FA4 only accesses indices 0..cnt-1 per query tile. This enables compact
+    # index tensors that avoid O(N^2) memory at long sequence lengths.
+    if idx.ndim == 4 and idx.shape[3] <= expected_index_shape[3]:
+        expected_index_shape = (*expected_index_shape[:3], idx.shape[3])
     expanded_idx = _expand_sparsity_tensor(
         idx, expected_index_shape, f"{name}_block_idx", context, hint
     )
@@ -200,9 +206,11 @@ def infer_block_sparse_expected_shapes(
         raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
     if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:
         raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.")
-    if mask_block_idx.shape[3] != expected_n_blocks:
+    # [Note] Allow Compact block sparse indices: FA4 only accesses indices 0..cnt-1
+    # per query tile, so idx.shape[3] can be <= expected_n_blocks.
+    if mask_block_idx.shape[3] > expected_n_blocks:
         raise ValueError(
-            f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}."
+            f"Block sparse tensors{context} n-block dimension must be <= {expected_n_blocks}."
         )
     if expected_m_blocks != num_m_blocks:
         raise ValueError(
diff --git a/tests/cute/test_mask_mod.py b/tests/cute/test_mask_mod.py
index d9ca8c68df5..26e0a5e1353 100644
--- a/tests/cute/test_mask_mod.py
+++ b/tests/cute/test_mask_mod.py
@@ -1712,5 +1712,91 @@ def test_persistent_blocksparse_empty_tiles():
 
 
 
+def test_compact_block_sparse_indices():
+    """Test that compact block sparse index tensors (idx.shape[3] < n_blocks) work correctly.
+
+    FA4 only accesses indices 0..cnt-1 per query tile, so the index tensor's last
+    dimension does not need to be as large as ceil(seqlen_k / block_size_n). This
+    test verifies that truncated (compact) index tensors produce identical output
+    to full-sized ones.
+    """
+    torch.manual_seed(42)
+    batch_size = 1
+    nheads = 4
+    seqlen_q = 1024
+    seqlen_k = 1024
+    headdim = 128
+    tile_m = 128
+    tile_n = 128
+    dtype = torch.bfloat16
+
+    sparse_tile_m = 2 * tile_m if COMPUTE_CAPABILITY == 10 else tile_m
+
+    mask_mod_cute, mask_mod_flex = get_mask_pair(
+        "block_diagonal", seqlen_q=seqlen_q, seqlen_k=seqlen_k, window_size=None
+    )
+    tensors = create_tensors(
+        batch_size, seqlen_q, seqlen_k, nheads, nheads, headdim, headdim, dtype
+    )
+
+    bm = create_block_mask(
+        mask_mod_flex, batch_size, nheads, seqlen_q, seqlen_k,
+        device="cuda", BLOCK_SIZE=(sparse_tile_m, tile_n),
+    )
+    (_, _, kv_mask_cnt, kv_mask_idx, full_kv_cnt, full_kv_idx, *_) = bm.as_tuple()
+
+    # Determine the max count across all query tiles — this is the compact last dim
+    max_mask_k = kv_mask_cnt.max().item() if kv_mask_cnt is not None else 0
+    max_full_k = full_kv_cnt.max().item() if full_kv_cnt is not None else 0
+    max_k = max(max_mask_k, max_full_k, 1)
+
+    # Truncate index tensors to compact size
+    kv_mask_idx_compact = kv_mask_idx[:, :, :, :max_k].contiguous()
+    full_kv_idx_compact = full_kv_idx[:, :, :, :max_k].contiguous() if full_kv_idx is not None else None
+
+    block_sparse_compact = BlockSparseTensorsTorch(
+        mask_block_cnt=kv_mask_cnt,
+        mask_block_idx=kv_mask_idx_compact,
+        full_block_cnt=full_kv_cnt,
+        full_block_idx=full_kv_idx_compact,
+        block_size=(sparse_tile_m, tile_n),
+    )
+
+    out_compact, _ = _flash_attn_fwd(
+        q=tensors["q"], k=tensors["k"], v=tensors["v"],
+        out=tensors["out"].clone(), lse=tensors["lse"].clone(),
+        softmax_scale=1.0 / math.sqrt(headdim),
+        causal=False, mask_mod=mask_mod_cute,
+        block_sparse_tensors=block_sparse_compact,
+        return_lse=True,
+    )
+
+    # Reference: use full-sized index tensors
+    block_sparse_full = BlockSparseTensorsTorch(
+        mask_block_cnt=kv_mask_cnt,
+        mask_block_idx=kv_mask_idx,
+        full_block_cnt=full_kv_cnt,
+        full_block_idx=full_kv_idx,
+        block_size=(sparse_tile_m, tile_n),
+    )
+
+    out_full, _ = _flash_attn_fwd(
+        q=tensors["q"], k=tensors["k"], v=tensors["v"],
+        out=tensors["out"].clone(), lse=tensors["lse"].clone(),
+        softmax_scale=1.0 / math.sqrt(headdim),
+        causal=False, mask_mod=mask_mod_cute,
+        block_sparse_tensors=block_sparse_full,
+        return_lse=True,
+    )
+
+    assert not torch.isnan(out_compact).any(), "Compact output has NaN"
+    assert torch.isfinite(out_compact).all(), "Compact output has Inf"
+    # Compact and full should produce bit-identical results
+    assert torch.equal(out_compact, out_full), (
+        f"Compact and full block sparse outputs differ: "
+        f"max diff = {(out_compact - out_full).abs().max().item():.2e}"
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__, "-v", "-s"])