Skip to content

[CuTe,Flex] varlen blocksparsity#2224

Merged
drisspg merged 2 commits into
Dao-AILab:mainfrom
reubenconducts:rstern/varlen-blocksparsity
May 7, 2026
Merged

[CuTe,Flex] varlen blocksparsity#2224
drisspg merged 2 commits into
Dao-AILab:mainfrom
reubenconducts:rstern/varlen-blocksparsity

Conversation

@reubenconducts
Copy link
Copy Markdown
Contributor

@reubenconducts reubenconducts commented Feb 2, 2026

This PR extends blocksparsity to the variable sequence length case. Whereas in batched blocksparsity the metadata tensors take the shapes

*_block_cnt: [batch_size, num_heads, num_m_blocks]
*_block_idx: [batch_size, num_heads, num_m_blocks, num_n_blocks],

in varlen blocksparsity, we pack our metadata tensors to take the shapes

*_block_cnt: [num_heads, total_m_blocks]
*_block_idx: [num_heads, total_n_blocks]

where total_m_blocks is the sum of all m blocks per head (equiv. number of work tiles per head) and total_n_blocks is the total of all n blocks potentially processed per head across all sequences in the batch. For example, consider a varlen batch with sequences contained in seqlens_q and seqlens_k. At batch index b, we let

num_m(b) = ceildiv(seqlens_q[b], tile_m)
num_n(b) = ceildiv(seqlens_k[b], tile_n)

and define

total_m_blocks = sum_{b \in B} num_m(b)
total_n_blocks = sum_{b \in B} num_m(b) * num_n(b)

To properly index into the blocksparsity tensors, we use auxiliary mCuTotalMBlocks and mBlockIdxOffsets tensors, which can be prepared on host.

cc @drisspg @v0i0

NOT INTENDED FOR THIS PR:

  • bwd support

@SeanLi-OI
Copy link
Copy Markdown

Hi there, @reubenconducts ! Thank you so much for your draft.

Since I also need this feature eagerly, I tried to continue development based on your branch which fixes some grammar issue (SeanLi-OI@7ccfc5e). Though it can run, but returns wrong results when batch_size > 1.

I completely understand you may be busy with other priorities. If you have a moment, I’d be truly grateful for any guidance:
Are my modifications heading in the right direction? Or do you plan to continue updating this PR?

@reubenconducts
Copy link
Copy Markdown
Contributor Author

@SeanLi-OI Yes, I will be continuing this, but not until next week.

@reubenconducts reubenconducts force-pushed the rstern/varlen-blocksparsity branch from a4f3021 to bc15c46 Compare February 15, 2026 19:53
@reubenconducts reubenconducts force-pushed the rstern/varlen-blocksparsity branch from bc15c46 to 03f2f92 Compare February 25, 2026 19:03
@reubenconducts reubenconducts changed the title [WIP] varlen blocksparsity [CuTe,Flex] varlen blocksparsity Feb 25, 2026
@reubenconducts reubenconducts marked this pull request as ready for review February 25, 2026 19:05
@reubenconducts reubenconducts marked this pull request as draft February 25, 2026 19:17
@reubenconducts reubenconducts force-pushed the rstern/varlen-blocksparsity branch from 03f2f92 to 04d3016 Compare February 26, 2026 16:42
@reubenconducts reubenconducts marked this pull request as ready for review February 26, 2026 16:42
Comment thread flash_attn/cute/block_sparse_utils.py Outdated
Comment thread flash_attn/cute/block_sparse_utils.py
@reubenconducts reubenconducts force-pushed the rstern/varlen-blocksparsity branch from ab9bbeb to 9c4370b Compare March 16, 2026 17:41
@wqwqazwsxedc
Copy link
Copy Markdown

Hi @drisspg, @reubenconducts, just checking in on this PR. Are there any remaining blockers or changes needed before it can be merged?
I think this feature could be useful to implement something akin to PrefixGrouper for varlen sequences, which would be useful for RL training. I'd be glad to help if necessary.

@reubenconducts reubenconducts force-pushed the rstern/varlen-blocksparsity branch from 6b85348 to 6f736e9 Compare April 10, 2026 17:39
@reubenconducts reubenconducts force-pushed the rstern/varlen-blocksparsity branch 2 times, most recently from 0c21cc0 to 19fbcd1 Compare April 29, 2026 19:09
normalized_block_sparse_tensors,
block_sparse_broadcast_pattern,
q_subtile_factor,
) = normalize_block_sparse_config(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could we jam somemore of this into the normalize to keep the interface cleaner?

Comment thread flash_attn/cute/interface.py Outdated
score_mod_bwd: Optional[Callable] = None,
mask_mod: Optional[Callable] = None,
block_sparse_tensors: Optional[list] = None,
cu_total_m_blocks: Optional[torch.Tensor] = None,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this would be attributes of the BlcokSparseTensor tuple?

@drisspg
Copy link
Copy Markdown
Collaborator

drisspg commented Apr 29, 2026

total_m_blocks = sum_b ceil(seqlen_q[b] / tile_m) whats the flow you envision for people producing this? Trying to minimize HtD / DtoH syncs I guess this could live in the dataloader right? Or you can overlaunch: batch_size * ceil(max_seqlen_q / tile_m) and then mask

Comment thread flash_attn/cute/block_sparse_utils.py Outdated
blocksparse_tensors: BlockSparseTensors,
seqlen_info: SeqlenInfoQK,
) -> Tuple[cutlass.Int32, cute.Tensor, cutlass.Int32, Optional[cute.Tensor]]:
mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super nit: I think just having a varlen speicifc sub func could make this even cleaner.. I hate cutedsl huge if/else constexpr blocks

