diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 92e5425e95e7..436de7c021f2 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -42,6 +42,8 @@ ) +METADATA_FIELDS = ["_original_object_hash", "_commit_hash", "transformers_version"] + logger = logging.get_logger(__name__) _re_configuration_file = re.compile(r"config\.(.*)\.json") @@ -313,6 +315,13 @@ def __init__(self, **kwargs): logger.error(f"Can't set {key} with value {value} for {self}") raise err + # If we load the object from an external source, we need to store the original object hash. (The hash can't + # be set here -- some classes overload __init__ and modify the instance after calling super().__init__) + self._original_object_hash = None + + def __hash__(self): + return hash(self.to_json_string(ignore_metadata=True)) + @property def name_or_path(self) -> str: return getattr(self, "_name_or_path", None) @@ -380,12 +389,17 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: non_default_generation_parameters = self._get_non_default_generation_parameters() if len(non_default_generation_parameters) > 0: - raise ValueError( + error_message = ( "Some non-default generation parameters are set in the model config. These should go into either a) " "`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file " "(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model) " f"\nNon-default generation parameters: {str(non_default_generation_parameters)}" ) + # If the user was responsible for setting these, raise an exception. Otherwise, don't crash (warn). + if hash(self) != self._original_object_hash: + raise ValueError(error_message) + else: + warnings.warn(error_message, UserWarning) os.makedirs(save_directory, exist_ok=True) @@ -542,7 +556,8 @@ def from_pretrained( f"{cls.model_type}. 
This is not supported for all configurations of models and can yield errors." ) - return cls.from_dict(config_dict, **kwargs) + config = cls.from_dict(config_dict, **kwargs) + return config @classmethod def get_config_dict( @@ -736,6 +751,7 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig": for key in to_remove: kwargs.pop(key, None) + config._original_object_hash = hash(config) # config object loaded from external source -> store hash logger.info(f"Model config {config}") if return_unused_kwargs: return config, kwargs @@ -756,7 +772,9 @@ def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig """ config_dict = cls._dict_from_json_file(json_file) - return cls(**config_dict) + config = cls(**config_dict) + config._original_object_hash = hash(config) # config object loaded from external source -> store hash + return config @classmethod def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): @@ -765,7 +783,12 @@ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): return json.loads(text) def __eq__(self, other): - return isinstance(other, PretrainedConfig) and (self.__dict__ == other.__dict__) + if not isinstance(other, PretrainedConfig): + return False + + self_without_metadata = self.to_json_string(use_diff=False, ignore_metadata=True) + other_without_metadata = other.to_json_string(use_diff=False, ignore_metadata=True) + return self_without_metadata == other_without_metadata def __repr__(self): return f"{self.__class__.__name__} {self.to_json_string()}" @@ -804,7 +827,7 @@ def to_diff_dict(self) -> Dict[str, Any]: serializable_config_dict[key] = diff elif ( key not in default_config_dict - or key == "transformers_version" + or key in METADATA_FIELDS or value != default_config_dict[key] or (key in class_config_dict and value != class_config_dict[key]) ): @@ -834,24 +857,27 @@ def to_dict(self) -> Dict[str, Any]: Returns: `Dict[str, Any]`: Dictionary of all the attributes that make up 
this configuration instance. """ - output = copy.deepcopy(self.__dict__) + self_dict = copy.deepcopy(self.__dict__) if hasattr(self.__class__, "model_type"): - output["model_type"] = self.__class__.model_type - if "_auto_class" in output: - del output["_auto_class"] - if "_commit_hash" in output: - del output["_commit_hash"] - if "_attn_implementation_internal" in output: - del output["_attn_implementation_internal"] + self_dict["model_type"] = self.__class__.model_type + if "_auto_class" in self_dict: + del self_dict["_auto_class"] + if "_attn_implementation_internal" in self_dict: + del self_dict["_attn_implementation_internal"] + + for key in METADATA_FIELDS: + self_dict.pop(key, None) # Transformers version when serializing the model - output["transformers_version"] = __version__ + self_dict["transformers_version"] = __version__ - for key, value in output.items(): + output = {} + for key, value in self_dict.items(): # Deal with nested configs like CLIP if isinstance(value, PretrainedConfig): value = value.to_dict() - del value["transformers_version"] + for metadata_field in METADATA_FIELDS: + value.pop(metadata_field, None) output[key] = value @@ -869,7 +895,7 @@ def to_dict(self) -> Dict[str, Any]: return output - def to_json_string(self, use_diff: bool = True) -> str: + def to_json_string(self, use_diff: bool = True, ignore_metadata: bool = False) -> str: """ Serializes this instance to a JSON string. @@ -877,6 +903,8 @@ def to_json_string(self, use_diff: bool = True) -> str: use_diff (`bool`, *optional*, defaults to `True`): If set to `True`, only the difference between the config instance and the default `PretrainedConfig()` is serialized to JSON string. + ignore_metadata (`bool`, *optional*, defaults to `False`): + Whether to ignore the metadata fields present in the instance Returns: `str`: String containing all the attributes that make up this configuration instance in JSON format. 
@@ -885,6 +913,17 @@ def to_json_string(self, use_diff: bool = True) -> str: config_dict = self.to_diff_dict() else: config_dict = self.to_dict() + + if ignore_metadata: + # top level metadata + for metadata_field in METADATA_FIELDS: + config_dict.pop(metadata_field, None) + # nested metadata + for value in config_dict.values(): + if isinstance(value, dict): + for metadata_field in METADATA_FIELDS: + value.pop(metadata_field, None) + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True): diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 1acd40641132..d37937c18f3e 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -485,6 +485,10 @@ def __init__(self, **kwargs): # Validate the values of the attributes self.validate(is_init=True) + # If we load the object from an external source, we need to store the original object hash. 
(The hash can't + # be set here -- some classes overload __init__ and modify the instance after calling super().__init__) + self._original_object_hash = None + def __hash__(self): return hash(self.to_json_string(ignore_metadata=True)) @@ -1055,11 +1059,11 @@ def from_pretrained( if kwargs.get("return_unused_kwargs") is True: config, unused_kwargs = cls.from_dict(config_dict, **kwargs) - config._original_object_hash = hash(config) # Hash to detect whether the instance was modified + config._original_object_hash = hash(config) # config object loaded from external source -> store hash return config, unused_kwargs else: config = cls.from_dict(config_dict, **kwargs) - config._original_object_hash = hash(config) # Hash to detect whether the instance was modified + config._original_object_hash = hash(config) # config object loaded from external source -> store hash return config @classmethod @@ -1096,6 +1100,7 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig": config = cls(**{**config_dict, **kwargs}) unused_kwargs = config.update(**kwargs) + config._original_object_hash = hash(config) # config object loaded from external source -> store hash logger.info(f"Generate config {config}") if return_unused_kwargs: return config, unused_kwargs @@ -1147,10 +1152,8 @@ def to_dict(self) -> Dict[str, Any]: output = copy.deepcopy(self.__dict__) # Fields to ignore at serialization time - if "_commit_hash" in output: - del output["_commit_hash"] - if "_original_object_hash" in output: - del output["_original_object_hash"] + for key in METADATA_FIELDS: + output.pop(key, None) # Transformers version when serializing this file output["transformers_version"] = __version__ @@ -1256,7 +1259,7 @@ def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig" ): generation_config.return_dict_in_generate = True - # Hash to detect whether the instance was modified + # `from_model_config` is a valid initializer and has post __init__ changes 
generation_config._original_object_hash = hash(generation_config) return generation_config diff --git a/tests/generation/test_configuration_utils.py b/tests/generation/test_configuration_utils.py index 9c7f4db3c923..78cfbbeba7c9 100644 --- a/tests/generation/test_configuration_utils.py +++ b/tests/generation/test_configuration_utils.py @@ -57,6 +57,7 @@ UnbatchedClassifierFreeGuidanceLogitsProcessor, WatermarkLogitsProcessor, ) +from transformers.generation.configuration_utils import METADATA_FIELDS from transformers.testing_utils import TOKEN, USER, is_staging_test, torch_device @@ -700,7 +701,7 @@ def test_push_to_hub(self): new_config = GenerationConfig.from_pretrained(tmp_repo) for k, v in config.to_dict().items(): - if k != "transformers_version": + if k not in METADATA_FIELDS: self.assertEqual(v, getattr(new_config, k)) finally: # Always (try to) delete the repo. @@ -720,7 +721,7 @@ def test_push_to_hub_via_save_pretrained(self): new_config = GenerationConfig.from_pretrained(tmp_repo) for k, v in config.to_dict().items(): - if k != "transformers_version": + if k not in METADATA_FIELDS: self.assertEqual(v, getattr(new_config, k)) finally: # Always (try to) delete the repo. @@ -739,7 +740,7 @@ def test_push_to_hub_in_organization(self): new_config = GenerationConfig.from_pretrained(tmp_repo) for k, v in config.to_dict().items(): - if k != "transformers_version": + if k not in METADATA_FIELDS: self.assertEqual(v, getattr(new_config, k)) finally: # Always (try to) delete the repo. @@ -759,7 +760,7 @@ def test_push_to_hub_in_organization_via_save_pretrained(self): new_config = GenerationConfig.from_pretrained(tmp_repo) for k, v in config.to_dict().items(): - if k != "transformers_version": + if k not in METADATA_FIELDS: self.assertEqual(v, getattr(new_config, k)) finally: # Always (try to) delete the repo. 
diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 81c6a008b133..d3b3d5d12a74 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -85,7 +85,7 @@ def create_and_test_config_to_json_file(self): config_first.to_json_file(json_file_path) config_second = self.config_class.from_json_file(json_file_path) - self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) + self.parent.assertEqual(config_second, config_first) def create_and_test_config_from_and_save_pretrained(self): config_first = self.config_class(**self.inputs_dict) @@ -94,7 +94,7 @@ def create_and_test_config_from_and_save_pretrained(self): config_first.save_pretrained(tmpdirname) config_second = self.config_class.from_pretrained(tmpdirname) - self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) + self.parent.assertEqual(config_second, config_first) with self.parent.assertRaises(OSError): self.config_class.from_pretrained(f".{tmpdirname}") @@ -108,7 +108,7 @@ def create_and_test_config_from_and_save_pretrained_subfolder(self): config_first.save_pretrained(sub_tmpdirname) config_second = self.config_class.from_pretrained(tmpdirname, subfolder=subfolder) - self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) + self.parent.assertEqual(config_second, config_first) def create_and_test_config_with_num_labels(self): config = self.config_class(**self.inputs_dict, num_labels=5) diff --git a/tests/utils/test_configuration_utils.py b/tests/utils/test_configuration_utils.py index 76394daf9ced..f984c307d789 100644 --- a/tests/utils/test_configuration_utils.py +++ b/tests/utils/test_configuration_utils.py @@ -26,7 +26,7 @@ from requests.exceptions import HTTPError from transformers import AutoConfig, BertConfig, GPT2Config -from transformers.configuration_utils import PretrainedConfig +from transformers.configuration_utils import METADATA_FIELDS, PretrainedConfig from 
transformers.testing_utils import TOKEN, USER, is_staging_test @@ -118,7 +118,7 @@ def test_push_to_hub(self): new_config = BertConfig.from_pretrained(tmp_repo) for k, v in config.to_dict().items(): - if k != "transformers_version": + if k not in METADATA_FIELDS: self.assertEqual(v, getattr(new_config, k)) finally: # Always (try to) delete the repo. @@ -137,7 +137,7 @@ def test_push_to_hub_via_save_pretrained(self): new_config = BertConfig.from_pretrained(tmp_repo) for k, v in config.to_dict().items(): - if k != "transformers_version": + if k not in METADATA_FIELDS: self.assertEqual(v, getattr(new_config, k)) finally: # Always (try to) delete the repo. @@ -154,7 +154,7 @@ def test_push_to_hub_in_organization(self): new_config = BertConfig.from_pretrained(tmp_repo) for k, v in config.to_dict().items(): - if k != "transformers_version": + if k not in METADATA_FIELDS: self.assertEqual(v, getattr(new_config, k)) finally: # Always (try to) delete the repo. @@ -172,7 +172,7 @@ def test_push_to_hub_in_organization_via_save_pretrained(self): new_config = BertConfig.from_pretrained(tmp_repo) for k, v in config.to_dict().items(): - if k != "transformers_version": + if k not in METADATA_FIELDS: self.assertEqual(v, getattr(new_config, k)) finally: # Always (try to) delete the repo. @@ -219,17 +219,16 @@ def test_config_from_string(self): def test_config_common_kwargs_is_complete(self): base_config = PretrainedConfig() - missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs] + missing_keys = {key for key in base_config.__dict__ if key not in config_common_kwargs} # If this part of the test fails, you have arguments to addin config_common_kwargs above. 
- self.assertListEqual( + self.assertSetEqual( missing_keys, - [ + { "is_encoder_decoder", "_name_or_path", - "_commit_hash", "_attn_implementation_internal", - "transformers_version", - ], + } + | set(METADATA_FIELDS), ) keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)] if len(keys_with_defaults) > 0: @@ -333,3 +332,24 @@ def test_loading_config_do_not_raise_future_warnings(self): with warnings.catch_warnings(): warnings.simplefilter("error") PretrainedConfig.from_pretrained("bert-base-uncased") + + def test_saving_untouched_config_with_generation_parameters(self): + """ + We don't want to save generation parameters in the model config. However, if a pretrained config has generation + parameters, we don't want to throw exceptions -- the user has done nothing incorrect, so lower them to + warnings. Tests that this behavior persists. + """ + # Saving a model config with a user-defined generation config will raise an exception. + config = BertConfig(min_length=3) + with tempfile.TemporaryDirectory() as tmp_dir: + with self.assertRaises(ValueError): + config.save_pretrained(tmp_dir) + + # However, if the user loads a pretrained config with generation parameters, we should not raise an exception + # at save time + config = AutoConfig.from_pretrained("openai/whisper-small") + self.assertTrue(len(config._get_non_default_generation_parameters()) > 0) # sanity check: has gen params + with tempfile.TemporaryDirectory() as tmp_dir: + with self.assertWarns(UserWarning) as cm: + config.save_pretrained(tmp_dir) + self.assertIn("non-default generation parameters are set in the model config", str(cm.warning))