4 changes: 4 additions & 0 deletions docs/source/en/model_doc/pixtral.md
@@ -78,6 +78,10 @@ output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up

[[autodoc]] PixtralVisionConfig

## PixtralTextConfig

[[autodoc]] PixtralTextConfig

## PixtralVisionModel

[[autodoc]] PixtralVisionModel
2 changes: 1 addition & 1 deletion src/transformers/__init__.py
@@ -700,7 +700,7 @@
"Pix2StructTextConfig",
"Pix2StructVisionConfig",
],
"models.pixtral": ["PixtralProcessor", "PixtralVisionConfig"],
"models.pixtral": ["PixtralProcessor", "PixtralVisionConfig", "PixtralTextConfig"],
"models.plbart": ["PLBartConfig"],
"models.poolformer": ["PoolFormerConfig"],
"models.pop2piano": ["Pop2PianoConfig"],
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -232,6 +232,7 @@
("phimoe", "PhimoeConfig"),
("pix2struct", "Pix2StructConfig"),
("pixtral", "PixtralVisionConfig"),
("pixtral_text", "PixtralTextConfig"),
("plbart", "PLBartConfig"),
("poolformer", "PoolFormerConfig"),
("pop2piano", "Pop2PianoConfig"),
@@ -574,6 +575,7 @@
("phobert", "PhoBERT"),
("pix2struct", "Pix2Struct"),
("pixtral", "Pixtral"),
("pixtral_text", "PixtralMistral"),
("plbart", "PLBart"),
("poolformer", "PoolFormer"),
("pop2piano", "Pop2Piano"),
@@ -740,6 +742,7 @@
("chinese_clip_vision_model", "chinese_clip"),
("rt_detr_resnet", "rt_detr"),
("granitevision", "llava_next"),
("pixtral_text", "pixtral"),
]
)

1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -555,6 +555,7 @@
("phi", "PhiForCausalLM"),
("phi3", "Phi3ForCausalLM"),
("phimoe", "PhimoeForCausalLM"),
("pixtral_text", "MistralForCausalLM"),
("plbart", "PLBartForCausalLM"),
("prophetnet", "ProphetNetForCausalLM"),
("qdqbert", "QDQBertLMHeadModel"),
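For context, a minimal sketch of what the new `("pixtral_text", "MistralForCausalLM")` entry enables, assuming a `transformers` build that includes this change and the `PixtralTextConfig` added below (hyperparameters here are illustrative, not taken from any checkpoint):

```python
from transformers import AutoModelForCausalLM, PixtralTextConfig

# Tiny config so the model instantiates quickly; head_dim is passed explicitly
# because PixtralTextConfig does not auto-populate it.
config = PixtralTextConfig(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
)

# The auto-mapping routes the "pixtral_text" config type to the existing Mistral decoder.
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # MistralForCausalLM
```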
115 changes: 114 additions & 1 deletion src/transformers/models/pixtral/configuration_pixtral.py
@@ -14,6 +14,7 @@
"""Pixtral model configuration"""

from ...configuration_utils import PretrainedConfig
from ...models.mistral.configuration_mistral import MistralConfig
from ...utils import logging


@@ -103,4 +104,116 @@ def __init__(
self.initializer_range = initializer_range


__all__ = ["PixtralVisionConfig"]
class PixtralTextConfig(MistralConfig):
r"""
This is the configuration class to store the configuration of the text model used by Pixtral. It subclasses
[`MistralConfig`] and mirrors its arguments and defaults, but is registered under its own `model_type`,
`"pixtral_text"`.

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
`input_ids` passed when calling [`MistralModel`].
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 14336):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
head_dim (`int`, *optional*):
The attention head dimension. Unlike [`MistralConfig`], it is not auto-populated from
`hidden_size // num_attention_heads` when left unset.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
allows sequences of up to 4096*32 tokens.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention window size. If not specified, will default to `4096`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
Example:

```python
>>> from transformers import PixtralTextConfig

>>> # Initializing a Pixtral text configuration
>>> configuration = PixtralTextConfig()

>>> # Accessing the configuration
>>> configuration
```"""

model_type = "pixtral_text"

def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
head_dim=None,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=10000.0,
sliding_window=4096,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.head_dim = head_dim # as opposed to MistralConfig, do not auto-populate

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout

PretrainedConfig.__init__(
self,
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)


__all__ = ["PixtralVisionConfig", "PixtralTextConfig"]
71 changes: 71 additions & 0 deletions tests/models/llava/test_configuration_llava.py
@@ -0,0 +1,71 @@
import tempfile
import unittest

from transformers import LlavaConfig


class LlavaConfigTest(unittest.TestCase):
def test_llava_reload(self):
"""
Simple test for reloading default llava configs
"""
with tempfile.TemporaryDirectory() as tmp_dir:
config = LlavaConfig()
config.save_pretrained(tmp_dir)

reloaded = LlavaConfig.from_pretrained(tmp_dir)
assert config.to_dict() == reloaded.to_dict()

def test_pixtral_reload(self):
"""
Simple test for reloading pixtral configs
"""
vision_config = {
"model_type": "pixtral",
"head_dim": 64,
"hidden_act": "silu",
"image_size": 1024,
"is_composition": True,
"patch_size": 16,
"rope_theta": 10000.0,
"tie_word_embeddings": False,
}

text_config = {
# "model_type": "mistral",
"model_type": "pixtral_text",
"hidden_size": 5120,
"head_dim": 128,
"num_attention_heads": 32,
"intermediate_size": 14336,
"is_composition": True,
"max_position_embeddings": 1024000,
"num_hidden_layers": 40,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-05,
"rope_theta": 1000000000.0,
"sliding_window": None,
"vocab_size": 131072,
}

with tempfile.TemporaryDirectory() as tmp_dir:
config = LlavaConfig(vision_config=vision_config, text_config=text_config)
config.save_pretrained(tmp_dir)

reloaded = LlavaConfig.from_pretrained(tmp_dir)
assert config.to_dict() == reloaded.to_dict()

def test_arbitrary_reload(self):
"""
Simple test for reloading arbitrarily composed subconfigs
"""
default_values = LlavaConfig().to_dict()
default_values["vision_config"]["model_type"] = "qwen2_vl"
default_values["text_config"]["model_type"] = "opt"

with tempfile.TemporaryDirectory() as tmp_dir:
config = LlavaConfig(**default_values)
config.save_pretrained(tmp_dir)

reloaded = LlavaConfig.from_pretrained(tmp_dir)
assert config.to_dict() == reloaded.to_dict()
1 change: 1 addition & 0 deletions utils/check_table.py
@@ -180,6 +180,7 @@ def _center_text(text: str, width: int) -> str:
"CLIPVisionModel",
"Qwen2AudioEncoder",
"SiglipVisionModel",
"PixtralMistral", # not a real model
]

