From bc621cf84a9a45731eba8ebe1de63e37f4c55369 Mon Sep 17 00:00:00 2001
From: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Date: Tue, 5 Nov 2024 16:14:50 +0100
Subject: [PATCH 1/2] Revert "Revert "Fix Whisper CI" (#34605)"

This reverts commit 74d3824cc0725829e7d92e1d43b97be1f18454f8.
---
 src/transformers/generation/utils.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 6e6d5b8bdce7..53cd2df3a49c 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1452,10 +1452,11 @@ def _prepare_generated_length(
         ):
             generation_config.max_length -= inputs_tensor.shape[1]
         elif has_default_max_length:  # by default let's always generate 20 new tokens
-            generation_config.max_length = generation_config.max_length + input_ids_length
-            max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
-            if max_position_embeddings is not None:
-                generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
+            if generation_config.max_length == GenerationConfig().max_length:
+                generation_config.max_length = generation_config.max_length + input_ids_length
+                max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+                if max_position_embeddings is not None:
+                    generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
 
         # same for min length
         if generation_config.min_new_tokens is not None:

From 1f74ddab3c683f9ff9e2b22b2e0fdccb09e63b35 Mon Sep 17 00:00:00 2001
From: ydshieh
Date: Tue, 5 Nov 2024 16:42:18 +0100
Subject: [PATCH 2/2] update

---
 src/transformers/generation/flax_utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py
index 88535b44e9c4..8e87ead7fdd5 100644
--- a/src/transformers/generation/flax_utils.py
+++ b/src/transformers/generation/flax_utils.py
@@ -398,7 +398,11 @@ def generate(
                 )
             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
         else:  # by default let's always generate 10 new tokens
-            generation_config.max_length = generation_config.max_length + input_ids_seq_length
+            if generation_config.max_length == GenerationConfig().max_length:
+                generation_config.max_length = generation_config.max_length + input_ids_seq_length
+                max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+                if max_position_embeddings is not None:
+                    generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
 
         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
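
Note: both hunks add the same guard: the prompt length (and the max_position_embeddings cap) is only folded into generation_config.max_length when max_length still equals the library default, i.e. when nothing has set it explicitly. The standalone Python sketch below only illustrates that guard; ToyGenerationConfig, prepare_max_length, and DEFAULT_MAX_LENGTH are hypothetical stand-ins rather than the transformers API, and the default of 20 merely mirrors the comment in the first hunk.

from dataclasses import dataclass
from typing import Optional

# Stand-in for GenerationConfig().max_length in the first hunk (the library default).
DEFAULT_MAX_LENGTH = 20


@dataclass
class ToyGenerationConfig:
    # Only the fields the patch touches; this is not the real GenerationConfig.
    max_length: int = DEFAULT_MAX_LENGTH
    max_new_tokens: Optional[int] = None


def prepare_max_length(
    config: ToyGenerationConfig,
    input_ids_length: int,
    max_position_embeddings: Optional[int] = None,
) -> int:
    """Mirror the guarded default-extension logic added by the patch."""
    if config.max_new_tokens is not None:
        # An explicit max_new_tokens still takes precedence (unchanged by the patch).
        return config.max_new_tokens + input_ids_length
    if config.max_length == DEFAULT_MAX_LENGTH:
        # Only a still-default max_length is extended by the prompt length ...
        length = config.max_length + input_ids_length
        # ... and capped by the model's position-embedding budget, when known.
        if max_position_embeddings is not None:
            length = min(length, max_position_embeddings)
        return length
    # An explicitly configured max_length is left untouched.
    return config.max_length


if __name__ == "__main__":
    # Default config: 20 new tokens on top of a 100-token prompt, capped at 448 positions.
    print(prepare_max_length(ToyGenerationConfig(), 100, max_position_embeddings=448))  # 120
    # An explicitly set max_length is no longer inflated by the prompt length.
    print(prepare_max_length(ToyGenerationConfig(max_length=448), 100))  # 448

With the guard in place, an explicitly configured max_length passes through unchanged, while a still-default one keeps the old behaviour: it grows with the prompt and stays within the position-embedding budget.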