Skip to content

[CuTe,Sm100] Varlen Dynamic Persistent scheduler and metadata#2559

Open
reubenconducts wants to merge 18 commits into
Dao-AILab:mainfrom
reubenconducts:dynamic_metadata
Open

[CuTe,Sm100] Varlen Dynamic Persistent scheduler and metadata#2559
reubenconducts wants to merge 18 commits into
Dao-AILab:mainfrom
reubenconducts:dynamic_metadata

Conversation

@reubenconducts
Copy link
Copy Markdown
Contributor

@reubenconducts reubenconducts commented May 13, 2026

This PR adds the VarlenDynamicPersistentScheduler, seen in FA3, to FA4.

  • Accepts scheduler metadata tensors, which can be prepared in the prepare_scheduler_metadata.py FlashPrepareScheduler kernel:
    • num_m_blocks_ptr to precompute number of m blocks for the lpt sort (unused on SM100)
    • num_splits_dynamic_ptr: holds num_splits for each sequence in a batch, which can vary in mixed workloads
    • virtual_batch_idx_ptr: unused currently; used to permute sequences according to "virtual" batch indices, e.g. when sorting a batch for load balancing (to be utilized in a subsequent PR)
    • num_nheads_in_l2_ptr: used for head swizzle computation in the tile scheduler
    • tile_count_semaphore: zeroed out here, used in main kernel for dynamic persistence
  • Threads metadata through interface with options to disable, call FlashPrepareKernel, or pass in pre-computed cached metadata to amortize cost across layers.
  • Refactors tile_scheduler.py so that dynamic persistent schedulers (CLC and "traditional", like that in this PR) share methods and so that varlen schedulers (SingleTile and DynamicPersistent) reuse common methods.

Additionally, adds in a mCuTotalMBlocks tensor to be used by the SingleTileVarlenScheduler to perform a binary search for the current batch. This is due to @horakka5 in #2520, with an additional mCuTotalSplitsMBlocks pointer used jointly to deduce dynamic num splits, when appropriate.

Comprehensive performance numbers are attached. The regressions seen (e.g. with large batch, small seqlen) are attributable mainly to the fact that the combine kernel does not early-exit for 1-split sequences. We see that "traditional" dynamic persistent outperforms CLC (to which I've also wired up the appropriate metadata tensors) almost always. It's worth noting that in the few tests where prepare kernel latency appears extreme, a large proportion of that latency comes from torch.empty (known issue on Grace CPUs; ostensibly fixed with cuda 13.x, but I ran on 12.x).

varlen_dynamic_scheduler_perf.txt

To-do in follow-up PRs:

  • Currently the combine kernel does not early-exit for 1-split sequences; this explodes in many cases from over-launching zero-work CTAs. Solution is to use a persistent scheduler.
  • Implement lpt batch sort in the FlashPrepareKernel. This is proven to have an enormous impact on mixed prefill/decode workloads (including with CLC, from my preliminary testing). (In addition to this varlen case, batch sort will be helpful for load balancing with blocksparsity.)

@reubenconducts reubenconducts marked this pull request as ready for review May 13, 2026 23:25
horakka5 added a commit to horakka5/flash-attention that referenced this pull request May 28, 2026
…d, use shared utility

Addresses @reubenconducts's May 22 review comments on Dao-AILab#2520:

1. Rename mTileCumsum -> mCuTotalMBlocks across all 9 kernels + scheduler
   + interface for consistency with the convention introduced in Dao-AILab#2224
   (already in main; used by blocksparse, and by Dao-AILab#2559).

2. Drop num_head (and the related pack_gqa arch-conditional remap) from
   the host cumsum. Per-batch cumsum is now pure m_blocks; the scheduler
   handles num_head separately. Removes the SM80/SM120 vs SM90/100/110/MLA
   branching that previously mirrored the pack_gqa_layout reshape behavior.

3. Replace the inline binary search in _varlen_coord_map's cumsum-on
   branch with a call to utils.get_batch_from_cu_tensor (the shared
   utility from Dao-AILab#2556). The existing snap-to-group-boundary + warp-scan
   structure is preserved — the cumsum serves as a hint to skip ahead,
   and the warp-scan refines to the exact batch using _get_num_m_blocks
   (which already handles pack_gqa, q_stage, cluster, etc.). This matches
   the scheduler-side approach in Dao-AILab#2559.

The pack_gqa seqlen multiplier stays in _compute_cu_total_m_blocks so
that per-batch m_block counts match the kernel's _get_num_m_blocks
formula — the snap is forward-only, so under-estimating per-batch counts
is safe but over-estimating (which dropping the multiplier would cause
when pack_gqa is on) would land the snap past the correct batch and the
warp-scan couldn't recover.

Verified on SM100:
- 72 new tests (test_varlen_scheduler_binary_search_correctness{,_bwd}): pass
- existing test_varlen (B=20 slice, 576 cases): pass
- existing test_flash_attn_mla_absorbed_varlen (480 cases): pass
…onfig method; remove cluster_size==1 restriction; guard architectures against unused scheduler metadata args
@reubenconducts reubenconducts force-pushed the dynamic_metadata branch 2 times, most recently from 2e55d75 to 3416514 Compare June 2, 2026 00:20
learnable_sink: Optional[cute.Tensor] = None,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
aux_tensors=None,
mCuTotalMBlocks: Optional[cute.Tensor] = None,
Copy link
Copy Markdown
Collaborator

@drisspg drisspg Jun 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we turn this into a namedTuple for keepign them colocated?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit tricky. This is what SchedulerMetadataTensorsTorch is for, but this PR doesn't add general scheduler metadata to the non-sm100 kernels.

Comment thread flash_attn/cute/flash_fwd_sm100.py Outdated
Comment thread flash_attn/cute/flash_fwd_sm100.py
Comment thread flash_attn/cute/flash_fwd_sm100.py
Comment thread flash_attn/cute/flash_fwd_sm100.py
Comment thread flash_attn/cute/flash_fwd_sm100.py Outdated
Comment thread flash_attn/cute/interface.py Outdated
return out, lse


def _get_scheduler_metadata(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you think some of this should live inteh scheduler file?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's best left in interface, for example to hint that tile size selection needs to agree.

Comment thread flash_attn/cute/tile_scheduler.py Outdated
)
)
nheads_in_l2 = min(nheads_in_l2, self.num_head)
mh_in_l2 = nheads_in_l2 * num_m_blocks
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how different is this swizziling (havent read) should it be shared between LPT or at least parts?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants