Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 9 additions & 35 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,9 @@ def __init__(
self.sep_token_id = sep_token_id
self.decoder_start_token_id = decoder_start_token_id

# Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
# parameters, saving them will be deprecated. In a distant future, we won't need to load them.
for parameter_name, default_value in self._get_global_generation_defaults().items():
setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))
# Parameters for sequence generation saved in the config are popped instead of being loaded.
for parameter_name in self._get_global_generation_defaults().keys():
kwargs.pop(parameter_name, None)

# Name or path to the pretrained checkpoint
self._name_or_path = str(kwargs.pop("name_or_path", ""))
Expand Down Expand Up @@ -445,14 +444,11 @@ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool =

non_default_generation_parameters = self._get_non_default_generation_parameters()
if len(non_default_generation_parameters) > 0:
# TODO (joao): this should be an exception if the user has modified the loaded config. See #33886
warnings.warn(
raise ValueError(
"Some non-default generation parameters are set in the model config. These should go into either a) "
"`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file "
"(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)."
"This warning will become an exception in the future."
f"\nNon-default generation parameters: {str(non_default_generation_parameters)}",
UserWarning,
)

os.makedirs(save_directory, exist_ok=True)
Expand Down Expand Up @@ -1101,40 +1097,18 @@ def _get_non_default_generation_parameters(self) -> dict[str, Any]:
non_default_generation_parameters = {}
decoder_attribute_name = None

# Some composite models don't have a default config, use their decoder config as a fallback for default values
# If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
if not self.has_no_defaults_at_init:
default_config = self.__class__()
else:
decoder_config = self.get_text_config(decoder=True)
if decoder_config is not self:
default_config = decoder_config.__class__()
else:
default_config = None

# If it is a composite model, we want to check the subconfig that will be used for generation
self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)

for parameter_name, default_global_value in self._get_global_generation_defaults().items():
if hasattr(self_decoder_config, parameter_name):
is_default_in_config = is_default_generation_value = None
parameter_value = getattr(self_decoder_config, parameter_name)
# Three cases in which is okay for the model config to hold generation config parameters:
parameter_value = getattr(self_decoder_config, parameter_name, None)
# Two cases in which it is okay for the model config to hold generation config parameters:
# 1. The parameter is set to `None`, effectively delegating its value to the generation config
if parameter_value is None:
# 2. The parameter is set to the global generation defaults
if parameter_value is None or parameter_value == default_global_value:
continue
# 2. If we have a default config, then the instance should hold the same generation defaults
if default_config is not None:
is_default_in_config = parameter_value == getattr(default_config, parameter_name)
# 3. if we don't have a default config, then the instance should hold the global generation defaults
else:
is_default_generation_value = parameter_value == default_global_value

is_non_default = (is_default_in_config is False) or (
is_default_in_config is None and is_default_generation_value is False
)
if is_non_default:
non_default_generation_parameters[parameter_name] = getattr(self_decoder_config, parameter_name)
non_default_generation_parameters[parameter_name] = getattr(self_decoder_config, parameter_name)

return non_default_generation_parameters

