Skip to content
73 changes: 56 additions & 17 deletions src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
)


METADATA_FIELDS = ["_original_object_hash", "_commit_hash", "transformers_version"]

logger = logging.get_logger(__name__)

_re_configuration_file = re.compile(r"config\.(.*)\.json")
Expand Down Expand Up @@ -313,6 +315,13 @@ def __init__(self, **kwargs):
logger.error(f"Can't set {key} with value {value} for {self}")
raise err

# If we load the object from an external source, we need to store the original object hash. (The hash can't
# be set here -- some classes overload __init__ and modify the instance after calling super().__init__)
self._original_object_hash = None

def __hash__(self):
return hash(self.to_json_string(ignore_metadata=True))

@property
def name_or_path(self) -> str:
return getattr(self, "_name_or_path", None)
Expand Down Expand Up @@ -380,12 +389,17 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:

non_default_generation_parameters = self._get_non_default_generation_parameters()
if len(non_default_generation_parameters) > 0:
raise ValueError(
error_message = (
"Some non-default generation parameters are set in the model config. These should go into either a) "
"`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file "
"(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) "
f"\nNon-default generation parameters: {str(non_default_generation_parameters)}"
)
# If the user was resposible for setting these, raise an exception. Otherwise, don't crash (warn).
if hash(self) != self._original_object_hash:
raise ValueError(error_message)
else:
warnings.warn(error_message, UserWarning)

os.makedirs(save_directory, exist_ok=True)

Expand Down Expand Up @@ -542,7 +556,8 @@ def from_pretrained(
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)
config = cls.from_dict(config_dict, **kwargs)
return config

@classmethod
def get_config_dict(
Expand Down Expand Up @@ -736,6 +751,7 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
for key in to_remove:
kwargs.pop(key, None)

config._original_object_hash = hash(config) # config object loaded from external source -> store hash
logger.info(f"Model config {config}")
if return_unused_kwargs:
return config, kwargs
Expand All @@ -756,7 +772,9 @@ def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig

"""
config_dict = cls._dict_from_json_file(json_file)
return cls(**config_dict)
config = cls(**config_dict)
config._original_object_hash = hash(config) # config object loaded from external source -> store hash
return config

@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
Expand All @@ -765,7 +783,12 @@ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
return json.loads(text)

def __eq__(self, other):
return isinstance(other, PretrainedConfig) and (self.__dict__ == other.__dict__)
if not isinstance(other, PretrainedConfig):
return False

self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True)
other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True)
return self_without_metadata == other_without_metadata

def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
Expand Down Expand Up @@ -804,7 +827,7 @@ def to_diff_dict(self) -> Dict[str, Any]:
serializable_config_dict[key] = diff
elif (
key not in default_config_dict
or key == "transformers_version"
or key in METADATA_FIELDS
or value != default_config_dict[key]
or (key in class_config_dict and value != class_config_dict[key])
):
Expand Down Expand Up @@ -834,24 +857,27 @@ def to_dict(self) -> Dict[str, Any]:
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
self_dict = copy.deepcopy(self.__dict__)
if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type
if "_auto_class" in output:
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]
if "_attn_implementation_internal" in output:
del output["_attn_implementation_internal"]
self_dict["model_type"] = self.__class__.model_type
if "_auto_class" in self_dict:
del self_dict["_auto_class"]
if "_attn_implementation_internal" in self_dict:
del self_dict["_attn_implementation_internal"]

for key in METADATA_FIELDS:
self_dict.pop(key, None)

# Transformers version when serializing the model
output["transformers_version"] = __version__
self_dict["transformers_version"] = __version__

for key, value in output.items():
output = {}
for key, value in self_dict.items():
# Deal with nested configs like CLIP
if isinstance(value, PretrainedConfig):
value = value.to_dict()
del value["transformers_version"]
for key in METADATA_FIELDS:
value.pop(key, None)

output[key] = value

Expand All @@ -869,14 +895,16 @@ def to_dict(self) -> Dict[str, Any]:

return output

def to_json_string(self, use_diff: bool = True) -> str:
def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str:
"""
Serializes this instance to a JSON string.

Args:
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
is serialized to JSON string.
ignore_metadata (`bool`, *optional*, defaults to `False`):
Whether to ignore the metadata fields present in the instance

Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
Expand All @@ -885,6 +913,17 @@ def to_json_string(self, use_diff: bool = True) -> str:
config_dict = self.to_diff_dict()
else:
config_dict = self.to_dict()

if ignore_metadata:
# top level metadata
for metadata_field in METADATA_FIELDS:
config_dict.pop(metadata_field, None)
# nested metadata
for value in config_dict.values():
if isinstance(value, dict):
for metadata_field in METADATA_FIELDS:
value.pop(metadata_field, None)

return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
Expand Down
17 changes: 10 additions & 7 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,10 @@ def __init__(self, **kwargs):
# Validate the values of the attributes
self.validate(is_init=True)

# If we load the object from an external source, we need to store the original object hash. (The hash can't
# be set here -- some classes overload __init__ and modify the instance after calling super().__init__)
self._original_object_hash = None

def __hash__(self):
return hash(self.to_json_string(ignore_metadata=True))

Expand Down Expand Up @@ -1055,11 +1059,11 @@ def from_pretrained(

if kwargs.get("return_unused_kwargs") is True:
config, unused_kwargs = cls.from_dict(config_dict, **kwargs)
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
config._original_object_hash = hash(config) # config object loaded from external source -> store hash
return config, unused_kwargs
else:
config = cls.from_dict(config_dict, **kwargs)
config._original_object_hash = hash(config) # Hash to detect whether the instance was modified
config._original_object_hash = hash(config) # config object loaded from external source -> store hash
return config

@classmethod
Expand Down Expand Up @@ -1096,6 +1100,7 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
config = cls(**{**config_dict, **kwargs})
unused_kwargs = config.update(**kwargs)

config._original_object_hash = hash(config) # config object loaded from external source -> store hash
logger.info(f"Generate config {config}")
if return_unused_kwargs:
return config, unused_kwargs
Expand Down Expand Up @@ -1147,10 +1152,8 @@ def to_dict(self) -> Dict[str, Any]:
output = copy.deepcopy(self.__dict__)

# Fields to ignore at serialization time
if "_commit_hash" in output:
del output["_commit_hash"]
if "_original_object_hash" in output:
del output["_original_object_hash"]
for key in METADATA_FIELDS:
output.pop(key, None)

# Transformers version when serializing this file
output["transformers_version"] = __version__
Expand Down Expand Up @@ -1256,7 +1259,7 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig"
):
generation_config.return_dict_in_generate = True

# Hash to detect whether the instance was modified
# `from_model_config` is a valid initializer and has post __init__ changes
generation_config._original_object_hash = hash(generation_config)
return generation_config

Expand Down
9 changes: 5 additions & 4 deletions tests/generation/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WatermarkLogitsProcessor,
)
from transformers.generation.configuration_utils import METADATA_FIELDS
from transformers.testing_utils import TOKEN, USER, is_staging_test, torch_device


Expand Down Expand Up @@ -700,7 +701,7 @@ def test_push_to_hub(self):

new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
if k not in METADATA_FIELDS:
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
Expand All @@ -720,7 +721,7 @@ def test_push_to_hub_via_save_pretrained(self):

new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
if k not in METADATA_FIELDS:
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
Expand All @@ -739,7 +740,7 @@ def test_push_to_hub_in_organization(self):

new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
if k not in METADATA_FIELDS:
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
Expand All @@ -759,7 +760,7 @@ def test_push_to_hub_in_organization_via_save_pretrained(self):

new_config = GenerationConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
if k not in METADATA_FIELDS:
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
Expand Down
6 changes: 3 additions & 3 deletions tests/test_configuration_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def create_and_test_config_to_json_file(self):
config_first.to_json_file(json_file_path)
config_second = self.config_class.from_json_file(json_file_path)

self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
self.parent.assertEqual(config_second, config_first)

def create_and_test_config_from_and_save_pretrained(self):
config_first = self.config_class(**self.inputs_dict)
Expand All @@ -94,7 +94,7 @@ def create_and_test_config_from_and_save_pretrained(self):
config_first.save_pretrained(tmpdirname)
config_second = self.config_class.from_pretrained(tmpdirname)

self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
self.parent.assertEqual(config_second, config_first)

with self.parent.assertRaises(OSError):
self.config_class.from_pretrained(f".{tmpdirname}")
Expand All @@ -108,7 +108,7 @@ def create_and_test_config_from_and_save_pretrained_subfolder(self):
config_first.save_pretrained(sub_tmpdirname)
config_second = self.config_class.from_pretrained(tmpdirname, subfolder=subfolder)

self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
self.parent.assertEqual(config_second, config_first)

def create_and_test_config_with_num_labels(self):
config = self.config_class(**self.inputs_dict, num_labels=5)
Expand Down
42 changes: 31 additions & 11 deletions tests/utils/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from requests.exceptions import HTTPError

from transformers import AutoConfig, BertConfig, GPT2Config
from transformers.configuration_utils import PretrainedConfig
from transformers.configuration_utils import METADATA_FIELDS, PretrainedConfig
from transformers.testing_utils import TOKEN, USER, is_staging_test


Expand Down Expand Up @@ -118,7 +118,7 @@ def test_push_to_hub(self):

new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
if k not in METADATA_FIELDS:
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
Expand All @@ -137,7 +137,7 @@ def test_push_to_hub_via_save_pretrained(self):

new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
if k not in METADATA_FIELDS:
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
Expand All @@ -154,7 +154,7 @@ def test_push_to_hub_in_organization(self):

new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
if k not in METADATA_FIELDS:
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
Expand All @@ -172,7 +172,7 @@ def test_push_to_hub_in_organization_via_save_pretrained(self):

new_config = BertConfig.from_pretrained(tmp_repo)
for k, v in config.to_dict().items():
if k != "transformers_version":
if k not in METADATA_FIELDS:
self.assertEqual(v, getattr(new_config, k))
finally:
# Always (try to) delete the repo.
Expand Down Expand Up @@ -219,17 +219,16 @@ def test_config_from_string(self):

def test_config_common_kwargs_is_complete(self):
base_config = PretrainedConfig()
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
missing_keys = {key for key in base_config.__dict__ if key not in config_common_kwargs}
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
self.assertListEqual(
self.assertSetEqual(
missing_keys,
[
{
"is_encoder_decoder",
"_name_or_path",
"_commit_hash",
"_attn_implementation_internal",
"transformers_version",
],
}
| set(METADATA_FIELDS),
)
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
if len(keys_with_defaults) > 0:
Expand Down Expand Up @@ -333,3 +332,24 @@ def test_loading_config_do_not_raise_future_warnings(self):
with warnings.catch_warnings():
warnings.simplefilter("error")
PretrainedConfig.from_pretrained("bert-base-uncased")

def test_saving_untouched_config_with_generation_parameters(self):
"""
We don't want to save generation parameters in the model config. However, if a pretrained config has generation
paremeters, we don't want to throw exceptions -- the user has done nothing incorrect, so lower them to
warnings. Tests that this behavior persists.
"""
# Saving a model config with a user-defined generation config will raise an exception.
config = BertConfig(min_length=3)
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(ValueError):
config.save_pretrained(tmp_dir)

# However, if the user loads a pretrained config with generation parameters, we should not raise an exception
# at save time
config = AutoConfig.from_pretrained("openai/whisper-small")
self.assertTrue(len(config._get_non_default_generation_parameters()) > 0) # sanity check: has gen params
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertWarns(UserWarning) as cm:
config.save_pretrained(tmp_dir)
self.assertIn("non-default generation parameters are set in the model config", str(cm.warning))