[Attention] Fix FlashMLA metadata builder arguments for q_len > 1 (#27368)
LucasWilkinson merged 2 commits into vllm-project:main
Conversation
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Code Review
This pull request correctly fixes a bug in the FlashMLA metadata builder for decode scenarios with a query length greater than one. The change properly calculates num_q_tokens_per_head_k and passes it to get_mla_metadata, which resolves the performance degradation and crashes noted in the description. The provided benchmarks clearly demonstrate the significant speedup achieved by this fix. The implementation is correct and well-targeted. Overall, this is an excellent and important bug fix.
mgoin
left a comment
Is there an eval we can run to validate this? I assume we could do deepseek with mtp enabled
LucasWilkinson
left a comment
LGTM; thanks for tracking this down!
nit: can you add a small note that we use the max, but all the query lens should be the same
@mgoin will do!
LucasWilkinson
left a comment
LGTM (assuming evals pass; don't merge till then, but I don't see any reason they won't)
@mgoin @LucasWilkinson confirmed evals look good.
…lm-project#27368) Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
…lm-project#27368) Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Purpose
As of #26541, FlashMLA supports `q_len > 1` in the decode pipeline. The `get_mla_metadata` call was not updated to match, however, leading to poor performance (and, potentially, crashes) in these cases. This PR is a simple bug fix that achieves a substantial speedup, especially at small batch sizes.

Note: uses the benchmarks in #26835 (not yet merged)
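For context, the core of the fix can be sketched as below. This is a minimal illustration, not the actual vLLM call site: `num_q_tokens_per_head_k` mirrors the argument name this PR passes to `get_mla_metadata`, while the head counts used in the example are hypothetical values.

```python
def num_q_tokens_per_head_k(q_len: int, num_q_heads: int, num_kv_heads: int) -> int:
    # Before this fix, the metadata builder implicitly assumed q_len == 1,
    # so for q_len > 1 the FlashMLA kernel was scheduled with too little
    # work per KV head, causing the slowdown described above.
    return q_len * num_q_heads // num_kv_heads

# MLA decode: many query heads share a single latent KV head.
print(num_q_tokens_per_head_k(1, 128, 1))  # 128: old implicit assumption
print(num_q_tokens_per_head_k(2, 128, 1))  # 256: q_len > 1 (e.g. MTP / spec decode)
```

Folding `q_len` into the ratio is what lets the scheduler provision enough query tokens per KV head when each decode request carries more than one query token.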
cc @LucasWilkinson
Test Plan
python benchmarks/attention_benchmarks/benchmark.py --config benchmarks/attention_benchmarks/configs/flashmla_bugfix_demo.yaml

Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.