[Spec Decode] Enable efficient speculative decoding with FlashInfer-MLA#25984
benchislett merged 8 commits into vllm-project:main
Conversation
Code Review
This pull request refactors the MLA backend to support speculative decoding with FlashInfer, which is a great improvement. The changes are mostly well-structured. However, I found a critical issue in the fallback logic for handling non-uniform query lengths in FlashInferMLAImpl, which could lead to a runtime error. My review includes a suggestion to fix this.
Update: the failed baseline is most likely due to an unknown bug in MLA chunked prefill logic. See #26042.
```python
# `reorder_batch_threshold > 1`, any decode requests which do not
# have the same query length as the first decode request will
# fall back to the prefill kernel.
supports_nonuniform_decode: ClassVar[bool] = False
```
nit: is this needed if it's always set to `False`? (I think we should set this for FlashAttnMLA, since it does support nonuniform decode.)
I think we could actually just unify `supports_spec_as_decode` and `supports_nonuniform_decode` into a single `supports_only_uniform_spec_decode`, and when that's `False` we just leave `reorder_batch_threshold` untouched and set `require_uniform = False`.
@LucasWilkinson I'm pretty sure there can be a full matrix of options here, and that different combinations are useful. For example (a rough sketch of how these combinations could be applied follows the list):

- `supports_spec_as_decode and supports_nonuniform_decode`: FlashAttnMLA, where `require_uniform=False` is correct (it can handle varlen), and the long `reorder_batch_threshold` allows it to handle spec requests.
- `supports_spec_as_decode and not supports_nonuniform_decode`, where `require_uniform=True` is required to function correctly, but `reorder_batch_threshold` can be overridden to `1 + num_spec_tokens` to handle spec decoding.
- `not supports_spec_as_decode and not supports_nonuniform_decode` is the default for the backends which require `q_len == 1`.
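To make the matrix concrete, here is a minimal, hypothetical sketch of how these flag combinations could translate into `require_uniform` and `reorder_batch_threshold` settings; the class and attribute names follow this thread and are not the actual vLLM implementation:

```python
from typing import ClassVar


class ExampleMLAMetadataBuilder:
    # Capability flags a concrete backend would override (assumed names).
    supports_spec_as_decode: ClassVar[bool] = False
    supports_nonuniform_decode: ClassVar[bool] = False

    def __init__(self, num_spec_tokens: int = 0) -> None:
        # Defaults: plain decode only (q_len == 1), no uniformity constraint.
        self.reorder_batch_threshold = 1
        self.require_uniform = False

        if self.supports_nonuniform_decode:
            # FlashAttnMLA-style: varlen decode works, and the backend's own
            # (already large) threshold captures spec requests naturally.
            self.require_uniform = False
        elif self.supports_spec_as_decode and num_spec_tokens > 0:
            # Uniform-only spec-as-decode: widen the threshold to cover the
            # speculated tokens, but require equal decode query lengths.
            self.reorder_batch_threshold = 1 + num_spec_tokens
            self.require_uniform = True
```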
I will update FlashAttnMLA to reflect the correct defaults, but I don't know how to support each of these 3 cases cleanly with only a single flag. Let me know if you would still prefer a different interface.
I think for the FlashAttnMLA case the reorder threshold is already high enough that we don't need to adjust `reorder_batch_threshold` when spec decoding is turned on; my suspicion would be that if a backend `supports_nonuniform_decode`, we should just set `reorder_batch_threshold >= 8`-ish so that we capture spec decode naturally (FlashAttnMLA is really the only example of this currently).

I think if a backend comes along that `supports_nonuniform_decode` but also benefits from dynamically adjusting `reorder_batch_threshold`, then we could add this flag back; but it just seems like unnecessary complexity currently (imo).
LucasWilkinson left a comment
Overall looks good to me 👍 left one follow-up: https://github.com/vllm-project/vllm/pull/25984/files#r2408516263
This pull request has merge conflicts that must be resolved before it can be merged.
Purpose
This PR refactors the `MLACommonMetadataBuilder` to easily support spec-decode kernel optimizations in MLA implementations. This is used to enable FlashInfer-MLA support using the trtllm-gen kernels, which have explicit support for spec-as-decode.
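For background on what "spec-as-decode" means at the batching level, here is a small, self-contained sketch; the function below is an illustrative assumption, not the actual `MLACommonMetadataBuilder` logic:

```python
def split_decode_prefill(
    query_lens: list[int], num_spec_tokens: int, require_uniform: bool
) -> tuple[list[int], list[int]]:
    """Return (decode_indices, prefill_indices) for a batch of requests."""
    threshold = 1 + num_spec_tokens  # analogue of reorder_batch_threshold
    decode = [i for i, q in enumerate(query_lens) if q <= threshold]
    prefill = [i for i, q in enumerate(query_lens) if q > threshold]
    if require_uniform and decode:
        # Kernels without varlen decode support need every decode query to
        # match the first decode request; mismatches fall back to prefill.
        first_len = query_lens[decode[0]]
        prefill += [i for i in decode if query_lens[i] != first_len]
        decode = [i for i in decode if query_lens[i] == first_len]
    return decode, prefill


# MTP=3 gives decode queries of length 4; a request with q_len 2 falls back
# to the prefill kernel when uniformity is required.
print(split_decode_prefill([4, 4, 2, 100], num_spec_tokens=3,
                           require_uniform=True))  # ([0, 1], [3, 2])
```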
Test Plan

I ran a suite of evals over `nvidia/DeepSeek-R1-FP4` and `deepseek-ai/DeepSeek-R1-0528` on 4xB200 and 8xB200 respectively, using the `Cutlass-MLA` and `FlashInfer-MLA` backends. Running MTP with FP4 on B200 requires the fix in #25987.
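For reference, a rough sketch of how one such run could be launched from the Python API, assuming MTP=3 on 4xB200; the speculative-config keys and the attention-backend name below are assumptions and should be checked against the vLLM docs for the exact revision:

```python
import os

from vllm import LLM, SamplingParams

# Assumed backend name; "CUTLASS_MLA" would be the other configuration tested.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

llm = LLM(
    model="nvidia/DeepSeek-R1-FP4",
    tensor_parallel_size=4,
    # MTP=3: three speculative tokens per step (config keys are assumptions).
    speculative_config={"method": "deepseek_mtp", "num_speculative_tokens": 3},
)

outputs = llm.generate(
    ["Explain speculative decoding in one paragraph."],
    SamplingParams(max_tokens=256),
)
print(outputs[0].outputs[0].text)
```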
Known issues

The `Cutlass-MLA` backend produces incorrect output when using speculative decoding. It is not clear to me why this happens; I have debugged with enforce-eager but did not identify any issues other than incorrect model output. I have not verified whether this also occurs on H200, but I believe `FLASH_ATTN_MLA` is also an option on Hopper, so it may be sufficient to deprecate `Cutlass-MLA` when speculative decoding is enabled.

See #26042 for tracking of this correctness issue, which seems to indicate the root cause is the MLA chunked prefill logic.
The fix is in #26063. I will rerun the experiments for a better baseline, but the MTP correctness results on this branch remain valid.
Test Result
4xB200 nvidia/DeepSeek-R1-FP4 FlashInfer-MLA MTP=3
4xB200 nvidia/DeepSeek-R1-FP4 Cutlass-MLA MTP=3
4xB200 nvidia/DeepSeek-R1-FP4 FlashInfer-MLA No-Spec
4xB200 nvidia/DeepSeek-R1-FP4 Cutlass-MLA No-Spec
8xB200 deepseek-ai/DeepSeek-R1-0528 FlashInfer-MLA MTP=3
8xB200 deepseek-ai/DeepSeek-R1-0528 Cutlass-MLA MTP=3
8xB200 deepseek-ai/DeepSeek-R1-0528 FlashInfer-MLA No-Spec
8xB200 deepseek-ai/DeepSeek-R1-0528 Cutlass-MLA No-Spec