diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 7d6118892c55..5b5a7d1794dd 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -282,6 +282,16 @@ def __init__(self, **kwargs): self._commit_hash = kwargs.pop("_commit_hash", None) self.transformers_version = kwargs.pop("transformers_version", __version__) + # Additional attributes without default values + if not self._from_model_config: + # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a model's default configuration file + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + def __eq__(self, other): self_dict = self.__dict__.copy() other_dict = other.__dict__.copy() @@ -537,7 +547,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig": if "_commit_hash" in kwargs and "_commit_hash" in config_dict: kwargs["_commit_hash"] = config_dict["_commit_hash"] - config = cls(**config_dict) + # merge `config_dict` and `kwargs` (kwargs take precedence) so that extra attributes are also set through `__init__` + + config = cls(**{**config_dict, **kwargs}) unused_kwargs = config.update(**kwargs) logger.info(f"Generate config {config}") @@ -546,6 +558,18 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig": else: return config + def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: + """ + Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None, + converts torch.dtype to a string of just the type. For example, `torch.float32` gets converted into *"float32"* + string, which can then be stored in the json format. 
+ """ + if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str): + d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] + for value in d.values(): + if isinstance(value, dict): + self.dict_torch_dtype_to_str(value) + def to_diff_dict(self) -> Dict[str, Any]: """ Removes all attributes from config which correspond to the default config attributes for better readability and @@ -566,6 +590,7 @@ def to_diff_dict(self) -> Dict[str, Any]: if key not in default_config_dict or key == "transformers_version" or value != default_config_dict[key]: serializable_config_dict[key] = value + self.dict_torch_dtype_to_str(serializable_config_dict) return serializable_config_dict def to_dict(self) -> Dict[str, Any]: @@ -582,6 +607,7 @@ def to_dict(self) -> Dict[str, Any]: # Transformers version when serializing this file output["transformers_version"] = __version__ + self.dict_torch_dtype_to_str(output) return output def to_json_string(self, use_diff: bool = True) -> str: @@ -630,7 +656,8 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" [`GenerationConfig`]: The configuration object instantiated from those parameters. """ config_dict = model_config.to_dict() - config = cls.from_dict(config_dict, return_unused_kwargs=False) + config_dict.pop("_from_model_config", None) + config = cls.from_dict(config_dict, return_unused_kwargs=False, _from_model_config=True) # Special case: some models have generation attributes set in the decoder. Use them if still unset in the # generation config. 
@@ -642,7 +669,6 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" if attr in decoder_config and getattr(config, attr) == getattr(default_generation_config, attr): setattr(config, attr, decoder_config[attr]) - config._from_model_config = True return config def update(self, **kwargs): diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index fcc481f2093e..2c655254f2e5 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -78,6 +78,20 @@ def test_update(self): # `.update()` returns a dictionary of unused kwargs self.assertEqual(unused_kwargs, {"foo": "bar"}) + def test_initialize_new_kwargs(self): + generation_config = GenerationConfig() + generation_config.foo = "bar" + + with tempfile.TemporaryDirectory("test-generation-config") as tmp_dir: + generation_config.save_pretrained(tmp_dir) + + new_config = GenerationConfig.from_pretrained(tmp_dir) + # update_kwargs was used to update the config on valid attributes + self.assertEqual(new_config.foo, "bar") + + generation_config = GenerationConfig.from_model_config(new_config) + assert not hasattr(generation_config, "foo") # no new kwargs should be initialized if from config + @is_staging_test class ConfigPushToHubTester(unittest.TestCase):