Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,20 +1118,25 @@ def from_model_config(cls, model_config: PreTrainedConfig | dict) -> "Generation

# Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
config_dict = {key: value for key, value in config_dict.items() if value is not None}

generation_config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)

# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config (which in turn is defined from the outer attributes of model config).
if not isinstance(model_config, dict):
decoder_config = model_config.get_text_config(decoder=True)
if decoder_config is not model_config:
default_generation_config = GenerationConfig()
decoder_config_dict = decoder_config.to_dict()
for attr in generation_config.to_dict():
is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
if attr in decoder_config_dict and is_unset:
setattr(generation_config, attr, decoder_config_dict[attr])
if isinstance(model_config, dict):
decoder_possible_text_config_names = ("decoder", "generator", "text_config")
for text_config_name in decoder_possible_text_config_names:
if text_config := model_config.get(text_config_name):
model_config = text_config
break
else:
model_config = model_config.get_text_config(decoder=True)
model_config = model_config.to_dict()

default_generation_config = GenerationConfig()
for attr in generation_config.to_dict():
is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
if attr in model_config and is_unset:
setattr(generation_config, attr, model_config[attr])

# If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
if generation_config.return_dict_in_generate is False:
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,9 @@ def adjust_generation_fn(
**repo_loading_kwargs,
)
except OSError:
# `self` already has a generation config created from model config, but model config will
# not contain any generation-specific params. These are popped at config's `__init__`.
# Thus we have to load from `config.json` and create a generation config from it (for BART)
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
Expand All @@ -418,6 +421,7 @@ def adjust_generation_fn(
_from_model_config=True,
**repo_loading_kwargs,
)

# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
if hasattr(self, "load_custom_generate") and trust_remote_code:
try:
Expand Down
21 changes: 19 additions & 2 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
AutoModelForImageTextToText,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
AutoModelForVision2Seq,
BartForConditionalGeneration,
BartTokenizer,
DataCollatorWithFlattening,
Expand Down Expand Up @@ -4401,7 +4400,7 @@ def test_generate_vision2text_conditioning(self):
"""Test that `decoder_input_ids` can be used to condition the generation in vision-to-text models"""
pixel_values = floats_tensor((2, 3, 30, 30))
conditioning_input = torch.tensor([[10], [10]]) # this should be the 2nd output token, after the BOS token
model = AutoModelForVision2Seq.from_pretrained(
model = AutoModelForImageTextToText.from_pretrained(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks unrelated? Nothing to change, just noticed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the mapping was deprecated in favor of AutoModelForImageTextToText, so just making sure we don't use deprecated stuff in tests

"hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2"
)
pixel_values = pixel_values.to(torch_device)
Expand All @@ -4421,6 +4420,24 @@ def test_generate_vision2text_conditioning(self):
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids, output_sequences_input_ids))
self.assertTrue(np.array_equal(output_sequences_decoder_input_ids[:, 1:2], conditioning_input))

@pytest.mark.generate
def test_load_generation_config_from_text_subconfig(self):
    """
    Tests that the generation config is loaded from the model's `text_config` when not present
    in the model repo. We should infer the text config correctly and re-use special tokens
    for generation. See https://github.com/huggingface/transformers/issues/42794
    """
    model = AutoModelForImageTextToText.from_pretrained(
        "hf-internal-testing/tiny-random-LlavaForConditionalGeneration-no-generation-config",
        device_map=torch_device,
    )
    # Special tokens should have been inherited from the nested text sub-config,
    # since the repo intentionally ships no generation_config.json
    self.assertTrue(model.generation_config.eos_token_id is not None)
    self.assertTrue(model.generation_config.bos_token_id is not None)
    self.assertTrue(model.generation_config.pad_token_id is not None)

    # test that we can generate without inputs, i.e. from BOS
    _ = model.generate()

@require_read_token
@slow
@require_torch_accelerator
Expand Down