Expand Down
27 changes: 15 additions & 12 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,9 @@ def from_pretrained(
else:
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")

if kwargs.get("return_unused_kwargs") is True:
if kwargs.get("_from_model_config", False):
return cls.from_model_config(config_dict)
elif kwargs.get("return_unused_kwargs") is True:
config, unused_kwargs = cls.from_dict(config_dict, **kwargs)
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
return config, unused_kwargs
Expand Down Expand Up @@ -1084,19 +1086,19 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool =
writer.write(self.to_json_string(use_diff=use_diff))

@classmethod
def from_model_config(cls, model_config: PreTrainedConfig) -> "GenerationConfig":
def from_model_config(cls, model_config: PreTrainedConfig | dict) -> "GenerationConfig":
"""
Instantiates a [`GenerationConfig`] from a [`PreTrainedConfig`]. This function is useful to convert legacy
[`PreTrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].

Args:
model_config (`PreTrainedConfig`):
model_config (`PreTrainedConfig | dict`):
The model config that will be used to instantiate the generation config.

Returns:
[`GenerationConfig`]: The configuration object instantiated from those parameters.
"""
config_dict = model_config.to_dict()
config_dict = model_config.to_dict() if not isinstance(model_config, dict) else model_config
config_dict.pop("_from_model_config", None)

# Removes all `None` from the model config dict -- this lets the generation config defaults to take hold
Expand All @@ -1106,14 +1108,15 @@ def from_model_config(cls, model_config: PreTrainedConfig) -> "GenerationConfig"

# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config (which in turn is defined from the outer attributes of model config).
decoder_config = model_config.get_text_config(decoder=True)
if decoder_config is not model_config:
default_generation_config = GenerationConfig()
decoder_config_dict = decoder_config.to_dict()
for attr in generation_config.to_dict():
is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
if attr in decoder_config_dict and is_unset:
setattr(generation_config, attr, decoder_config_dict[attr])
if not isinstance(model_config, dict):
decoder_config = model_config.get_text_config(decoder=True)
if decoder_config is not model_config:
default_generation_config = GenerationConfig()
decoder_config_dict = decoder_config.to_dict()
for attr in generation_config.to_dict():
is_unset = getattr(generation_config, attr) == getattr(default_generation_config, attr)
if attr in decoder_config_dict and is_unset:
setattr(generation_config, attr, decoder_config_dict[attr])

# If any `output_...` flag is set to `True`, we ensure `return_dict_in_generate` is set to `True`.
if generation_config.return_dict_in_generate is False:
Expand Down
16 changes: 11 additions & 5 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,14 @@ def adjust_generation_fn(
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
self.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
config_file_name="config.json",
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
_from_model_config=True,
**repo_loading_kwargs,
)
# Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
if hasattr(self, "load_custom_generate") and trust_remote_code:
try:
Expand Down Expand Up @@ -1778,14 +1786,12 @@ def _prepare_generation_config(
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config: # 4)
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed in v5."
raise ValueError(
"You have modified the pretrained model configuration to control generation."
" This strategy to control generation is not supported anymore. "
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
UserWarning,
)
self.generation_config = new_generation_config

generation_config = self.generation_config
using_model_generation_config = True
Expand Down
17 changes: 2 additions & 15 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,8 @@ def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
self.config._attn_implementation, is_init_check=True
)
if self.can_generate():
self.generation_config = GenerationConfig.from_model_config(config)

# for initialization of the loss
loss_type = self.__class__.__name__
Expand All @@ -1470,8 +1472,6 @@ def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):

self.name_or_path = config.name_or_path
self.warnings_issued = {}
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None

# Overwrite the class attribute to make it an instance attribute, so models like
# `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
# when a different component (e.g. language_model) is used.
Expand Down Expand Up @@ -3174,19 +3174,6 @@ def save_pretrained(
# Save the config
if is_main_process:
if not _hf_peft_config_loaded:
# If the model config has set attributes that should be in the generation config, move them there.
misplaced_generation_parameters = model_to_save.config._get_non_default_generation_parameters()
if self.can_generate() and len(misplaced_generation_parameters) > 0:
warnings.warn(
"Moving the following attributes in the config to the generation config: "
f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
"generation parameters in the model config, as opposed to in the generation config.",
UserWarning,
)
for param_name, param_value in misplaced_generation_parameters.items():
setattr(model_to_save.generation_config, param_name, param_value)
setattr(model_to_save.config, param_name, None)

model_to_save.config.save_pretrained(save_directory)
if self.can_generate():
model_to_save.generation_config.save_pretrained(save_directory)
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/bart/configuration_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,10 @@ def __init__(
)
self.tie_encoder_decoder = True
# ensure backward compatibility for BART CNN models
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
if kwargs.get("force_bos_token_to_be_generated", False):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
"The config can simply be saved and uploaded again to be fixed."
f"Please make sure the generation config includes `forced_bos_token_id={self.bos_token_id}`. "
)


Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/mvp/configuration_mvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ def __init__(
**kwargs,
)

if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
if kwargs.get("force_bos_token_to_be_generated", False):
self.forced_bos_token_id = self.bos_token_id
warnings.warn(
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
f"Please make sure the generated config includes `forced_bos_token_id={self.bos_token_id}` . "
"The config can simply be saved and uploaded again to be fixed."
)

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/rag/configuration_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __init__(

self.use_cache = use_cache

if self.forced_eos_token_id is None:
if forced_eos_token_id is None:
self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None)

@classmethod
Expand Down
1 change: 0 additions & 1 deletion src/transformers/models/udop/modeling_udop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,6 @@ def __init__(self, config):
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
self.embed_patches = UdopPatchEmbeddings(config)
self.is_decoder = config.is_decoder
self._max_length = config.max_length
self.num_layers = config.num_layers

self.block = nn.ModuleList(
Expand Down
4 changes: 2 additions & 2 deletions tests/models/bart/test_modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,9 +975,9 @@ def test_xsum_summarization_same_as_fairseq(self):
self.assertEqual(EXPECTED_SUMMARY, decoded[0])

def test_xsum_config_generation_params(self):
config = BartConfig.from_pretrained("facebook/bart-large-xsum")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-xsum")
expected_params = {"num_beams": 6, "do_sample": False, "early_stopping": True, "length_penalty": 1.0}
config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()}
config_params = {k: getattr(model.generation_config, k, "MISSING") for k, v in expected_params.items()}
self.assertDictEqual(expected_params, config_params)

@slow
Expand Down
4 changes: 2 additions & 2 deletions tests/models/encoder_decoder/test_modeling_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,9 @@ def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config
generated_output = enc_dec_model.generate(
input_ids,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
max_length=enc_dec_model.generation_config.max_length,
)
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (enc_dec_model.generation_config.max_length,))

def create_and_check_encoder_decoder_shared_weights(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,9 @@ def check_encoder_decoder_model_generate(
generated_output = enc_dec_model.generate(
inputs,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
max_length=enc_dec_model.generation_config.max_length,
)
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (enc_dec_model.generation_config.max_length,))

def test_encoder_decoder_model(self):
input_ids_dict = self.prepare_config_and_inputs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,9 +313,9 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
generated_output = enc_dec_model.generate(
inputs,
decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
max_length=decoder_config.max_length,
max_length=enc_dec_model.generation_config.max_length,
)
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
self.assertEqual(generated_output.shape, (inputs.shape[0],) + (enc_dec_model.generation_config.max_length,))

def test_encoder_decoder_model(self):
input_ids_dict = self.prepare_config_and_inputs()
Expand Down Expand Up @@ -883,10 +883,12 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
max_length=enc_dec_model.generation_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
self.assertEqual(
generated_output.shape, (pixel_values.shape[0],) + (enc_dec_model.generation_config.max_length,)
)

@unittest.skip(reason="There are no published pretrained TrOCR checkpoints for now")
def test_real_model_save_load_from_pretrained(self):
Expand Down Expand Up @@ -994,10 +996,12 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
max_length=enc_dec_model.generation_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
self.assertEqual(
generated_output.shape, (pixel_values.shape[0],) + (enc_dec_model.generation_config.max_length,)
)

@unittest.skip(reason="VIT2GPT2 also has an integration test for testinf save-load")
def test_real_model_save_load_from_pretrained(self):
Expand Down Expand Up @@ -1105,10 +1109,12 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
generated_output = enc_dec_model.generate(
pixel_values=pixel_values,
decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
max_length=decoder_config.max_length,
max_length=enc_dec_model.generation_config.max_length,
**kwargs,
)
self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
self.assertEqual(
generated_output.shape, (pixel_values.shape[0],) + (enc_dec_model.generation_config.max_length,)
)

@unittest.skip(reason="Donut has an Integration test for that")
def test_real_model_save_load_from_pretrained(self):
Expand Down
Loading