Skip to content

split out varlen batch search into utils#2556

Merged
jayhshah merged 2 commits into
Dao-AILab:mainfrom
reubenconducts:rstern/search-util
May 14, 2026
Merged

split out varlen batch search into utils#2556
jayhshah merged 2 commits into
Dao-AILab:mainfrom
reubenconducts:rstern/search-util

Conversation

@reubenconducts
Copy link
Copy Markdown
Contributor

@jayhshah jayhshah merged commit 0409f9a into Dao-AILab:main May 14, 2026
ussoewwin added a commit to ussoewwin/flash-attention that referenced this pull request May 21, 2026
horakka5 added a commit to horakka5/flash-attention that referenced this pull request May 28, 2026
…d, use shared utility

Addresses @reubenconducts's May 22 review comments on Dao-AILab#2520:

1. Rename mTileCumsum -> mCuTotalMBlocks across all 9 kernels + scheduler
   + interface for consistency with the convention introduced in Dao-AILab#2224
   (already in main; used by blocksparse, and by Dao-AILab#2559).

2. Drop num_head (and the related pack_gqa arch-conditional remap) from
   the host cumsum. Per-batch cumsum is now pure m_blocks; the scheduler
   handles num_head separately. Removes the SM80/SM120 vs SM90/100/110/MLA
   branching that previously mirrored the pack_gqa_layout reshape behavior.

3. Replace the inline binary search in _varlen_coord_map's cumsum-on
   branch with a call to utils.get_batch_from_cu_tensor (the shared
   utility from Dao-AILab#2556). The existing snap-to-group-boundary + warp-scan
   structure is preserved — the cumsum serves as a hint to skip ahead,
   and the warp-scan refines to the exact batch using _get_num_m_blocks
   (which already handles pack_gqa, q_stage, cluster, etc.). This matches
   the scheduler-side approach in Dao-AILab#2559.

The pack_gqa seqlen multiplier stays in _compute_cu_total_m_blocks so
that per-batch m_block counts match the kernel's _get_num_m_blocks
formula — the snap is forward-only, so under-estimating per-batch counts
is safe but over-estimating (which dropping the multiplier would cause
when pack_gqa is on) would land the snap past the correct batch and the
warp-scan couldn't recover.

Verified on SM100:
- 72 new tests (test_varlen_scheduler_binary_search_correctness{,_bwd}): pass
- existing test_varlen (B=20 slice, 576 cases): pass
- existing test_flash_attn_mla_absorbed_varlen (480 cases): pass
reubenconducts added a commit to reubenconducts/flash-attention that referenced this pull request Jun 2, 2026
* split out varlen batch search into utils

* more descriptive name
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants