Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,16 @@ def __init__(self, **kwargs):
self._commit_hash = kwargs.pop("_commit_hash", None)
self.transformers_version = kwargs.pop("transformers_version", __version__)

# Additional attributes without default values
if not self._from_model_config:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add a comment here explaining why we need this if, otherwise we may be like "wtf?" in the future

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, thanks for the comment!

# we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a model's default configuration file
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err

def __eq__(self, other):
self_dict = self.__dict__.copy()
other_dict = other.__dict__.copy()
Expand Down Expand Up @@ -537,7 +547,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]

config = cls(**config_dict)
# remove all the arguments that are in the config_dict

config = cls(**config_dict, **kwargs)
unused_kwargs = config.update(**kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line now only exists to obtain unused_kwargs, as the kwargs get written to the config in the line above, correct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well yes, for example when _from_model_config is set to True, then it is still in the kwargs, think I saw something like this.


logger.info(f"Generate config {config}")
Expand All @@ -546,6 +558,18 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
else:
return config

def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
"""
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
string, which can then be stored in the json format.
"""
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
for value in d.values():
if isinstance(value, dict):
self.dict_torch_dtype_to_str(value)

def to_diff_dict(self) -> Dict[str, Any]:
"""
Removes all attributes from config which correspond to the default config attributes for better readability and
Expand All @@ -566,6 +590,7 @@ def to_diff_dict(self) -> Dict[str, Any]:
if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]:
serializable_config_dict[key] = value

self.dict_torch_dtype_to_str(serializable_config_dict)
return serializable_config_dict

def to_dict(self) -> Dict[str, Any]:
Expand All @@ -582,6 +607,7 @@ def to_dict(self) -> Dict[str, Any]:
# Transformers version when serializing this file
output["transformers_version"] = __version__

self.dict_torch_dtype_to_str(output)
return output

def to_json_string(self, use_diff: bool = True) -> str:
Expand Down Expand Up @@ -630,7 +656,8 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig"
[`GenerationConfig`]: The configuration object instantiated from those parameters.
"""
config_dict = model_config.to_dict()
config = cls.from_dict(config_dict, return_unused_kwargs=False)
config_dict.pop("_from_model_config", None)
config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True)

# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config.
Expand All @@ -642,7 +669,6 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig"
if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr):
setattr(config, attr, decoder_config[attr])

config._from_model_config = True
return config

def update(self, **kwargs):
Expand Down
14 changes: 14 additions & 0 deletions tests/generation/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@ def test_update(self):
# `.update()` returns a dictionary of unused kwargs
self.assertEqual(unused_kwargs, {"foo": "bar"})

def test_initialize_new_kwargs(self):
generation_config = GenerationConfig()
generation_config.foo = "bar"

with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir:
generation_config.save_pretrained(tmp_dir)

new_config = GenerationConfig.from_pretrained(tmp_dir)
# update_kwargs was used to update the config on valid attributes
self.assertEqual(new_config.foo, "bar")

generation_config = GenerationConfig.from_model_config(new_config)
assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config


@is_staging_test
class ConfigPushToHubTester(unittest.TestCase):
Expand Down