[Bugfix] Fix sparse MLA metadata building #33579
Conversation
Code Review
This pull request addresses a bug in sparse Multi-Head Latent Attention (MLA) by correctly handling metadata for sparse backends. The changes introduce a separate logic path for sparse implementations, bypassing the prefill/decode split that caused issues. For sparse backends, all tokens are now correctly routed through the forward_mqa path, which aligns with their design. The refactoring is clear and effectively resolves the bug. The changes look good.
Hi @MatthewBonanni, the pre-commit checks have failed. Please run:

```
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
cc @zou3519 as this failure will probably show up in vLLM benchmark runs on PyTorch CI until this PR is merged. Here is an example failure: https://github.com/pytorch/pytorch-integration-testing/actions/runs/21584614705/job/62189623833#step:19:28492
```python
if is_sparse_impl:
    has_decode = True
    has_prefill = False
    num_decode_tokens = q.size(0)
```
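As context for the snippet above, here is a self-contained sketch of why `q.size(0)` ends up counting every token as a decode token on the sparse path. The `FakeTensor` class is a hypothetical stand-in for `torch.Tensor`, and the shape is illustrative; only the `is_sparse_impl` branch mirrors the diff.

```python
class FakeTensor:
    """Minimal stand-in for torch.Tensor (illustration only)."""
    def __init__(self, shape):
        self.shape = tuple(shape)

    def size(self, dim):
        return self.shape[dim]

# q stacks one query row per scheduled token: [num_tokens, num_heads, head_dim]
q = FakeTensor((7, 16, 64))  # 7 tokens in the batch, hypothetical head sizes

is_sparse_impl = True
if is_sparse_impl:
    # Sparse MLA routes every token through the MQA ("decode") pathway,
    # so the whole batch is labeled decode and q.size(0) is the total count.
    has_decode, has_prefill = True, False
    num_decode_tokens = q.size(0)

print(num_decode_tokens)  # 7
```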
I’m new to this area, so I have a possibly naive question. Why is q.size(0) equal to num_decode_tokens?
Where should I start tracing this?
This is because the sparse MLA implementation currently uses purely the MQA pathway for both prefill and decode, so every scheduled token, i.e. `q.size(0)` of them, is counted as a decode token (the MQA pathway is memory-bandwidth optimized and is only used for decodes in dense MLA).
Sorry, the naming here is a bit confusing; basically the MQA pathway is more memory-bandwidth efficient and the MHA pathway is more compute efficient. So it makes sense to use MQA for decode, where attention is memory bound, and MHA for prefill, which is more compute bound. However, sparsity changes this calculus, and currently we only have an MQA pathway, partly due to kernel support and partly because with sparsity and longer contexts it can make sense to use MQA for prefill too. Hence the renaming from `forward_decode` -> `forward_mqa` and `forward_prefill` -> `forward_mha`, to relax the associations with prefill/decode. There is still some legacy naming here that will likely need to be refactored in future PRs.
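The trade-off described above can be sketched as a small dispatch function. This is illustrative only — the real vLLM dispatch lives in the MLA layer and is more involved; `forward_mqa`/`forward_mha` are the method names from this thread, and the boolean flags are assumptions.

```python
def pick_pathway(is_sparse: bool, is_decode: bool) -> str:
    """Choose an attention pathway for a batch of tokens.

    MQA pathway: memory-bandwidth efficient -> suits memory-bound decode.
    MHA pathway: compute efficient -> suits compute-bound prefill.
    Sparse MLA currently only has an MQA kernel, so it always uses MQA.
    """
    if is_sparse:
        return "forward_mqa"
    return "forward_mqa" if is_decode else "forward_mha"

# Dense MLA splits by phase; sparse MLA does not.
print(pick_pathway(is_sparse=False, is_decode=True))   # forward_mqa
print(pick_pathway(is_sparse=False, is_decode=False))  # forward_mha
print(pick_pathway(is_sparse=True, is_decode=False))   # forward_mqa
```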
@MatthewBonanni to avoid confusion, can we get rid of `has_decode`, `has_prefill`, and `num_decode_tokens` and instead use `num_mqa_tokens` and `num_mha_tokens`, then do:

```python
if num_mha_tokens > 0:
    ...
if num_mqa_tokens > 0:
    ...
```

I think this may cause less confusion.
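A minimal sketch of what the suggested renaming could look like. The `build_token_counts` helper and its arguments are hypothetical, not vLLM's actual API; it only captures the split described in this thread.

```python
def build_token_counts(num_tokens: int, is_sparse_impl: bool,
                       num_decode_tokens: int) -> tuple[int, int]:
    """Split tokens by pathway rather than by prefill/decode labels.

    Sparse backends route everything through the MQA pathway; dense MLA
    sends decode tokens to MQA and the remaining prefill tokens to MHA.
    Returns (num_mqa_tokens, num_mha_tokens).
    """
    if is_sparse_impl:
        return num_tokens, 0
    return num_decode_tokens, num_tokens - num_decode_tokens

num_mqa_tokens, num_mha_tokens = build_token_counts(10, False, 4)
if num_mha_tokens > 0:
    pass  # would call forward_mha on the prefill-style tokens
if num_mqa_tokens > 0:
    pass  # would call forward_mqa on the decode-style tokens
```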
LucasWilkinson left a comment
Thanks for fixing this; overall it makes sense to me, but I think we should consider: https://github.com/vllm-project/vllm/pull/33579/changes#r2757360444
LucasWilkinson left a comment
LGTM thanks for the cleanups!
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: felix01.yu <felix01.yu@vipshop.com>
Purpose
Fix #33546
#33284 broke sparse MLA by moving logic from the backend to the layer without properly accounting for sparse backends.
Test Plan
Test Result
Main: crashes during startup
PR: