fix varlen w/ paging split kv bug #2550
Merged
c746642 to d6c4c89
drisspg reviewed May 11, 2026
drisspg reviewed May 11, 2026
d6c4c89 to eb61bc2
drisspg reviewed May 12, 2026
eb61bc2 to 88c8bef
drisspg approved these changes May 12, 2026
reubenconducts pushed a commit to reubenconducts/flash-attention that referenced this pull request Jun 2, 2026
Paged KV always runs with the split KV kernel because mha_varlen_fwd calls run_mha_fwd(..., force_split_kernel = paged KV). However, set_params_splitkv was never called on that path before #2448, so num_splits stayed 0 and the splitting heuristic never ran. So even though the split KV kernel always ran with paged KV, it still did only a single pass.
Before #2448, set_params_splitkv (which sets the oaccum and runs the heuristic) only ran when seqlen_ngroup_swapped = True, a case that also reshapes the varlen input to have a batch dimension.
Today, the condition for calling set_params_splitkv is seqlen_ngroup_swapped = True OR paged. The paged condition was added because we wanted to give users the option to pass in num_splits=1 as a way to force batch invariance between paged (split kernel) and not paged (regular kernel). As a side effect, the heuristic now runs for paged KV and can set num_splits > 1, so the split KV kernel can actually run with real splitting, which causes the varlen o_batch_stride bug (see #2542).
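To make the three states concrete, here is a toy Python model of the dispatch described above. It is only a sketch, not the actual C++ code: effective_num_splits and num_splits_heuristic_stub are hypothetical names, the heuristic is replaced by a preset value, and treating a request below 1 as "let the heuristic decide" is an assumption.

```python
def num_splits_heuristic_stub(heuristic_choice: int) -> int:
    # Hypothetical stand-in for the real heuristic; it just returns a preset value.
    return heuristic_choice


def effective_num_splits(paged, seqlen_ngroup_swapped, requested,
                         heuristic_choice, after_2448):
    """Toy model: the num_splits value the kernel ends up seeing."""
    # Before #2448 the setup step ran only for seqlen_ngroup_swapped;
    # after #2448 it also runs for paged KV.
    calls_setup = seqlen_ngroup_swapped or (after_2448 and paged)
    if not calls_setup:
        # num_splits stays 0, so the forced split kernel does a single pass.
        return 0
    # Assumption: a request below 1 means "let the heuristic decide".
    if requested < 1:
        return num_splits_heuristic_stub(heuristic_choice)
    return requested


# Paged KV, no seqlen/ngroup swap, heuristic would pick 4 splits.
before = effective_num_splits(True, False, 0, 4, after_2448=False)
after = effective_num_splits(True, False, 0, 4, after_2448=True)
pinned = effective_num_splits(True, False, 1, 4, after_2448=True)
print(before, after, pinned)  # prints: 0 4 1
```

The middle value is the problematic case: paged KV with no explicit num_splits now reaches the heuristic and can end up with real splitting.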
This PR fixes the bug by restoring the original behaviour. We move this params.num_splits line out (since that's the only reason why we needed to call set_params_splitkv), and error if the user provides anything != 1.
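A rough Python sketch of the resulting check, for illustration only. The function name paged_varlen_num_splits and the error message are hypothetical, and the sketch assumes the check applies on the paged varlen path; the real change lives in the C++ API.

```python
def paged_varlen_num_splits(requested: int) -> int:
    """Hypothetical check: the paged varlen path only accepts num_splits == 1."""
    if requested != 1:
        raise ValueError(
            f"varlen with paged KV does not support split KV yet: "
            f"got num_splits={requested}, expected 1"
        )
    # No heuristic runs here, so the forced split kernel stays single-pass.
    return requested


print(paged_varlen_num_splits(1))  # prints: 1
```

Rejecting other values up front, rather than silently clamping them to 1, keeps the caller from assuming real splitting happened.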
#2542 will land as a follow-up PR to truly enable varlen split KV.