Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,8 +1440,11 @@ def _prepare_generated_length(
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]
else: # by default let's always generate 10 new tokens
elif has_default_max_length: # by default let's always generate 20 new tokens
generation_config.max_length = generation_config.max_length + input_ids_length
max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
if max_position_embeddings is not None:
generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

# same for min length
if generation_config.min_new_tokens is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,9 @@ def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config

# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
input_ids,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,9 @@ def check_encoder_decoder_model_generate(

# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
inputs,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,9 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val

# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
inputs,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
)
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))

Expand Down Expand Up @@ -873,6 +875,7 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
Expand Down Expand Up @@ -990,6 +993,7 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
Expand Down Expand Up @@ -1107,6 +1111,7 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
Expand Down