diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py
index ccd4b5508410..138aa9144e1d 100644
--- a/src/transformers/generation/candidate_generator.py
+++ b/src/transformers/generation/candidate_generator.py
@@ -453,15 +453,17 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
             assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
             vocabulary_size)` containing the logits associated to each candidate.
         """
+        max_new_tokens = int(self.num_assistant_tokens)
+        if max_new_tokens == 0:
+            return input_ids, None
+
         input_ids = input_ids.to(self.assistant_model.device)
         remove_from_pkv = 0
         assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
         self.prev_assistant_ids = assistant_input_ids
 
-        min_new_tokens, max_new_tokens = self._calculate_new_tokens(assistant_input_ids)
-        if max_new_tokens == 0:
-            return input_ids, None
+        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)
 
         self._update_past_and_masks(assistant_input_ids, remove_from_pkv)
         generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)