Merged

35 commits (all by reubenconducts):
8d393a0 rebase to main (Nov 18, 2025)
888ad96 varlen support for score mod (Nov 21, 2025)
ac4d72a interface change for varlen score mod (Nov 21, 2025)
d20b9a3 implement varlen support for score mod (Nov 24, 2025)
0511659 varlen score mod working; updated tests (Nov 25, 2025)
366f340 modify varlen score mod to use fastdiv_mods updated per sequence (Nov 25, 2025)
42128fa updated test suite (Nov 25, 2025)
c08567a current working state of varlen score mod (Nov 26, 2025)
73d6310 refactor varlen score mod tests (Nov 27, 2025)
18e54fa fix to transpose (Nov 27, 2025)
119ba5a refactor varlen score mod tests; fix bug; clean up varlen score mod a… (Dec 2, 2025)
24587de refactor test_score_mod.py to use external score mod definition file (Dec 2, 2025)
d2c078c update flash_fwd.py for varlen score mod (Dec 2, 2025)
4953dc4 sm90 varlen score mod working; test revisions (Dec 2, 2025)
2cc5041 enable packgqa for varlen score mod; set up fastdiv_mod recomputation (Dec 3, 2025)
ad84a01 update flash_fwd_sm100.py for recomputing fastdiv_mods & format varle… (Dec 3, 2025)
3123d6f Overwrite pack_gqa.py, tile_scheduler.py, and test_flash_attn.py with… (Dec 3, 2025)
92f0990 rebase to main (Dec 3, 2025)
bebc8ea fix test rebase artifacts (Dec 3, 2025)
633b7f8 fix floor_if_packed redundancy (Dec 3, 2025)
263a9e6 correct sm90 divmods mismatch (Dec 3, 2025)
23f8c61 revert test_flash_attn to main (Dec 3, 2025)
dff4eab add varlen score mod benchmark script (Dec 3, 2025)
83201a0 packgqa for varlen (independent of score mod) (Dec 3, 2025)
6e43ac0 rm benchmark from PR (Dec 3, 2025)
cdba9f1 move score mod arg wrapping to utils.py (Dec 4, 2025)
73c922b format with ruff (Dec 4, 2025)
5b114f4 major refactor: change score_mod signature to accept seqlen_info and … (Dec 4, 2025)
9858bf4 reinstate varlen packgqa exclusion checks (Dec 5, 2025)
8cdd54a move fastdiv_mods recomputation out of apply_score_mod in prep for va… (Dec 5, 2025)
005fb85 remove duplicate fastdiv_mod recomputation (Dec 5, 2025)
ef3b967 [Fix] fastdiv_mods for paged attn and seqused_* (Dec 12, 2025)
64d6c31 clean up PR; fix paged_kv varlen for sm90 (Dec 15, 2025)
241353b update to varlen score mod test script (paged kv) (Dec 15, 2025)
ef4eb95 remove premature seqlen arguments from sm90 apply_mask_mod (Dec 15, 2025)
46 changes: 41 additions & 5 deletions flash_attn/cute/flash_fwd.py
@@ -1050,6 +1050,7 @@ def compute_one_n_block(
batch_idx: cutlass.Int32,
head_idx: cutlass.Int32,
m_block: cutlass.Int32,
seqlen: SeqlenInfoQK,
aux_tensors=None,
fastdiv_mods=None,
mask_fn: Optional[Callable] = None,
@@ -1105,6 +1106,7 @@ def load_V_next():
m_block,
acc_S,
n_block,
seqlen,
softmax_scale=softmax.softmax_scale,
aux_tensors=aux_tensors,
fastdiv_mods=fastdiv_mods,
@@ -1502,7 +1504,11 @@ def __call__(
seqlen_q = cute.size(mQ.shape[0]) // (
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
)
seqlen_k = cute.size(mK.shape[0])
seqlen_k = (
cute.size(mK.shape[0])
if const_expr(mPageTable is None)
else mK.shape[0] * mPageTable.shape[1]
)
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
@@ -1982,6 +1988,25 @@ def mma(
# shape: (atom_v_m * rest_m)
m_block, head_idx, batch_idx, _ = work_tile.tile_idx
seqlen = SeqlenInfoCls(batch_idx)

# Recompute fastdiv_mods if necessary for varlen with aux_tensors
recompute_fastdiv_mods_q = cutlass.const_expr(
aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
)
recompute_fastdiv_mods_k = cutlass.const_expr(
aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
)
if cutlass.const_expr(fastdiv_mods is not None):
seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
fastdiv_mods = (
seqlen_q_divmod
if not recompute_fastdiv_mods_q
else FastDivmodDivisor(seqlen.seqlen_q),
seqlen_k_divmod
if not recompute_fastdiv_mods_k
else FastDivmodDivisor(seqlen.seqlen_k),
)

mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k)
mask_fn = partial(
mask.apply_mask,
@@ -2046,6 +2071,7 @@ def mma(
if const_expr(self.intra_wg_overlap):
kv_consumer_state = process_first_half_block(
n_block=n_block_max - 1,
seqlen=seqlen,
kv_consumer_state=kv_consumer_state,
mask_fn=partial(mask_fn, mask_mod=self.mask_mod),
score_mod_fn=score_mod_fn,
@@ -2058,6 +2084,7 @@ def mma(
kv_consumer_state = mma_one_n_block(
kv_consumer_state,
n_block=n_block_max - 1,
seqlen=seqlen,
mma_pv_fn=partial(mma_pv_fn, zero_init=True),
is_first_n_block=True,
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=True),
@@ -2077,6 +2104,7 @@ def mma(
kv_consumer_state = mma_one_n_block(
kv_consumer_state,
n_block=n_block_max - 1 - n_tile,
seqlen=seqlen,
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
)
@@ -2091,6 +2119,7 @@ def mma(
kv_consumer_state = mma_one_n_block(
kv_consumer_state,
n_block=n_block_max - 1 - n_tile,
seqlen=seqlen,
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
)
@@ -2102,6 +2131,7 @@ def mma(
kv_consumer_state = mma_one_n_block(
kv_consumer_state,
n_block=n_block_max - 1 - n_tile,
seqlen=seqlen,
mma_pv_fn=partial(mma_pv_fn, zero_init=not O_should_accumulate),
mask_fn=partial(mask_fn, mask_mod=self.mask_mod, mask_seqlen=False),
)
@@ -2195,6 +2225,7 @@ def first_half_block_overlap(
tOrP: cute.Tensor,
smem_copy_params: SimpleNamespace,
softmax: Softmax,
seqlen: SeqlenInfoQK,
mask_fn: Callable = None,
score_mod_fn: Optional[Callable] = None,
is_first_block: bool = False,
@@ -2207,7 +2238,7 @@ def first_half_block_overlap(

# Apply score modification if present
if const_expr(score_mod_fn is not None):
score_mod_fn(acc_S, n_block=n_block)
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)

# Apply mask; mask_seqlen always True for first block
# Caveat: if full block further right than mask block, seqlen masking is redundant;
@@ -2267,6 +2298,7 @@ def mma_one_n_block(
tOrP: cute.Tensor,
smem_copy_params: SimpleNamespace,
softmax: Softmax,
seqlen: SeqlenInfoQK,
score_mod_fn: Optional[Callable] = None,
mask_fn: Optional[Callable] = None,
is_first_n_block: cutlass.Constexpr = False,
@@ -2281,7 +2313,7 @@ def mma_one_n_block(

# handle score mods and masking
if const_expr(score_mod_fn is not None):
score_mod_fn(acc_S, n_block=n_block)
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
if const_expr(mask_fn is not None):
mask_fn(acc_S=acc_S, n_block=n_block)

@@ -2326,6 +2358,7 @@ def mma_one_n_block_intrawg_overlap(
tOrP: cute.Tensor,
smem_copy_params: SimpleNamespace,
softmax: Softmax,
seqlen: SeqlenInfoQK,
score_mod_fn: Optional[Callable] = None,
mask_fn: Optional[Callable] = None,
check_inf: cutlass.Constexpr = True,
@@ -2345,7 +2378,7 @@ def mma_one_n_block_intrawg_overlap(

# handle score mods and masking
if const_expr(score_mod_fn is not None):
score_mod_fn(acc_S, n_block=n_block)
score_mod_fn(acc_S, n_block=n_block, seqlen=seqlen)
if const_expr(mask_fn is not None):
mask_fn(acc_S=acc_S, n_block=n_block)
# if cute.arch.thread_idx()[0] == 128: cute.print_tensor(utils.make_acc_tensor_mn_view(acc_S))
@@ -2392,6 +2425,7 @@ def apply_score_mod(
acc_S,
n_block,
softmax_scale,
seqlen,
aux_tensors: Optional[list] = None,
fastdiv_mods=None,
):
@@ -2411,6 +2445,7 @@ def apply_score_mod(
self.qk_acc_dtype,
aux_tensors,
fastdiv_mods,
seqlen_info=seqlen,
constant_q_idx=None,
qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1,
)
@@ -2436,4 +2471,5 @@ def warp_scheduler_barrier_arrive(self):
cute.arch.barrier_arrive(
barrier_id=int(NamedBarrierFwd.WarpSchedulerWG1) + next_wg,
number_of_threads=2 * self.num_threads_per_warp_group,
)

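A note on the @@ -1982 hunk above: with fixed-length batches, the (seqlen_q_divmod, seqlen_k_divmod) pair built once at launch serves every work tile, and for paged KV the k divisor is sized from the upper bound mK.shape[0] * mPageTable.shape[1] (page size times pages per sequence). Once cu_seqlens_* or seqused_* make lengths per-batch, the divisors must be rebuilt per tile before apply_score_mod runs. Below is a minimal host-side sketch of the magic-multiplier fast-divmod idea and the per-tile rebuild; the class name, bit widths, and internals of the real FastDivmodDivisor are illustrative assumptions, not the library's implementation.

class FastDivmodSketch:
    # Precompute a magic multiplier and shift for a fixed divisor so later
    # div/mod operations become multiply-and-shift instead of hardware division.
    def __init__(self, divisor: int):
        assert divisor > 0
        self.divisor = divisor
        self.shift = (divisor - 1).bit_length()  # ceil(log2(divisor))
        # multiplier = ceil(2**(32 + shift) / divisor); exact for 32-bit numerators
        self.multiplier = -(-(1 << (32 + self.shift)) // divisor)

    def divmod(self, x: int):
        q = (x * self.multiplier) >> (32 + self.shift)
        return q, x - q * self.divisor

def divmods_for_tile(seqlen_info, precomputed, recompute_q, recompute_k):
    # Mirrors the hunk above: reuse the launch-time divisors unless this
    # batch's lengths are dynamic, in which case rebuild them for the tile.
    q_divmod, k_divmod = precomputed
    if recompute_q:
        q_divmod = FastDivmodSketch(seqlen_info.seqlen_q)
    if recompute_k:
        k_divmod = FastDivmodSketch(seqlen_info.seqlen_k)
    return q_divmod, k_divmod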
31 changes: 29 additions & 2 deletions flash_attn/cute/flash_fwd_sm100.py
@@ -658,7 +658,11 @@ class SharedStorage:
seqlen_q = cute.size(mQ.shape[0]) // (
self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1
)
seqlen_k = cute.size(mK.shape[0])
seqlen_k = (
cute.size(mK.shape[0])
if const_expr(mPageTable is None)
else mK.shape[0] * mPageTable.shape[1]
)
seqlen_q_divmod = FastDivmodDivisor(seqlen_q)
seqlen_k_divmod = FastDivmodDivisor(seqlen_k)
fastdiv_mods = (seqlen_q_divmod, seqlen_k_divmod)
@@ -1624,6 +1628,26 @@ def softmax_loop(
head_idx=head_idx,
aux_tensors=aux_tensors,
)

# Recompute fastdiv_mods if necessary
recompute_fastdiv_mods_q = cutlass.const_expr(
aux_tensors is not None and (seqlen.has_cu_seqlens_q or seqlen.has_seqused_q)
)
recompute_fastdiv_mods_k = cutlass.const_expr(
aux_tensors is not None and (seqlen.has_cu_seqlens_k or seqlen.has_seqused_k)
)

if cutlass.const_expr(fastdiv_mods is not None):
seqlen_q_divmod, seqlen_k_divmod = fastdiv_mods
fastdiv_mods = (
seqlen_q_divmod
if not recompute_fastdiv_mods_q
else FastDivmodDivisor(seqlen.seqlen_q),
seqlen_k_divmod
if not recompute_fastdiv_mods_k
else FastDivmodDivisor(seqlen.seqlen_k),
)

mask_mod = self.mask_mod if const_expr(self.mask_mod is not None) else None
mask_fn = partial(
mask.apply_mask_sm100,
@@ -1874,6 +1898,7 @@ def softmax_step(
m_block,
n_block,
softmax,
seqlen,
aux_tensors,
fastdiv_mods,
)
@@ -2369,7 +2394,7 @@ def correction_epilogue(
self.check_hdim_v_oob,
self.qhead_per_kvhead,
)

# load acc O from smem to rmem for wider vectorization
tOrO = cute.make_fragment_like(tOsO, self.o_dtype)
cute.autovec_copy(tOsO, tOrO)
@@ -2637,6 +2662,7 @@ def apply_score_mod(
m_block,
n_block,
softmax,
seqlen: SeqlenInfoQK,
aux_tensors=None,
fastdiv_mods=(None, None),
):
Expand Down Expand Up @@ -2673,6 +2699,7 @@ def apply_score_mod(
self.qk_acc_dtype,
aux_tensors,
fastdiv_mods,
seqlen_info=seqlen,
constant_q_idx=q_idx_logical,
qhead_per_kvhead=self.qhead_per_kvhead if cutlass.const_expr(self.pack_gqa) else 1,
)
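
For orientation, a score_mod written against the new interface might look like the sketch below. The element-wise signature (score plus batch/head/q/k indices, seqlen_info, aux_tensors) is inferred from the arguments threaded through apply_score_mod above and is an assumption, not a documented API; a real score_mod must also be expressible in the CuTe DSL so it can be compiled into the kernel.

def rel_pos_bias_score_mod(score, batch_idx, head_idx, q_idx, k_idx,
                           seqlen_info, aux_tensors):
    # Hypothetical varlen-aware score mod: add a per-head relative-position
    # bias. q_idx / k_idx are taken to be logical per-sequence positions and
    # seqlen_info carries this batch's actual lengths.
    bias_table = aux_tensors[0]  # assumed shape: (num_heads, max_distance)
    dist = k_idx - q_idx + (seqlen_info.seqlen_q - 1)  # shift into [0, ...)
    dist = max(0, min(dist, bias_table.shape[1] - 1))  # clamp to the table
    return score + bias_table[head_idx, dist]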
18 changes: 11 additions & 7 deletions flash_attn/cute/interface.py
@@ -114,7 +114,7 @@ def _flash_attn_fwd(
...
score_mod: A callable that takes the attention scores and applies a modification.
mask_mod: A callable that takes token position information and selectively masks
block_sparse_tensors: A tuple of tensors used for block sparsity.
return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate
out: Optional pre-allocated output tensor. If None, will be allocated internally.
lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
@@ -294,6 +294,7 @@ def _flash_attn_fwd(
if compute_capability == 9: # TODO: tune block size according to hdim.
if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity:
n_block_size = 192

if compute_capability == 10:
# TODO: fix the varlen case
if (
@@ -335,7 +336,7 @@ def _flash_attn_fwd(
elif lse is not None:
lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1)
else:
lse_tensor = None

# hash score and mask mods for compile cache
score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
@@ -351,11 +352,6 @@ def _flash_attn_fwd(
or seqused_q is not None
or seqused_k is not None
)
if score_mod is not None:
if is_varlen:
raise NotImplementedError(
"score_mod with aux_tensors is not yet supported for varlen sequences. This will be fixed in a future PR."
)

if mask_mod is not None:
if is_varlen:
@@ -1154,6 +1150,8 @@ def forward(
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
score_mod: Optional[Callable] = None,
aux_tensors: Optional[list] = None,
):
out, lse = _flash_attn_fwd(
q,
@@ -1172,6 +1170,8 @@ def forward(
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
score_mod=score_mod,
aux_tensors=aux_tensors,
)
ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
ctx.softmax_scale = softmax_scale
@@ -1261,6 +1261,8 @@ def flash_attn_varlen_func(
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
score_mod: Optional[Callable] = None,
aux_tensors: Optional[list] = None,
):
return FlashAttnVarlenFunc.apply(
q,
@@ -1279,6 +1281,8 @@ def flash_attn_varlen_func(
num_splits,
pack_gqa,
deterministic,
score_mod,
aux_tensors,
)


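Taken together, the new score_mod / aux_tensors arguments allow a call like the hedged sketch below. Tensor layouts follow the usual flash-attn varlen packing ((total_tokens, nheads, headdim) with int32 cu_seqlens); the keyword names besides score_mod and aux_tensors, and the callback signature, are assumptions, and the callable must be CuTe-DSL-compatible.

import torch
from flash_attn.cute.interface import flash_attn_varlen_func

nheads, headdim = 8, 128
q = torch.randn(6, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(9, nheads, headdim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(9, nheads, headdim, device="cuda", dtype=torch.bfloat16)
cu_seqlens_q = torch.tensor([0, 2, 6], device="cuda", dtype=torch.int32)  # lengths 2, 4
cu_seqlens_k = torch.tensor([0, 4, 9], device="cuda", dtype=torch.int32)  # lengths 4, 5
head_bias = torch.randn(nheads, device="cuda", dtype=torch.float32)

def per_head_bias(score, batch_idx, head_idx, q_idx, k_idx, seqlen_info, aux_tensors):
    return score + aux_tensors[0][head_idx]  # additive per-head bias

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    score_mod=per_head_bias,
    aux_tensors=[head_bias],
)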
13 changes: 12 additions & 1 deletion flash_attn/cute/seqlen_info.py
@@ -42,6 +42,8 @@ class SeqlenInfoQK:
seqlen_k: cutlass.Int32
has_cu_seqlens_q: cutlass.Constexpr[bool]
has_cu_seqlens_k: cutlass.Constexpr[bool]
has_seqused_q: cutlass.Constexpr[bool]
has_seqused_k: cutlass.Constexpr[bool]

@staticmethod
def create(
Expand Down Expand Up @@ -73,8 +75,17 @@ def create(
)
has_cu_seqlens_q: int = mCuSeqlensQ is not None
has_cu_seqlens_k: int = mCuSeqlensK is not None
has_seqused_q: int = mSeqUsedQ is not None
has_seqused_k: int = mSeqUsedK is not None
return SeqlenInfoQK(
offset_q, offset_k, seqlen_q, seqlen_k, has_cu_seqlens_q, has_cu_seqlens_k
offset_q,
offset_k,
seqlen_q,
seqlen_k,
has_cu_seqlens_q,
has_cu_seqlens_k,
has_seqused_q,
has_seqused_k,
)

def offset_batch_Q(self, mQ: cute.Tensor, batch_idx: Int32, dim: int) -> cute.Tensor:
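
The new has_seqused_* flags are what let the kernels above decide, at compile time, whether per-batch lengths are dynamic and the fastdiv_mods need rebuilding. As a rough reference, the resolution order sketched below (seqused overrides the cu_seqlens delta, which overrides the static shape) is inferred from context and the usual flash-attn semantics, not copied from the elided body of create().

def resolve_seqlen(batch_idx, static_len, cu_seqlens=None, seqused=None):
    # Illustrative host-side mirror of SeqlenInfoQK.create's length logic.
    if seqused is not None:        # explicit per-batch used length
        return int(seqused[batch_idx])
    if cu_seqlens is not None:     # packed varlen: delta of prefix offsets
        return int(cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx])
    return static_len              # fixed-length batch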