[Spec Decode] Utilities and refactor to support qlen>1 decode kernels for spec decode#25183
benchislett wants to merge 4 commits into vllm-project:main from
Conversation
Code Review
This pull request introduces several utilities and refactors to support speculative decoding with query lengths greater than one. The changes include making reorder_batch_threshold an instance variable for dynamic configuration, adding a 'uniform' mode for batch splitting, and providing helper functions for tensor reshaping. The new logic for batch splitting is well-tested and appears correct. The refactoring improves code clarity and prepares the codebase for new speculative decoding backends. I have one suggestion to improve an assertion in a new helper function for clarity and correctness.
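To illustrate the idea behind the reorder-batch threshold described above, here is a minimal sketch of classifying requests as decodes when their query length is at most a threshold (e.g. num_speculative_tokens + 1, so spec-verify batches count as decodes). The function name and signature are hypothetical, not the PR's actual implementation:

```python
import torch

def split_decodes_and_prefills(query_lens: torch.Tensor,
                               reorder_batch_threshold: int):
    """Classify each request as a decode if its query length is at most
    the threshold, else a prefill. Hypothetical sketch only."""
    is_decode = query_lens <= reorder_batch_threshold
    decode_idxs = torch.nonzero(is_decode, as_tuple=True)[0]
    prefill_idxs = torch.nonzero(~is_decode, as_tuple=True)[0]
    return decode_idxs, prefill_idxs

# Example: num_speculative_tokens = 2, so threshold = 3.
qlens = torch.tensor([1, 3, 7, 2, 16])
dec, pre = split_decodes_and_prefills(qlens, reorder_batch_threshold=3)
print(dec.tolist(), pre.tolist())  # [0, 1, 3] [2, 4]
```

With this classification, all requests running the spec-verify step (qlen up to num_speculative_tokens + 1) land in the decode partition and can be handled by decode kernels.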
@@ -766,6 +798,40 @@ def reorder_batch_to_split_decodes_and_prefills(
    return modified_batch

def reshape_query_for_spec_decode(query: torch.Tensor,
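For context on what a reshape helper like this might do: it converts a flat token-major query tensor into a batched layout suitable for qlen>1 decode kernels. A minimal sketch, assuming a flat [num_tokens, num_heads, head_dim] input and a uniform query length per request; the real helper's signature and behavior in the PR may differ:

```python
import torch

def reshape_query_for_spec_decode(query: torch.Tensor,
                                  batch_size: int) -> torch.Tensor:
    """Reshape [num_tokens, num_heads, head_dim] into
    [batch_size, qlen, num_heads, head_dim], assuming every request in
    the batch has the same (uniform) query length. Sketch only."""
    num_tokens, num_heads, head_dim = query.shape
    assert num_tokens % batch_size == 0, (
        "token count must be divisible by batch size for a uniform qlen")
    qlen = num_tokens // batch_size
    return query.view(batch_size, qlen, num_heads, head_dim)

q = torch.randn(12, 8, 64)  # 4 requests with qlen 3 each
out = reshape_query_for_spec_decode(q, batch_size=4)
print(out.shape)  # torch.Size([4, 3, 8, 64])
```

The uniform-qlen assumption matches the 'uniform' batch-splitting mode mentioned in the review summary, where every decode request carries the same number of query tokens.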
Are these not used yet? Should we just include them in the follow-up once they are actually used? Or maybe we should add FlashMLA support in this PR, just so everything is used (and tested, since we can do a FlashMLA + MTP lm_eval run)?
I think it is worth committing now and using in subsequent PRs, mostly because it will be used by FlashMLA and also FlashInfer-MLA, and maybe more. Merging it here as a helper means that all the downstream PRs can reuse the same code from main instead of duplicating it in each.
But I don't feel particularly strongly about this, and can remove it if you think it's better to add separately.
Co-authored-by: lhsjohn <huashuoli@tencent.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Force-pushed from 6eae35e to cfa3273
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This pull request has merge conflicts that must be resolved before it can be merged.
Closing, see #25196
Purpose
This PR makes common changes needed to enable FlashInfer, FlashInferMLA, FlashMLA, and other new backends for speculative decoding. This PR does not add explicit support for any one of these.
Included in this PR is:
Refactoring reorder_batch_threshold so that it is no longer a ClassVar. This is because it can now be specialized at initialization time depending on whether or not speculative decoding is enabled: when it is, we can set it to num_speculative_tokens + 1 so that all spec-verify batches can be classified as decodes. A helper function is also included to facilitate this.
Test Plan
See tests/v1/attention/test_attention_splitting.py
Test Result
All passing locally