From 836dde9825278b0d24bc00f2d9a8695c78bcbd2b Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Tue, 11 Feb 2025 11:20:38 -0500
Subject: [PATCH] Revert "use is_composition for pixtral"

This reverts commit a53d5f9fc5149c84419b0e9e03db6d99362add53.
---
 docs/source/en/model_doc/pixtral.md           |   4 +
 src/transformers/__init__.py                  |   2 +-
 .../models/auto/configuration_auto.py         |   3 +
 src/transformers/models/auto/modeling_auto.py |   1 +
 .../models/pixtral/configuration_pixtral.py   | 115 +++++++++++++++++-
 .../models/llava/test_configuration_llava.py  |  71 +++++++++++
 utils/check_table.py                          |   1 +
 7 files changed, 195 insertions(+), 2 deletions(-)
 create mode 100644 tests/models/llava/test_configuration_llava.py

diff --git a/docs/source/en/model_doc/pixtral.md b/docs/source/en/model_doc/pixtral.md
index 62bdc004c517..6439acc0d6ed 100644
--- a/docs/source/en/model_doc/pixtral.md
+++ b/docs/source/en/model_doc/pixtral.md
@@ -78,6 +78,10 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up
 
 [[autodoc]] PixtralVisionConfig
 
+## PixtralTextConfig
+
+[[autodoc]] PixtralTextConfig
+
 ## PixtralVisionModel
 
 [[autodoc]] PixtralVisionModel
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 80438981047d..c00f40d074f1 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -700,7 +700,7 @@
         "Pix2StructTextConfig",
         "Pix2StructVisionConfig",
     ],
-    "models.pixtral": ["PixtralProcessor", "PixtralVisionConfig"],
+    "models.pixtral": ["PixtralProcessor", "PixtralVisionConfig", "PixtralTextConfig"],
     "models.plbart": ["PLBartConfig"],
     "models.poolformer": ["PoolFormerConfig"],
     "models.pop2piano": ["Pop2PianoConfig"],
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index f52fc2b12ff7..3ddc56cecb02 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -232,6 +232,7 @@
         ("phimoe", "PhimoeConfig"),
         ("pix2struct", "Pix2StructConfig"),
         ("pixtral", "PixtralVisionConfig"),
+        ("pixtral_text", "PixtralTextConfig"),
         ("plbart", "PLBartConfig"),
         ("poolformer", "PoolFormerConfig"),
         ("pop2piano", "Pop2PianoConfig"),
@@ -574,6 +575,7 @@
         ("phobert", "PhoBERT"),
         ("pix2struct", "Pix2Struct"),
         ("pixtral", "Pixtral"),
+        ("pixtral_text", "PixtralMistral"),
         ("plbart", "PLBart"),
         ("poolformer", "PoolFormer"),
         ("pop2piano", "Pop2Piano"),
@@ -740,6 +742,7 @@
         ("chinese_clip_vision_model", "chinese_clip"),
         ("rt_detr_resnet", "rt_detr"),
         ("granitevision", "llava_next"),
+        ("pixtral_text", "pixtral"),
     ]
 )
 
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 686c9c930f60..69d08e9891f1 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -555,6 +555,7 @@
         ("phi", "PhiForCausalLM"),
         ("phi3", "Phi3ForCausalLM"),
         ("phimoe", "PhimoeForCausalLM"),
+        ("pixtral_text", "MistralForCausalLM"),
         ("plbart", "PLBartForCausalLM"),
         ("prophetnet", "ProphetNetForCausalLM"),
         ("qdqbert", "QDQBertLMHeadModel"),
diff --git a/src/transformers/models/pixtral/configuration_pixtral.py b/src/transformers/models/pixtral/configuration_pixtral.py
index d4710e00e421..17ce92abb423 100644
--- a/src/transformers/models/pixtral/configuration_pixtral.py
+++ b/src/transformers/models/pixtral/configuration_pixtral.py
@@ -14,6 +14,7 @@
 """Pixtral model configuration"""
 
 from ...configuration_utils import PretrainedConfig
+from ...models.mistral.configuration_mistral import MistralConfig
 from ...utils import logging
 
 
@@ -103,4 +104,116 @@ def __init__(
         self.initializer_range = initializer_range
 
 
-__all__ = ["PixtralVisionConfig"]
+class PixtralTextConfig(MistralConfig):
+    r"""
+    Configuration for the text (language model) backbone of a Pixtral model; it mirrors [`MistralConfig`] and accepts the same arguments.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`MistralModel`]
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 14336):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details checkout [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
+            The attention head dimension.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
+            The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
+            allows sequence of up to 4096*32 tokens.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            The id of the padding token.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the "end-of-sequence" token.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Sliding window attention window size. If not specified, will default to `4096`.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+
+    ```python
+    >>> config = PixtralTextConfig()
+    ```"""
+
+    model_type = "pixtral_text"
+
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=14336,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=8,
+        head_dim=None,
+        hidden_act="silu",
+        max_position_embeddings=4096 * 32,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        sliding_window=4096,
+        attention_dropout=0.0,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.sliding_window = sliding_window
+        self.head_dim = head_dim  # as opposed to MistralConfig, do not auto-populate
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.attention_dropout = attention_dropout
+
+        PretrainedConfig.__init__(
+            self,
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
+__all__ = ["PixtralVisionConfig", "PixtralTextConfig"]
diff --git a/tests/models/llava/test_configuration_llava.py b/tests/models/llava/test_configuration_llava.py
new file mode 100644
index 000000000000..ea39cf95418a
--- /dev/null
+++ b/tests/models/llava/test_configuration_llava.py
@@ -0,0 +1,71 @@
+import tempfile
+import unittest
+
+from transformers import LlavaConfig
+
+
+class LlavaConfigTest(unittest.TestCase):
+    def test_llava_reload(self):
+        """
+        Simple test for reloading default llava configs
+        """
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            config = LlavaConfig()
+            config.save_pretrained(tmp_dir)
+
+            reloaded = LlavaConfig.from_pretrained(tmp_dir)
+            assert config.to_dict() == reloaded.to_dict()
+
+    def test_pixtral_reload(self):
+        """
+        Simple test for reloading pixtral configs
+        """
+        vision_config = {
+            "model_type": "pixtral",
+            "head_dim": 64,
+            "hidden_act": "silu",
+            "image_size": 1024,
+            "is_composition": True,
+            "patch_size": 16,
+            "rope_theta": 10000.0,
+            "tie_word_embeddings": False,
+        }
+
+        text_config = {
+            # "model_type": "mistral",
+            "model_type": "pixtral_text",
+            "hidden_size": 5120,
+            "head_dim": 128,
+            "num_attention_heads": 32,
+            "intermediate_size": 14336,
+            "is_composition": True,
+            "max_position_embeddings": 1024000,
+            "num_hidden_layers": 40,
+            "num_key_value_heads": 8,
+            "rms_norm_eps": 1e-05,
+            "rope_theta": 1000000000.0,
+            "sliding_window": None,
+            "vocab_size": 131072,
+        }
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            config = LlavaConfig(vision_config=vision_config, text_config=text_config)
+            config.save_pretrained(tmp_dir)
+
+            reloaded = LlavaConfig.from_pretrained(tmp_dir)
+            assert config.to_dict() == reloaded.to_dict()
+
+    def test_arbitrary_reload(self):
+        """
+        Simple test for reloading arbitrarily composed subconfigs
+        """
+        default_values = LlavaConfig().to_dict()
+        default_values["vision_config"]["model_type"] = "qwen2_vl"
+        default_values["text_config"]["model_type"] = "opt"
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            config = LlavaConfig(**default_values)
+            config.save_pretrained(tmp_dir)
+
+            reloaded = LlavaConfig.from_pretrained(tmp_dir)
+            assert config.to_dict() == reloaded.to_dict()
diff --git a/utils/check_table.py b/utils/check_table.py
index 957bfd5af6af..e37c91e29d68 100644
--- a/utils/check_table.py
+++ b/utils/check_table.py
@@ -180,6 +180,7 @@ def _center_text(text: str, width: int) -> str:
     "CLIPVisionModel",
     "Qwen2AudioEncoder",
     "SiglipVisionModel",
+    "PixtralMistral",  # not a real model
 ]
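
Below is a minimal usage sketch of the composition this revert restores, mirroring the new `test_pixtral_reload`. It assumes a transformers build with this patch applied, so that `PixtralTextConfig` is importable from the top-level package; the shape values are illustrative, not required.

```python
# Minimal sketch, assuming this patch is applied so `PixtralTextConfig` is exported
# from the top-level `transformers` package. Values below are illustrative only.
import tempfile

from transformers import LlavaConfig, PixtralTextConfig, PixtralVisionConfig

# Compose a Llava-style config from explicit Pixtral vision/text sub-configs.
vision_config = PixtralVisionConfig(image_size=1024, patch_size=16)
text_config = PixtralTextConfig(hidden_size=5120, num_hidden_layers=40, num_key_value_heads=8, vocab_size=131072)
config = LlavaConfig(vision_config=vision_config, text_config=text_config)

with tempfile.TemporaryDirectory() as tmp_dir:
    config.save_pretrained(tmp_dir)  # serializes both sub-configs with their model_type
    reloaded = LlavaConfig.from_pretrained(tmp_dir)  # "pixtral_text" maps back to PixtralTextConfig
    assert config.to_dict() == reloaded.to_dict()  # round trip is lossless
```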