diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 30babc19e732..930a61628c51 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -619,8 +619,24 @@ def _process_assistant_outputs(
 
 class AssistantToTargetTranslator:
     """
-    Translate the assistant into the target universe.
+    Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
+    vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding,
+    as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
+    (https://www.arxiv.org/abs/2502.05202).
+    It maintains mappings between the two vocabularies and handles token/logit conversion.
+
+    Args:
+        target_tokenizer (`PreTrainedTokenizerBase`):
+            The tokenizer used by the target (main) model.
+        assistant_tokenizer (`PreTrainedTokenizerBase`):
+            The tokenizer used by the assistant model.
+        assistant_model_device (`str`, *optional*, defaults to `"cpu"`):
+            The device where the assistant model is located. Used for placing tensors.
+        target_vocab_size (`int`, *optional*):
+            The size of the target model's vocabulary. If not provided, it is inferred from the target tokenizer.
     """
+    FILTER_VALUE: float = -float("Inf")  # The value used to filter out unmapped tokens in the logits.
+    SUPPRESS_TOKEN_ID: int = -1  # The ID used to mark suppressed tokens in the mapping.
 
     def __init__(
         self,
@@ -628,8 +644,6 @@ def __init__(
         assistant_tokenizer: "PreTrainedTokenizerBase",
         assistant_model_device: str = "cpu",
         target_vocab_size: Optional[int] = None,
-        filter_value: float = -float("Inf"),
-        suppress_tokens_id: int = -1,
     ):
         self._target_tokenizer: "PreTrainedTokenizerBase" = target_tokenizer
         self._assistant_tokenizer: "PreTrainedTokenizerBase" = assistant_tokenizer
@@ -638,8 +652,6 @@ def __init__(
             self.target_vocab_size: int = len(self._target_tokenizer.get_vocab())
         else:
             self.target_vocab_size: int = target_vocab_size
-        self.filter_value: float = filter_value
-        self.suppress_tokens_id: int = suppress_tokens_id
         self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
             self._get_assistant_to_target_input_ids()
         )
@@ -677,7 +689,7 @@ def _get_assistant_to_target_input_ids(self):
         }
 
         max_assistant_index = max(assistant_vocab.values())
-        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.suppress_tokens_id, dtype=int)
+        assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
         target_to_assistant_input_ids: Dict[int, int] = {}
         for tok, assistant_id in assistant_vocab.items():
             target_id = target_vocab.get(tok)
@@ -690,7 +702,7 @@ def _get_suppress_input_ids(self) -> list[int]:
         """
         Get the input ids that are in the assistant vocab but not in the target vocab.
         """
-        return torch.where(self._assistant_to_target_input_ids == self.suppress_tokens_id)[0]
+        return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0]
 
     def get_target_ids(
         self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
@@ -714,9 +726,9 @@ def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatT
         """
         target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
-        target_logits: torch.FloatTensor = torch.full(target_shape, self.filter_value).to(self._assistant_model_device)
+        target_logits: torch.FloatTensor = torch.full(target_shape, self.FILTER_VALUE).to(self._assistant_model_device)
         # Mask for valid indices
-        assistant_indices_mask = self._assistant_to_target_input_ids != self.suppress_tokens_id
+        assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
         # Exclude invalid indices
         target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
         valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 578946ce8a94..e8048ee4c92a 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -855,33 +855,32 @@ def _get_candidate_generator(
                 max_length=generation_config.max_length,
             )
         elif different_tokenizers:
-            match generation_config.do_sample:
-                case True:
-                    candidate_generator = UniversalSpeculativeDecodingGenerator(
-                        input_ids=input_ids,
-                        assistant_model=assistant_model,
-                        generation_config=generation_config,
-                        model_kwargs=model_kwargs,
-                        inputs_tensor=inputs_tensor,
-                        logits_processor=logits_processor,
-                        target_tokenizer=target_tokenizer,
-                        assistant_tokenizer=assistant_tokenizer,
-                        # required in the case that self.config.vocab_size is different from the length of target_tokenizer.get_vocab()
-                        target_vocab_size=self.config.vocab_size,
-                    )
-                case False:
-                    candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
-                        input_ids=input_ids,
-                        assistant_model=assistant_model,
-                        generation_config=generation_config,
-                        model_kwargs=model_kwargs,
-                        inputs_tensor=inputs_tensor,
-                        logits_processor=logits_processor,
-                        target_tokenizer=target_tokenizer,
-                        assistant_tokenizer=assistant_tokenizer,
-                    )
-                case _:
-                    raise ValueError(f"Invalid value for `do_sample`: {generation_config.do_sample}")
+            if generation_config.do_sample is True:
+                candidate_generator = UniversalSpeculativeDecodingGenerator(
+                    input_ids=input_ids,
+                    assistant_model=assistant_model,
+                    generation_config=generation_config,
+                    model_kwargs=model_kwargs,
+                    inputs_tensor=inputs_tensor,
+                    logits_processor=logits_processor,
+                    target_tokenizer=target_tokenizer,
+                    assistant_tokenizer=assistant_tokenizer,
+                    # required when self.config.vocab_size differs from len(target_tokenizer.get_vocab())
+                    target_vocab_size=self.config.vocab_size,
+                )
+            elif generation_config.do_sample is False:
+                candidate_generator = AssistedCandidateGeneratorDifferentTokenizers(
+                    input_ids=input_ids,
+                    assistant_model=assistant_model,
+                    generation_config=generation_config,
+                    model_kwargs=model_kwargs,
+                    inputs_tensor=inputs_tensor,
+                    logits_processor=logits_processor,
+                    target_tokenizer=target_tokenizer,
+                    assistant_tokenizer=assistant_tokenizer,
+                )
+            else:
+                raise ValueError(f"Invalid value for `do_sample`: expected a boolean, got {type(generation_config.do_sample).__name__}")
         else:
             candidate_generator = AssistedCandidateGenerator(
                 input_ids=input_ids,
diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py
index 6239c9268007..54ce3b3ee1e2 100644
--- a/tests/generation/test_candidate_generator.py
+++ b/tests/generation/test_candidate_generator.py
@@ -39,7 +39,7 @@ def setUp(self):
 
     def test_get_assistant_to_target_input_ids(self):
         """Test the mapping from assistant tokens to target tokens."""
-        expected_mapping = [0, 1, 2, self.translator.suppress_tokens_id, self.translator.suppress_tokens_id]
+        expected_mapping = [0, 1, 2, self.translator.SUPPRESS_TOKEN_ID, self.translator.SUPPRESS_TOKEN_ID]
         actual_mapping = self.translator._assistant_to_target_input_ids.tolist()
         self.assertEqual(actual_mapping, expected_mapping)
 
@@ -56,7 +56,7 @@ def test_get_target_ids(self):
         assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]])  # 'hello world foo baz' in assistant tokenizer
 
         expected_target_ids = torch.LongTensor(
-            [[0, 1, 2, self.translator.suppress_tokens_id]]
-        )  # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.suppress_tokens_id since it does not exist in target vocab)
+            [[0, 1, 2, self.translator.SUPPRESS_TOKEN_ID]]
+        )  # 'hello world foo baz' in target tokenizer (baz is mapped to self.translator.SUPPRESS_TOKEN_ID since it does not exist in target vocab)
 
         actual_target_ids = self.translator.get_target_ids(
@@ -67,10 +67,10 @@ def test_get_target_logits(self):
         """Test the conversion of assistant logits to target logits."""
         # Assistant logits for IDs 0, 1, 2
-        assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.filter_value]]])  # Shape (1, 1, 5)
+        assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, self.translator.FILTER_VALUE]]])  # Shape (1, 1, 5)
 
         # Expected target logits (target_vocab_size = 4)
-        expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.filter_value)
+        expected_target_logits = torch.full((1, 1, self.target_vocab_size), self.translator.FILTER_VALUE)
         expected_target_logits[0, 0, 0] = 0.1  # 'hello'
         expected_target_logits[0, 0, 1] = 0.2  # 'world'
         expected_target_logits[0, 0, 2] = 0.3  # 'foo'
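A minimal usage sketch (not part of the patch) of how the renamed constants surface through the mapping that the tests above assert on. The duck-typed `MockTokenizer` is an illustrative assumption: it exposes only the `get_vocab()` call the translator is shown using, and the toy vocabularies mirror the test fixtures, with 'hello'/'world'/'foo' shared and two assistant-only tokens left unmapped.

```python
from transformers.generation.candidate_generator import AssistantToTargetTranslator


class MockTokenizer:
    # Illustrative stand-in (assumption): exposes only get_vocab(), which is
    # the call the translator is shown using in the diff above.
    def __init__(self, vocab):
        self._vocab = vocab

    def get_vocab(self):
        return self._vocab


target_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2, "bar": 3})
assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2, "qux": 3, "baz": 4})

translator = AssistantToTargetTranslator(
    target_tokenizer=target_tokenizer,
    assistant_tokenizer=assistant_tokenizer,
    assistant_model_device="cpu",
    target_vocab_size=4,
)

# Assistant ids 3 ('qux') and 4 ('baz') have no target counterpart, so they
# map to SUPPRESS_TOKEN_ID (-1), matching the expected_mapping in the tests.
print(translator._assistant_to_target_input_ids.tolist())  # [0, 1, 2, -1, -1]
```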
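And a self-contained sketch of the logit translation that `get_target_logits` performs, using the constants, mapping, and shapes from the hunks above: a target-sized tensor is pre-filled with `FILTER_VALUE`, and the assistant logits are scattered into it through the id mapping, so unmapped target tokens stay at -inf. The final scatter assignment is an assumed continuation of the snippet, since the hunk in the patch ends just before it.

```python
import torch

FILTER_VALUE = -float("Inf")
SUPPRESS_TOKEN_ID = -1

# Mapping from assistant ids to target ids; -1 marks assistant-only tokens.
assistant_to_target_input_ids = torch.tensor([0, 1, 2, SUPPRESS_TOKEN_ID, SUPPRESS_TOKEN_ID])
assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, FILTER_VALUE]]])  # shape (1, 1, 5)
target_vocab_size = 4

# Pre-fill with FILTER_VALUE so unmapped target ids keep -inf.
target_logits = torch.full((1, 1, target_vocab_size), FILTER_VALUE)
# Mask for valid indices, then route assistant logits to their target slots.
assistant_indices_mask = assistant_to_target_input_ids != SUPPRESS_TOKEN_ID
target_logits_supported_indices = assistant_to_target_input_ids[assistant_indices_mask]
valid_assistant_logits = assistant_logits[..., : assistant_to_target_input_ids.shape[0]]
# Assumed continuation: the scatter that fills the supported target positions.
target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]

print(target_logits)  # tensor([[[0.1000, 0.2000, 0.3000, -inf]]])
```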
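Finally, a hedged end-to-end sketch of the dispatch that the rewritten branch in `_get_candidate_generator` implements: with mismatched tokenizers, `do_sample=True` selects `UniversalSpeculativeDecodingGenerator`, while `do_sample=False` selects `AssistedCandidateGeneratorDifferentTokenizers`. The checkpoint names are placeholders, not an endorsement of a particular pairing; any target/assistant pair with different vocabularies exercises this path.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoints (assumption): any pair with different tokenizers works here.
target_name = "Qwen/Qwen2.5-0.5B-Instruct"
assistant_name = "double7/vicuna-68m"

target_tokenizer = AutoTokenizer.from_pretrained(target_name)
assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_name)
model = AutoModelForCausalLM.from_pretrained(target_name)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_name)

inputs = target_tokenizer("The quick brown fox", return_tensors="pt")

# do_sample=True  -> UniversalSpeculativeDecodingGenerator (sampling, per the paper);
# do_sample=False -> AssistedCandidateGeneratorDifferentTokenizers (greedy).
outputs = model.generate(
    **inputs,
    assistant_model=assistant_model,
    tokenizer=target_tokenizer,
    assistant_tokenizer=assistant_tokenizer,
    do_sample=True,
    max_new_tokens=20,
)
print(target_tokenizer.decode(outputs[0], skip_special_tokens=True))
```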