Merged
22 commits
495ef79
[ROCm Windows] fix build failed (#2519)
Apophis3158 May 6, 2026
c263382
don't disable 2cta due to cuda 12 in bwd (#2543)
reubenconducts May 6, 2026
9192248
[CuTe,Bwd] guard softcap for varlen backward (#2544)
reubenconducts May 7, 2026
25b451e
[CuTe,Flex] varlen blocksparsity (#2224)
reubenconducts May 7, 2026
09aa322
[FA4][hd256] Fix layout of non-contiguous qkv in backward kernel (#2545)
wangsiyu May 7, 2026
ab66326
fix incorrect calculation of n_block global max for bwd deterministic…
jayhshah May 8, 2026
9bad4be
fix varlen w/ paging split kv bug (#2550)
liangel-02 May 12, 2026
484b981
Fix ZeroDivisionError in num_splits_heuristic for empty Q workloads (…
shivam2199 May 13, 2026
9cee95f
[Cute, flex, sm90] fix sm90 flex (#2563)
geruome May 13, 2026
0409f9a
split out varlen batch search into utils (#2556)
reubenconducts May 14, 2026
8a8b2f1
allow for zero length sequences in hdim 256 sm100 kernels (#2568)
jayhshah May 16, 2026
4178915
Enable split-kv for blocksparse tensors (#2536)
drisspg May 19, 2026
0cb66b4
Wrap mask contruction in a function for mask subclassing (#2584)
sryap May 22, 2026
3da76cd
Build Fix: Update abi3 tag to cp310 and minimum python version to 3.…
aw920h May 22, 2026
fe5fb1b
[Cute,Flex,Sm100] vectorized mask_mod (#2261)
reubenconducts May 24, 2026
2d5d5a1
Update architecture assertion for SM 10.x and 11.x (#2572)
ocss884 May 24, 2026
59cf537
Include sm_110 in Blackwell-family arch gating (follow-up to #2572) (…
Johnsonms May 26, 2026
6c4f74f
Use is_family_of for sm_90 and sm_103 arch checks (#2589)
Johnsonms May 26, 2026
59f01d6
Bump AITER submodule to commit 3b2e6f4 (#2540)
sstamenk May 27, 2026
0bbb25a
Clamp kv_stage to avoid SMEM overflow for small head_dims on SM100 (#…
Johnsonms May 28, 2026
4569a8a
Merge remote-tracking branch 'upstream/main' into sync_upstream
MatthewBonanni May 28, 2026
316617c
Fix pre-commit
MatthewBonanni May 30, 2026
3 changes: 3 additions & 0 deletions csrc/flash_attn/flash_api.cpp
@@ -708,6 +708,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
             set_params_splitkv(params, batch_size, num_heads, head_size,
                                max_seqlen_k, max_seqlen_q, head_size_rounded,
                                p_dropout, num_splits, get_num_sm(get_current_device()), opts);
+        } else if (paged_KV) {
+            TORCH_CHECK(num_splits <= 1, "num_splits > 1 is not supported for varlen paged KV");
+            params.num_splits = num_splits;
         }
 
         if (leftpad_k_.has_value()) {
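Reviewer note: the `else if (paged_KV)` branch in this hunk rejects `num_splits > 1` and otherwise stores the caller's value in `params.num_splits`. Below is a minimal Python sketch of that branch only. The function name `num_splits_for_paged_kv` is hypothetical, and the preceding `if` branch (whose condition is not visible in this hunk) is deliberately not modeled.

```python
from typing import Optional


def num_splits_for_paged_kv(paged_kv: bool, num_splits: int) -> Optional[int]:
    """Model only the `else if (paged_KV)` branch of the hunk (hypothetical helper).

    Returns the value that would be stored in params.num_splits, or None when
    the branch is not taken (the non-paged path is outside this sketch).
    """
    if not paged_kv:
        return None
    if num_splits > 1:
        # Mirrors TORCH_CHECK(num_splits <= 1, ...), which raises on failure.
        raise ValueError("num_splits > 1 is not supported for varlen paged KV")
    return num_splits
```

Under this reading, `num_splits` values of 0 and 1 pass through unchanged for paged KV, and anything larger fails loudly instead of being silently ignored.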
6 changes: 3 additions & 3 deletions flash_attn/cute/block_info.py
@@ -143,14 +143,14 @@ def get_n_block_max_for_m_block(
         self,
         seqlen_info: SeqlenInfoQK,
         m_block: Int32,
-        n_block_global_max: Int32,
     ) -> Int32:
+        n_block_max = cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)
         if const_expr(self.is_causal or self.window_size_right is not None):
             m_idx_max = (m_block + 1) * self.tile_m
             if const_expr(self.qhead_per_kvhead_packgqa > 1):
                 m_idx_max = cute.ceil_div(m_idx_max, self.qhead_per_kvhead_packgqa)
             n_idx_right = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q
             if const_expr(self.window_size_right is not None):
                 n_idx_right += self.window_size_right
-            return min(n_block_global_max, cute.ceil_div(n_idx_right, self.tile_n))
-        return n_block_global_max
+            n_block_max = min(n_block_max, cute.ceil_div(n_idx_right, self.tile_n))
+        return n_block_max
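Reviewer note: a plain-Python model of the variant of `get_n_block_max_for_m_block` that starts from `cute.ceil_div(seqlen_info.seqlen_k, self.tile_n)` may help when checking the bound by hand. This is only a sketch. It assumes ordinary Python integers in place of `Int32`, drops `const_expr`, and uses a local `ceil_div` helper standing in for `cute.ceil_div`. The free function `n_block_max_for_m_block` and its keyword arguments are illustrative names, not part of the library.

```python
from typing import Optional


def ceil_div(a: int, b: int) -> int:
    """Ceiling division on Python ints (stand-in for cute.ceil_div)."""
    return -(-a // b)


def n_block_max_for_m_block(
    seqlen_q: int,
    seqlen_k: int,
    m_block: int,
    tile_m: int,
    tile_n: int,
    is_causal: bool = False,
    window_size_right: Optional[int] = None,
    qhead_per_kvhead_packgqa: int = 1,
) -> int:
    # Upper bound with no masking: every K/V block of the sequence.
    n_block_max = ceil_div(seqlen_k, tile_n)
    if is_causal or window_size_right is not None:
        # One past the last query row covered by this m_block.
        m_idx_max = (m_block + 1) * tile_m
        if qhead_per_kvhead_packgqa > 1:
            # With packed GQA, several query heads share one row index.
            m_idx_max = ceil_div(m_idx_max, qhead_per_kvhead_packgqa)
        # Rightmost key index visible to those rows (bottom-right aligned).
        n_idx_right = m_idx_max + seqlen_k - seqlen_q
        if window_size_right is not None:
            n_idx_right += window_size_right
        n_block_max = min(n_block_max, ceil_div(n_idx_right, tile_n))
    return n_block_max


if __name__ == "__main__":
    # Causal, 512 x 512, 128 x 128 tiles: block row i needs i + 1 key blocks.
    for m_block in range(4):
        print(m_block, n_block_max_for_m_block(512, 512, m_block, 128, 128, is_causal=True))
```

With these inputs the causal case yields 1, 2, 3, 4 key blocks for `m_block` 0 through 3, while the unmasked case always returns `ceil_div(512, 128) = 4`. The result no longer depends on a caller-supplied global maximum.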