From bbec75ea9067e95d7cf7bf5d406f63c3201e846b Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Tue, 8 Feb 2022 14:41:25 -0500
Subject: [PATCH 1/2] Make sure custom configs work with Transformers

---
 src/transformers/configuration_utils.py   |  2 +-
 src/transformers/modeling_utils.py        |  4 ++--
 tests/test_modeling_common.py             | 13 +++++++++++--
 utils/test_module/custom_configuration.py |  7 +++++++
 utils/test_module/custom_modeling.py      | 17 ++++++++++++++++-
 5 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index 19da2d24ebd6..ddd884693481 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -368,7 +368,7 @@ def __init__(self, **kwargs):
 
     @property
     def name_or_path(self) -> str:
-        return self._name_or_path
+        return getattr(self, "_name_or_path", "")
 
     @name_or_path.setter
     def name_or_path(self, value):
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 4845ff074306..3b7d9b582379 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -621,10 +621,10 @@ def tie_weights(self):
             weights instead.
         """
         output_embeddings = self.get_output_embeddings()
-        if output_embeddings is not None and self.config.tie_word_embeddings:
+        if output_embeddings is not None and getattr(self.config, "tie_word_embeddings", True):
             self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
 
-        if self.config.is_encoder_decoder and self.config.tie_encoder_decoder:
+        if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
             if hasattr(self, self.base_model_prefix):
                 self = getattr(self, self.base_model_prefix)
             self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 40885e1d22e8..93fcfd0bb7a8 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -59,14 +59,14 @@
 
 sys.path.append(str(Path(__file__).parent.parent / "utils"))
 
-from test_module.custom_configuration import CustomConfig  # noqa E402
+from test_module.custom_configuration import CustomConfig, NoSuperInitConfig  # noqa E402
 
 
 if is_torch_available():
     import torch
     from torch import nn
 
-    from test_module.custom_modeling import CustomModel
+    from test_module.custom_modeling import CustomModel, NoSuperInitModel
     from transformers import (
         BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
@@ -2091,6 +2091,15 @@ def test_model_from_pretrained_torch_dtype(self):
         model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
         self.assertEqual(model.dtype, torch.float16)
 
+    def test_no_super_init_config_and_model(self):
+        config = NoSuperInitConfig(attribute=32)
+        model = NoSuperInitModel(config)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir)
+
+            model = NoSuperInitModel.from_pretrained(tmp_dir)
+
 
 @require_torch
 @is_staging_test
diff --git a/utils/test_module/custom_configuration.py b/utils/test_module/custom_configuration.py
index 4bb0fe6a15dc..676486fc5171 100644
--- a/utils/test_module/custom_configuration.py
+++ b/utils/test_module/custom_configuration.py
@@ -7,3 +7,10 @@ class CustomConfig(PretrainedConfig):
     def __init__(self, attribute=1, **kwargs):
         self.attribute = attribute
         super().__init__(**kwargs)
+
+
+class NoSuperInitConfig(PretrainedConfig):
+    model_type = "custom"
+
+    def __init__(self, attribute=1, **kwargs):
+        self.attribute = attribute
diff --git a/utils/test_module/custom_modeling.py b/utils/test_module/custom_modeling.py
index 0e0defd6eb56..07c078494d67 100644
--- a/utils/test_module/custom_modeling.py
+++ b/utils/test_module/custom_modeling.py
@@ -2,7 +2,7 @@
 
 from transformers import PreTrainedModel
 
-from .custom_configuration import CustomConfig
+from .custom_configuration import CustomConfig, NoSuperInitConfig
 
 
 class CustomModel(PreTrainedModel):
@@ -18,3 +18,18 @@ def forward(self, x):
 
     def _init_weights(self, module):
         pass
+
+
+class NoSuperInitModel(PreTrainedModel):
+    config_class = NoSuperInitConfig
+    base_model_prefix = "custom"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.linear = torch.nn.Linear(config.attribute, config.attribute)
+
+    def forward(self, x):
+        return self.linear(x)
+
+    def _init_weights(self, module):
+        pass

From fe5c26036df199440239755d1f184cd8eeaf7647 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Wed, 9 Feb 2022 09:32:54 -0500
Subject: [PATCH 2/2] Apply code review suggestions

---
 src/transformers/configuration_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index ddd884693481..c8af1ca4e3fa 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -368,7 +368,7 @@ def __init__(self, **kwargs):
 
     @property
    def name_or_path(self) -> str:
-        return getattr(self, "_name_or_path", "")
+        return getattr(self, "_name_or_path", None)
 
     @name_or_path.setter
     def name_or_path(self, value):
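
Usage sketch of the scenario these patches target: a custom PretrainedConfig subclass that never calls super().__init__() is missing the attributes normally set there (_name_or_path, tie_word_embeddings, is_encoder_decoder, tie_encoder_decoder), which the patched code now reads with getattr fallbacks. The classes below mirror NoSuperInitConfig / NoSuperInitModel from utils/test_module so the snippet is self-contained, and the save/reload round trip follows the new test in tests/test_modeling_common.py; the attribute value and variable names are only illustrative.

import tempfile

import torch

from transformers import PretrainedConfig, PreTrainedModel


class NoSuperInitConfig(PretrainedConfig):
    model_type = "custom"

    def __init__(self, attribute=1, **kwargs):
        # Deliberately skips super().__init__(**kwargs), so the usual
        # PretrainedConfig attributes are never set on the instance.
        self.attribute = attribute


class NoSuperInitModel(PreTrainedModel):
    config_class = NoSuperInitConfig
    base_model_prefix = "custom"

    def __init__(self, config):
        super().__init__(config)
        self.linear = torch.nn.Linear(config.attribute, config.attribute)

    def forward(self, x):
        return self.linear(x)

    def _init_weights(self, module):
        pass


config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    # This round trip is what the patches make work when super().__init__()
    # was never called on the config.
    model.save_pretrained(tmp_dir)
    reloaded = NoSuperInitModel.from_pretrained(tmp_dir)
    print(reloaded.config.attribute)  # the custom attribute survives the round trip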