[Spec Decode] Integrate Suffix Decoding from Arctic Inference #25784
```diff
@@ -13,7 +13,7 @@
 from vllm.config.parallel import ParallelConfig
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.utils import LazyLoader
+from vllm.utils import LazyLoader, has_arctic_inference

 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -32,7 +32,7 @@
 SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                             "mlp_speculator", "draft_model", "deepseek_mtp",
                             "ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
-                            "longcat_flash_mtp", "mtp"]
+                            "longcat_flash_mtp", "mtp", "suffix"]
 MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
                    "qwen3_next_mtp", "longcat_flash_mtp")

@@ -123,6 +123,27 @@ class SpeculativeConfig:
                                 ParallelConfig] = None  # type: ignore
     """The parallel configuration for the draft model initialized internal."""

+    # Suffix decoding configuration
+    suffix_decoding_max_tree_depth: int = 64
+    """The maximum depth of the suffix decoding global and prompt trees. The
+    tree depth limits the sum of the prefix match and speculation lengths."""
+
+    suffix_decoding_max_cached_requests: int = 10000
+    """The maximum number of requests to cache in the global suffix tree. If
+    exceeded, will trigger eviction in FIFO order. If set to 0, the global
+    suffix tree is disabled and past responses are not cached (prompt trees
+    are still used)."""
```
**Comment on lines +138 to +142**

**Collaborator:** Do you mind exploring the memory usage of the suffix decoding cache? Just curious whether there are other mechanisms to limit the cache memory usage in the case where all requests are long-context; in that case, the suffix tree corresponding to 10000 requests could be quite large.

**Author (Contributor):** Just a small clarification that only the responses are stored in the global cache, not prompts. Suffix decoding builds one tree for each prompt, which is only used while that request is alive and is deleted after the request finishes. For other ways to limit the global cache size, perhaps a limit based on the number of tokens might be better? Open to suggestions.

**Collaborator:** Yeah, makes sense. However, there would still be cases with long outputs (e.g. reasoning models), so I think adding some token-level guarding might still be helpful, and we could do that in a relatively simple way. WDYT? That could be a follow-up enhancement, but it might be worth flagging that the memory usage could be linear in the total tokens of the `<max_cached_recent_requests>` cached requests, IIUC.

**Author (Contributor):** Agreed, having an additional config to limit the number of tokens makes sense. Created snowflakedb/ArcticInference#215 to track.
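For illustration, a rough sketch of the kind of token-level guard discussed above, assuming a FIFO response cache keyed by request id. The class name, fields, and default limits here are hypothetical and not part of this PR; the actual follow-up is tracked in snowflakedb/ArcticInference#215.

```python
from collections import OrderedDict


class TokenBudgetedResponseCache:
    """Hypothetical sketch: FIFO response cache with an extra total-token cap."""

    def __init__(self, max_cached_requests: int = 10000,
                 max_cached_tokens: int = 1_000_000):
        self.max_cached_requests = max_cached_requests
        self.max_cached_tokens = max_cached_tokens
        self._responses: OrderedDict[str, list[int]] = OrderedDict()
        self._total_tokens = 0

    def cache_response(self, req_id: str, token_ids: list[int]) -> None:
        # Store the response, then evict the oldest entries until both the
        # request-count and token-count budgets are respected again.
        self._responses[req_id] = list(token_ids)
        self._total_tokens += len(token_ids)
        while self._responses and (
                len(self._responses) > self.max_cached_requests
                or self._total_tokens > self.max_cached_tokens):
            _, evicted = self._responses.popitem(last=False)
            self._total_tokens -= len(evicted)
```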
```diff
+
+    suffix_decoding_max_spec_factor: float = 1.0
+    """The maximum spec factor for suffix decoding. The spec factor controls
+    speculation lengths based on the prefix match length: max_spec_tokens =
+    max_spec_factor * prefix_match_length."""
+
+    suffix_decoding_min_token_prob: float = 0.1
+    """The minimum token probability for suffix decoding. Will only speculate
+    tokens with estimated probability (based on frequency counts) greater than
+    or equal to this value."""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
```
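For reference, a minimal usage sketch of the new method from the offline `LLM` API, assuming the dict passed as `speculative_config` is forwarded to `SpeculativeConfig` unchanged. The model name is a placeholder and the values shown are simply the defaults from the fields above.

```python
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    speculative_config={
        "method": "suffix",
        # Defaults shown explicitly; num_speculative_tokens falls back to 32
        # in __post_init__ when left unset.
        "suffix_decoding_max_tree_depth": 64,
        "suffix_decoding_max_cached_requests": 10000,
        "suffix_decoding_max_spec_factor": 1.0,
        "suffix_decoding_min_token_prob": 0.1,
    },
)
```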
```diff
@@ -227,6 +248,8 @@ def __post_init__(self):
                 self.quantization = self.target_model_config.quantization
             elif self.method in ("ngram", "[ngram]"):
                 self.model = "ngram"
+            elif self.method == "suffix":
+                self.model = "suffix"
             else:
                 raise ValueError(
                     "num_speculative_tokens was provided but without "
@@ -271,6 +294,31 @@ def __post_init__(self):
                 # draft related config as None here.
                 self.draft_model_config = self.target_model_config
                 self.draft_parallel_config = self.target_parallel_config
+            elif self.method == "suffix":
+                if not has_arctic_inference():
+                    raise ImportError(
+                        "Arctic Inference is required for suffix decoding. "
+                        "Please install via `pip install arctic-inference`.")
+                if self.num_speculative_tokens is None:
+                    self.num_speculative_tokens = 32
+                # Validate values
+                if self.suffix_decoding_max_tree_depth < 4:
+                    raise ValueError(
+                        f"suffix_decoding_max_tree_depth="
+                        f"{self.suffix_decoding_max_tree_depth} must be >= 4")
+                if self.suffix_decoding_max_cached_requests < 0:
+                    raise ValueError(
+                        f"suffix_decoding_max_cached_requests="
+                        f"{self.suffix_decoding_max_cached_requests} must be >= 0")
+                if self.suffix_decoding_max_spec_factor < 0:
+                    raise ValueError(
+                        f"suffix_decoding_max_spec_factor="
+                        f"{self.suffix_decoding_max_spec_factor} must be >= 0")
+                if (self.suffix_decoding_min_token_prob < 0
+                        or self.suffix_decoding_min_token_prob > 1):
+                    raise ValueError(
+                        f"suffix_decoding_min_token_prob="
+                        f"{self.suffix_decoding_min_token_prob} must be in [0, 1]")
             else:
                 self.prompt_lookup_max = 0
                 self.prompt_lookup_min = 0
```
```diff
@@ -557,6 +605,9 @@ def use_eagle(self) -> bool:

     def __repr__(self) -> str:
         method = self.method
-        model = None if method == "ngram" else self.draft_model_config.model
+        if method in ("ngram", "suffix"):
+            model = None
+        else:
+            model = self.draft_model_config.model
         num_spec_tokens = self.num_speculative_tokens
         return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
```
```diff
@@ -3440,6 +3440,12 @@ def has_tilelang() -> bool:
     return _has_module("tilelang")


+def has_arctic_inference() -> bool:
+    """Whether the optional `arctic_inference` package is available."""
+    return _has_module("arctic_inference")
+
+
 def set_process_title(name: str,
                       suffix: str = "",
                       prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None:
```
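The `_has_module` helper itself is not shown in this diff; a plausible sketch of such a check (an assumption, not the actual vLLM implementation) is:

```python
import importlib.util


def _has_module(name: str) -> bool:
    # Report whether `name` is importable without actually importing it.
    return importlib.util.find_spec(name) is not None


def has_arctic_inference() -> bool:
    """Whether the optional `arctic_inference` package is available."""
    return _has_module("arctic_inference")
```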
New file (`@@ -0,0 +1,99 @@`):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch


class SuffixDecodingProposer:

    def __init__(self, vllm_config: VllmConfig):
        config = vllm_config.speculative_config
        self.num_speculative_tokens = config.num_speculative_tokens
        self.max_tree_depth = config.suffix_decoding_max_tree_depth
        self.max_spec_factor = config.suffix_decoding_max_spec_factor
        self.min_token_prob = config.suffix_decoding_min_token_prob
        self.max_model_len = vllm_config.model_config.max_model_len

        # Lazy import to avoid error when Suffix Decoding is not used.
        from arctic_inference.suffix_decoding import SuffixDecodingCache

        self.suffix_cache = SuffixDecodingCache(
            max_tree_depth=config.suffix_decoding_max_tree_depth,
            max_cached_requests=config.suffix_decoding_max_cached_requests)

    def update(
        self,
        input_batch: InputBatch,
        sampled_token_ids: list[list[int]],
    ):
        seen_req_ids = set()
        for i, sampled_ids in enumerate(sampled_token_ids):
            req_id = input_batch.req_ids[i]
            seen_req_ids.add(req_id)

            if not sampled_ids:
                continue

            index = input_batch.req_id_to_index[req_id]
            if req_id not in self.suffix_cache.active_requests:
                if req_id in self.suffix_cache.cached_requests:
                    # Reset the suffix cache for this request.
                    self.suffix_cache.evict_cached_response(req_id)
                num_prompt_tokens = input_batch.num_prompt_tokens[index]
                prompt_token_ids = (
                    input_batch.token_ids_cpu[index, :num_prompt_tokens])
                prompt_token_ids = prompt_token_ids.tolist()
                self.suffix_cache.start_request(req_id, prompt_token_ids)

            self.suffix_cache.add_active_response(req_id, sampled_ids)

        # Stop requests that are not seen
        for req_id in list(self.suffix_cache.active_requests):
```
**Reviewer** (suggested change: `for req_id in self.suffix_cache.active_requests:`): It seems superfluous to copy the backing dictionary's keys into a list here, especially since what we do below isn't order-dependent?

**Author:** `stop_request` changes `active_requests` during iteration, so casting to a list avoids mutating the collection while iterating over it.

**Reviewer:** Ah, I see. Then perhaps it might be worth exposing a `SuffixDecodingCache.stop_inactive_requests` method that handles this in place?

**Author:** It's not immediately clear how much that would help, since the suffix cache relies on the calling code to tell it which requests are active or inactive, which means passing a newly constructed list in the first place.

**Reviewer:** Whoops, my bad, I meant something along the lines of `def stop_all_requests(self, *, req_ids_to_keep: Iterable[...] | None = None)` on `SuffixDecodingCache`, to which we pass `seen_req_ids`; the method would then basically reproduce what this for loop does. Sure, it's just a single copy (one of many paper cuts given that this is Python), but seeing how `update` will be called on a hot path (correct me if I'm wrong), we have reason to avoid pessimization.
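A minimal sketch of the helper the reviewer is proposing, written here as a free function for self-containment; the name, signature, and body are hypothetical, and the real proposal would be a method on `SuffixDecodingCache` in arctic_inference.

```python
from collections.abc import Iterable
from typing import Protocol


class _SuffixCacheLike(Protocol):
    # Only the parts of SuffixDecodingCache that the helper needs.
    active_requests: dict

    def stop_request(self, req_id: str) -> None: ...


def stop_all_requests(cache: _SuffixCacheLike, *,
                      req_ids_to_keep: Iterable[str] = ()) -> None:
    """Stop every active request whose id is not in `req_ids_to_keep`."""
    keep = set(req_ids_to_keep)
    # The key copy is still needed in pure Python because stop_request()
    # mutates active_requests; a native implementation could avoid it.
    for req_id in list(cache.active_requests):
        if req_id not in keep:
            cache.stop_request(req_id)
```

The proposer's `update` could then end with `stop_all_requests(self.suffix_cache, req_ids_to_keep=seen_req_ids)` instead of the explicit loop.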
**Reviewer:** n00b question: IIUC, `self.max_tree_depth` serves a role similar to `max_ngram` in ngram spec decoding, right? And we will use every suffix of the pattern from length 1 to `max_tree_depth` to try to find the most frequently matched subtree in `suffix_cache`, right?

**Author:** In suffix decoding, `max_tree_depth` is the upper limit for `match_length + spec_length`, and yes, it will try all suffixes in that range.
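A small illustration of how these bounds compose, based only on the config docstrings above (this is not the arctic_inference implementation):

```python
def max_speculation_length(match_length: int,
                           *,
                           num_speculative_tokens: int = 32,
                           max_tree_depth: int = 64,
                           max_spec_factor: float = 1.0) -> int:
    # The tree depth caps match_length + speculation length, max_spec_factor
    # caps speculation relative to the prefix match, and
    # num_speculative_tokens is the overall per-step cap.
    return max(0, min(num_speculative_tokens,
                      int(max_spec_factor * match_length),
                      max_tree_depth - match_length))
```

With the defaults, a 3-token prefix match allows at most 3 speculated tokens, while a 40-token match is capped at 24 by the tree depth.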
**Reviewer:** When the pattern string is very short (1-2 tokens), is the probability of getting the correct predicted token lower? Is it necessary to set a parameter similar to `prompt_lookup_min`?

**Author:** Yeah, the probability will be lower, but suffix decoding will avoid making that speculation if it is too low, controlled by the `suffix_decoding_min_token_prob` parameter. Basically, it looks at the frequency of all the next tokens in the suffix tree and stops speculating if the most frequent token occurs less than `suffix_decoding_min_token_prob` of the time. There is an underlying `suffix_decoding_max_spec_offset` parameter that plays a similar role to `prompt_lookup_min`, but we haven't found any cases where setting it actually helped beyond the existing mechanism, so it was not exposed.
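A toy sketch of the stopping rule described in that reply, with per-node frequency counts standing in for the real suffix tree statistics; the function name and structure are illustrative only.

```python
from collections import Counter


def pick_next_token(next_token_counts: Counter[int],
                    min_token_prob: float = 0.1) -> int | None:
    """Return the most frequent next token, or None to stop speculating."""
    total = sum(next_token_counts.values())
    if total == 0:
        return None
    token, count = next_token_counts.most_common(1)[0]
    # Estimated probability is the empirical frequency at this tree node.
    if count / total < min_token_prob:
        return None
    return token
```

For example, `pick_next_token(Counter({7: 9, 8: 1}))` returns `7`, while `pick_next_token(Counter({i: 1 for i in range(20)}))` returns `None` because no candidate reaches the 0.1 threshold.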
**Reviewer:** While this conversion of the numpy array to a list won't have as much overhead as a d2h copy, is there a reason why the `SuffixDecodingCache` class can't directly support ndarray parameters (also on line 45), particularly if the proposer is meant to be the sole consumer of the class? I'm not very familiar with pybind11, but it seems like the code in arctic_inference relies on it to automatically marshal CPython lists into STL vectors, necessitating yet another copy internally. Since the token IDs in both cases (`SuffixTree.extend` and `SuffixTree.speculate`) only require an immutable view of the data, it would be beneficial to either:
- assuming pybind11 supports it, use a primitive to directly wrap the ndarray and pass that to the extension class, or
- expose the `memoryview` returned by `ndarray.data` and consume that on the extension class side.
**Author:** Yes, we can add a zero-copy API in a future arctic-inference release; tracking in snowflakedb/ArcticInference#203.
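For context on the copy being discussed, a small illustration of the difference between materializing a Python list and taking a zero-copy view of a NumPy buffer; the zero-copy path on the arctic_inference side is the future work referenced above.

```python
import numpy as np

token_ids = np.arange(8, dtype=np.int32)

copied = token_ids.tolist()    # builds a new Python list (one copy per call)
view = memoryview(token_ids)   # zero-copy view over the same int32 buffer

assert view.obj is token_ids
assert list(view) == copied
```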