From 85d14a049517ec0482cae0b4d3f18178d30622cc Mon Sep 17 00:00:00 2001
From: jmamou
Date: Thu, 28 Nov 2024 06:13:57 -0800
Subject: [PATCH 1/2] fix UAG

---
 src/transformers/generation/candidate_generator.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index ccd4b5508410..84931e58ee87 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -453,15 +453,17 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
             assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
             vocabulary_size)` containing the logits associated to each candidate.
         """
+        max_new_tokens = int(self.num_assistant_tokens)
+        if max_new_tokens == 0:  # TODO
+            return input_ids, None
+
         input_ids = input_ids.to(self.assistant_model.device)
         remove_from_pkv = 0
 
         assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
         self.prev_assistant_ids = assistant_input_ids
 
-        min_new_tokens, max_new_tokens = self._calculate_new_tokens(assistant_input_ids)
-        if max_new_tokens == 0:
-            return input_ids, None
+        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)
 
         self._update_past_and_masks(assistant_input_ids, remove_from_pkv)
         generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)

From 145ac612e4b5b9872ce50abbc3481a4e59c0707a Mon Sep 17 00:00:00 2001
From: jmamou
Date: Thu, 28 Nov 2024 06:17:03 -0800
Subject: [PATCH 2/2] minor

---
 src/transformers/generation/candidate_generator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index 84931e58ee87..138aa9144e1d 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -454,7 +454,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
             vocabulary_size)` containing the logits associated to each candidate.
         """
         max_new_tokens = int(self.num_assistant_tokens)
-        if max_new_tokens == 0:  # TODO
+        if max_new_tokens == 0:
             return input_ids, None
 
         input_ids = input_ids.to(self.assistant_model.device)
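
For readers following the change: the core of PATCH 1/2 is that the draft-token budget is now taken directly from `self.num_assistant_tokens` before any device transfer, with an early return when the budget is zero, and `min_new_tokens` is clamped against the main model's minimum length instead of going through `_calculate_new_tokens`. Below is a minimal standalone sketch of that clamping logic, assuming hypothetical inputs; the function name `plan_assistant_tokens` and its parameters are illustrative stand-ins for the attributes used in `get_candidates`, not part of the transformers API.

def plan_assistant_tokens(num_assistant_tokens: float, main_model_min_length: int, assistant_prompt_len: int):
    # Mirror of the patched logic: cast the budget to int and bail out early when it is zero.
    max_new_tokens = int(num_assistant_tokens)
    if max_new_tokens == 0:
        return None  # get_candidates returns (input_ids, None) in this case

    # Draft at least enough tokens to reach the main model's minimum length,
    # but never more than the budget and never a negative count.
    min_new_tokens = max(min(max_new_tokens, main_model_min_length - assistant_prompt_len), 0)
    return min_new_tokens, max_new_tokens

print(plan_assistant_tokens(5, 12, 10))  # (2, 5): must draft at least 2 tokens, at most 5
print(plan_assistant_tokens(5, 12, 20))  # (0, 5): prompt already meets the minimum length
print(plan_assistant_tokens(0, 12, 10))  # None: zero budget, early return with no candidates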