Allow compact block sparse index tensors #2417
Merged
drisspg merged 2 commits into Dao-AILab:main on Apr 1, 2026
Conversation
drisspg reviewed Mar 31, 2026
drisspg (Collaborator) left a comment
Can you add a test in test_mask_mod?
Relax validation in block_sparsity.py to allow idx.shape[3] <= expected_n_blocks instead of requiring exact equality. FA4 only accesses indices 0..cnt-1 per query tile, so the index tensor's last dimension does not need to be as large as ceil(seqlen_k / block_size_n). This enables memory-efficient compact index tensors that avoid O(N^2) memory at long sequence lengths (e.g., 1M+ tokens for sparse attention / NSA workloads).

Changes:
- _check_and_expand_block: accept a compact n-block dimension and expand only the batch/head/m-block dimensions
- infer_block_sparse_expected_shapes: change the strict equality check to an upper-bound check (error only when the n-block count exceeds the expected value, not when it is smaller)

Backward compatible: existing code that passes full-sized tensors is unaffected.
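A minimal sketch of the relaxed upper-bound check described above. The function name and shape-tuple interface are hypothetical, not the actual `block_sparsity.py` API; the point is only that the last dimension is validated as at most, rather than exactly, `ceil(seqlen_k / block_size_n)`:

```python
import math

def check_n_block_dim(idx_shape, seqlen_k, block_size_n):
    # Hypothetical sketch of the relaxed validation: the index
    # tensor's last dim may be *at most* ceil(seqlen_k / block_size_n).
    # Smaller (compact) tensors are now accepted; larger ones error.
    expected_n_blocks = math.ceil(seqlen_k / block_size_n)
    n_blocks = idx_shape[3]
    if n_blocks > expected_n_blocks:
        raise ValueError(
            f"index tensor has {n_blocks} n-blocks, "
            f"expected at most {expected_n_blocks}"
        )
    return expected_n_blocks
```

Under the previous strict-equality check, only `idx_shape[3] == expected_n_blocks` would pass; here any smaller value is also accepted.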
Force-pushed from de84a3a to 04d5eca (Compare)
Verify that truncating block sparse index tensors to idx.shape[3] = max(cnt) (instead of the full ceil(seqlen_k / block_size_n)) produces bit-identical output to full-sized tensors. This validates the relaxed validation from the previous commit.
drisspg
approved these changes
Apr 1, 2026
Relax validation in block_sparsity.py to allow idx.shape[3] <= expected_n_blocks instead of requiring exact equality.
FA4 only accesses indices 0..cnt-1 per query tile, so the index tensor's last dimension does not need to be as large as ceil(seqlen_k / block_size_n). This enables memory-efficient compact index tensors that avoid O(N^2) memory at long sequence lengths (e.g., sparse attention).
Changes: see [Note] Allow Compact block sparse indices
Backward compatible: existing code that passes full-sized tensors is unaffected.
This is a follow-up to PR #2085.
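Back-of-envelope arithmetic for the O(N^2) memory claim. The per-tile budget of 64 sparse blocks is an illustrative assumption, not a number from the PR:

```python
import math

# Hypothetical memory arithmetic at 1M tokens (block sizes illustrative).
seqlen = 1_000_000
block_m = block_n = 128
m_blocks = math.ceil(seqlen / block_m)   # query tiles
n_blocks = math.ceil(seqlen / block_n)   # key/value tiles

# Full-sized index tensor: one entry per (m_block, n_block) pair,
# per batch and head -- this is the O(N^2) term.
full_entries = m_blocks * n_blocks

# Compact tensor: if each query tile attends to at most 64 sparse
# blocks (assumed), the last dim shrinks from n_blocks to 64.
compact_entries = m_blocks * 64
```

At these numbers the compact tensor is over 100x smaller per batch-head, which is why the relaxed validation matters for long-sequence sparse attention.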