diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 5ba09a9d518d..80940e3b7e0c 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -628,7 +628,8 @@ def __init__( target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device: str = "cpu", - target_vocab_size: int = None, + default_AssistantToTargetTranslator + target_vocab_size: Optional[int] = None, filter_value: float = -float("Inf"), suppress_tokens_id: int = -1, ): @@ -642,9 +643,13 @@ def __init__( self.filter_value: float = filter_value self.suppress_tokens_id: int = suppress_tokens_id self._assistant_to_target_input_ids = self._get_assistant_to_target_input_ids() - self.logits_processors: LogitsProcessorList = LogitsProcessorList( - [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)] - ) + self._suppress_input_ids: list[int] = self._get_suppress_input_ids() + self.logits_processors: Optional[LogitsProcessorList] = None + if len(self._suppress_input_ids) > 0: + # len(self._suppress_input_ids) = 0 if the assistant vocab is a subset of the target vocab + self.logits_processors = LogitsProcessorList( + [SuppressTokensLogitsProcessor(self._get_suppress_input_ids(), self._assistant_model_device)] + ) def _get_assistant_to_target_input_ids(self): target_vocab = self._target_tokenizer.get_vocab() @@ -733,7 +738,7 @@ def get_translator( target_tokenizer: "PreTrainedTokenizerBase", assistant_tokenizer: "PreTrainedTokenizerBase", assistant_model_device: str = "cpu", - target_vocab_size: int = None, + target_vocab_size: Optional[int] = None, ) -> AssistantToTargetTranslator: with cls._lock: assistant_dict = cls._cache.get(target_tokenizer) @@ -826,7 +831,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, generation_args["generation_config"].return_dict_in_generate = True # Generate and process outputs using translator - generation_args["logits_processor"] = self._atm_translator.logits_processors + if self._atm_translator.logits_processors is not None: + generation_args["logits_processor"] = self._atm_translator.logits_processors self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args) # Use translator to convert tokens and logits