diff --git a/src/transformers/configuration_encoder_decoder.py b/src/transformers/configuration_encoder_decoder.py
index eff92bf245cd..e357d15a067b 100644
--- a/src/transformers/configuration_encoder_decoder.py
+++ b/src/transformers/configuration_encoder_decoder.py
@@ -70,6 +70,7 @@ class EncoderDecoderConfig(PretrainedConfig):
         >>> model = EncoderDecoderModel.from_pretrained('my-model', config=encoder_decoder_config)
     """
     model_type = "encoder_decoder"
+    is_composition = True

     def __init__(self, **kwargs):
         super().__init__(**kwargs)
diff --git a/src/transformers/configuration_fsmt.py b/src/transformers/configuration_fsmt.py
index 747f47dd5290..b20328bc438b 100644
--- a/src/transformers/configuration_fsmt.py
+++ b/src/transformers/configuration_fsmt.py
@@ -126,9 +126,9 @@ class FSMTConfig(PretrainedConfig):
     # update the defaults from config file
     def __init__(
         self,
-        langs,
-        src_vocab_size,
-        tgt_vocab_size,
+        langs=["en", "de"],
+        src_vocab_size=42024,
+        tgt_vocab_size=42024,
         activation_function="relu",
         d_model=1024,
         max_length=200,
diff --git a/src/transformers/configuration_rag.py b/src/transformers/configuration_rag.py
index 30baca04c527..c18e1980b4e9 100644
--- a/src/transformers/configuration_rag.py
+++ b/src/transformers/configuration_rag.py
@@ -77,6 +77,7 @@
 @add_start_docstrings(RAG_CONFIG_DOC)
 class RagConfig(PretrainedConfig):
     model_type = "rag"
+    is_composition = True

     def __init__(
         self,
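Both EncoderDecoderConfig and RagConfig opt in here because they are built from sub-configs and therefore have no sensible zero-argument constructor. A minimal sketch of what "composition" means in practice; BertConfig and the EncoderDecoderConfig.from_encoder_decoder_configs helper are assumed from the library's public API and are not part of this diff:

# Sketch: a composition config wraps other PretrainedConfig instances instead
# of defining flat defaults of its own, so it is assembled from sub-configs
# rather than instantiated bare. is_composition = True signals this to generic
# machinery such as to_diff_dict() and the common config tests further down.
from transformers import BertConfig, EncoderDecoderConfig

encoder_config = BertConfig()
decoder_config = BertConfig(is_decoder=True, add_cross_attention=True)

config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)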
""" model_type: str = "" + is_composition: bool = False def __init__(self, **kwargs): # Attributes with defaults @@ -476,11 +481,18 @@ def to_diff_dict(self) -> Dict[str, Any]: # get the default config dict default_config_dict = PretrainedConfig().to_dict() + # get class specific config dict + class_config_dict = self.__class__().to_dict() if not self.is_composition else {} + serializable_config_dict = {} # only serialize values that differ from the default config for key, value in config_dict.items(): - if key not in default_config_dict or value != default_config_dict[key]: + if ( + key not in default_config_dict + or value != default_config_dict[key] + or (key in class_config_dict and value != class_config_dict[key]) + ): serializable_config_dict[key] = value return serializable_config_dict diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 7498ae6caf7e..53dbc9eeb913 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -66,9 +66,16 @@ def create_and_test_config_with_num_labels(self): self.parent.assertEqual(len(config.id2label), 3) self.parent.assertEqual(len(config.label2id), 3) + def check_config_can_be_init_without_params(self): + if self.config_class.is_composition: + return + config = self.config_class() + self.parent.assertIsNotNone(config) + def run_common_tests(self): self.create_and_test_config_common_properties() self.create_and_test_config_to_json_string() self.create_and_test_config_to_json_file() self.create_and_test_config_from_and_save_pretrained() self.create_and_test_config_with_num_labels() + self.check_config_can_be_init_without_params() diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py index 90ca042db89c..55336c5d2fe6 100644 --- a/tests/test_modeling_prophetnet.py +++ b/tests/test_modeling_prophetnet.py @@ -901,6 +901,15 @@ def test_attn_mask_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_model_with_attn_mask(*config_and_inputs) + def test_config_save(self): + config = self.model_tester.prepare_config_and_inputs()[0] + config.add_cross_attention = False + with tempfile.TemporaryDirectory() as tmp_dirname: + config.save_pretrained(tmp_dirname) + config = ProphetNetConfig.from_pretrained(tmp_dirname) + + self.assertFalse(config.add_cross_attention) + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs()