From cd4213832e1258689937ed08f071f300114ad2a9 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Wed, 26 Jul 2023 14:11:15 +0000
Subject: [PATCH 1/5] support from pretrained args

---
 .../models/mpt/configuration_mpt.py | 39 ++++++++++++++++---
 1 file changed, 33 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/mpt/configuration_mpt.py b/src/transformers/models/mpt/configuration_mpt.py
index 9601fe5b3e06..b87d64ea81da 100644
--- a/src/transformers/models/mpt/configuration_mpt.py
+++ b/src/transformers/models/mpt/configuration_mpt.py
@@ -101,6 +101,24 @@ def __init__(
             f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}"
         )
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+        # get the text config dict if we are loading from CLIPConfig
+        if config_dict.get("model_type") == "mpt":
+            config_dict = config_dict["attn_config"]
+
+        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+            logger.warning(
+                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+            )
+
+        return cls.from_dict(config_dict, **kwargs)
+
 
 class MptConfig(PretrainedConfig):
     """
@@ -180,6 +198,7 @@ class MptConfig(PretrainedConfig):
         "hidden_size": "d_model",
         "num_hidden_layers": "n_layers",
     }
+    is_composition = True
 
     def __init__(
         self,
@@ -204,6 +223,7 @@ def __init__(
         initializer_range=0.02,
         **kwargs,
     ):
+        self.attn_config = attn_config
         self.d_model = d_model
         self.n_heads = n_heads
         self.n_layers = n_layers
@@ -222,20 +242,26 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.use_cache = use_cache
         self.initializer_range = initializer_range
 
+        super().__init__(**kwargs)
+
+    @property
+    def attn_config(self):
+        return self._attn_config
+
+    @attn_config.setter
+    def attn_config(self, attn_config):
         if attn_config is None:
-            self.attn_config = MptAttentionConfig()
+            self._attn_config = MptAttentionConfig()
         elif isinstance(attn_config, dict):
-            self.attn_config = MptAttentionConfig(**attn_config)
+            self._attn_config = MptAttentionConfig(**attn_config)
         elif isinstance(attn_config, MptAttentionConfig):
-            self.attn_config = attn_config
+            self._attn_config = attn_config
         else:
             raise ValueError(
                 f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
             )
 
-        super().__init__(**kwargs)
-
     def to_dict(self):
         """
         Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
@@ -245,7 +271,8 @@ def to_dict(self):
         """
         output = copy.deepcopy(self.__dict__)
         output["attn_config"] = (
-            self.attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
+            self._attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
         )
+        del output["_attn_config"]
         output["model_type"] = self.__class__.model_type
         return output

From 20ca46c26b18e9fabd0c77eb0e02867e1fad8cad Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Wed, 26 Jul 2023 14:23:14 +0000
Subject: [PATCH 2/5] draft addition of tests

---
 src/transformers/models/mpt/configuration_mpt.py |  3 +--
 tests/models/mpt/test_modeling_mpt.py            | 13 ++++++++++++-
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/mpt/configuration_mpt.py b/src/transformers/models/mpt/configuration_mpt.py
index b87d64ea81da..ecc7ad648b38 100644
--- a/src/transformers/models/mpt/configuration_mpt.py
+++ b/src/transformers/models/mpt/configuration_mpt.py
@@ -118,7 +118,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "Pretrained
             )
 
         return cls.from_dict(config_dict, **kwargs)
-    
+
 
 class MptConfig(PretrainedConfig):
     """
@@ -248,7 +248,6 @@ def __init__(
     def attn_config(self):
         return self._attn_config
-
 
     @attn_config.setter
     def attn_config(self, attn_config):
         if attn_config is None:
diff --git a/tests/models/mpt/test_modeling_mpt.py b/tests/models/mpt/test_modeling_mpt.py
index 82c157a4bcfe..e7e1143ee755 100644
--- a/tests/models/mpt/test_modeling_mpt.py
+++ b/tests/models/mpt/test_modeling_mpt.py
@@ -326,6 +326,17 @@ def prepare_config_and_inputs_for_common(self):
 
         return config, inputs_dict
 
+class MptConfigTester(ConfigTester):
+
+    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
+        super().__init__(parent, config_class, has_text_modality, common_properties, **kwargs)
+
+    def test_attn_config_as_dict(self):
+        pass
+
+    def run_common_tests(self):
+        self.test_attn_config_as_dict()
+        return super().run_common_tests()
 
 @require_torch
 class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
@@ -353,7 +364,7 @@ class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
 
     def setUp(self):
         self.model_tester = MptModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=MptConfig, n_embd=37)
+        self.config_tester = MptConfigTester(self, config_class=MptConfig, n_embd=37)
 
     def test_config(self):
         self.config_tester.run_common_tests()

From 2f76e21be00e6975f3b8b6bc57654cfdcea83dca Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Wed, 26 Jul 2023 15:42:12 +0000
Subject: [PATCH 3/5] update test

---
 tests/models/mpt/test_modeling_mpt.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/models/mpt/test_modeling_mpt.py b/tests/models/mpt/test_modeling_mpt.py
index e7e1143ee755..40b85bab65fa 100644
--- a/tests/models/mpt/test_modeling_mpt.py
+++ b/tests/models/mpt/test_modeling_mpt.py
@@ -332,8 +332,10 @@ def __init__(self, parent, config_class=None, has_text_modality=True, common_pro
         super().__init__(parent, config_class, has_text_modality, common_properties, **kwargs)
 
     def test_attn_config_as_dict(self):
-        pass
-
+        config = self.config_class(**self.inputs_dict, attn_config = dict(attn_impl="flash", softmax_scale=None))
+        assert config.attn_config.attn_impl == "flash"
+        assert config.attn_config.softmax_scale is None
+
     def run_common_tests(self):
         self.test_attn_config_as_dict()
         return super().run_common_tests()

From 0d168903ee06c61be58e1d5c595d176d8ac470f5 Mon Sep 17 00:00:00 2001
From: Arthur Zucker
Date: Wed, 26 Jul 2023 15:43:28 +0000
Subject: [PATCH 4/5] use parrent assert true

---
 tests/models/mpt/test_modeling_mpt.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/tests/models/mpt/test_modeling_mpt.py b/tests/models/mpt/test_modeling_mpt.py
index 40b85bab65fa..f3fc6d35951c 100644
--- a/tests/models/mpt/test_modeling_mpt.py
+++ b/tests/models/mpt/test_modeling_mpt.py
@@ -326,20 +326,21 @@ def prepare_config_and_inputs_for_common(self):
 
         return config, inputs_dict
 
+
 class MptConfigTester(ConfigTester):
-
     def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
         super().__init__(parent, config_class, has_text_modality, common_properties, **kwargs)
-
+
     def test_attn_config_as_dict(self):
-        config = self.config_class(**self.inputs_dict, attn_config = dict(attn_impl="flash", softmax_scale=None))
-        assert config.attn_config.attn_impl == "flash"
-        assert config.attn_config.softmax_scale is None
+        config = self.config_class(**self.inputs_dict, attn_config={"attn_impl": "flash", "softmax_scale": None})
+        self.parent.assertTrue(config.attn_config.attn_impl == "flash")
+        self.parent.assertTrue(config.attn_config.softmax_scale is None)
 
     def run_common_tests(self):
         self.test_attn_config_as_dict()
         return super().run_common_tests()
 
+
 @require_torch
 class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (

From 3eaf812ef8a50fca49beaef180f8b578b45940f9 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Thu, 27 Jul 2023 15:34:30 +0200
Subject: [PATCH 5/5] Update src/transformers/models/mpt/configuration_mpt.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
---
 src/transformers/models/mpt/configuration_mpt.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/models/mpt/configuration_mpt.py b/src/transformers/models/mpt/configuration_mpt.py
index ecc7ad648b38..fced045f71e1 100644
--- a/src/transformers/models/mpt/configuration_mpt.py
+++ b/src/transformers/models/mpt/configuration_mpt.py
@@ -107,7 +107,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "Pretrained
 
         config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
 
-        # get the text config dict if we are loading from CLIPConfig
         if config_dict.get("model_type") == "mpt":
             config_dict = config_dict["attn_config"]
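
A minimal usage sketch of what this series enables, assuming the merged state of the five patches above; the checkpoint name and attention values are illustrative, not taken from the patches:

    from transformers import MptConfig
    from transformers.models.mpt.configuration_mpt import MptAttentionConfig

    # `attn_config` may be passed as a plain dict; the new property setter
    # wraps it in an `MptAttentionConfig` instance (mirrors test_attn_config_as_dict).
    config = MptConfig(attn_config={"attn_impl": "flash", "softmax_scale": None})
    assert isinstance(config.attn_config, MptAttentionConfig)
    assert config.attn_config.softmax_scale is None

    # The new `MptAttentionConfig.from_pretrained` pulls the nested "attn_config"
    # entry out of a full MPT checkpoint config (checkpoint name assumed here).
    attn_config = MptAttentionConfig.from_pretrained("mosaicml/mpt-7b")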