Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions src/transformers/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,23 @@ def __init__(
f"`attn_type` has to be either `multihead_attention` or `multiquery_attention`. Received: {attn_type}"
)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
    """Instantiate an attention config from a pretrained checkpoint.

    When the checkpoint holds a full ``mpt`` model config, the nested
    ``attn_config`` sub-dictionary is extracted and used as the source dict.

    Args:
        pretrained_model_name_or_path: Model id or local path handed to
            `get_config_dict`.
        **kwargs: Extra keyword arguments forwarded to `get_config_dict`
            and `from_dict`.

    Returns:
        `PretrainedConfig`: The instantiated configuration object.
    """
    cls._set_token_in_kwargs(kwargs)

    config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

    # A full MPT model config nests the attention settings under "attn_config".
    if config_dict.get("model_type") == "mpt":
        config_dict = config_dict["attn_config"]

    type_mismatch = (
        "model_type" in config_dict
        and hasattr(cls, "model_type")
        and config_dict["model_type"] != cls.model_type
    )
    if type_mismatch:
        logger.warning(
            f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
            f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
        )

    return cls.from_dict(config_dict, **kwargs)


class MptConfig(PretrainedConfig):
"""
Expand Down Expand Up @@ -180,6 +197,7 @@ class MptConfig(PretrainedConfig):
"hidden_size": "d_model",
"num_hidden_layers": "n_layers",
}
is_composition = True

def __init__(
self,
Expand All @@ -204,6 +222,7 @@ def __init__(
initializer_range=0.02,
**kwargs,
):
self.attn_config = attn_config
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
Expand All @@ -222,20 +241,25 @@ def __init__(
self.layer_norm_epsilon = layer_norm_epsilon
self.use_cache = use_cache
self.initializer_range = initializer_range
super().__init__(**kwargs)

@property
def attn_config(self):
    """The attention sub-configuration; always an `MptAttentionConfig` instance."""
    return self._attn_config

@attn_config.setter
def attn_config(self, attn_config):
    # Accept `None`, a plain dict, or an `MptAttentionConfig`; normalize the
    # first two forms to an `MptAttentionConfig` before storing.
    if attn_config is None:
        attn_config = MptAttentionConfig()
    elif isinstance(attn_config, dict):
        attn_config = MptAttentionConfig(**attn_config)
    elif not isinstance(attn_config, MptAttentionConfig):
        raise ValueError(
            f"`attn_config` has to be either a `MptAttentionConfig` or a dictionary. Received: {type(attn_config)}"
        )
    self._attn_config = attn_config

def to_dict(self):
    """
    Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

    Returns:
        `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
    """
    output = copy.deepcopy(self.__dict__)
    # Expose the attention config under its public key as a plain dict and
    # drop the private attribute so the serialized form round-trips cleanly.
    output["attn_config"] = (
        self.attn_config if isinstance(self.attn_config, dict) else self._attn_config.to_dict()
    )
    del output["_attn_config"]
    output["model_type"] = self.__class__.model_type
    return output
16 changes: 15 additions & 1 deletion tests/models/mpt/test_modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,20 @@ def prepare_config_and_inputs_for_common(self):
return config, inputs_dict


class MptConfigTester(ConfigTester):
    """Config tester that additionally verifies dict-based `attn_config` handling."""

    def __init__(self, parent, config_class=None, has_text_modality=True, common_properties=None, **kwargs):
        super().__init__(parent, config_class, has_text_modality, common_properties, **kwargs)

    def test_attn_config_as_dict(self):
        # A plain dict passed as `attn_config` must be converted to an
        # `MptAttentionConfig` carrying the given values.
        attn_overrides = {"attn_impl": "flash", "softmax_scale": None}
        config = self.config_class(**self.inputs_dict, attn_config=attn_overrides)
        self.parent.assertTrue(config.attn_config.attn_impl == "flash")
        self.parent.assertTrue(config.attn_config.softmax_scale is None)

    def run_common_tests(self):
        # Run the MPT-specific check first, then the shared config tests.
        self.test_attn_config_as_dict()
        return super().run_common_tests()


@require_torch
class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
Expand All @@ -353,7 +367,7 @@ class MptModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,

def setUp(self):
    # Use the MPT-specific config tester so dict `attn_config` handling is
    # exercised alongside the common config tests.
    self.model_tester = MptModelTester(self)
    self.config_tester = MptConfigTester(self, config_class=MptConfig, n_embd=37)

def test_config(self):
    # Delegates to the config tester, which runs the shared config checks
    # (plus any tester-specific ones via its `run_common_tests` override).
    self.config_tester.run_common_tests()
Expand Down