diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 7cab88a4bc2e..d85860ad5dc4 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -194,45 +194,15 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
                 vocabulary_size)` containing the logits associated to each candidate.
         """
         input_ids = input_ids.to(self.assistant_model.device)
-
-        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
-        new_cur_len = input_ids.shape[-1]
-        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
-        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
+        # Calculate new tokens to generate
+        min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
         if max_new_tokens == 0:
             return input_ids, None
-
-        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
-        # (which implicitly contains the number of accepted candidates from the previous round)
-        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
-        if has_past_key_values:
-            new_cache_size = new_cur_len - 1
-            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
-                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
-            )  # the assistant does not have the token after the last match, hence the -1
-
-            self.assistant_kwargs = _prepare_attention_mask(
-                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
-            )
-            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
-
-        # 2. Forecast next N tokens using the assistant model.
-        assistant_generation_kwargs = {
-            self.input_ids_key: input_ids,
-            "min_new_tokens": min_new_tokens,
-            "max_new_tokens": max_new_tokens,
-            "generation_config": self.generation_config,
-            "logits_processor": self.logits_processor,
-        }
-
-        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
-
-        # 3. Update variables for the next round of candidate generation
-        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
-
-        # 4. Prepare variables for output
-        candidate_logits = torch.stack(assistant_output.scores, dim=1)
-        candidate_ids = assistant_output.sequences
+        # Update past key values and masks
+        self._update_past_and_masks(input_ids)
+        # Generate candidates
+        generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
+        candidate_ids, candidate_logits = self._generate_candidates(generation_args)
         return candidate_ids, candidate_logits

     def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
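The bound computed by the new `_calculate_new_tokens` helper is unchanged from the inline code it replaces: candidates are capped at `max_length - cur_len - 1` because the target model appends one extra verified token. A minimal standalone sketch of the arithmetic; the function name and values below are illustrative, not part of the module:

# Standalone sketch of the _calculate_new_tokens arithmetic (illustrative only).
def calculate_new_tokens(cur_len, num_assistant_tokens, max_length, main_model_min_length):
    # Cap candidates at max_length - cur_len - 1: when the target model verifies
    # the candidates it emits one extra token, so room must be left for it.
    max_new_tokens = min(int(num_assistant_tokens), max_length - cur_len - 1)
    # Respect the main model's minimum length, but never exceed the cap or go negative.
    min_new_tokens = max(min(max_new_tokens, main_model_min_length - cur_len), 0)
    return min_new_tokens, max_new_tokens

# 15-token prompt, max_length=20, assistant window of 5 -> capped at 4 candidates:
assert calculate_new_tokens(15, 5, 20, 0) == (0, 4)
# A main-model minimum of 18 forces at least 3 of them:
assert calculate_new_tokens(15, 5, 20, 18) == (3, 4)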
@@ -261,6 +231,45 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F
         else:
             self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)

+    def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> Tuple[int, int]:
+        """Calculate the minimum and maximum number of new tokens to generate."""
+        new_cur_len = input_ids.shape[-1]
+        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
+        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
+        return min_new_tokens, max_new_tokens
+
+    def _update_past_and_masks(self, input_ids: torch.LongTensor, remove_from_pkv: int = 0) -> bool:
+        """Update past key values and attention masks for subsequent generation rounds."""
+        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
+        if has_past_key_values:
+            new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
+            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
+                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
+            )
+            self.assistant_kwargs = _prepare_attention_mask(
+                self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
+            )
+            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])
+        return has_past_key_values
+
+    def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> Dict:
+        """Prepare arguments for the generation call."""
+        return {
+            self.input_ids_key: input_ids,
+            "min_new_tokens": min_new_tokens,
+            "max_new_tokens": max_new_tokens,
+            "generation_config": self.generation_config,
+            "logits_processor": self.logits_processor,
+        }
+
+    def _generate_candidates(self, generation_args: Dict) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
+        """Generate candidate sequences using the assistant model."""
+        assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
+        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
+        candidate_logits = torch.stack(assistant_output.scores, dim=1)
+        candidate_ids = assistant_output.sequences
+        return candidate_ids, candidate_logits
+

 class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
     """
@@ -310,6 +319,8 @@ def __init__(
         self.target_tokenizer = target_tokenizer
         self.assistant_tokenizer = assistant_tokenizer
+        self.prev_target_ids = None
+        self.prev_tokens = None
         self.prev_assistant_ids = None
         self.target_lookbehind = assistant_model.generation_config.target_lookbehind
         self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind
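The paired `- 1` subtractions in `_update_past_and_masks` are easy to misread: the cache is first shrunk to the accepted prefix (minus any tokens the caller asks to drop via `remove_from_pkv`), and then by one more position, because the assistant never computed the key/value for the token after the last match. A toy model of just the length bookkeeping; this is an illustration, not the library's cache API:

# Toy model of the cache-length bookkeeping in _update_past_and_masks (illustrative).
def cropped_cache_len(cur_len, remove_from_pkv=0):
    # Drop cached keys/values for candidates the target model rejected...
    new_cache_size = cur_len - 1 - remove_from_pkv
    # ...then one more position: the assistant never saw the token that follows
    # the last match, hence the extra -1 passed to _crop_past_key_values.
    return new_cache_size - 1

assert cropped_cache_len(10) == 8
assert cropped_cache_len(10, remove_from_pkv=2) == 6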
@@ -440,18 +451,41 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
             return input_ids, None

         input_ids = input_ids.to(self.assistant_model.device)
+        remove_from_pkv = 0
+
+        assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
+        self.prev_assistant_ids = assistant_input_ids
+
+        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)
+
+        self._update_past_and_masks(assistant_input_ids, remove_from_pkv)
+        generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
+        self.assistant_kwargs.pop("attention_mask", None)
+
+        assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
+        new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences, assistant_input_ids)
+
+        # Update state
+        self.prev_target_ids = input_ids
+        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
+        self.prev_tokens = assistant_output.sequences
+
+        if input_ids.shape[1] >= new_target_ids.shape[1]:
+            return input_ids, None
+
+        return new_target_ids, None
+
+    def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, int]:
+        """Converts target input IDs to assistant input IDs, handling discrepancies."""
         convert_kwargs = {
             "source_tokenizer": self.target_tokenizer,
             "destination_tokenizer": self.assistant_tokenizer,
         }
         remove_from_pkv = 0

-        # Since re-encoding the tokens may result in tokenization discrepancies, we use 2 look behind values
-        # (one for each conversion) which mark where to start looking for the overlap between the
-        # source and target encodings, to ensure the new tokens include the correct prompt suffix.
-        if self.prev_assistant_ids is not None and input_ids.shape[1] > self.target_lookbehind:
+        if self.prev_tokens is not None and self.prev_target_ids.shape[1] > self.target_lookbehind:
             # input_ids contains all target prompt input ids and some new target input ids
-            start_index_in_target_window = input_ids.shape[1] - self.target_lookbehind
+            start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind

             new_assistant_ids = self.convert_source_tokens_to_target_tokens(
                 input_ids[:, start_index_in_target_window:], **convert_kwargs
@@ -459,8 +493,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
             prompt_use_length = new_assistant_ids.shape[1]
             prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]

-            discrepancy_length, new_tokens_only, discrepancy_only = (
-                AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt_use, new_assistant_ids)
+            discrepancy_length, new_tokens_only, discrepancy_only = self._get_tokens_diag(
+                prompt_use, new_assistant_ids
             )
             assistant_input_ids = self.prev_assistant_ids
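`_get_tokens_diag` (unchanged by this diff, now called through `self`) is what locates the overlap between the previously used assistant ids and the re-encoded lookbehind window. A simplified numpy stand-in for the idea follows; it omits the real helper's discrepancy handling (it always reports a discrepancy length of 0) and is not the actual implementation:

import numpy as np

# Simplified stand-in for _get_tokens_diag: find the longest suffix of `prompt`
# that re-appears as a prefix of the re-encoded sequence; whatever follows that
# overlap is genuinely new.
def tokens_diag_sketch(prompt, prompt_plus_new_tokens):
    for start in range(prompt.shape[1]):
        overlap = prompt.shape[1] - start
        if np.array_equal(prompt_plus_new_tokens[:, :overlap], prompt[:, start:]):
            return 0, prompt_plus_new_tokens[:, overlap:], prompt_plus_new_tokens[:, :0]
    return None, None, None  # no intersection between the two encodings

assert tokens_diag_sketch(np.array([[1, 2, 3]]), np.array([[4, 5, 6]]))[0] is None
_, new_tokens, _ = tokens_diag_sketch(np.array([[1, 2, 3]]), np.array([[2, 3, 4, 5]]))
assert new_tokens.tolist() == [[4, 5]]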
@@ -481,48 +515,21 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
             else:
                 # edge case: in case of no intersection between prompt and new_assistant_ids
                 assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
-
         else:
             assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
+            self.prev_target_ids = input_ids

-        self.prev_assistant_ids = assistant_input_ids
-        new_cur_len = assistant_input_ids.shape[-1]
-        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
-
-        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
-        # (which implicitly contains the number of accepted candidates from the previous round)
-        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
-        if has_past_key_values:
-            new_cache_size = new_cur_len - 1 - remove_from_pkv
-            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
-                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
-            )  # the assistant does not have the token after the last match, hence the -1
-
-            self.assistant_kwargs = _prepare_attention_mask(
-                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
-            )
-            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
-
-        # 2. Forecast next N tokens using the assistant model.
-        assistant_generation_kwargs = {
-            self.input_ids_key: assistant_input_ids,
-            "min_new_tokens": min_new_tokens,
-            "max_new_tokens": max_new_tokens,
-            "generation_config": self.generation_config,
-            "logits_processor": self.logits_processor,
-        }
-
-        self.assistant_kwargs.pop("attention_mask", None)
-
-        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
+        return assistant_input_ids, remove_from_pkv

+    def _process_assistant_outputs(
+        self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor, assistant_input_ids: torch.LongTensor
+    ) -> torch.LongTensor:
+        """Processes assistant outputs to obtain target input IDs."""
         num_prev_assistant = self.prev_assistant_ids.shape[1]
         start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
-        if start_assistant_look_index < 0:
-            start_assistant_look_index = 0

         new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
-            assistant_output.sequences[:, start_assistant_look_index:],
+            assistant_sequences[:, start_assistant_look_index:],
             source_tokenizer=self.assistant_tokenizer,
             destination_tokenizer=self.target_tokenizer,
         )
@@ -530,9 +537,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         target_prompt_use = input_ids[:, -target_prompt_use_length:]

-        _, target_new_tokens_only, _ = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
-            target_prompt_use, new_target_ids_from_window
-        )
+        _, target_new_tokens_only, _ = self._get_tokens_diag(target_prompt_use, new_target_ids_from_window)

         new_target_ids = input_ids

@@ -546,14 +551,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
         if hasattr(self.generation_config, "max_length"):
             new_target_ids = new_target_ids[:, : self.generation_config.max_length]

-        # 3. Update variables for the next round of candidate generation
-        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
-
-        # 4. Prepare variables for output
-        if input_ids.shape[1] >= new_target_ids.shape[1]:
-            return input_ids, None
-
-        return new_target_ids, None
+        return new_target_ids


 class PromptLookupCandidateGenerator(CandidateGenerator):
diff --git a/tests/generation/test_candidate_generator.py b/tests/generation/test_candidate_generator.py
new file mode 100644
index 000000000000..03fd51324b02
--- /dev/null
+++ b/tests/generation/test_candidate_generator.py
@@ -0,0 +1,43 @@
+import unittest
+
+import numpy as np
+
+from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
+
+
+class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
+    def test_no_intersection(self):
+        prompt = np.array([[1, 2, 3]])
+        prompt_plus_new_tokens = np.array([[4, 5, 6]])
+        result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
+        self.assertEqual(result, (None, None, None))
+
+    def test_complete_overlap(self):
+        prompt = np.array([[1, 2, 3]])
+        prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
+        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
+            prompt, prompt_plus_new_tokens
+        )
+        self.assertEqual(discrep_length, 0)
+        np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
+        np.testing.assert_array_equal(discrep_only, np.array([[]]))
+
+    def test_partial_overlap(self):
+        prompt = np.array([[1, 2, 3]])
+        prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
+        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
+            prompt, prompt_plus_new_tokens
+        )
+        self.assertEqual(discrep_length, 0)
+        np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
+        np.testing.assert_array_equal(discrep_only, np.array([[]]))
+
+    def test_no_new_tokens(self):
+        prompt = np.array([[1, 2, 3]])
+        prompt_plus_new_tokens = np.array([[1, 2, 3]])
+        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
+            prompt, prompt_plus_new_tokens
+        )
+        self.assertEqual(discrep_length, 0)
+        np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
+        np.testing.assert_array_equal(discrep_only, np.array([[]]))
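The relocated tests are self-contained (pure numpy, no model downloads), so, assuming a standard development checkout, they can be run in isolation with pytest:

python -m pytest tests/generation/test_candidate_generator.py -v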
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 063e9a3da8fd..86d7c0b198c0 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -92,7 +92,6 @@
         WatermarkDetector,
         WatermarkingConfig,
     )
-    from transformers.generation.candidate_generator import AssistedCandidateGeneratorDifferentTokenizers
     from transformers.generation.utils import _speculative_sampling

@@ -4274,41 +4273,3 @@ def test_generate_from_inputs_embeds_with_bos_token_id_is_none(self):
         # bos_token_id is required when no input ids nor inputs_embeds is passed
         with self.assertRaises(ValueError):
             model.generate(max_length=20, bos_token_id=None)
-
-
-class TestAssistedCandidateGeneratorDifferentTokenizers(unittest.TestCase):
-    def test_no_intersection(self):
-        prompt = np.array([[1, 2, 3]])
-        prompt_plus_new_tokens = np.array([[4, 5, 6]])
-        result = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, prompt_plus_new_tokens)
-        self.assertEqual(result, (None, None, None))
-
-    def test_complete_overlap(self):
-        prompt = np.array([[1, 2, 3]])
-        prompt_plus_new_tokens = np.array([[1, 2, 3, 4, 5]])
-        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
-            prompt, prompt_plus_new_tokens
-        )
-        self.assertEqual(discrep_length, 0)
-        np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
-        np.testing.assert_array_equal(discrep_only, np.array([[]]))
-
-    def test_partial_overlap(self):
-        prompt = np.array([[1, 2, 3]])
-        prompt_plus_new_tokens = np.array([[2, 3, 4, 5]])
-        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
-            prompt, prompt_plus_new_tokens
-        )
-        self.assertEqual(discrep_length, 0)
-        np.testing.assert_array_equal(new_tokens_only, np.array([[4, 5]]))
-        np.testing.assert_array_equal(discrep_only, np.array([[]]))
-
-    def test_no_new_tokens(self):
-        prompt = np.array([[1, 2, 3]])
-        prompt_plus_new_tokens = np.array([[1, 2, 3]])
-        discrep_length, new_tokens_only, discrep_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
-            prompt, prompt_plus_new_tokens
-        )
-        self.assertEqual(discrep_length, 0)
-        np.testing.assert_array_equal(new_tokens_only, np.array([[]]))
-        np.testing.assert_array_equal(discrep_only, np.array([[]]))
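For context, the code paths touched here back assisted generation when the draft and target models use different tokenizers. A sketch of how they are reached from user code, assuming `generate` accepts `tokenizer`/`assistant_tokenizer` for this case; the checkpoint names are placeholders, not real models:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoints: any target/assistant pair with different vocabularies.
target = AutoModelForCausalLM.from_pretrained("target-checkpoint")
assistant = AutoModelForCausalLM.from_pretrained("assistant-checkpoint")
target_tok = AutoTokenizer.from_pretrained("target-checkpoint")
assistant_tok = AutoTokenizer.from_pretrained("assistant-checkpoint")

inputs = target_tok("The quick brown fox", return_tensors="pt")

# With mismatched tokenizers, generate() is expected to route through
# AssistedCandidateGeneratorDifferentTokenizers internally (assumed API).
out = target.generate(
    **inputs,
    assistant_model=assistant,
    tokenizer=target_tok,
    assistant_tokenizer=assistant_tok,
    max_new_tokens=32,
)
print(target_tok.batch_decode(out, skip_special_tokens=True)[0])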