diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 212f560ff433..1da9e21f012e 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -201,45 +201,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
             vocabulary_size)` containing the logits associated to each candidate.
         """
         input_ids = input_ids.to(self.assistant_model.device)
-
-        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
-        new_cur_len = input_ids.shape[-1]
-        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
-        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
+        # Calculate new tokens to generate
+        min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
         if max_new_tokens == 0:
             return input_ids, None
-
-        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
-        # (which implicitly contains the number of accepted candidates from the previous round)
-        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
-        if has_past_key_values:
-            new_cache_size = new_cur_len - 1
-            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
-                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
-            )  # the assistant does not have the token after the last match, hence the -1
-
-            self.assistant_kwargs = _prepare_attention_mask(
-                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
-            )
-            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
-
-        # 2. Forecast next N tokens using the assistant model.
-        assistant_generation_kwargs = {
-            self.input_ids_key: input_ids,
-            "min_new_tokens": min_new_tokens,
-            "max_new_tokens": max_new_tokens,
-            "generation_config": self.generation_config,
-            "logits_processor": self.logits_processor,
-        }
-
-        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
-
-        # 3. Update variables for the next round of candidate generation
-        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
-
-        # 4. Prepare variables for output
-        candidate_logits = torch.stack(assistant_output.scores, dim=1)
-        candidate_ids = assistant_output.sequences
+        # Update past key values and masks
+        self._update_past_and_masks(input_ids)
+        # Generate candidates
+        generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
+        candidate_ids, candidate_logits = self._generate_candidates(generation_args)
         return candidate_ids, candidate_logits
 
     def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
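The refactored `get_candidates` above now reads as four steps: compute the token budget, refresh the cache and masks, assemble the generation arguments, and run the assistant. The candidates it returns are consumed by the assisted-decoding loop on the target side, which accepts the longest prefix of proposals that matches its own choices. A toy, self-contained sketch of that acceptance rule (the helper and tensors below are hypothetical illustrations, not the transformers implementation):

    import torch

    def accept_longest_prefix(candidate_ids: torch.Tensor, target_ids: torch.Tensor) -> int:
        """Number of leading candidate tokens the target model agrees with."""
        matches = (candidate_ids == target_ids).int()
        # cumprod zeroes out the first mismatch and everything after it
        return int(torch.cumprod(matches, dim=-1).sum())

    candidates = torch.tensor([4, 8, 15, 16])
    target_choice = torch.tensor([4, 8, 99, 16])  # target disagrees at position 2
    assert accept_longest_prefix(candidates, target_choice) == 2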
@@ -268,6 +238,45 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
         else:
             self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
 
+    def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
+        """Calculate the minimum and maximum number of new tokens to generate."""
+        new_cur_len = input_ids.shape[-1]
+        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
+        return min_new_tokens, max_new_tokens
+
+    def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
+        """Update past key values and attention masks for subsequent generation rounds."""
+        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
+        if has_past_key_values:
+            new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
+            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
+                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
+            )  # the assistant does not have the token after the last match, hence the -1
+            self.assistant_kwargs = _prepare_attention_mask(
+                self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
+            )
+            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
+        return has_past_key_values
+
+    def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
+        """Prepare arguments for the generation call."""
+        return {
+            self.input_ids_key: input_ids,
+            "min_new_tokens": min_new_tokens,
+            "max_new_tokens": max_new_tokens,
+            "generation_config": self.generation_config,
+            "logits_processor": self.logits_processor,
+        }
+
+    def _generate_candidates(self, generation_args: Dict) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
+        """Generate candidate sequences using the assistant model."""
+        assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
+        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
+        candidate_logits = torch.stack(assistant_output.scores, dim=1)
+        candidate_ids = assistant_output.sequences
+        return candidate_ids, candidate_logits
+
 
 class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
     """
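The token-budget arithmetic that `_calculate_new_tokens` centralizes is easy to sanity-check in isolation. The function below copies the two formulas out of the method (the standalone wrapper itself is hypothetical): with a 16-token prompt, a budget of 5 assistant tokens, and `max_length=20`, only 3 candidates fit, because the target model appends one extra token during verification.

    def calculate_new_tokens(cur_len, num_assistant_tokens, max_length, main_model_min_length):
        # Never propose more than `max_length - 1` total tokens: the target
        # model generates one extra token when verifying the candidates.
        max_new_tokens = min(int(num_assistant_tokens), max_length - cur_len - 1)
        min_new_tokens = max(min(max_new_tokens, main_model_min_length - cur_len), 0)
        return min_new_tokens, max_new_tokens

    assert calculate_new_tokens(16, 5, 20, 0) == (0, 3)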
""" - max_new_tokens = int(self.num_assistant_tokens) + input_ids = input_ids.to(self.assistant_model.device) + remove_from_pkv = 0 + + assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids) + self.prev_assistant_ids = assistant_input_ids + + min_new_tokens, max_new_tokens = self._calculate_new_tokens(assistant_input_ids) if max_new_tokens == 0: return input_ids, None - input_ids = input_ids.to(self.assistant_model.device) + self._update_past_and_masks(assistant_input_ids, remove_from_pkv) + generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) + self.assistant_kwargs.pop("attention_mask", None) + + assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs) + new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids) + + # Update state + self.prev_target_ids = input_ids + self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values + self.prev_tokens = assistant_output.sequences + + if input_ids.shape[1] >= new_target_ids.shape[1]: + return input_ids, None + + return new_target_ids, None + + def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]: + """Converts target input IDs to assistant input IDs, handling discrepancies.""" convert_kwargs = { "source_tokenizer": self.target_tokenizer, "destination_tokenizer": self.assistant_tokenizer, } remove_from_pkv = 0 - # Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values - # (one for each conversion) which mark where to start looking for the overlap between the - # source and target encodings, to ensure the new tokens include the correct prompt suffix. if self.prev_tokens is not None and self.prev_target_ids.shape[1] > self.target_lookbehind: # input_ids contains all target prompt input ids and some new target input ids start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind @@ -467,8 +498,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, prompt_use_length = new_assistant_ids.shape[1] prompt_use = self.prev_assistant_ids[:, -prompt_use_length:] - discrepancy_length, new_tokens_only, discrepancy_only = ( - AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt_use, new_assistant_ids) + discrepancy_length, new_tokens_only, discrepancy_only = self._get_tokens_diag( + prompt_use, new_assistant_ids ) assistant_input_ids = self.prev_assistant_ids @@ -489,47 +520,21 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, else: # edge case: in case of no intersection between prompt and new_assistant_ids assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1) - else: assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs) self.prev_target_ids = input_ids - self.prev_assistant_ids = assistant_input_ids - new_cur_len = assistant_input_ids.shape[-1] - min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0) - - # 1. 
@@ -489,47 +520,21 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
             else:
                 # edge case: in case of no intersection between prompt and new_assistant_ids
                 assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
-
         else:
             assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
             self.prev_target_ids = input_ids
 
-        self.prev_assistant_ids = assistant_input_ids
-        new_cur_len = assistant_input_ids.shape[-1]
-        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
-
-        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
-        # (which implicitly contains the number of accepted candidates from the previous round)
-        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
-        if has_past_key_values:
-            new_cache_size = new_cur_len - 1 - remove_from_pkv
-            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
-                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
-            )  # the assistant does not have the token after the last match, hence the -1
-
-            self.assistant_kwargs = _prepare_attention_mask(
-                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
-            )
-            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
-
-        # 2. Forecast next N tokens using the assistant model.
-        assistant_generation_kwargs = {
-            self.input_ids_key: assistant_input_ids,
-            "min_new_tokens": min_new_tokens,
-            "max_new_tokens": max_new_tokens,
-            "generation_config": self.generation_config,
-            "logits_processor": self.logits_processor,
-        }
-
-        self.assistant_kwargs.pop("attention_mask", None)
-
-        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
+        return assistant_input_ids, remove_from_pkv
 
+    def _process_assistant_outputs(
+        self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor, assistant_input_ids: torch.LongTensor
+    ) -> torch.LongTensor:
+        """Processes assistant outputs to obtain target input IDs."""
         num_prev_assistant = self.prev_assistant_ids.shape[1]
         start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
 
         new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
-            assistant_output.sequences[:, start_assistant_look_index:],
+            assistant_sequences[:, start_assistant_look_index:],
             source_tokenizer=self.assistant_tokenizer,
             destination_tokenizer=self.target_tokenizer,
         )
@@ -537,9 +542,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         target_prompt_use = input_ids[:, -target_prompt_use_length:]
 
-        _, target_new_tokens_only, _ = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
-            target_prompt_use, new_target_ids_from_window
-        )
+        _, target_new_tokens_only, _ = self._get_tokens_diag(target_prompt_use, new_target_ids_from_window)
 
         new_target_ids = input_ids
 
@@ -550,20 +553,10 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         else:
             # edge case: in case of no intersection between prompt and new_target_ids
             new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)
 
-        self.prev_target_ids = input_ids
-
         if hasattr(self.generation_config, "max_length"):
             new_target_ids = new_target_ids[:, : self.generation_config.max_length]
 
-        # 3. Update variables for the next round of candidate generation
-        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
-        self.prev_tokens = assistant_output.sequences
-
-        # 4. Prepare variables for output
-        if input_ids.shape[1] >= new_target_ids.shape[1]:
-            return input_ids, None
-
-        return new_target_ids, None
+        return new_target_ids
 
 
 class AssistantToTargetTranslator:
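`_prepare_assistant_input_ids` and `_process_assistant_outputs` both hinge on the same trick: re-encode a lookbehind window of text with the other tokenizer, then keep only the tokens that extend what was already seen. A much-simplified, self-contained analogue of that overlap search (a hypothetical helper; the real `_get_tokens_diag` also reports discrepancy tokens via a diagonal-match matrix):

    def new_tokens_after_overlap(prev_ids: list, reencoded_ids: list) -> list:
        # Longest suffix of prev_ids that is a prefix of reencoded_ids.
        for k in range(min(len(prev_ids), len(reencoded_ids)), 0, -1):
            if prev_ids[-k:] == reencoded_ids[:k]:
                return reencoded_ids[k:]
        return reencoded_ids  # edge case: no intersection, append everything

    assert new_tokens_after_overlap([5, 6, 7], [6, 7, 8, 9]) == [8, 9]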
@@ -687,28 +680,7 @@ def cleanup(cls):
 class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers):
     """
     `CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with different tokenizers
-    for the assistant and main models. This class generates candidates through the use of a smaller
-    model.
-
-    Args:
-        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
-        assistant_model (`PreTrainedModel`):
-            The model to be used for generating candidates. This model should be smaller than the main model.
-        target_tokenizer (`PreTrainedTokenizerBase`):
-            The tokenizer used for the target model.
-        assistant_tokenizer (`PreTrainedTokenizerBase`):
-            The tokenizer used for the assistant model.
-        generation_config (`~generation.GenerationConfig`, *optional*):
-            The generation configuration to be used as base parametrization for the generation call.
-        logits_processor (`LogitsProcessorList`):
-            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
-            used to modify the prediction scores of the language modeling head applied at each generation step.
-        model_kwargs (`Dict`):
-            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
-            model as well.
-        inputs_tensor (`torch.Tensor`, *optional*):
-            The model input tensor. In encoder-decoder models, this is the encoder input.
+    for the assistant and main models. This class generates candidates through the use of a smaller model.
     """
 
     def __init__(
@@ -722,6 +694,7 @@ def __init__(
         inputs_tensor: Optional[torch.Tensor] = None,
         logits_processor: "LogitsProcessorList" = None,
     ):
+        # Initialize translator before parent class
         self._atm_translator = AssistantVocabTranslatorCache.get_translator(target_tokenizer, assistant_tokenizer)
         super().__init__(
             input_ids,
@@ -733,99 +706,75 @@ def __init__(
             inputs_tensor,
             logits_processor,
         )
+        # Track sequence lengths and previous assistant IDs
         self._prev_target_seq_len: int = 0
-        self._prev_assistant_ids: torch.LongTensor | None = None
+        self._prev_assistant_ids: Optional[torch.LongTensor] = None
 
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
-        Fetches the candidates to be tried for the current input.
-
-        Args:
-            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
-
-        Return:
-            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
-            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
-            vocabulary_size)` containing the logits associated to each candidate.
+        Simplified version of get_candidates that uses the translator cache for token conversion.
""" - has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None - - def get_assistant_input_ids(target_input_ids: torch.LongTensor) -> torch.LongTensor: - nonlocal has_past_key_values - target_seq_len = target_input_ids.shape[-1] - target_new_ids = target_input_ids[:, -(target_seq_len - self._prev_target_seq_len) :] - self._prev_target_seq_len = target_seq_len - # Convert target_new_ids to string - target_new_toks = self.target_tokenizer.batch_decode( - target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True - ) - # Convert the string to assistant_new_ids - assistant_new_ids = self.assistant_tokenizer( - target_new_toks, add_special_tokens=False, return_tensors="pt" - )["input_ids"] - if self._prev_assistant_ids is None: - self._prev_assistant_ids = assistant_new_ids - else: - self._prev_assistant_ids = torch.cat(self._prev_assistant_ids, assistant_new_ids, dim=-1) - return self._prev_assistant_ids.to(self.assistant_model.device) - + input_ids = input_ids.to(self.assistant_model.device) target_input_ids = input_ids.clone() - input_ids = get_assistant_input_ids(input_ids) + assistant_input_ids = self._prepare_assistant_input_ids(target_input_ids) - # Don't generate more than `max_length - 1` candidates since the target model generates one extra token. - new_cur_len = input_ids.shape[-1] - max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1) - # TODO: Debug - # min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0) + # Standard generation steps + min_new_tokens, max_new_tokens = self._calculate_new_tokens(assistant_input_ids) if max_new_tokens == 0: return input_ids, None - # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length - # (which implicitly contains the number of accepted candidates from the previous round) - # has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None - if has_past_key_values: - new_cache_size = new_cur_len - 1 - self.assistant_kwargs["past_key_values"] = _crop_past_key_values( - self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1 - ) # the assistant does not have the token after the last match, hence the -1 - self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len) + self._update_past_and_masks(assistant_input_ids) + generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens) - # we need to update the attention mask to reflect the new input_ids length - self.assistant_kwargs = _prepare_attention_mask( - self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder - ) - - # 2. Forecast next N tokens using the assistant model. - assistant_generation_kwargs = { - self.input_ids_key: input_ids, - # "min_new_tokens": min_new_tokens, - # "max_new_tokens": max_new_tokens, - "min_new_tokens": 100, - "max_new_tokens": 100, - "generation_config": self.generation_config, - "logits_processor": self.logits_processor, - } - - assistant_output = self.assistant_model.generate( - **assistant_generation_kwargs, **self.assistant_kwargs, output_logits=True - ) + # Ensure scores are returned + generation_args["generation_config"].output_scores = True + generation_args["generation_config"].return_dict_in_generate = True - # 3. 
+        # Generate and process outputs using translator
+        assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
         self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
 
-        # 4. Prepare variables for output
-        candidate_logits = torch.stack(assistant_output.logits, dim=1)
-        if not candidate_logits.shape[1] > 1:
-            msg = f"Since we set min_new_tokens to {assistant_generation_kwargs['min_new_tokens']} and max_new_tokens to {assistant_generation_kwargs['max_new_tokens']}, we expect at least 2 candidates, but seems like we got {candidate_logits.shape[1]} candidates."
-            raise Exception(msg)
+        candidate_logits = torch.stack(assistant_output.scores, dim=1)
+        if candidate_logits.shape[1] <= 1:
+            raise ValueError(
+                f"Expected at least 2 candidate tokens, but got {candidate_logits.shape[1]}. "
+                f"min_new_tokens: {generation_args['min_new_tokens']}, max_new_tokens: {generation_args['max_new_tokens']}."
+            )
+
+        # Use translator to convert tokens and logits
         candidate_ids = assistant_output.sequences
         candidate_logits = self._atm_translator.logits_processors(input_ids=candidate_ids, scores=candidate_logits)
-        target_ids = self._atm_translator.get_target_ids(input_ids, target_input_ids, candidate_ids)
-
+        target_ids = self._atm_translator.get_target_ids(assistant_input_ids, target_input_ids, candidate_ids)
         target_logits = self._atm_translator.get_target_logits(candidate_logits)
+
         return target_ids, target_logits
 
+    def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> torch.LongTensor:
+        """
+        Simplified token conversion that only processes new tokens.
+        """
+        # Calculate new tokens since last call
+        target_seq_len = target_input_ids.shape[-1]
+        new_token_count = target_seq_len - self._prev_target_seq_len
+        target_new_ids = target_input_ids[:, -new_token_count:]
+        self._prev_target_seq_len = target_seq_len
+
+        # Convert only the new tokens
+        target_new_text = self.target_tokenizer.batch_decode(
+            target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+        )
+        assistant_new_ids = self.assistant_tokenizer(
+            target_new_text, add_special_tokens=False, return_tensors="pt"
+        )["input_ids"].to(self.assistant_model.device)
+
+        # Update or initialize assistant IDs
+        if self._prev_assistant_ids is None:
+            self._prev_assistant_ids = assistant_new_ids
+        else:
+            self._prev_assistant_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
+
+        return self._prev_assistant_ids
+
 
 class PromptLookupCandidateGenerator(CandidateGenerator):
     """
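End to end, `UniversalSpeculativeDecodingGenerator` is exercised through the public `generate` API by passing an assistant model together with both tokenizers. A minimal usage sketch, assuming a transformers build that includes this refactor (the checkpoint names are illustrative; any pair of compatible causal LMs works):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
    assistant_tokenizer = AutoTokenizer.from_pretrained("double7/vicuna-68m")
    assistant_model = AutoModelForCausalLM.from_pretrained("double7/vicuna-68m")

    inputs = tokenizer("Alice and Bob", return_tensors="pt")
    outputs = model.generate(
        **inputs,
        assistant_model=assistant_model,
        tokenizer=tokenizer,
        assistant_tokenizer=assistant_tokenizer,
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))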