Comment thread flash_attn/cute/interface.py Outdated
seqused_k_tensor,
learnable_sink_tensor,
cu_total_m_blocks_tensor,
cu_total_n_blocks_tensor,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: cu_total_n_blocks is a bit confusing to me. The name sounds like a cumulative sum of N_b, but semantically it is a cumulative sum of M_b * N_b the offset into the packed *_block_idx tensors. Maybe something like cu_block_idx_offsets ?

@drisspg
Copy link
Copy Markdown
Collaborator

drisspg commented Apr 29, 2026

Overall, sorry this took so long to review. I'm pretty on board. I did a bunch of codexing / experiments on the pr, and I think that we could build transposy-like metadata for the backward decently efficiently. I still don't have a full sense of all the different var-length sub-features this would need to interact with, but I think this provides a decent amount of value and is actually pretty well contained

We should spell out somewhere very explicitly what the mask_mod semantics are for this feature

@reubenconducts reubenconducts force-pushed the rstern/varlen-blocksparsity branch from 5816ec0 to 1b8a0fa Compare May 2, 2026 12:37
Comment thread flash_attn/cute/block_sparsity.py Outdated
f"Varlen block sparsity requires sparse_block_size[0]={base_m_block} "
f"(= q_stage * tile_m); got {sparse_block_size_q}."
)
total_m_blocks = tensors.cu_total_m_blocks[-1].item()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these items expected?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think they are / need to be also rebase :)

Squashed forward-path varlen support: extends BlockSparseTensors usage to
[num_heads, total_m_blocks] / [num_heads, total_n_blocks] layouts, threads
cu_seqlens / cu_total_m_blocks / cu_total_n_blocks through the kernel and
compute_block_sparsity, and routes through get_curr_blocksparse_tensors and
get_total_block_count for shape-aware indexing.
…cks/cu_block_idx_offsets into BlockSparseTensors instead of threading them as standalone parameters; drop the two <tensor>[-1].item() syncs in normalize_block_sparse_config
@reubenconducts reubenconducts force-pushed the rstern/varlen-blocksparsity branch from 1b8a0fa to a8b8251 Compare May 6, 2026 03:52
@drisspg
Copy link
Copy Markdown
Collaborator

drisspg commented May 7, 2026

let me know when ready to merge

@drisspg drisspg merged commit 25b451e into Dao-AILab:main May 7, 2026
horakka5 added a commit to horakka5/flash-attention that referenced this pull request May 28, 2026
…d, use shared utility

Addresses @reubenconducts's May 22 review comments on Dao-AILab#2520:

1. Rename mTileCumsum -> mCuTotalMBlocks across all 9 kernels + scheduler
   + interface for consistency with the convention introduced in Dao-AILab#2224
   (already in main; used by blocksparse, and by Dao-AILab#2559).

2. Drop num_head (and the related pack_gqa arch-conditional remap) from
   the host cumsum. Per-batch cumsum is now pure m_blocks; the scheduler
   handles num_head separately. Removes the SM80/SM120 vs SM90/100/110/MLA
   branching that previously mirrored the pack_gqa_layout reshape behavior.

3. Replace the inline binary search in _varlen_coord_map's cumsum-on
   branch with a call to utils.get_batch_from_cu_tensor (the shared
   utility from Dao-AILab#2556). The existing snap-to-group-boundary + warp-scan
   structure is preserved — the cumsum serves as a hint to skip ahead,
   and the warp-scan refines to the exact batch using _get_num_m_blocks
   (which already handles pack_gqa, q_stage, cluster, etc.). This matches
   the scheduler-side approach in Dao-AILab#2559.

The pack_gqa seqlen multiplier stays in _compute_cu_total_m_blocks so
that per-batch m_block counts match the kernel's _get_num_m_blocks
formula — the snap is forward-only, so under-estimating per-batch counts
is safe but over-estimating (which dropping the multiplier would cause
when pack_gqa is on) would land the snap past the correct batch and the
warp-scan couldn't recover.

Verified on SM100:
- 72 new tests (test_varlen_scheduler_binary_search_correctness{,_bwd}): pass
- existing test_varlen (B=20 slice, 576 cases): pass
- existing test_flash_attn_mla_absorbed_varlen (480 cases): pass
reubenconducts added a commit to reubenconducts/flash-attention that referenced this pull request Jun 2, 2026
* varlen block-sparsity for forward

Squashed forward-path varlen support: extends BlockSparseTensors usage to
[num_heads, total_m_blocks] / [num_heads, total_n_blocks] layouts, threads
cu_seqlens / cu_total_m_blocks / cu_total_n_blocks through the kernel and
compute_block_sparsity, and routes through get_curr_blocksparse_tensors and
get_total_block_count for shape-aware indexing.

* rename cu_total_n_blocks to cu_block_idx_offsets; move cu_total_m_blocks/cu_block_idx_offsets into BlockSparseTensors instead of threading them as standalone parameters; drop the two <tensor>[-1].item() syncs in normalize_block_sparse_config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants