Allow compact block sparse index tensors#2417

Merged
drisspg merged 2 commits into Dao-AILab:main from
jduprat:compact-block-sparse-indices
Apr 1, 2026

@jduprat
Contributor

@jduprat jduprat commented Mar 31, 2026

Relax validation in block_sparsity.py to allow idx.shape[3] <= expected_n_blocks instead of requiring exact equality.

FA4 only accesses indices 0..cnt-1 per query tile, so the index tensor's last dimension does not need to be as large as ceil(seqlen_k / block_size_n). This enables memory-efficient compact index tensors that avoid O(N^2) memory at long sequence lengths (e.g., sparse attention).
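
To make the memory argument concrete, here is a rough back-of-the-envelope sketch. All numbers are assumptions chosen for illustration and are not taken from the PR: a 1M-token sequence, 128x128 blocks, 4-byte indices, and at most 16 active key blocks per query tile. `index_bytes` and `ceil_div` are hypothetical helpers.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def index_bytes(m_blocks: int, n_width: int, bytes_per_index: int = 4) -> int:
    """Bytes used by one (batch, head) slice of an index tensor
    with shape [m_blocks, n_width]."""
    return m_blocks * n_width * bytes_per_index


# Assumed workload, for illustration only.
seqlen = 1 << 20            # 1,048,576 tokens for both queries and keys
block_m = block_n = 128     # assumed tile sizes
max_cnt = 16                # assumed maximum active key blocks per query tile

m_blocks = ceil_div(seqlen, block_m)      # 8192 query tiles
full_width = ceil_div(seqlen, block_n)    # 8192 slots per tile when full-sized

full_bytes = index_bytes(m_blocks, full_width)
compact_bytes = index_bytes(m_blocks, max_cnt)

print(full_bytes // 2**20, "MiB per (batch, head) with full-sized indices")
print(compact_bytes // 2**10, "KiB per (batch, head) with compact indices")
```

Under these assumptions the full-sized index costs 256 MiB per (batch, head) slice, while the compact one costs 512 KiB, because its size grows with `m_blocks * max_cnt` rather than `m_blocks * n_blocks`.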

Changes:

  • _check_and_expand_block: accept compact n-block dimension and expand only the batch/head/m-block dimensions
  • infer_block_sparse_expected_shapes: change strict equality check to upper-bound check (error only when n-blocks exceeds expected, not when smaller)
    See: [Note] Allow Compact block sparse indices

Backward compatible: existing code that passes full-sized tensors is unaffected.
This is a follow-up to PR #2085.
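
The relaxed check described above can be sketched roughly as follows. This is a minimal illustration and not the code in `block_sparsity.py`: `check_compact_n_blocks` is a hypothetical helper, and it assumes the index tensor shape is `(batch, heads, m_blocks, n_blocks)` so that `shape[3]` is the n-block dimension.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def check_compact_n_blocks(idx_shape, seqlen_k: int, block_size_n: int) -> int:
    """Hypothetical upper-bound check on the n-block dimension.

    Accepts any idx_shape[3] <= ceil(seqlen_k / block_size_n) and
    rejects only shapes that are larger than the full-sized width.
    Returns the accepted n-block width.
    """
    if len(idx_shape) != 4:
        raise ValueError(f"expected a 4D index shape, got {tuple(idx_shape)}")
    expected_n_blocks = ceil_div(seqlen_k, block_size_n)
    n_blocks = idx_shape[3]
    if n_blocks > expected_n_blocks:
        raise ValueError(
            f"n-block dimension {n_blocks} exceeds expected {expected_n_blocks}"
        )
    return n_blocks


# Full-sized tensor: still accepted (backward compatible).
check_compact_n_blocks((1, 1, 8, 8), seqlen_k=1024, block_size_n=128)
# Compact tensor: accepted under the relaxed rule.
check_compact_n_blocks((1, 1, 8, 4), seqlen_k=1024, block_size_n=128)
```

A shape such as `(1, 1, 8, 9)` would still raise, since nine n-blocks exceed the eight that a 1024-token key sequence with 128-wide blocks can contain.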

Comment thread flash_attn/cute/block_sparsity.py Outdated
Collaborator

@drisspg drisspg left a comment

Can you add one test in test_mask_mod?

Relax validation in block_sparsity.py to allow idx.shape[3] <= expected_n_blocks
instead of requiring exact equality.

FA4 only accesses indices 0..cnt-1 per query tile, so the index tensor's last
dimension does not need to be as large as ceil(seqlen_k / block_size_n). This
enables memory-efficient compact index tensors that avoid O(N^2) memory at long
sequence lengths (e.g., 1M+ tokens for sparse attention / NSA workloads).

Changes:
- _check_and_expand_block: accept compact n-block dimension and expand only the
  batch/head/m-block dimensions
- infer_block_sparse_expected_shapes: change strict equality check to upper-bound
  check (error only when n-blocks exceeds expected, not when smaller)

Backward compatible: existing code that passes full-sized tensors is unaffected.
@jduprat jduprat force-pushed the compact-block-sparse-indices branch from de84a3a to 04d5eca Compare March 31, 2026 23:06
Verify that truncating block sparse index tensors to idx.shape[3] = max(cnt)
(instead of the full ceil(seqlen_k / block_size_n)) produces bit-identical
output to full-sized tensors. This validates the relaxed validation from
the previous commit.
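
The idea behind this test can be illustrated without the kernel. The sketch below uses plain Python lists as a stand-in for one (batch, head) slice of the index tensor, and it is not the test added to test_mask_mod: `compact_indices` and `visited_blocks` are hypothetical helpers. It truncates each row to `max(cnt)` and checks that the entries a consumer reads (positions `0..cnt-1` per query tile) are unchanged.

```python
def compact_indices(full_idx, cnt):
    """Truncate every row to max(cnt) entries.

    full_idx: one row of key-block indices per query tile.
    cnt: number of valid entries in each row.
    """
    width = max(cnt) if cnt else 0
    return [row[:width] for row in full_idx]


def visited_blocks(idx, cnt):
    """Entries actually read per query tile: positions 0..cnt-1."""
    return [row[:c] for row, c in zip(idx, cnt)]


# Three query tiles, eight possible key blocks; unused slots are padded with 0.
full_idx = [
    [0, 3, 0, 0, 0, 0, 0, 0],
    [1, 2, 5, 0, 0, 0, 0, 0],
    [7, 0, 0, 0, 0, 0, 0, 0],
]
cnt = [2, 3, 1]

compact = compact_indices(full_idx, cnt)

# The compact tensor is narrower but exposes the same valid entries.
assert len(compact[0]) == max(cnt)
assert visited_blocks(compact, cnt) == visited_blocks(full_idx, cnt)
```

Because only the first `cnt` entries of each row are ever read, the padding beyond `max(cnt)` carries no information, which is why the truncated and full-sized tensors are expected to give identical output.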
@drisspg drisspg merged commit f6a16e1 into Dao-AILab:main Apr 1, 2026