[CuTe,Flex] varlen blocksparsity #2224
|
Hi there, @reubenconducts! Thank you so much for your draft. Since I also need this feature urgently, I tried to continue development on top of your branch, in a commit that fixes some grammar issues (SeanLi-OI@7ccfc5e). It runs, but returns wrong results when batch_size > 1. I completely understand you may be busy with other priorities. If you have a moment, I’d be truly grateful for any guidance: |
|
@SeanLi-OI Yes, I will be continuing this, but not until next week. |
Force-pushed: a4f3021 to bc15c46
Force-pushed: bc15c46 to 03f2f92
Force-pushed: 03f2f92 to 04d3016
Force-pushed: ab9bbeb to 9c4370b
|
Hi @drisspg, @reubenconducts, just checking in on this PR. Are there any remaining blockers or changes needed before it can be merged? |
Force-pushed: 6b85348 to 6f736e9
Force-pushed: 0c21cc0 to 19fbcd1
| normalized_block_sparse_tensors, | ||
| block_sparse_broadcast_pattern, | ||
| q_subtile_factor, | ||
| ) = normalize_block_sparse_config( |
nit: could we jam some more of this into normalize_block_sparse_config to keep the interface cleaner?
| score_mod_bwd: Optional[Callable] = None, | ||
| mask_mod: Optional[Callable] = None, | ||
| block_sparse_tensors: Optional[list] = None, | ||
| cu_total_m_blocks: Optional[torch.Tensor] = None, |
I thought these would be attributes of the BlockSparseTensors tuple?
|
total_m_blocks = sum_b ceil(seqlen_q[b] / tile_m): what's the flow you envision for people producing this? To minimize HtoD / DtoH syncs, I guess this could live in the dataloader, right? Or you could overlaunch with batch_size * ceil(max_seqlen_q / tile_m) blocks and then mask. |
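A minimal host-side sketch of the two options raised in this comment, assuming the sequence lengths are already available as plain Python ints (for example inside a dataloader), so no device sync is needed. The helper names `cu_total_m_blocks_host` and `overlaunch_m_blocks` are hypothetical and are not part of the PR.

```python
from itertools import accumulate


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def cu_total_m_blocks_host(seqlens_q: list[int], tile_m: int) -> list[int]:
    """Hypothetical helper: prefix sum of per-sequence m-block counts.

    Returns batch_size + 1 entries; the last one is total_m_blocks.
    """
    counts = [ceil_div(s, tile_m) for s in seqlens_q]
    return [0, *accumulate(counts)]


def overlaunch_m_blocks(batch_size: int, max_seqlen_q: int, tile_m: int) -> int:
    """Hypothetical helper: upper bound used when overlaunching and masking."""
    return batch_size * ceil_div(max_seqlen_q, tile_m)


seqlens_q = [100, 257, 64]
tile_m = 128

cu = cu_total_m_blocks_host(seqlens_q, tile_m)
exact_total = cu[-1]
upper_bound = overlaunch_m_blocks(len(seqlens_q), max(seqlens_q), tile_m)
print(cu, exact_total, upper_bound)
```

The exact count needs the per-sequence lengths on host, while the overlaunch bound only needs batch_size and max_seqlen_q, at the cost of launching tiles that are masked out.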
| blocksparse_tensors: BlockSparseTensors, | ||
| seqlen_info: SeqlenInfoQK, | ||
| ) -> Tuple[cutlass.Int32, cute.Tensor, cutlass.Int32, Optional[cute.Tensor]]: | ||
| mask_block_cnt, mask_block_idx, full_block_cnt, full_block_idx = blocksparse_tensors |
super nit: I think just having a varlen-specific sub-function could make this even cleaner. I hate CuTe DSL's huge if/else constexpr blocks.
| seqused_k_tensor, | ||
| learnable_sink_tensor, | ||
| cu_total_m_blocks_tensor, | ||
| cu_total_n_blocks_tensor, |
Nit: cu_total_n_blocks is a bit confusing to me. The name sounds like a cumulative sum of N_b, but semantically it is a cumulative sum of M_b * N_b, i.e. the offset into the packed *_block_idx tensors. Maybe something like cu_block_idx_offsets?
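To make the naming concern concrete, here is a small sketch contrasting the two cumulative sums. The per-sequence block counts are made-up illustrative values, and `cu_n_blocks` is a hypothetical name for what cu_total_n_blocks might be misread as.

```python
from itertools import accumulate

# Illustrative per-sequence block counts (assumed values, not from the PR).
num_m = [1, 3, 1]  # M_b: m blocks in sequence b
num_n = [2, 4, 1]  # N_b: n blocks in sequence b

# What the name "cu_total_n_blocks" suggests: a prefix sum of N_b.
cu_n_blocks = [0, *accumulate(num_n)]

# What the tensor actually holds per the comment above: a prefix sum of
# M_b * N_b, i.e. where each sequence starts in the packed *_block_idx tensors.
cu_block_idx_offsets = [0, *accumulate(m * n for m, n in zip(num_m, num_n))]

print(cu_n_blocks)
print(cu_block_idx_offsets)
```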
|
Overall, sorry this took so long to review. I'm pretty on board. I did a bunch of codexing / experiments on the PR, and I think we could build transpose-like metadata for the backward pass fairly efficiently. I still don't have a full sense of all the different varlen sub-features this would need to interact with, but I think this provides a decent amount of value and is actually pretty well contained. We should spell out somewhere very explicitly what the mask_mod semantics are for this feature. |
Force-pushed: 5816ec0 to 1b8a0fa
| f"Varlen block sparsity requires sparse_block_size[0]={base_m_block} " | ||
| f"(= q_stage * tile_m); got {sparse_block_size_q}." | ||
| ) | ||
| total_m_blocks = tensors.cu_total_m_blocks[-1].item() |
I don't think they are / need to be. Also, rebase :)
Squashed forward-path varlen support: extends BlockSparseTensors usage to [num_heads, total_m_blocks] / [num_heads, total_n_blocks] layouts, threads cu_seqlens / cu_total_m_blocks / cu_total_n_blocks through the kernel and compute_block_sparsity, and routes through get_curr_blocksparse_tensors and get_total_block_count for shape-aware indexing.
…cks/cu_block_idx_offsets into BlockSparseTensors instead of threading them as standalone parameters; drop the two <tensor>[-1].item() syncs in normalize_block_sparse_config
Force-pushed: 1b8a0fa to a8b8251
|
let me know when ready to merge |
…d, use shared utility

Addresses @reubenconducts's May 22 review comments on Dao-AILab#2520:

1. Rename mTileCumsum -> mCuTotalMBlocks across all 9 kernels + scheduler + interface for consistency with the convention introduced in Dao-AILab#2224 (already in main; used by blocksparse, and by Dao-AILab#2559).
2. Drop num_head (and the related pack_gqa arch-conditional remap) from the host cumsum. Per-batch cumsum is now pure m_blocks; the scheduler handles num_head separately. Removes the SM80/SM120 vs SM90/100/110/MLA branching that previously mirrored the pack_gqa_layout reshape behavior.
3. Replace the inline binary search in _varlen_coord_map's cumsum-on branch with a call to utils.get_batch_from_cu_tensor (the shared utility from Dao-AILab#2556). The existing snap-to-group-boundary + warp-scan structure is preserved: the cumsum serves as a hint to skip ahead, and the warp-scan refines to the exact batch using _get_num_m_blocks (which already handles pack_gqa, q_stage, cluster, etc.). This matches the scheduler-side approach in Dao-AILab#2559.

The pack_gqa seqlen multiplier stays in _compute_cu_total_m_blocks so that per-batch m_block counts match the kernel's _get_num_m_blocks formula. The snap is forward-only, so under-estimating per-batch counts is safe, but over-estimating (which dropping the multiplier would cause when pack_gqa is on) would land the snap past the correct batch and the warp-scan couldn't recover.

Verified on SM100:
- 72 new tests (test_varlen_scheduler_binary_search_correctness{,_bwd}): pass
- existing test_varlen (B=20 slice, 576 cases): pass
- existing test_flash_attn_mla_absorbed_varlen (480 cases): pass
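For readers unfamiliar with the lookup this commit message refers to, the following is a host-side sketch of mapping a global m-block index to its batch via binary search over a cumulative tensor. It is not the implementation of utils.get_batch_from_cu_tensor; `batch_from_cu` is a hypothetical stand-in written with the standard-library `bisect` module, and it ignores the pack_gqa, q_stage, and cluster handling mentioned above.

```python
from bisect import bisect_right


def batch_from_cu(cu: list[int], tile_idx: int) -> tuple[int, int]:
    """Hypothetical helper: find (batch, local m-block) for a global tile index.

    `cu` is a non-decreasing prefix sum with cu[0] == 0 and
    cu[-1] == total_m_blocks. Sequences with zero blocks are skipped.
    """
    if not 0 <= tile_idx < cu[-1]:
        raise IndexError(f"tile_idx {tile_idx} out of range [0, {cu[-1]})")
    # Last position b such that cu[b] <= tile_idx.
    b = bisect_right(cu, tile_idx) - 1
    return b, tile_idx - cu[b]


cu_example = [0, 1, 4, 5]  # per-batch m-block counts 1, 3, 1
lookups = [batch_from_cu(cu_example, t) for t in range(cu_example[-1])]
print(lookups)

# An empty middle sequence contributes no tiles and is never returned.
cu_with_empty = [0, 2, 2, 3]
print(batch_from_cu(cu_with_empty, 2))
```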
* varlen block-sparsity for forward

  Squashed forward-path varlen support: extends BlockSparseTensors usage to [num_heads, total_m_blocks] / [num_heads, total_n_blocks] layouts, threads cu_seqlens / cu_total_m_blocks / cu_total_n_blocks through the kernel and compute_block_sparsity, and routes through get_curr_blocksparse_tensors and get_total_block_count for shape-aware indexing.

* rename cu_total_n_blocks to cu_block_idx_offsets; move cu_total_m_blocks/cu_block_idx_offsets into BlockSparseTensors instead of threading them as standalone parameters; drop the two <tensor>[-1].item() syncs in normalize_block_sparse_config
This PR extends blocksparsity to the variable sequence length case. Whereas in batched blocksparsity the metadata tensors are laid out per batch element, in varlen blocksparsity we pack the metadata tensors into the shapes

[num_heads, total_m_blocks] for the *_block_cnt tensors
[num_heads, total_n_blocks] for the *_block_idx tensors

where total_m_blocks is the sum of all m blocks per head (equivalently, the number of work tiles per head) and total_n_blocks is the total number of n blocks potentially processed per head across all sequences in the batch. For example, consider a varlen batch B with sequence lengths given by seqlens_q and seqlens_k. At batch index b, we let num_m(b) and num_n(b) denote the number of m blocks and n blocks of sequence b, and define

total_m_blocks = sum_{b \in B} num_m(b)
total_n_blocks = sum_{b \in B} num_m(b) * num_n(b)

To properly index into the blocksparsity tensors, we use auxiliary mCuTotalMBlocks and mBlockIdxOffsets tensors, which can be prepared on host.

cc @drisspg @v0i0
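As a worked illustration of the definitions above, the sketch below builds both auxiliary prefix sums on host and locates the metadata for one (batch, m-block) pair. It rests on two assumptions that the description does not state: that num_m(b) and num_n(b) are ceiling divisions of the sequence lengths by block sizes (called `m_block` and `n_block` here), and that within each sequence the packed *_block_idx entries are stored row-major as [num_m(b), num_n(b)]. The function `packed_positions` is hypothetical.

```python
from itertools import accumulate


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


seqlens_q = [100, 257, 64]
seqlens_k = [200, 500, 90]
m_block, n_block = 128, 128  # assumed block sizes

num_m = [ceil_div(s, m_block) for s in seqlens_q]
num_n = [ceil_div(s, n_block) for s in seqlens_k]

# Prefix sums playing the roles of mCuTotalMBlocks and mBlockIdxOffsets.
cu_total_m_blocks = [0, *accumulate(num_m)]
block_idx_offsets = [0, *accumulate(m * n for m, n in zip(num_m, num_n))]

total_m_blocks = cu_total_m_blocks[-1]
total_n_blocks = block_idx_offsets[-1]


def packed_positions(b: int, m: int) -> tuple[int, int, int]:
    """Hypothetical: column in *_block_cnt and [start, end) slice in *_block_idx.

    Assumes a row-major [num_m(b), num_n(b)] layout inside each sequence.
    """
    cnt_pos = cu_total_m_blocks[b] + m
    idx_start = block_idx_offsets[b] + m * num_n[b]
    return cnt_pos, idx_start, idx_start + num_n[b]


print(total_m_blocks, total_n_blocks)
print(packed_positions(1, 2))
```

Under these assumptions, the last m block of the second sequence reads its count from column 3 and its candidate n-block indices from columns 10 through 13 of the packed tensors, for every head.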
NOT INTENDED FOR THIS PR: