Skip to content

fix varlen w/ paging split kv bug#2550

Merged
drisspg merged 1 commit into
Dao-AILab:mainfrom
liangel-02:varlen_splitkv_paging
May 12, 2026
Merged

fix varlen w/ paging split kv bug#2550
drisspg merged 1 commit into
Dao-AILab:mainfrom
liangel-02:varlen_splitkv_paging

Conversation

@liangel-02
Copy link
Copy Markdown
Contributor

@liangel-02 liangel-02 commented May 9, 2026

Paged KV always runs with split KV kernel because mha_varlen_fwd calls run_mha_fwd(..., force_split_kernel = paged KV) but because set_params_splitkv was never called before #2448, num_splits stayed 0 → the splitting heuristic never ran so even though the split KV kernel always runs with paged KV, it still did only a single pass.

Before #2448, set_params_splitkv (which sets the oaccum + runs the heuristic) only ran when seqlen_ngroup_swapped = True, which also reshapes the varlen input to have a batch dimension.

Today, the condition for calling set_params_splitkv is seqlen_ngroup_swapped = True OR paged because we wanted to give users the option to pass in num_splits=1 as a way to force batch invariance between paged (split kernel) vs not paged (regular kernel), but then this caused the heuristic to run, potentially setting num_splits > 1. So then now the split KV kernel could actually run with real splitting which causes the varlen o_batch_stride bug (see #2542)

This PR fixes the bug by restoring original behaviour. we move this params.num_splits line out (since that's the only reason why we needed to call set_params_splitkv), and error if user provides anything != 1.

#2542 will be landed as a follow up PR to truly enable varlen split KV.

@liangel-02 liangel-02 force-pushed the varlen_splitkv_paging branch from c746642 to d6c4c89 Compare May 10, 2026 22:21
@liangel-02 liangel-02 marked this pull request as ready for review May 10, 2026 22:23
Comment thread csrc/flash_attn/flash_api.cpp
Comment thread tests/test_flash_attn.py Outdated
@liangel-02 liangel-02 force-pushed the varlen_splitkv_paging branch from d6c4c89 to eb61bc2 Compare May 11, 2026 22:08
Comment thread tests/test_flash_attn.py
@liangel-02 liangel-02 force-pushed the varlen_splitkv_paging branch from eb61bc2 to 88c8bef Compare May 12, 2026 20:50
@drisspg drisspg merged commit 9bad4be into Dao-AILab:main May 12, 2026
reubenconducts pushed a commit to reubenconducts/flash-attention that referenced this pull request Jun 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants