diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py
index 88535b44e9c4..8e87ead7fdd5 100644
--- a/src/transformers/generation/flax_utils.py
+++ b/src/transformers/generation/flax_utils.py
@@ -398,7 +398,11 @@ def generate(
                 )
             generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
         else:  # by default let's always generate 10 new tokens
-            generation_config.max_length = generation_config.max_length + input_ids_seq_length
+            if generation_config.max_length == GenerationConfig().max_length:
+                generation_config.max_length = generation_config.max_length + input_ids_seq_length
+                max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+                if max_position_embeddings is not None:
+                    generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

         if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 6e6d5b8bdce7..53cd2df3a49c 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1452,10 +1452,11 @@ def _prepare_generated_length(
         ):
             generation_config.max_length -= inputs_tensor.shape[1]
         elif has_default_max_length:  # by default let's always generate 20 new tokens
-            generation_config.max_length = generation_config.max_length + input_ids_length
-            max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
-            if max_position_embeddings is not None:
-                generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
+            if generation_config.max_length == GenerationConfig().max_length:
+                generation_config.max_length = generation_config.max_length + input_ids_length
+                max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+                if max_position_embeddings is not None:
+                    generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

         # same for min length
         if generation_config.min_new_tokens is not None: