diff --git a/src/transformers/models/mpt/configuration_mpt.py b/src/transformers/models/mpt/configuration_mpt.py
index 9601fe5b3e06..fced045f71e1 100644
--- a/src/transformers/models/mpt/configuration_mpt.py
+++ b/src/transformers/models/mpt/configuration_mpt.py
@@ -101,6 +101,23 @@ def __init__(
                 f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}"
             )
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
+        cls._set_token_in_kwargs(kwargs)
+
+        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+        if config_dict.get("model_type") == "mpt":
+            config_dict = config_dict["attn_config"]
+
+        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+            logger.warning(
+                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+            )
+
+        return cls.from_dict(config_dict, **kwargs)
+
 
 class MptConfig(PretrainedConfig):
     """
@@ -180,6 +197,7 @@ class MptConfig(PretrainedConfig):
         "hidden_size": "d_model",
         "num_hidden_layers": "n_layers",
     }
+    is_composition = True
 
     def __init__(
         self,
@@ -204,6 +222,7 @@ def __init__(
         initializer_range=0.02,
         **kwargs,
     ):
+        self.attn_config = attn_config
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
@@ -222,20 +241,25 @@ def __init__(
         self.layer_norm_epsilon = layer_norm_epsilon
         self.use_cache = use_cache
         self.initializer_range = initializer_range
+        super().__init__(**kwargs)
+
+    @property
+    def attn_config(self):
+        return self._attn_config
 
+    @attn_config.setter
+    def attn_config(self, attn_config):
         if attn_config is None:
-            self.attn_config = MptAttentionConfig()
+            self._attn_config = MptAttentionConfig()
         elif isinstance(attn_config, dict):
-            self.attn_config = MptAttentionConfig(**attn_config)
+            self._attn_config = MptAttentionConfig(**attn_config)
         elif isinstance(attn_config, MptAttentionConfig):
-            self.attn_config = attn_config
+            self._attn_config = attn_config
         else:
             raise ValueError(
                 f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
             )
 
-        super().__init__(**kwargs)
-
     def to_dict(self):
         """
         Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
@@ -245,7 +269,8 @@ def to_dict(self):
         """
         output = copy.deepcopy(self.__dict__)
         output["attn_config"] = (
-            self.attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
+            self._attn_config.to_dict() if not isinstance(self.attn_config, dict) else self.attn_config
         )
+        del output["_attn_config"]
         output["model_type"] = self.__class__.model_type
         return output
diff --git a/tests/models/mpt/test_modeling_mpt.py b/tests/models/mpt/test_modeling_mpt.py
index 82c157a4bcfe..f3fc6d35951c 100644
--- a/tests/models/mpt/test_modeling_mpt.py
+++ b/tests/models/mpt/test_modeling_mpt.py
@@ -327,6 +327,20 @@ def prepare_config_and_inputs_for_common(self):
         return config, inputs_dict
 
 
+class MptConfigTester(ConfigTester):
+    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
+        super().__init__(parent, config_class, has_text_modality, common_properties, **kwargs)
+
+    def test_attn_config_as_dict(self):
+        config = self.config_class(**self.inputs_dict, attn_config={"attn_impl": "flash", "softmax_scale": None})
+        self.parent.assertTrue(config.attn_config.attn_impl == "flash")
+        self.parent.assertTrue(config.attn_config.softmax_scale is None)
+
+    def run_common_tests(self):
+        self.test_attn_config_as_dict()
+        return super().run_common_tests()
+
+
 @require_torch
 class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     all_model_classes = (
@@ -353,7 +367,7 @@ class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
 
     def setUp(self):
         self.model_tester = MptModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=MptConfig, n_embd=37)
+        self.config_tester = MptConfigTester(self, config_class=MptConfig, n_embd=37)
 
     def test_config(self):
         self.config_tester.run_common_tests()
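
For reference, a minimal sketch of how the reworked `attn_config` handling behaves once this patch is applied. It only uses what the diff introduces (the property/setter, the `to_dict` override, and the `attn_impl`/`softmax_scale` fields exercised by the new test); it is an illustration, not part of the change itself.

```python
from transformers.models.mpt.configuration_mpt import MptAttentionConfig, MptConfig

# The new setter normalizes whatever is passed as `attn_config`:
# a dict becomes an MptAttentionConfig, an existing instance is kept as-is,
# and None falls back to the defaults.
config = MptConfig(attn_config={"attn_impl": "flash", "softmax_scale": None})
assert isinstance(config.attn_config, MptAttentionConfig)
assert config.attn_config.attn_impl == "flash"

# to_dict() serializes the nested config back under "attn_config"
# and drops the private "_attn_config" key added by the property.
serialized = config.to_dict()
assert "attn_config" in serialized and "_attn_config" not in serialized

# Round-tripping through from_dict re-types the nested config via the setter.
restored = MptConfig.from_dict(serialized)
assert isinstance(restored.attn_config, MptAttentionConfig)
```

The added `MptAttentionConfig.from_pretrained` path, which pulls the nested `attn_config` out of a full MPT checkpoint, is not exercised here since it needs a hosted checkpoint.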