Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
0ae26dd
add dynamicpersistentvarlenscheduler to flash_fwd_sm100 and prepare k…
reubenconducts May 5, 2026
e601530
mild refactor to tile scheduler protocol, guard num_m_blocks_ptr for …
reubenconducts May 6, 2026
2c14ab7
rename varlen_batch_idx -> virtual_batch_idx, because it is relevant …
reubenconducts May 7, 2026
79aea14
split out VarlenSchedulerBase to share code between SingleTile and Dy…
reubenconducts May 12, 2026
e4071c0
add benchmark script for varlen dynamic persistent scheduler
reubenconducts May 13, 2026
70efc72
minor clean up
reubenconducts May 13, 2026
f5d0abd
updates to has_work logic, tile scheduler selection, and varlen test …
reubenconducts May 13, 2026
2e9bc0f
fix tile scheduler dispatch logic
reubenconducts May 13, 2026
e76dde1
integrate binary batch search for single tile varlen
reubenconducts May 25, 2026
83328ab
refactor tile scheduler for compositionality
reubenconducts May 25, 2026
5a69571
work PR 2520 into interface and kernels
reubenconducts May 26, 2026
9c68299
Merge branch 'main' into dynamic_metadata
reubenconducts May 26, 2026
5b992d5
fix linter errors
reubenconducts May 26, 2026
e2f42e8
Merge branch 'dynamic_metadata' of https://github.com/reubenconducts/…
reubenconducts May 26, 2026
5391f12
wip: modify scheduler metadata public api
reubenconducts Jun 1, 2026
97b9640
[2026-06-01] pull origin main
reubenconducts Jun 1, 2026
3416514
clean up scheduler metadata API; add docstrings; split out _get_fwd_c…
reubenconducts Jun 1, 2026
3f1f217
address driss' comments
reubenconducts Jun 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
517 changes: 517 additions & 0 deletions benchmarks/benchmark_varlen_sched.py

Large diffs are not rendered by default.

23 changes: 18 additions & 5 deletions flash_attn/cute/block_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ class BlockInfo:
window_size_left: Optional[Int32] = None
window_size_right: Optional[Int32] = None
qhead_per_kvhead_packgqa: cutlass.Constexpr[int] = 1
num_splits: Int32 = 1
num_splits_dynamic_ptr: Optional[cute.Tensor] = None
num_n_blocks_per_split: Optional[cutlass.Constexpr[Int32]] = None

