[Spec Decode] Integrate Suffix Decoding from Arctic Inference #25784
Changes from 23 commits
The first hunk adds an optional-dependency helper next to the existing `has_*()` checks:

```diff
@@ -3303,6 +3303,12 @@ def has_tilelang() -> bool:
     return _has_module("tilelang")
 
 
+def has_arctic_inference() -> bool:
+    """Whether the optional `arctic_inference` package is available."""
+    return _has_module("arctic_inference")
+
+
 def set_process_title(
     name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX
 ) -> None:
```
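As a usage note (not part of the diff), here is a minimal sketch of how such a helper is typically used to fail fast before the lazy import. It assumes the helper is importable from `vllm.utils`; the `require_arctic_inference` wrapper and its error message are illustrative, not from the PR:

```python
# Illustrative sketch: guard the optional dependency before importing it.
# `has_arctic_inference` comes from the diff above; this wrapper is hypothetical.
from vllm.utils import has_arctic_inference


def require_arctic_inference() -> None:
    # Fail fast with an actionable message if the optional package is missing.
    if not has_arctic_inference():
        raise ImportError(
            "Suffix Decoding requires the optional `arctic_inference` package "
            "(see https://github.com/snowflakedb/ArcticInference)."
        )
```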
The second file is new (98 added lines) and implements the proposer:

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch


class SuffixDecodingProposer:
    """
    Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975).
    This class imports and uses the official implementation from Arctic Inference
    (https://github.com/snowflakedb/ArcticInference).
    """

    def __init__(self, vllm_config: VllmConfig):
        config = vllm_config.speculative_config
        self.num_speculative_tokens = config.num_speculative_tokens
        self.max_tree_depth = config.suffix_decoding_max_tree_depth
        self.max_spec_factor = config.suffix_decoding_max_spec_factor
        self.min_token_prob = config.suffix_decoding_min_token_prob
        self.max_model_len = vllm_config.model_config.max_model_len

        # Lazy import to avoid an error when Suffix Decoding is not used.
        from arctic_inference.suffix_decoding import SuffixDecodingCache

        # Initialize an empty cache. This object takes care of caching request
        # outputs, evicting old requests, and managing the per-prompt suffix trees.
        self.suffix_cache = SuffixDecodingCache(
            max_tree_depth=config.suffix_decoding_max_tree_depth,
            max_cached_requests=config.suffix_decoding_max_cached_requests,
        )

    def propose(
        self,
        input_batch: InputBatch,
        sampled_token_ids: list[list[int]],
    ) -> list[list[int]]:
        """
        Propose speculative tokens for each request in the input batch. Suffix
        Decoding speculates a dynamic number of tokens for each request at every
        decoding step, so the entries in the returned list may have different
        lengths.
        """
        draft_token_ids: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            if not sampled_ids:
                # Skip speculative decoding for partial prefills.
                draft_token_ids.append([])
                continue

            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
            req_id = input_batch.req_ids[i]
            if req_id in input_batch.spec_decode_unsupported_reqs:
                draft_token_ids.append([])
                continue

            num_tokens = input_batch.num_tokens_no_spec[i]
            if num_tokens >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                draft_token_ids.append([])
                continue
```
A review thread on lines +42 to +60 (the skip checks above):

Collaborator: Hmm, I might have forgotten to flush one of my previous comments. There seems to be quite a bit of code here duplicated from NgramProposer. I'm wondering if we should come up with some ModelFreeProposer class and put the common logic there. Ideally, that would make future extensions easier as well.

Author: It seems currently it's just the three …
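For illustration, here is one shape the suggested refactor could take. `ModelFreeProposer` and `_propose_one` are hypothetical names that do not exist in this PR; only the three skip conditions are taken from the diff above:

```python
# Hypothetical sketch of the ModelFreeProposer idea from the review thread.
# Only the three skip conditions mirror this PR; everything else is illustrative.
from abc import ABC, abstractmethod

from vllm.v1.worker.gpu_input_batch import InputBatch


class ModelFreeProposer(ABC):
    """Shared skip logic for proposers that do not run a draft model."""

    max_model_len: int  # set by subclasses in __init__

    def propose(
        self,
        input_batch: InputBatch,
        sampled_token_ids: list[list[int]],
    ) -> list[list[int]]:
        drafts: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            req_id = input_batch.req_ids[i]
            num_tokens = input_batch.num_tokens_no_spec[i]
            if (not sampled_ids  # partial prefill
                    or req_id in input_batch.spec_decode_unsupported_reqs
                    or num_tokens >= self.max_model_len):
                drafts.append([])
                continue
            drafts.append(self._propose_one(input_batch, i, sampled_ids))
        return drafts

    @abstractmethod
    def _propose_one(
        self, input_batch: InputBatch, index: int, sampled_ids: list[int]
    ) -> list[int]:
        """Produce draft tokens for one request that passed the skip checks."""

    def load_model(self, *args, **kwargs):
        # Model-free proposers have nothing to load.
        pass
```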
After the skip checks, the proposer updates the per-request suffix cache and builds the lookup pattern:

```python
            index = input_batch.req_id_to_index[req_id]
            if req_id not in self.suffix_cache.active_requests:
                if req_id in self.suffix_cache.cached_requests:
                    # Reset the suffix cache for this request.
                    self.suffix_cache.evict_cached_response(req_id)
                num_prompt_tokens = input_batch.num_prompt_tokens[index]
                prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens]
                # Start a new request; this builds the suffix tree for that prompt.
                self.suffix_cache.start_request(req_id, prompt_token_ids)

            # Append the newly sampled ids to the suffix cache for this request.
            self.suffix_cache.add_active_response(req_id, sampled_ids)

            start = max(0, num_tokens - self.max_tree_depth)
            pattern = input_batch.token_ids_cpu[i, start:num_tokens]
```
A review thread on the pattern construction:

Collaborator: n00b question: IIUC, self.max_tree_depth serves a role similar to max_ngram in ngram spec decoding, right? And we will use every suffix of the pattern, from length 1 to length max_tree_depth, to try to find the most frequently matched subtree in suffix_cache, right?

Author: In suffix decoding, …

Reviewer: When the pattern string is very short (1-2 tokens), is the probability of getting the correct predicted token lower? Is it necessary to set a parameter similar to prompt_lookup_min?

Author: Yeah, the probability will be lower, but suffix decoding will avoid making that speculation if it's too low, controlled by the … there is this underlying …
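To make that discussion concrete, here is a toy, self-contained sketch of the general idea. This is not Arctic Inference's implementation (the real SuffixDecodingCache is an incremental tree structure), and the exact interplay of the knobs is my reading of them: count continuations of every suffix up to max_tree_depth, match the longest previously seen suffix of the pattern, then greedily extend while the empirical next-token probability stays above min_token_prob, with the budget capped by max_spec_factor times the match length.

```python
# Toy illustration only: frequency counts over suffixes of previously seen
# tokens, then greedy speculation from the longest matching suffix.
from collections import defaultdict


def build_counts(tokens: list[int], max_tree_depth: int):
    # counts[context][next_token] = how often next_token followed that context.
    counts: dict[tuple[int, ...], dict[int, int]] = defaultdict(lambda: defaultdict(int))
    for i in range(len(tokens) - 1):
        for depth in range(1, max_tree_depth + 1):
            if i + 1 - depth < 0:
                break
            ctx = tuple(tokens[i + 1 - depth : i + 1])
            counts[ctx][tokens[i + 1]] += 1
    return counts


def toy_speculate(counts, pattern, max_spec_tokens, max_spec_factor, min_token_prob):
    # Find the longest suffix of `pattern` that has been seen before.
    match: tuple[int, ...] = ()
    for depth in range(1, len(pattern) + 1):
        suffix = tuple(pattern[-depth:])
        if suffix in counts:
            match = suffix
    if not match:
        return []
    # Longer matches earn a larger speculation budget (max_spec_factor).
    budget = min(max_spec_tokens, int(max_spec_factor * len(match)))
    draft: list[int] = []
    ctx = match
    while len(draft) < budget and ctx in counts:
        nxt, freq = max(counts[ctx].items(), key=lambda kv: kv[1])
        if freq / sum(counts[ctx].values()) < min_token_prob:
            break  # continuation is too uncertain to speculate on
        draft.append(nxt)
        ctx = (ctx + (nxt,))[-len(match):]  # slide the context window
    return draft


history = [1, 2, 3, 4, 2, 3, 4, 5]
counts = build_counts(history, max_tree_depth=4)
print(toy_speculate(counts, [2, 3], max_spec_tokens=4,
                    max_spec_factor=2.0, min_token_prob=0.5))  # -> [4, 2, 3, 4]
```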
The proposer then queries the cache for a draft, stops tracking requests that have left the batch, and needs no model loading:

```python
            draft = self.suffix_cache.speculate(
                req_id,
                pattern,
                max_spec_tokens=min(
                    self.num_speculative_tokens, self.max_model_len - num_tokens - 1
                ),
                max_spec_factor=self.max_spec_factor,
                min_token_prob=self.min_token_prob,
            )

            draft_token_ids.append(draft.token_ids)

        # Stop requests that were not seen in the input batch.
        for req_id in (self.suffix_cache.active_requests -
                       input_batch.req_id_to_index.keys()):
            self.suffix_cache.stop_request(req_id)

        return draft_token_ids

    def load_model(self, *args, **kwargs):
        # No model to load.
        pass
```
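Stepping back, a hypothetical end-to-end usage sketch (not part of this PR's diff). The suffix_decoding_* field names are taken from the proposer's reads of speculative_config above; the "method": "suffix" value and the numeric settings are assumptions for illustration:

```python
# Hypothetical usage sketch: enabling the new proposer from vLLM's offline API.
# Field names follow SuffixDecodingProposer.__init__; the method string and the
# example values are assumptions, not taken from this PR.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "suffix",  # assumed method name
        "num_speculative_tokens": 8,
        "suffix_decoding_max_tree_depth": 24,
        "suffix_decoding_max_cached_requests": 1000,
        "suffix_decoding_max_spec_factor": 1.0,
        "suffix_decoding_min_token_prob": 0.1,
    },
)

outputs = llm.generate(["Hello, world!"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```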
A review thread on the memory usage of the global response cache:

Reviewer: Do you mind exploring the memory usage of the suffix decoding cache? Just curious whether there are other mechanisms to limit the cache memory usage in the case where all requests have long context; the suffix trees corresponding to 10,000 requests could be quite large.

Author: Just a small clarification: only the responses are stored in the global cache, not the prompts. Suffix decoding builds one tree for each prompt, which is only used while that request is alive and is deleted after the request finishes. For other ways to limit the global cache size, perhaps a limit based on the number of tokens might be better? Open to suggestions.

Reviewer: Yeah, makes sense. However, there would still be cases with long outputs (e.g. reasoning models), so I think adding some token-level guarding might still be helpful, and we could do that in a relatively simple way. WDYT? That could be a follow-up enhancement, but it might be worth flagging that the memory usage could be linear in max(total tokens in <max_cached_recent_requests> requests), IIUC.

Author: Agreed, having an additional config to limit the number of tokens makes sense. Created snowflakedb/ArcticInference#215 to track.
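A rough sketch of what such a token-level guard could look like (purely hypothetical; the actual design is left to snowflakedb/ArcticInference#215). Only evict_cached_response() is assumed from SuffixDecodingCache, since it appears in the proposer above; the wrapper class and its hook are invented for illustration:

```python
# Hypothetical token-budget guard tracked on the caller side: evict the oldest
# cached responses (FIFO) once the total number of cached tokens exceeds a budget.
from collections import OrderedDict


class TokenBudgetGuard:
    def __init__(self, suffix_cache, max_cached_tokens: int):
        self.suffix_cache = suffix_cache
        self.max_cached_tokens = max_cached_tokens
        self._cached: OrderedDict[str, int] = OrderedDict()  # req_id -> token count
        self._total = 0

    def on_request_cached(self, req_id: str, num_response_tokens: int) -> None:
        # Hypothetical hook: called when a finished request's response is moved
        # into the global cache.
        self._total += num_response_tokens - self._cached.get(req_id, 0)
        self._cached[req_id] = num_response_tokens
        self._cached.move_to_end(req_id)
        # Evict the oldest cached responses until we are back under the budget,
        # always keeping at least the most recently cached response.
        while self._total > self.max_cached_tokens and len(self._cached) > 1:
            old_id, old_tokens = self._cached.popitem(last=False)
            self.suffix_cache.evict_cached_response(old_id)
            self._total -= old_tokens
```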