-
Notifications
You must be signed in to change notification settings - Fork 31.6k
Universal Speculative Decoding CandidateGenerator
#35029
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 81 commits
aa7e01a
f6b7f20
0ded37c
d48b69b
b47e33a
8a99129
4649bd2
f199c94
19c0057
acf5a4b
3712117
63f2f46
6ac33f1
6938311
5a0db3b
92f8ad3
7c8708e
880d0ae
d9b5e74
25974d5
b8636ab
1ef46b7
f8e94eb
439db84
643901d
77097ff
d08b4f0
f242dc1
ede1176
511ee96
5e47945
25a4349
95fe744
0ad88b2
bc5fa61
41a5670
44f7ba7
57aafcc
6f95c33
078f763
faac2fc
76a2dd3
43e96e7
e63cb9d
8aa6020
2169973
c6da827
2cf9e8e
a1c0d05
d830091
19d0cce
5b8217d
9b0126a
578d0b3
7db2695
fb69900
e40c775
b5ce873
bfccdea
d34d7ea
9d4d9f9
4e92e9c
3fe2d31
a350b1c
e047adf
011f595
2652490
701edbb
7088978
3b89341
e43dba8
a529795
9025751
25cd5da
77edae2
a2a2882
a556947
407d898
a24b193
1afdaa3
88f6877
4e3660a
c162c88
d0798a0
d18d090
ae2f16f
02dba31
49a228f
7f76fec
32335a5
1a79647
78a2a2c
751a099
bfb636d
00e325d
8a39f5b
64c95fe
7661fc9
503ece9
fb7187d
dedcf98
7e3f3dc
c9fc5a6
94e8a31
4e23470
eae175c
6784931
d20f07b
be79a15
9cb0a3a
b0e7a16
683bbee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | |||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,6 +14,8 @@ | ||||||||||||||||||||||||||||||||||
| # limitations under the License. | |||||||||||||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||||||||||||
| import copy | |||||||||||||||||||||||||||||||||||
| import threading | |||||||||||||||||||||||||||||||||||
| import weakref | |||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple | |||||||||||||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||||||||||||
| import numpy as np | |||||||||||||||||||||||||||||||||||
|
|
@@ -27,7 +29,11 @@ | ||||||||||||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||||||||||||
| from ..cache_utils import DynamicCache | |||||||||||||||||||||||||||||||||||
| from ..pytorch_utils import isin_mps_friendly | |||||||||||||||||||||||||||||||||||
| from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor | |||||||||||||||||||||||||||||||||||
| from .logits_process import ( | |||||||||||||||||||||||||||||||||||
| LogitsProcessorList, | |||||||||||||||||||||||||||||||||||
| MinLengthLogitsProcessor, | |||||||||||||||||||||||||||||||||||
| SuppressTokensLogitsProcessor, | |||||||||||||||||||||||||||||||||||
| ) | |||||||||||||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||||||||||||
| if TYPE_CHECKING: | |||||||||||||||||||||||||||||||||||
|
|
@@ -284,18 +290,21 @@ def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]: | ||||||||||||||||||||||||||||||||||
| min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0) | |||||||||||||||||||||||||||||||||||
| return min_new_tokens, max_new_tokens | |||||||||||||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||||||||||||
| def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool: | |||||||||||||||||||||||||||||||||||
| def _update_past_and_masks( | |||||||||||||||||||||||||||||||||||
| self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1 | |||||||||||||||||||||||||||||||||||
| ) -> bool: | |||||||||||||||||||||||||||||||||||
| """Update past key values and attention masks for subsequent generation rounds.""" | |||||||||||||||||||||||||||||||||||
| has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None | |||||||||||||||||||||||||||||||||||
| if has_past_key_values: | |||||||||||||||||||||||||||||||||||
| new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv | |||||||||||||||||||||||||||||||||||
| self.assistant_kwargs["past_key_values"] = _crop_past_key_values( | |||||||||||||||||||||||||||||||||||
| self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1 | |||||||||||||||||||||||||||||||||||
| self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - num_added_tokens | |||||||||||||||||||||||||||||||||||
| ) | |||||||||||||||||||||||||||||||||||
| self.assistant_kwargs = _prepare_attention_mask( | |||||||||||||||||||||||||||||||||||
| self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder | |||||||||||||||||||||||||||||||||||
| ) | |||||||||||||||||||||||||||||||||||
| self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1]) | |||||||||||||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||||||||||||
| return has_past_key_values | |||||||||||||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||||||||||||
| def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict: | |||||||||||||||||||||||||||||||||||
|
|
@@ -609,6 +618,286 @@ def _process_assistant_outputs( | ||||||||||||||||||||||||||||||||||
| return new_target_ids | |||||||||||||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||||||||||||
| class AssistantToTargetTranslator: | |||||||||||||||||||||||||||||||||||
| """ | |||||||||||||||||||||||||||||||||||
| Translate the assistant into the target universe. | |||||||||||||||||||||||||||||||||||
keyboardAnt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
|||||||||||||||||||||||||||||||||||
| """ | |||||||||||||||||||||||||||||||||||
|
|
|||||||||||||||||||||||||||||||||||
| def __init__( | |||||||||||||||||||||||||||||||||||
| self, | |||||||||||||||||||||||||||||||||||
| target_tokenizer: "PreTrainedTokenizerBase", | |||||||||||||||||||||||||||||||||||
| assistant_tokenizer: "PreTrainedTokenizerBase", | |||||||||||||||||||||||||||||||||||
| assistant_model_device: str = "cpu", | |||||||||||||||||||||||||||||||||||
| target_vocab_size: Optional[int] = None, | |||||||||||||||||||||||||||||||||||
| filter_value: float = -float("Inf"), | |||||||||||||||||||||||||||||||||||
| suppress_tokens_id: int = -1, | |||||||||||||||||||||||||||||||||||
keyboardAnt marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
|||||||||||||||||||||||||||||||||||
| ): | |||||||||||||||||||||||||||||||||||
| self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer | |||||||||||||||||||||||||||||||||||
| self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer | |||||||||||||||||||||||||||||||||||
| self._assistant_model_device: str = assistant_model_device | |||||||||||||||||||||||||||||||||||
| if target_vocab_size: | |||||||||||||||||||||||||||||||||||
| self.target_vocab_size: int = target_vocab_size | |||||||||||||||||||||||||||||||||||
| else: | |||||||||||||||||||||||||||||||||||
| self.target_vocab_size: int = len(self._target_tokenizer.get_vocab()) | |||||||||||||||||||||||||||||||||||
|
|||||||||||||||||||||||||||||||||||
| Model | len(tokenizer.get_vocab()) | config.vocab_size |
|---|---|---|
| Qwen/Qwen2.5-1.5B-Instruct | 151665 | 151936 |
| Qwen/Qwen2.5-0.5B-Instruct | 151665 | 151936 |
| Qwen/Qwen2.5-0.5B | 151665 | 151936 |
| microsoft/Phi-3.5-mini-instruct | 32011 | 32064 |
| microsoft/Phi-3-medium-128k-instruct | 32011 | 32064 |
| deepseek-ai/DeepSeek-R1-Distill-Qwen-32B | 151665 | 152064 |
| deepseek-ai/DeepSeek-R1-Distill-Qwen-14B | 151665 | 152064 |
| deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B | 151665 | 151936 |
| deepseek-ai/DeepSeek-R1 | 128815 | 129280 |
| deepseek-ai/DeepSeek-R1-Distill-Qwen-7B | 151665 | 152064 |
Anyway logits shape corresponds to config.vocab_size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, it would actually be better to use config.vocab_size! If you check the tokenizer for e.g. Qwen/Qwen2.5-0.5B, you'll see that the last token as the id 151664, which means we don't want to use tokens in the [151665, 151935] interval.
The extra ids in the tokenizer are used to pad the embedding layer, so as to make it a multiple of a power of 2 (different optimizations may use different values for the power of 2).
(Imo the value in the two should be the same, corresponding to the unpadded value 🤔 )
Uh oh!
There was an error while loading. Please reload this page.