@cute.jit
def get_n_block_min_max(
self,
seqlen_info: SeqlenInfoQK,
m_block: Int32,
split_idx: Int32 = 0,
batch_idx: Int32 = 0,
num_splits: Int32 = 1,
) -> Tuple[Int32, Int32]:
n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
Expand All @@ -45,11 +49,20 @@ def get_n_block_min_max(
n_idx_left = n_idx - self.window_size_left
n_block_min = cutlass.max(n_idx_left // self.tile_n, 0)
if cutlass.const_expr(self.is_split_kv):
num_n_blocks_per_split = (
Int32(0)
if n_block_max <= n_block_min
else (n_block_max - n_block_min + num_splits - 1) // num_splits
)
if const_expr(self.num_splits_dynamic_ptr is not None):
# Unpack num_splits from top 16 bits of split_idx (packed by scheduler)
num_splits = split_idx >> 16
split_idx = split_idx & 0xFFFF
else:
num_splits = self.num_splits
if const_expr(self.num_n_blocks_per_split is not None):
num_n_blocks_per_split = self.num_n_blocks_per_split
else:
num_n_blocks_per_split = (
Int32(0)
if n_block_max <= n_block_min
else (n_block_max - n_block_min + num_splits - 1) // num_splits
)
n_block_min = n_block_min + split_idx * num_n_blocks_per_split
n_block_max = cutlass.min(n_block_min + num_n_blocks_per_split, n_block_max)
return n_block_min, n_block_max
Expand Down
2 changes: 2 additions & 0 deletions flash_attn/cute/flash_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def __call__(
mdV_semaphore: Optional[cute.Tensor] = None,
aux_tensors: Optional[list] = None,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
mCuTotalMBlocks: Optional[cute.Tensor] = None,
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
stream: cuda.CUstream = None,
):
Expand Down Expand Up @@ -429,6 +430,7 @@ def __call__(
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,
mCuSeqlensQ=mCuSeqlensK,
mSeqUsedQ=mSeqUsedK,
cu_total_m_blocks_ptr=mCuTotalMBlocks,
)

tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
Expand Down
2 changes: 2 additions & 0 deletions flash_attn/cute/flash_bwd_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def __call__(
scale: cutlass.Float32,
mCuSeqlensQ: Optional[cute.Tensor],
mSeqUsedQ: Optional[cute.Tensor],
mCuTotalMBlocks: Optional[cute.Tensor] = None,
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
stream: cuda.CUstream = None,
):
Expand Down Expand Up @@ -258,6 +259,7 @@ def __call__(
tile_shape_mn=(self.tile_m, 1),
mCuSeqlensQ=mCuSeqlensQ,
mSeqUsedQ=mSeqUsedQ,
cu_total_m_blocks_ptr=mCuTotalMBlocks,
)

tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
Expand Down
2 changes: 2 additions & 0 deletions flash_attn/cute/flash_bwd_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __call__(
mCuSeqlensQ: Optional[cute.Tensor], # (batch + 1,)
mSeqUsedQ: Optional[cute.Tensor], # (batch,)
mdLSE: Optional[cute.Tensor], # (batch, nheads, seqlen) or (nheads, total_q)
mCuTotalMBlocks: Optional[cute.Tensor] = None,
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
stream: cuda.CUstream = None,
):
Expand Down Expand Up @@ -193,6 +194,7 @@ def __call__(
tile_shape_mn=(self.tile_m, 1),
mCuSeqlensQ=mCuSeqlensQ,
mSeqUsedQ=mSeqUsedQ,
cu_total_m_blocks_ptr=mCuTotalMBlocks,
)

tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
Expand Down
2 changes: 2 additions & 0 deletions flash_attn/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,7 @@ def __call__(
aux_tensors: Optional[list] = None,
# Block-sparse tensors (Q direction - for iterating m_blocks per n_block):
blocksparse_tensors: Optional[BlockSparseTensors] = None,
mCuTotalMBlocks: Optional[cute.Tensor] = None,
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
stream: cuda.CUstream = None,
):
Expand Down Expand Up @@ -732,6 +733,7 @@ def __call__(
qhead_per_kvhead_packgqa=1, # pack_gqa disabled for bwd
element_size=self.k_dtype.width // 8,
is_persistent=self.is_persistent, # persistent mode not tested
cu_total_m_blocks_ptr=mCuTotalMBlocks,
lpt=self.spt,
head_swizzle=self.deterministic,
)
Expand Down
2 changes: 2 additions & 0 deletions flash_attn/cute/flash_bwd_sm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def __call__(
mdV_semaphore: Optional[cute.Tensor] = None,
aux_tensors: Optional[list] = None,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
mCuTotalMBlocks: Optional[cute.Tensor] = None,
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
stream: cuda.CUstream = None,
):
Expand Down Expand Up @@ -536,6 +537,7 @@ def _qkv_transpose(t):
is_persistent=False,
lpt=self.spt,
head_swizzle=self.deterministic,
cu_total_m_blocks_ptr=mCuTotalMBlocks,
)

tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
Expand Down
4 changes: 4 additions & 0 deletions flash_attn/cute/flash_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,8 @@ def __call__(
learnable_sink: Optional[cute.Tensor] = None,
blocksparse_tensors: Optional[BlockSparseTensors] = None,
aux_tensors=None,
mCuTotalMBlocks: Optional[cute.Tensor] = None,
Copy link
Copy Markdown
Collaborator

@drisspg drisspg Jun 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we turn this into a NamedTuple for keeping them colocated?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit tricky. This is what SchedulerMetadataTensorsTorch is for, but this PR doesn't add general scheduler metadata to the non-sm100 kernels.

mCuTotalSplitsMBlocks: Optional[cute.Tensor] = None,
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
stream: cuda.CUstream = None,
):
Expand Down Expand Up @@ -698,6 +700,8 @@ def __call__(
qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
mCuSeqlensQ=mCuSeqlensQ,
mSeqUsedQ=mSeqUsedQ,
cu_total_m_blocks_ptr=mCuTotalMBlocks,
cu_total_splits_m_blocks_ptr=mCuTotalSplitsMBlocks,
)
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
grid_dim = TileScheduler.get_grid_shape(tile_sched_params)
Expand Down
15 changes: 8 additions & 7 deletions flash_attn/cute/flash_fwd_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def __call__(
cu_seqlens: Optional[cute.Tensor] = None,
seqused: Optional[cute.Tensor] = None,
num_splits_dynamic_ptr: Optional[cute.Tensor] = None,
varlen_batch_idx: Optional[cute.Tensor] = None,
virtual_batch_idx: Optional[cute.Tensor] = None,
semaphore_to_reset: Optional[cute.Tensor] = None,
# Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI).
stream: cuda.CUstream = None,
Expand Down Expand Up @@ -301,7 +301,7 @@ class SharedStorage:
cu_seqlens,
seqused,
num_splits_dynamic_ptr,
varlen_batch_idx,
virtual_batch_idx,
semaphore_to_reset,
SharedStorage,
self.smem_layout_lse,
Expand Down Expand Up @@ -330,7 +330,7 @@ def kernel(
cu_seqlens: Optional[cute.Tensor],
seqused: Optional[cute.Tensor],
num_splits_dynamic_ptr: Optional[cute.Tensor],
varlen_batch_idx: Optional[cute.Tensor],
virtual_batch_idx: Optional[cute.Tensor],
semaphore_to_reset: Optional[cute.Tensor],
SharedStorage: cutlass.Constexpr,
smem_layout_lse: cute.Layout | cute.ComposedLayout,
Expand All @@ -349,8 +349,8 @@ def kernel(

# Map virtual batch index to real batch index (for persistent tile schedulers)
batch_idx = (
varlen_batch_idx[maybe_virtual_batch]
if const_expr(varlen_batch_idx is not None)
virtual_batch_idx[maybe_virtual_batch]
if const_expr(virtual_batch_idx is not None)
else maybe_virtual_batch
)

Expand Down Expand Up @@ -394,8 +394,9 @@ def kernel(
num_head = mO_partial.shape[3]
max_idx = seqlen * num_head

# Early exit for single split if dynamic
if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
# TODO: early exit for single split if dynamic — for now always merge so the
# num_splits_dynamic == 1 case still writes mO from mO_partial[0].
if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 0) and (
const_expr(not varlen) or m_block * self.tile_m < max_idx
):
# Wait for dependent grids (e.g., the main attention kernel that produces O_partial/LSE_partial)
Expand Down
4 changes: 2 additions & 2 deletions flash_attn/cute/flash_fwd_mla_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import flash_attn.cute.blackwell_helpers as fa_sm100_utils
from flash_attn.cute.softmax import SoftmaxSm100
from flash_attn.cute.tile_scheduler import (
ClcState,
SchedulerState,
SchedulingMode,
TileSchedulerArguments,
TileSchedulerProtocol,
Expand Down Expand Up @@ -993,7 +993,7 @@ def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None):
clc_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps
)
clc = ClcState.create(
clc = SchedulerState.create_clc(
hw_scheduler=ClcDynamicPersistentTileScheduler.create(
self.tile_scheduler_cls.clc_problem_shape(tile_sched_params),
cute.arch.block_idx(),
Expand Down
Loading
Loading