Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,8 @@
title: MGP-STR
- local: model_doc/mistral3
title: Mistral3
- local: model_doc/mistral4
title: Mistral4
- local: model_doc/mllama
title: mllama
- local: model_doc/mm-grounding-dino
Expand Down
116 changes: 116 additions & 0 deletions docs/source/en/model_doc/mistral4.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
<!--Copyright 2026 Mistral AI and the HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.

-->
*This model was released on 2026-03-16 and added to Hugging Face Transformers on 2026-03-16.*

# Mistral4

## Overview

Mistral 4 is a powerful hybrid model with the capability of acting as both a general instruction model and a reasoning model. It unifies the capabilities of three different model families - Instruct, Reasoning ( previous called Magistral ), and Devstral - into a single, unified model.

[Mistral-Small-4](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603) consists of the following architectural choices:

- MoE: 128 experts and 4 active.
- 119B with 6.5B activated parameters per token.
- 256k Context Length.
- Multimodal Input: Accepts both text and image input, with text output.
- Instruct and Reasoning functionalities with Function Calls
- Reasoning Effort configurable by request.

Mistral 4 offers the following capabilities:

- **Reasoning Mode**: Switch between a fast instant reply mode, and a reasoning thinking mode, boosting performance with test time compute when requested.
- **Vision**: Enables the model to analyze images and provide insights based on visual content, in addition to text.
- **Multilingual**: Supports dozens of languages, including English, French, Spanish, German, Italian, Portuguese, Dutch, Chinese, Japanese, Korean, Arabic.
- **System Prompt**: Maintains strong adherence and support for system prompts.
- **Agentic**: Offers best-in-class agentic capabilities with native function calling and JSON outputting.
- **Speed-Optimized**: Delivers best-in-class performance and speed.
- **Apache 2.0 License**: Open-source license allowing usage and modification for both commercial and non-commercial purposes.
- **Large Context Window**: Supports a 256k context window.

## Usage examples

```py
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration


model_id = "mistralai/Mistral-Small-4-119B-2603"

processor = AutoProcessor.from_pretrained(model_id)
model = Mistral3ForConditionalGeneration.from_pretrained(
model_id, device_map="auto"
)

image_url = "https://static.wikia.nocookie.net/essentialsdocs/images/7/70/Battle.png/revision/latest?cb=20220523172438"

messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "What action do you think I should take in this situation? List all the possible actions and explain why you think they are good or bad.",
},
{"type": "image_url", "image_url": {"url": image_url}},
],
},
]

inputs = processor.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, reasoning_effort="high")
inputs = inputs.to(model.device)

output = model.generate(
**inputs,
max_new_tokens=512,
)[0]

# Setting `skip_special_tokens=False` to visualize reasoning trace between [THINK] [/THINK] tags.
decoded_output = processor.decode(output[len(inputs["input_ids"][0]):], skip_special_tokens=False)
print(decoded_output)
```

## Mistral4Config

[[autodoc]] Mistral4Config

## Mistral4PreTrainedModel

[[autodoc]] Mistral4PreTrainedModel
- forward

## Mistral4Model

[[autodoc]] Mistral4Model
- forward

## Mistral4ForCausalLM

[[autodoc]] Mistral4ForCausalLM

## Mistral4ForSequenceClassification

[[autodoc]] Mistral4ForSequenceClassification

## Mistral4ForTokenClassification

[[autodoc]] Mistral4ForTokenClassification

## Mistral4ForQuestionAnswering

[[autodoc]] Mistral4ForQuestionAnswering
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@
from .ministral3 import *
from .mistral import *
from .mistral3 import *
from .mistral4 import *
from .mixtral import *
from .mlcd import *
from .mllama import *
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@
("ministral3", "Ministral3Config"),
("mistral", "MistralConfig"),
("mistral3", "Mistral3Config"),
("mistral4", "Mistral4Config"),
("mixtral", "MixtralConfig"),
("mlcd", "MLCDVisionConfig"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works
("mlcd_vision_model", "MLCDVisionConfig"),
Expand Down Expand Up @@ -797,6 +798,7 @@
("ministral3", "Ministral3"),
("mistral", "Mistral"),
("mistral3", "Mistral3"),
("mistral4", "Mistral4"),
("mixtral", "Mixtral"),
("mlcd", "MLCD"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works
("mlcd_vision_model", "MLCD"),
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("ministral3", "Ministral3Model"),
("mistral", "MistralModel"),
("mistral3", "Mistral3Model"),
("mistral4", "Mistral4Model"),
("mixtral", "MixtralModel"),
("mlcd", "MLCDVisionModel"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works
("mlcd_vision_model", "MLCDVisionModel"),
Expand Down Expand Up @@ -541,6 +542,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("mamba2", "Mamba2ForCausalLM"),
("megatron-bert", "MegatronBertForPreTraining"),
("mistral3", "Mistral3ForConditionalGeneration"),
("mistral4", "Mistral4ForCausalLM"),
("mllama", "MllamaForConditionalGeneration"),
("mobilebert", "MobileBertForPreTraining"),
("mpnet", "MPNetForMaskedLM"),
Expand Down Expand Up @@ -981,6 +983,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
("mistral3", "Mistral3ForConditionalGeneration"),
("mistral4", "Mistral4ForCausalLM"),
("mllama", "MllamaForConditionalGeneration"),
("ovis2", "Ovis2ForConditionalGeneration"),
("paddleocr_vl", "PaddleOCRVLForConditionalGeneration"),
Expand Down Expand Up @@ -1243,6 +1246,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("ministral", "MinistralForSequenceClassification"),
("ministral3", "Ministral3ForSequenceClassification"),
("mistral", "MistralForSequenceClassification"),
("mistral4", "Mistral4ForSequenceClassification"),
("mixtral", "MixtralForSequenceClassification"),
("mobilebert", "MobileBertForSequenceClassification"),
("modernbert", "ModernBertForSequenceClassification"),
Expand Down Expand Up @@ -1456,6 +1460,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("ministral", "MinistralForTokenClassification"),
("ministral3", "Ministral3ForTokenClassification"),
("mistral", "MistralForTokenClassification"),
("mistral4", "Mistral4ForTokenClassification"),
("mixtral", "MixtralForTokenClassification"),
("mobilebert", "MobileBertForTokenClassification"),
("modernbert", "ModernBertForTokenClassification"),
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/ministral3/modeling_ministral3.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def eager_attention_forward(
return attn_output, attn_weights


def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
def get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
return scaling.unsqueeze(-1)

Expand Down Expand Up @@ -144,7 +144,7 @@ def forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens
query_states = query_states * _get_llama_4_attn_scale(
query_states = query_states * get_llama_4_attn_scale(
cache_position,
self.config.rope_parameters.get("llama_4_scaling_beta"),
self.config.rope_parameters.get("original_max_position_embeddings"),
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/ministral3/modular_ministral3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
logger = logging.get_logger(__name__)


def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
def get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor:
scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings))
return scaling.unsqueeze(-1)

Expand All @@ -51,7 +51,7 @@ def forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens
query_states = query_states * _get_llama_4_attn_scale(
query_states = query_states * get_llama_4_attn_scale(
cache_position,
self.config.rope_parameters.get("llama_4_scaling_beta"),
self.config.rope_parameters.get("original_max_position_embeddings"),
Expand Down
27 changes: 27 additions & 0 deletions src/transformers/models/mistral4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_mistral4 import *
from .modeling_mistral4 import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
Loading
Loading