[Spec Decode] Integrate Suffix Decoding from Arctic Inference #25784
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request integrates Suffix Decoding from Arctic Inference as a new speculative decoding method. The changes are well-structured, adding new configuration options, validation, and the core logic for proposing draft tokens and managing the suffix cache. My review identifies a potential type inconsistency in the token sequences passed to the arctic-inference library, which could lead to runtime errors. I've suggested a fix to ensure consistency.
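To make the concern concrete, here is a hedged sketch of the kind of normalization such a review typically suggests; the helper name and the assumption that the library expects plain list[int] inputs are mine, not taken from the PR or from arctic-inference.

```python
from typing import Sequence

import numpy as np


def to_int_list(token_ids: Sequence[int] | np.ndarray) -> list[int]:
    """Hypothetical helper: coerce any token-id sequence to list[int].

    Iterating a numpy array yields np.int64 values; converting them to
    builtin ints keeps every sequence handed to the suffix cache the
    same type and avoids surprises in strictly typed consumers.
    """
    if isinstance(token_ids, np.ndarray):
        return token_ids.tolist()
    return [int(t) for t in token_ids]
```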
@codex review
note to reviewers:
💡 Codex Review
Here are some automated review suggestions for this pull request.
@aurickq, thanks for your awesome contribution, the results look good! Suffix decoding outperforms n-gram at
The out=1024 and out=256 runs also use two different datasets, so they might not be very comparable. Other than that, when concurrency is high and the number of output tokens is low (e.g. 256), request completion time becomes dominated by mixed-prefill batches that drag up the mean TPOT metric, so it makes sense that in these cases the performance of suffix and ngram approaches each other. As for why suffix becomes a little worse than ngram for spec-bench at out=256 and concurrency=64, here is my guess: the SpecBench dataset is more open-ended (higher entropy, less repetition) than refactor-benchmark, so we would already expect suffix/ngram to perform worse on it. The benchmark is also small (400-500 examples), so suffix decoding might not have built a sufficiently large cache to accurately predict the next tokens. In fact, the benchmarks show that suffix decoding performs better when this cache is disabled in this setting. I have some ideas for handling this latter issue when the cached data is sparse, which I might later implement and contribute as a "suffix v2" method, if it works.
Thanks a lot for the contribution, @aurickq! A few questions.
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
    if not sampled_ids:
        # Skip speculative decoding for partial prefills.
        draft_token_ids.append([])
        continue

    # Skip requests that require sampling parameters that are not
    # supported with speculative decoding.
    req_id = input_batch.req_ids[i]
    if req_id in input_batch.spec_decode_unsupported_reqs:
        draft_token_ids.append([])
        continue

    num_tokens = input_batch.num_tokens_no_spec[i]
    if num_tokens >= self.max_model_len:
        # Skip requests that have already reached the max model length.
        draft_token_ids.append([])
        continue
Hmm, I might have forgotten to flush one of my previous comments. There seems to be quite a bit of duplicated code here from NgramProposer. I'm wondering if we should come up with some ModelFreeProposer class and put the common logic there.
Ideally, that would make future extensions easier as well.
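For illustration, here is a minimal sketch of what such a shared base class could look like, with the common skip logic pulled out of the per-request loop. The name ModelFreeProposer comes from the comment above; every other name (method signatures, attributes) is an assumption for the sketch, not vLLM's actual code.

```python
from abc import ABC, abstractmethod


class ModelFreeProposer(ABC):
    """Hypothetical base class for model-free proposers (ngram, suffix)."""

    def __init__(self, max_model_len: int):
        self.max_model_len = max_model_len

    def propose(self, input_batch, sampled_token_ids) -> list[list[int]]:
        draft_token_ids: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            if self._should_skip(input_batch, i, sampled_ids):
                # Partial prefill, unsupported sampling params, or max length.
                draft_token_ids.append([])
                continue
            draft_token_ids.append(
                self._propose_for_request(input_batch, i, sampled_ids))
        return draft_token_ids

    def _should_skip(self, input_batch, i, sampled_ids) -> bool:
        if not sampled_ids:
            return True
        if input_batch.req_ids[i] in input_batch.spec_decode_unsupported_reqs:
            return True
        return input_batch.num_tokens_no_spec[i] >= self.max_model_len

    @abstractmethod
    def _propose_for_request(self, input_batch, i, sampled_ids) -> list[int]:
        """Method-specific drafting (ngram match, suffix tree lookup, ...)."""
```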
It seems that currently it's just the three continue statements that overlap, which is a pretty small part of the NgramProposer.
LGTM, as the logic is showing promising results and is quite self-contained, so it doesn't affect the majority of use cases.
CC @houseroad for the potential final review
This pull request has merge conflicts that must be resolved before it can be merged.
Rebased. Could someone help trigger CI?
@aurickq Could you try to address the DCO and doc build failure first?
Fixed the doc failure. For the DCO, in the past I've avoided addressing it since it leaks my personal email publicly :) (not sure if this has changed).
@aurickq Why would you use the --no-enable-prefix-caching parameter?
Purpose
This PR adds Suffix Decoding (https://arxiv.org/abs/2411.04975) as a new speculative decoding method in vLLM. Suffix Decoding is a dynamic n-gram matching method that builds suffix trees over the prompt and previously generated outputs and adaptively chooses how many draft tokens to propose at each step.
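For readers who want to try it, here is a minimal usage sketch; the method name "suffix" and the configuration keys follow the spirit of this PR but are assumptions, so check the merged documentation for the exact names.

```python
from vllm import LLM, SamplingParams

# Assumed configuration keys; verify against the merged documentation.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "suffix",            # method name added by this PR (assumed)
        "num_speculative_tokens": 16,  # cap on drafted tokens per step
    },
)

outputs = llm.generate(
    ["Refactor the following function to use a list comprehension: ..."],
    SamplingParams(max_tokens=256, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```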
Test Plan
Test Result
Benchmarks on Specbench and Blazedit are below (on H200). Suffix Decoding beats ngram in pretty much all cases. In practice, we have seen larger speedups for real user interactions and agentic requests, since they tend to exhibit more output repetition than these benchmark datasets.
Script for benchmark reproduction: benchmark.sh
Specbench
[Charts: time per output token (ms), total drafted tokens, total accepted tokens]
Blazedit
[Charts: time per output token (ms), total drafted tokens, total accepted tokens]
Older Results (before optimizing)
refactor-bench (out=1024)
Results are mean TPOT (ms)
spec-bench (out=256)
Results are mean TPOT (ms)