
Commit e959530

Add Mistral3 (#36790)
* initial start
* style and dummies
* Create convert_mistral3_weights_to_hf.py
* update
* typo
* typo
* Update convert_mistral3_weights_to_hf.py
* Update convert_mistral3_weights_to_hf.py
* Update convert_mistral3_weights_to_hf.py
* Update convert_mistral3_weights_to_hf.py
* up
* Update convert_mistral3_weights_to_hf.py
* Update convert_mistral3_weights_to_hf.py
* update
* update
* Update image_processing_mistral3.py
* Update convert_mistral3_weights_to_hf.py
* fix patch merger
* Update convert_mistral3_weights_to_hf.py
* Update convert_mistral3_weights_to_hf.py
* up
* update modular to fit
* style
* Update convert_mistral3_weights_to_hf.py
* typo
* Update modular_mistral3.py
* simplify a lot all shape shenanigans
* simplify
* add working test processor
* Add partially working common modeling tests
* All tests working and remove mistral3 image processors
* add docs and fixup
* fix inference with image size >1540
* 🚨fix test image proc pixtral
* Remove vision_feature_select_strategy
* Update convert_mistral3_weights_to_hf.py
* Update convert_mistral3_weights_to_hf.py
* Update convert_mistral3_weights_to_hf.py
* Update convert_mistral3_weights_to_hf.py
* clean
* fix test checkpoints
* Update test_modeling_mistral3.py
* Update test_modeling_mistral3.py
* style
* Use Pixtral processor
* up
* finish cleaning processor to use pixtral directly
* Update __init__.py
* Update processing_pixtral.py
* doc
* Update __init__.py
* Update mistral3.md
* Update _toctree.yml

---------

Co-authored-by: yonigozlan <[email protected]>
Co-authored-by: yonigozlan <[email protected]>
1 parent bd92073 commit e959530

21 files changed (+2303 / -6 lines changed)

docs/source/en/_toctree.yml (+2)

@@ -529,6 +529,8 @@
       title: MegatronGPT2
     - local: model_doc/mistral
       title: Mistral
+    - local: model_doc/mistral3
+      title: Mistral3
     - local: model_doc/mixtral
       title: Mixtral
     - local: model_doc/mluke
docs/source/en/model_doc/mistral3.md (+234)

@@ -0,0 +1,234 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Mistral3

## Overview

Building upon Mistral Small 3 (2501), Mistral Small 3.1 (2503) adds state-of-the-art vision understanding and enhances long context capabilities up to 128k tokens without compromising text performance. With 24 billion parameters, this model achieves top-tier capabilities in both text and vision tasks.

It is ideal for:
- Fast-response conversational agents.
- Low-latency function calling.
- Subject matter experts via fine-tuning.
- Local inference for hobbyists and organizations handling sensitive data.
- Programming and math reasoning.
- Long document understanding.
- Visual understanding.

This model was contributed by [cyrilvallez](https://huggingface.co/cyrilvallez) and [yonigozlan](https://huggingface.co/yonigozlan).

The original code can be found [here](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/pixtral.py) and [here](https://github.com/mistralai/mistral-common).

## Usage example

### Inference with Pipeline

Here is how you can use the `image-text-to-text` pipeline to perform inference with the `Mistral3` models in just a few lines of code:

```python
>>> import torch
>>> from transformers import pipeline

>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {
...                 "type": "image",
...                 "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
...             },
...             {"type": "text", "text": "Describe this image."},
...         ],
...     },
... ]

>>> pipe = pipeline("image-text-to-text", model="mistralai/Mistral-Small-3.1-24B-Instruct-2503", torch_dtype=torch.bfloat16)
>>> outputs = pipe(text=messages, max_new_tokens=50, return_full_text=False)
>>> outputs[0]["generated_text"]
'The image depicts a vibrant and lush garden scene featuring a variety of wildflowers and plants. The central focus is on a large, pinkish-purple flower, likely a Greater Celandine (Chelidonium majus), with a'
```

### Inference on a single image

This example demonstrates how to perform inference on a single image with the Mistral3 models using chat templates.

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> torch_device = "cuda"
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)

>>> messages = [
...     {
...         "role": "user",
...         "content": [
...             {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
...             {"type": "text", "text": "Describe this image"},
...         ],
...     }
... ]

>>> inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)

>>> generate_ids = model.generate(**inputs, max_new_tokens=20)
>>> decoded_output = processor.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)

>>> decoded_output
"The image depicts two cats lying on a pink blanket. The larger cat, which appears to be an"...
```

### Text-only generation

This example shows how to generate text using the Mistral3 model without providing any image input.

````python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> torch_device = "cuda"
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)

>>> SYSTEM_PROMPT = "You are a conversational agent that always answers straight to the point, always end your accurate response with an ASCII drawing of a cat."
>>> user_prompt = "Give me 5 non-formal ways to say 'See you later' in French."

>>> messages = [
...     {"role": "system", "content": SYSTEM_PROMPT},
...     {"role": "user", "content": user_prompt},
... ]

>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
>>> inputs = processor(text=text, return_tensors="pt").to(model.device)
>>> generate_ids = model.generate(**inputs, max_new_tokens=50, do_sample=False)
>>> decoded_output = processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)[0]

>>> print(decoded_output)
"1. À plus tard!
2. Salut, à plus!
3. À toute!
4. À la prochaine!
5. Je me casse, à plus!

```
 /\_/\
( o.o )
 > ^ <
```"
````

### Batched image and text inputs

Mistral3 models also support batched image and text inputs.

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import torch

>>> torch_device = "cuda"
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> model = AutoModelForImageTextToText.from_pretrained(model_checkpoint, device_map=torch_device, torch_dtype=torch.bfloat16)

>>> messages = [
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
...                 {"type": "text", "text": "Write a haiku for this image"},
...             ],
...         },
...     ],
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
...                 {"type": "text", "text": "Describe this image"},
...             ],
...         },
...     ],
... ]

>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)

>>> output = model.generate(**inputs, max_new_tokens=25)

>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
>>> decoded_outputs
["Write a haiku for this imageCalm waters reflect\nWhispers of the forest's breath\nPeace on wooden path",
 "Describe this imageThe image depicts a vibrant street scene in what appears to be a Chinatown district. The focal point is a traditional Chinese"]
```

### Batched multi-image input and quantization with BitsAndBytes

This implementation of the Mistral3 models supports batched text and image inputs with a different number of images for each prompt.
This example also shows how to use `BitsAndBytes` to load the model in 4-bit quantization.

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
>>> import torch

>>> torch_device = "cuda"
>>> model_checkpoint = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
>>> processor = AutoProcessor.from_pretrained(model_checkpoint)
>>> quantization_config = BitsAndBytesConfig(load_in_4bit=True)
>>> model = AutoModelForImageTextToText.from_pretrained(
...     model_checkpoint, quantization_config=quantization_config
... )

>>> messages = [
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
...                 {"type": "text", "text": "Write a haiku for this image"},
...             ],
...         },
...     ],
...     [
...         {
...             "role": "user",
...             "content": [
...                 {"type": "image", "url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"},
...                 {"type": "image", "url": "https://thumbs.dreamstime.com/b/golden-gate-bridge-san-francisco-purple-flowers-california-echium-candicans-36805947.jpg"},
...                 {"type": "text", "text": "These images depict two different landmarks. Can you identify them?"},
...             ],
...         },
...     ],
... ]

>>> inputs = processor.apply_chat_template(messages, padding=True, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt").to(model.device, dtype=torch.bfloat16)

>>> output = model.generate(**inputs, max_new_tokens=25)

>>> decoded_outputs = processor.batch_decode(output, skip_special_tokens=True)
>>> decoded_outputs
["Write a haiku for this imageSure, here is a haiku inspired by the image:\n\nCalm lake's wooden path\nSilent forest stands guard\n",
 "These images depict two different landmarks. Can you identify them? Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City."]
```

## Mistral3Config

[[autodoc]] Mistral3Config

## Mistral3ForConditionalGeneration

[[autodoc]] Mistral3ForConditionalGeneration
    - forward

src/transformers/__init__.py (+12)

@@ -613,6 +613,7 @@
     ],
     "models.mimi": ["MimiConfig"],
     "models.mistral": ["MistralConfig"],
+    "models.mistral3": ["Mistral3Config"],
     "models.mixtral": ["MixtralConfig"],
     "models.mllama": [
         "MllamaConfig",
@@ -2940,6 +2941,12 @@
             "MistralPreTrainedModel",
         ]
     )
+    _import_structure["models.mistral3"].extend(
+        [
+            "Mistral3ForConditionalGeneration",
+            "Mistral3PreTrainedModel",
+        ]
+    )
     _import_structure["models.mixtral"].extend(
         [
             "MixtralForCausalLM",
@@ -5788,6 +5795,7 @@
         MimiConfig,
     )
     from .models.mistral import MistralConfig
+    from .models.mistral3 import Mistral3Config
     from .models.mixtral import MixtralConfig
     from .models.mllama import (
         MllamaConfig,
@@ -7844,6 +7852,10 @@
         MistralModel,
         MistralPreTrainedModel,
     )
+    from .models.mistral3 import (
+        Mistral3ForConditionalGeneration,
+        Mistral3PreTrainedModel,
+    )
     from .models.mixtral import (
         MixtralForCausalLM,
         MixtralForQuestionAnswering,
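
With these entries in the top-level lazy import structure, the new classes should be importable directly from `transformers` in a build that includes this commit. A minimal sketch (the default config values below are illustrative, not the released 24B checkpoint's):

```python
# Sketch only: checks that the new top-level exports resolve.
from transformers import Mistral3Config, Mistral3ForConditionalGeneration

config = Mistral3Config()                 # default values, not the released checkpoint's
print(config.model_type)                  # "mistral3"
print(Mistral3ForConditionalGeneration.__name__)
```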

src/transformers/models/__init__.py (+1)

@@ -169,6 +169,7 @@
     mgp_str,
     mimi,
     mistral,
+    mistral3,
     mixtral,
     mllama,
     mluke,

src/transformers/models/auto/configuration_auto.py (+2)

@@ -192,6 +192,7 @@
         ("mgp-str", "MgpstrConfig"),
         ("mimi", "MimiConfig"),
         ("mistral", "MistralConfig"),
+        ("mistral3", "Mistral3Config"),
         ("mixtral", "MixtralConfig"),
         ("mllama", "MllamaConfig"),
         ("mobilebert", "MobileBertConfig"),
@@ -537,6 +538,7 @@
         ("mgp-str", "MGP-STR"),
         ("mimi", "Mimi"),
         ("mistral", "Mistral"),
+        ("mistral3", "Mistral3"),
         ("mixtral", "Mixtral"),
         ("mllama", "Mllama"),
         ("mluke", "mLUKE"),

src/transformers/models/auto/image_processing_auto.py (+1)

@@ -111,6 +111,7 @@
         ("mask2former", ("Mask2FormerImageProcessor",)),
         ("maskformer", ("MaskFormerImageProcessor",)),
         ("mgp-str", ("ViTImageProcessor", "ViTImageProcessorFast")),
+        ("mistral3", ("PixtralImageProcessor", "PixtralImageProcessorFast")),
         ("mllama", ("MllamaImageProcessor",)),
         ("mobilenet_v1", ("MobileNetV1ImageProcessor",)),
         ("mobilenet_v2", ("MobileNetV2ImageProcessor",)),

src/transformers/models/auto/modeling_auto.py (+3)

@@ -361,6 +361,7 @@
         ("mamba2", "Mamba2ForCausalLM"),
         ("mega", "MegaForMaskedLM"),
         ("megatron-bert", "MegatronBertForPreTraining"),
+        ("mistral3", "Mistral3ForConditionalGeneration"),
         ("mllama", "MllamaForConditionalGeneration"),
         ("mobilebert", "MobileBertForPreTraining"),
         ("mpnet", "MPNetForMaskedLM"),
@@ -802,6 +803,7 @@
         ("llava_next", "LlavaNextForConditionalGeneration"),
         ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
         ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
+        ("mistral3", "Mistral3ForConditionalGeneration"),
         ("mllama", "MllamaForConditionalGeneration"),
         ("paligemma", "PaliGemmaForConditionalGeneration"),
         ("pix2struct", "Pix2StructForConditionalGeneration"),
@@ -839,6 +841,7 @@
         ("llava", "LlavaForConditionalGeneration"),
         ("llava_next", "LlavaNextForConditionalGeneration"),
         ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
+        ("mistral3", "Mistral3ForConditionalGeneration"),
         ("mllama", "MllamaForConditionalGeneration"),
         ("paligemma", "PaliGemmaForConditionalGeneration"),
         ("pix2struct", "Pix2StructForConditionalGeneration"),

src/transformers/models/auto/processing_auto.py (+1)

@@ -84,6 +84,7 @@
         ("markuplm", "MarkupLMProcessor"),
         ("mctct", "MCTCTProcessor"),
         ("mgp-str", "MgpstrProcessor"),
+        ("mistral3", "PixtralProcessor"),
         ("mllama", "MllamaProcessor"),
         ("moonshine", "Wav2Vec2Processor"),
         ("oneformer", "OneFormerProcessor"),
src/transformers/models/mistral3/__init__.py (+28)

@@ -0,0 +1,28 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_mistral3 import *
    from .modeling_mistral3 import *
    from .processing_mistral3 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
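
This file follows the repo's lazy-import pattern: at runtime the module object is swapped for a `_LazyModule`, so the heavy submodules are only imported when an attribute is first accessed, while the `TYPE_CHECKING` branch keeps static analyzers and type checkers happy. A rough sketch of the observable effect, assuming a transformers build that includes this commit:

```python
import sys

import transformers

# Importing transformers alone should not import the mistral3 modeling code...
assert "transformers.models.mistral3.modeling_mistral3" not in sys.modules

# ...until the class is first accessed, which triggers the lazy import.
_ = transformers.Mistral3ForConditionalGeneration
assert "transformers.models.mistral3.modeling_mistral3" in sys.modules
```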
