From e03c7fcdfe22761d10850c1b914ddd41ddb5f57a Mon Sep 17 00:00:00 2001 From: juliendenize Date: Wed, 11 Mar 2026 15:33:59 +0000 Subject: [PATCH 1/9] Add Mistral Small 4 --- docs/source/en/model_doc/mistral4.md | 116 +++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 3 + .../models/ministral3/modeling_ministral3.py | 4 +- .../models/ministral3/modular_ministral3.py | 4 +- src/transformers/models/mistral4/__init__.py | 37 + .../models/mistral4/configuration_mistral4.py | 261 +++++++ .../mistral4/convert_mistral4_weight_to_hf.py | 608 +++++++++++++++ .../models/mistral4/modeling_mistral4.py | 730 ++++++++++++++++++ .../models/mistral4/modular_mistral4.py | 291 +++++++ src/transformers/processing_utils.py | 5 + tests/models/mistral4/__init__.py | 17 + .../models/mistral4/test_modeling_mistral4.py | 133 ++++ 14 files changed, 2208 insertions(+), 4 deletions(-) create mode 100644 docs/source/en/model_doc/mistral4.md create mode 100644 src/transformers/models/mistral4/__init__.py create mode 100644 src/transformers/models/mistral4/configuration_mistral4.py create mode 100644 src/transformers/models/mistral4/convert_mistral4_weight_to_hf.py create mode 100644 src/transformers/models/mistral4/modeling_mistral4.py create mode 100644 src/transformers/models/mistral4/modular_mistral4.py create mode 100644 tests/models/mistral4/__init__.py create mode 100644 tests/models/mistral4/test_modeling_mistral4.py diff --git a/docs/source/en/model_doc/mistral4.md b/docs/source/en/model_doc/mistral4.md new file mode 100644 index 000000000000..610883d85914 --- /dev/null +++ b/docs/source/en/model_doc/mistral4.md @@ -0,0 +1,116 @@ + +*This model was released on 2026-16-02 and added to Hugging Face Transformers on 2026-16-02.* + +# Mistral4 + +## Overview + +Mistral 4 is a powerful hybrid model with the capability of acting as both a general instruction model and a reasoning model. 
It unifies the capabilities of three different model families - Instruct, Reasoning (previously called Magistral), and Devstral - into a single, unified model. + +[Mistral-Small-4](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603) consists of the following architectural choices: + +- MoE: 128 experts and 4 active. +- 119B with 6.5B activated parameters per token. +- 256k Context Length. +- Multimodal Input: Accepts both text and image input, with text output. +- Instruct and Reasoning functionalities with Function Calls + - Reasoning Effort configurable by request. + +Mistral 4 offers the following capabilities: + +- **Reasoning Mode**: Switch between a fast instant reply mode, and a reasoning thinking mode, boosting performance with test-time compute when requested. +- **Vision**: Enables the model to analyze images and provide insights based on visual content, in addition to text. +- **Multilingual**: Supports dozens of languages, including English, French, Spanish, German, Italian, Portuguese, Dutch, Chinese, Japanese, Korean, Arabic. +- **System Prompt**: Maintains strong adherence and support for system prompts. +- **Agentic**: Offers best-in-class agentic capabilities with native function calling and JSON outputting. +- **Speed-Optimized**: Delivers best-in-class performance and speed. +- **Apache 2.0 License**: Open-source license allowing usage and modification for both commercial and non-commercial purposes. +- **Large Context Window**: Supports a 256k context window. 
+ +## Usage examples + +```py +import torch +from transformers import AutoProcessor, Mistral3ForConditionalGeneration + + +model_id = "mistralai/Mistral-Small-4-119B-2603" + +processor = AutoProcessor.from_pretrained(model_id) +model = Mistral3ForConditionalGeneration.from_pretrained( + model_id, device_map="auto" +) + +image_url = "https://static.wikia.nocookie.net/essentialsdocs/images/7/70/Battle.png/revision/latest?cb=20220523172438" + +messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What action do you think I should take in this situation? List all the possible actions and explain why you think they are good or bad.", + }, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + }, +] + +inputs = processor.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, reasoning_effort="high") +inputs = inputs.to(model.device) + +output = model.generate( + **inputs, + max_new_tokens=512, +)[0] + +# Setting `skip_special_tokens=False` to visualize reasoning trace between [THINK] [/THINK] tags. 
+decoded_output = processor.decode(output[len(inputs["input_ids"][0]):], skip_special_tokens=False) +print(decoded_output) +``` + +## Mistral4Config + +[[autodoc]] Mistral4Config + +## Mistral4PreTrainedModel + +[[autodoc]] Mistral4PreTrainedModel + - forward + +## Mistral4Model + +[[autodoc]] Mistral4Model + - forward + +## Mistral4ForCausalLM + +[[autodoc]] Mistral4ForCausalLM + +## Mistral4ForSequenceClassification + +[[autodoc]] Mistral4ForSequenceClassification + +## Mistral4ForTokenClassification + +[[autodoc]] Mistral4ForTokenClassification + +## Mistral4ForQuestionAnswering + +[[autodoc]] Mistral4ForQuestionAnswering diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 7caa68039500..5f45081ac4a0 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -250,6 +250,7 @@ from .ministral3 import * from .mistral import * from .mistral3 import * + from .mistral4 import * from .mixtral import * from .mlcd import * from .mllama import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 3c71c30f120d..476b5362343f 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -287,6 +287,7 @@ ("ministral3", "Ministral3Config"), ("mistral", "MistralConfig"), ("mistral3", "Mistral3Config"), + ("mistral4", "Mistral4Config"), ("mixtral", "MixtralConfig"), ("mlcd", "MLCDVisionConfig"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works ("mlcd_vision_model", "MLCDVisionConfig"), @@ -797,6 +798,7 @@ ("ministral3", "Ministral3"), ("mistral", "Mistral"), ("mistral3", "Mistral3"), + ("mistral4", "Mistral4"), ("mixtral", "Mixtral"), ("mlcd", "MLCD"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works ("mlcd_vision_model", "MLCD"), diff --git a/src/transformers/models/auto/modeling_auto.py 
b/src/transformers/models/auto/modeling_auto.py index 8a94da531eb8..316e394aad87 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -282,6 +282,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ministral3", "Ministral3Model"), ("mistral", "MistralModel"), ("mistral3", "Mistral3Model"), + ("mistral4", "Mistral4Model"), ("mixtral", "MixtralModel"), ("mlcd", "MLCDVisionModel"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works ("mlcd_vision_model", "MLCDVisionModel"), @@ -541,6 +542,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mamba2", "Mamba2ForCausalLM"), ("megatron-bert", "MegatronBertForPreTraining"), ("mistral3", "Mistral3ForConditionalGeneration"), + ("mistral4", "Mistral4ForCausalLM"), ("mllama", "MllamaForConditionalGeneration"), ("mobilebert", "MobileBertForPreTraining"), ("mpnet", "MPNetForMaskedLM"), @@ -981,6 +983,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), ("mistral3", "Mistral3ForConditionalGeneration"), + ("mistral4", "Mistral4ForCausalLM"), ("mllama", "MllamaForConditionalGeneration"), ("ovis2", "Ovis2ForConditionalGeneration"), ("paddleocr_vl", "PaddleOCRVLForConditionalGeneration"), diff --git a/src/transformers/models/ministral3/modeling_ministral3.py b/src/transformers/models/ministral3/modeling_ministral3.py index be0659788d71..3d2d15fa1a84 100644 --- a/src/transformers/models/ministral3/modeling_ministral3.py +++ b/src/transformers/models/ministral3/modeling_ministral3.py @@ -102,7 +102,7 @@ def eager_attention_forward( return attn_output, attn_weights -def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: +def get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: 
int) -> torch.Tensor: scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) return scaling.unsqueeze(-1) @@ -144,7 +144,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens - query_states = query_states * _get_llama_4_attn_scale( + query_states = query_states * get_llama_4_attn_scale( cache_position, self.config.rope_parameters.get("llama_4_scaling_beta"), self.config.rope_parameters.get("original_max_position_embeddings"), diff --git a/src/transformers/models/ministral3/modular_ministral3.py b/src/transformers/models/ministral3/modular_ministral3.py index bde7829d769a..9dcac11dd058 100644 --- a/src/transformers/models/ministral3/modular_ministral3.py +++ b/src/transformers/models/ministral3/modular_ministral3.py @@ -26,7 +26,7 @@ logger = logging.get_logger(__name__) -def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: +def get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) return scaling.unsqueeze(-1) @@ -51,7 +51,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens - query_states = query_states * _get_llama_4_attn_scale( + query_states = query_states * get_llama_4_attn_scale( cache_position, self.config.rope_parameters.get("llama_4_scaling_beta"), self.config.rope_parameters.get("original_max_position_embeddings"), diff --git 
a/src/transformers/models/mistral4/__init__.py b/src/transformers/models/mistral4/__init__.py new file mode 100644 index 000000000000..fdde4abe6b16 --- /dev/null +++ b/src/transformers/models/mistral4/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mistral4 import * + from .modeling_mistral4 import * +else: + import sys + + _file = globals()["__file__"] + # Explicitly define the import structure to include modeling classes + import_structure = define_import_structure(_file) + # Add the modeling classes explicitly + import_structure[frozenset({})]["modeling_mistral4"] = { + "Mistral4PreTrainedModel", + "Mistral4Model", + "Mistral4ForCausalLM", + "Mistral4ForSequenceClassification", + "Mistral4ForTokenClassification", + } + sys.modules[__name__] = _LazyModule(__name__, _file, import_structure, module_spec=__spec__) diff --git a/src/transformers/models/mistral4/configuration_mistral4.py b/src/transformers/models/mistral4/configuration_mistral4.py new file mode 100644 index 000000000000..e9ea9981e7cf --- /dev/null +++ b/src/transformers/models/mistral4/configuration_mistral4.py @@ -0,0 +1,261 @@ +# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters + + +class Mistral4Config(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Mistral4Model`]. It is used to instantiate a Mistral4 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Mistral-Small-4 model. + e.g. [mistralai/Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603) + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 131072): + Vocabulary size of the Mistral4 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Mistral4Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 12288): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 36): + Number of hidden layers in the Transformer decoder. 
+ num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 32): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + n_shared_experts (`int`, *optional*, defaults to 1): + Number of shared experts. + n_routed_experts (`int`, *optional*, defaults to 128): + Number of routed experts. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for routed experts. + kv_lora_rank (`int`, *optional*, defaults to 256): + Rank of the LoRA matrices for key and value projections. + q_lora_rank (`int`, *optional*, defaults to 1024): + Rank of the LoRA matrices for query projections. + qk_rope_head_dim (`int`, *optional*, defaults to 64): + Dimension of the query/key heads that use rotary position embeddings. + v_head_dim (`int`, *optional*, defaults to 128): + Dimension of the value heads. + qk_nope_head_dim (`int`, *optional*, defaults to 64): + Dimension of the query/key heads that don't use rotary position embeddings. + n_group (`int`, *optional*, defaults to 1): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to 1): + Number of selected groups for each token (for each token, ensuring the selected experts are only within `topk_group` groups). 
+ num_experts_per_tok (`int`, *optional*, defaults to 4): + Number of selected experts, None means dense model. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the weights of the routed experts. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 1048576): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 11): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). 
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + rope_interleave (`bool`, *optional*, defaults to `True`): + Whether to interleave the rotary position embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import Mistral4Model, Mistral4Config + + >>> # Initializing a Mistral4 style configuration + >>> configuration = Mistral4Config() + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mistral4" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.mlp.experts.gate_up_proj": "rowwise", + "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + attribute_map = { + "num_local_experts": "n_routed_experts", + } + + def __init__( + self, + vocab_size: int | None = 131072, + hidden_size: int | None = 4096, + intermediate_size: int | None = 12288, + moe_intermediate_size: int | None = 2048, + 
num_hidden_layers: int | None = 36, + num_attention_heads: int | None = 32, + num_key_value_heads: int | None = 32, + n_shared_experts: int | None = 1, + n_routed_experts: int | None = 128, + routed_scaling_factor: float | None = 1.0, + kv_lora_rank: int | None = 256, + q_lora_rank: int | None = 1024, + qk_rope_head_dim: int | None = 64, + v_head_dim: int | None = 128, + qk_nope_head_dim: int | None = 64, + n_group: int | None = 1, + topk_group: int | None = 1, + num_experts_per_tok: int | None = 4, + first_k_dense_replace: int | None = 0, + norm_topk_prob: bool | None = True, + hidden_act: str | None = "silu", + max_position_embeddings: int | None = 1048576, + initializer_range: float | None = 0.02, + rms_norm_eps: float | None = 1e-6, + use_cache: bool | None = True, + pad_token_id: int | None = 11, + bos_token_id: int | None = 1, + eos_token_id: int | None = 2, + pretraining_tp: int | None = 1, + tie_word_embeddings: bool | None = False, + rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, + rope_interleave: bool | None = True, + attention_bias: bool | None = False, + attention_dropout: float | None = 0.0, + mlp_bias: bool | None = False, + **kwargs, + ): + if rope_parameters is None: + rope_parameters = { + "type": "yarn", + "rope_theta": 10000.0, + "factor": 128.0, + "original_max_position_embeddings": 8192, + "max_position_embeddings": max_position_embeddings, + "beta_fast": 32.0, + "beta_slow": 1.0, + "mscale_all_dim": 1.0, + "mscale": 1.0, + "llama_4_scaling_beta": 0.1, + "partial_rotary_factor": qk_rope_head_dim / (qk_nope_head_dim + qk_rope_head_dim), + } + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = 
n_routed_experts + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.head_dim = qk_nope_head_dim + qk_rope_head_dim + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.rope_interleave = rope_interleave + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_parameters = rope_parameters + self.rope_parameters.setdefault("partial_rotary_factor", self.qk_rope_head_dim / self.head_dim) + + self.tie_word_embeddings = tie_word_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.mlp_bias = mlp_bias + super().__init__(ignore_keys_at_rope_validation={"llama_4_scaling_beta", "max_position_embeddings"}, **kwargs) + + def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs): + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or self.rope_parameters + self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {} + + # Standardize and validate the correctness of rotary position embeddings parameters + self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta)) + self.standardize_rope_params() + if 
ignore_keys_at_rope_validation is not None: + self.ignore_keys_at_rope_validation = self.ignore_keys_at_rope_validation | ignore_keys_at_rope_validation + self.validate_rope() + + # Convert to float because RoPE fn expect a float. Models on the hub were saved as int + for key in ["beta_fast", "beta_slow", "factor"]: + if key in self.rope_parameters: + self.rope_parameters[key] = float(self.rope_parameters[key]) + return kwargs + + +__all__ = ["Mistral4Config"] diff --git a/src/transformers/models/mistral4/convert_mistral4_weight_to_hf.py b/src/transformers/models/mistral4/convert_mistral4_weight_to_hf.py new file mode 100644 index 000000000000..b3c3b7714fca --- /dev/null +++ b/src/transformers/models/mistral4/convert_mistral4_weight_to_hf.py @@ -0,0 +1,608 @@ +# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import json +import re +from collections import defaultdict +from pathlib import Path +from typing import Any + +import torch +from safetensors.torch import load_file + +from transformers import ( + GenerationConfig, + Mistral3Config, + Mistral3ForConditionalGeneration, + Mistral4Config, + PixtralImageProcessorFast, + PixtralProcessor, + PixtralVisionConfig, +) +from transformers.core_model_loading import ( + Concatenate, + ConversionOps, + MergeModulelist, + WeightRenaming, +) +from transformers.integrations.finegrained_fp8 import replace_with_fp8_linear +from transformers.integrations.mistral import convert_tekken_tokenizer +from transformers.models.mistral4.modeling_mistral4 import Mistral4ForCausalLM +from transformers.quantizers.auto import AutoQuantizationConfig + + +_FP8_DTYPE = torch.float8_e4m3fn +_FP8_MIN = torch.finfo(_FP8_DTYPE).min +_FP8_MAX = torch.finfo(_FP8_DTYPE).max + +EXPERT_KEY_PATTERN = re.compile(r"^layers\.(\d+)\.experts\.(\d+)\.(w[123])\.(weight|qscale_act|qscale_weight)$") + + +class FP8RescaleMergeAndConcatenate(ConversionOps): + r"""FP8-aware gate+up expert fusion with per-expert scale rescaling. + + Takes per-expert gate (w1) and up (w3) weight tensors together with their + FP8 `weight_scale_inv` values, rescales both to a common scale per expert + (the max of the two), concatenates gate+up along `dim=0`, and stacks + across experts along a new leading dimension. 
+ """ + + @torch.no_grad() + def convert( + self, + input_dict: dict[str, list[torch.Tensor]], + source_patterns: list[str], + target_patterns: list[str], + **kwargs: Any, + ) -> dict[str, torch.Tensor]: + w1_weights = input_dict["w1.weight"] + w3_weights = input_dict["w3.weight"] + w1_scales = input_dict["w1.qscale_weight"] + w3_scales = input_dict["w3.qscale_weight"] + + gate_up_list: list[torch.Tensor] = [] + scale_inv_list: list[torch.Tensor] = [] + + for e in range(len(w1_weights)): + fused_scale_inv = torch.max(w1_scales[e], w3_scales[e]) + gate = _rescale_fp8(w1_weights[e], w1_scales[e], fused_scale_inv) + up = _rescale_fp8(w3_weights[e], w3_scales[e], fused_scale_inv) + gate_up_list.append(torch.cat([gate, up], dim=0)) + scale_inv_list.append(fused_scale_inv) + + gate_up_proj_scale_inv = torch.stack(scale_inv_list) + while gate_up_proj_scale_inv.ndim < 3: + gate_up_proj_scale_inv = gate_up_proj_scale_inv.unsqueeze(-1) + + return { + "gate_up_proj": torch.stack(gate_up_list, dim=0), + "gate_up_proj_scale_inv": gate_up_proj_scale_inv, + } + + +def _get_text_renamings(prefix: str) -> list[WeightRenaming]: + r"""Build `WeightRenaming` list for text-model keys.""" + return [ + WeightRenaming("^output", "lm_head"), + WeightRenaming("^norm", f"{prefix}.norm"), + WeightRenaming("^tok_embeddings", f"{prefix}.embed_tokens"), + WeightRenaming("^layers", f"{prefix}.layers"), + WeightRenaming("attention_norm", "input_layernorm"), + WeightRenaming("ffn_norm", "post_attention_layernorm"), + WeightRenaming(r"attention\.wkv_a_with_mqa", "self_attn.kv_a_proj_with_mqa"), + WeightRenaming(r"attention\.wkv_b", "self_attn.kv_b_proj"), + WeightRenaming(r"attention\.wq_a", "self_attn.q_a_proj"), + WeightRenaming(r"attention\.wq_b", "self_attn.q_b_proj"), + WeightRenaming(r"attention\.wo", "self_attn.o_proj"), + WeightRenaming(r"attention\.q_a_norm", "self_attn.q_a_layernorm"), + WeightRenaming(r"attention\.kv_a_norm", "self_attn.kv_a_layernorm"), + 
WeightRenaming(r"\.gate\.weight", ".mlp.gate.weight"), + WeightRenaming(r"shared_experts\.w1", "mlp.shared_experts.gate_proj"), + WeightRenaming(r"shared_experts\.w2", "mlp.shared_experts.down_proj"), + WeightRenaming(r"shared_experts\.w3", "mlp.shared_experts.up_proj"), + WeightRenaming(r"\.qscale_act", ".activation_scale"), + WeightRenaming(r"\.qscale_weight", ".weight_scale_inv"), + ] + + +def _get_vision_renamings() -> list[WeightRenaming]: + r"""Build `WeightRenaming` list for vision-model keys.""" + return [ + WeightRenaming("^vision_encoder", "model.vision_tower"), + WeightRenaming(r"^vision_language_adapter\.w_in", "model.multi_modal_projector.linear_1"), + WeightRenaming(r"^vision_language_adapter\.w_out", "model.multi_modal_projector.linear_2"), + WeightRenaming("^patch_merger", "model.multi_modal_projector.patch_merger"), + WeightRenaming("^pre_mm_projector_norm", "model.multi_modal_projector.norm"), + WeightRenaming(r"attention\.wq\.", "attention.q_proj."), + WeightRenaming(r"attention\.wk\.", "attention.k_proj."), + WeightRenaming(r"attention\.wv\.", "attention.v_proj."), + WeightRenaming(r"attention\.wo\.", "attention.o_proj."), + WeightRenaming(r"feed_forward\.w1", "feed_forward.gate_proj"), + WeightRenaming(r"feed_forward\.w2", "feed_forward.down_proj"), + WeightRenaming(r"feed_forward\.w3", "feed_forward.up_proj"), + ] + + +_VISION_KEY_PREFIXES = ("vision_encoder.", "vision_language_adapter.", "patch_merger.", "pre_mm_projector_norm.") + + +def _is_vision_key(key: str) -> bool: + r"""Return whether *key* belongs to the vision / projector components.""" + return key.startswith(_VISION_KEY_PREFIXES) + + +def _rename_key( + key: str, + text_renamings: list[WeightRenaming], + vision_renamings: list[WeightRenaming], +) -> str: + r"""Apply the appropriate `WeightRenaming` chain to *key*.""" + renamings = vision_renamings if _is_vision_key(key) else text_renamings + for renaming in renamings: + key, _ = renaming.rename_source_key(key) + return key + + 
+def _rescale_fp8( + tensor: torch.Tensor, + original_scale_inv: torch.Tensor, + target_scale_inv: torch.Tensor, +) -> torch.Tensor: + r"""Rescale an FP8 tensor from *original_scale_inv* to *target_scale_inv*.""" + ratio = original_scale_inv / target_scale_inv + return (tensor.to(torch.bfloat16) * ratio).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) + + +def _descale_fp8_to_bf16(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor: + r"""Descale an FP8 tensor back to BF16 using its `weight_scale_inv`.""" + return (tensor.to(torch.bfloat16) * scale_inv.to(torch.bfloat16)).to(torch.bfloat16) + + +def _permute_for_rope(tensor: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor: + r"""Permute Q/K weight matrices from Mistral's interleaved layout to HF's contiguous-half layout.""" + tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2) + tensor = tensor.transpose(1, 2) + tensor = tensor.reshape(dim1, dim2) + return tensor + + +def _maybe_permute_vision_rope( + new_key: str, + tensor: torch.Tensor, + vision_config: PixtralVisionConfig, +) -> torch.Tensor: + r"""Apply RoPE permutation to vision Q/K weight matrices if applicable.""" + num_attention_heads = vision_config.num_attention_heads + hidden_size = vision_config.hidden_size + head_dim = vision_config.head_dim + attn_dim = head_dim * num_attention_heads + + if "q_proj" in new_key and new_key.endswith("weight"): + tensor = _permute_for_rope(tensor, num_attention_heads, attn_dim, hidden_size) + elif "k_proj" in new_key and new_key.endswith("weight"): + tensor = _permute_for_rope(tensor, num_attention_heads, attn_dim, hidden_size) + + return tensor + + +def _fuse_experts_for_layer( + grouped: dict[tuple, dict[int, torch.Tensor]], + layer_idx: int, + n_experts: int, + base: str, + output_fp8: bool, +) -> dict[str, torch.Tensor]: + r"""Fuse per-expert weights for a single layer""" + merge_op = MergeModulelist(dim=0) + + w1 = grouped[(layer_idx, "w1", "weight")] + w2 = grouped[(layer_idx, 
"w2", "weight")] + w3 = grouped[(layer_idx, "w3", "weight")] + + w1_scales = grouped.get((layer_idx, "w1", "qscale_weight")) + w2_scales = grouped.get((layer_idx, "w2", "qscale_weight")) + w3_scales = grouped.get((layer_idx, "w3", "qscale_weight")) + + result: dict[str, torch.Tensor] = {} + + if output_fp8: + fp8_fuse_op = FP8RescaleMergeAndConcatenate() + gate_up_result = fp8_fuse_op.convert( + input_dict={ + "w1.weight": [w1[e] for e in range(n_experts)], + "w3.weight": [w3[e] for e in range(n_experts)], + "w1.qscale_weight": [w1_scales[e] for e in range(n_experts)], + "w3.qscale_weight": [w3_scales[e] for e in range(n_experts)], + }, + source_patterns=["w1.weight", "w3.weight", "w1.qscale_weight", "w3.qscale_weight"], + target_patterns=["gate_up_proj", "gate_up_proj_scale_inv"], + ) + result[f"{base}.gate_up_proj"] = gate_up_result["gate_up_proj"] + result[f"{base}.gate_up_proj_scale_inv"] = gate_up_result["gate_up_proj_scale_inv"] + + down_result = merge_op.convert( + input_dict={"w2": [w2[e] for e in range(n_experts)]}, + source_patterns=["w2"], + target_patterns=["down_proj"], + ) + result[f"{base}.down_proj"] = down_result["down_proj"] + down_proj_scale_inv = torch.stack([w2_scales[e] for e in range(n_experts)]) + while down_proj_scale_inv.ndim < 3: + down_proj_scale_inv = down_proj_scale_inv.unsqueeze(-1) + result[f"{base}.down_proj_scale_inv"] = down_proj_scale_inv + + w1_act = grouped.get((layer_idx, "w1", "qscale_act")) + if w1_act is not None: + w2_act = grouped[(layer_idx, "w2", "qscale_act")] + w3_act = grouped[(layer_idx, "w3", "qscale_act")] + result[f"{base}.gate_up_proj_activation_scale"] = torch.stack( + [torch.max(w1_act[e], w3_act[e]) for e in range(n_experts)] + ) + result[f"{base}.down_proj_activation_scale"] = torch.stack([w2_act[e] for e in range(n_experts)]) + else: + concat_op = Concatenate(dim=1) + + w1_list = [_descale_fp8_to_bf16(w1[e], w1_scales[e]) if w1_scales else w1[e] for e in range(n_experts)] + w3_list = 
[_descale_fp8_to_bf16(w3[e], w3_scales[e]) if w3_scales else w3[e] for e in range(n_experts)] + w2_list = [_descale_fp8_to_bf16(w2[e], w2_scales[e]) if w2_scales else w2[e] for e in range(n_experts)] + + step1 = merge_op.convert( + input_dict={"w1": w1_list, "w3": w3_list}, + source_patterns=["w1", "w3"], + target_patterns=["gate_up_proj"], + ) + gate_up = concat_op.convert(step1, source_patterns=["w1", "w3"], target_patterns=["gate_up_proj"]) + result[f"{base}.gate_up_proj"] = gate_up["gate_up_proj"] + + down = merge_op.convert( + input_dict={"w2": w2_list}, + source_patterns=["w2"], + target_patterns=["down_proj"], + ) + result[f"{base}.down_proj"] = down["down_proj"] + + return result + + +def fuse_experts( + expert_weights: dict[tuple, torch.Tensor], + n_experts: int, + has_vision: bool, + output_fp8: bool, +) -> dict[str, torch.Tensor]: + r"""Fuse per-expert weights across all layers.""" + prefix = "model.language_model" if has_vision else "model" + + grouped: dict[tuple, dict[int, torch.Tensor]] = defaultdict(dict) + for (layer_idx, expert_idx, param_type, suffix), tensor in expert_weights.items(): + grouped[(layer_idx, param_type, suffix)][int(expert_idx)] = tensor + + consumed_keys: set[tuple] = set() + result: dict[str, torch.Tensor] = {} + layers = sorted({layer_idx for (layer_idx, _, _) in grouped}) + + for layer_idx in layers: + base = f"{prefix}.layers.{layer_idx}.mlp.experts" + + w1_weight_key = (layer_idx, "w1", "weight") + assert w1_weight_key in grouped, f"Layer {layer_idx}: missing w1 weights" + assert len(grouped[w1_weight_key]) == n_experts, ( + f"Layer {layer_idx}: expected {n_experts} w1 experts, got {len(grouped[w1_weight_key])}" + ) + + for param in ("w1", "w2", "w3"): + for suffix in ("weight", "qscale_weight", "qscale_act"): + key = (layer_idx, param, suffix) + if key in grouped: + consumed_keys.add(key) + + layer_result = _fuse_experts_for_layer(grouped, layer_idx, n_experts, base, output_fp8) + + result.update(layer_result) + + 
def convert_config(
    original_config: dict,
    max_position_embeddings: int = 1_048_576,
    is_vision: bool = True,
    output_fp8: bool = True,
) -> Mistral3Config | Mistral4Config:
    r"""Convert original Mistral `params.json` to a HF config object.

    The input mapping is only read: neither `original_config` nor its nested
    `vision_encoder` dict is modified, so the caller can keep using them afterwards.

    Args:
        original_config: Parsed `params.json`.
        max_position_embeddings: Fallback used when `params.json` has neither
            `max_position_embeddings` nor `max_seq_len`.
        is_vision: Whether a `vision_encoder` section is expected in `original_config`.
        output_fp8: When true and the source declares `fp8_e4m3` weights, attach an FP8
            quantization config to the result.

    Returns:
        A `Mistral4Config` for text-only checkpoints, otherwise a `Mistral3Config` wrapping
        the text config and a `PixtralVisionConfig`.

    Raises:
        AssertionError: If `is_vision` disagrees with the presence of `vision_encoder`, if the
            `moe` section is missing, or if the FP8 activation scheme is not `"TENSOR"`.
        KeyError: If a required key (e.g. `dim`, `yarn`) is absent.
    """
    # Read (do not pop) so the caller's dict is left intact.
    original_vision_config = original_config.get("vision_encoder")
    assert is_vision == (original_vision_config is not None), (
        f"is_vision={is_vision} does not match presence of 'vision_encoder' in the original config"
    )

    text_kwargs: dict[str, Any] = {
        "hidden_size": original_config["dim"],
        "num_hidden_layers": original_config["n_layers"],
        "intermediate_size": original_config["hidden_dim"],
        "num_attention_heads": original_config["n_heads"],
        "num_key_value_heads": original_config["n_kv_heads"],
        "rms_norm_eps": original_config["norm_eps"],
        "vocab_size": original_config["vocab_size"],
        "tie_word_embeddings": original_config.get("tied_embeddings", False),
        "sliding_window": int(original_config["sliding_window"])
        if original_config.get("sliding_window") is not None
        else None,
        "max_position_embeddings": original_config.get(
            "max_position_embeddings", original_config.get("max_seq_len", max_position_embeddings)
        ),
        "bos_token_id": 1,
        "eos_token_id": 2,
        "pad_token_id": 11,
    }

    # Attention head/rank sizes are forwarded only when the checkpoint specifies them.
    for key in ["q_lora_rank", "qk_rope_head_dim", "qk_nope_head_dim", "kv_lora_rank", "v_head_dim"]:
        if key in original_config:
            text_kwargs[key] = original_config[key]

    moe = original_config.get("moe")
    assert moe is not None, "the original config has no 'moe' section"
    text_kwargs.update(
        {
            "n_routed_experts": moe.get("num_experts", 128),
            "num_experts_per_tok": moe.get("num_experts_per_tok", 4),
            "first_k_dense_replace": moe.get("first_k_dense_replace", 0),
            "n_shared_experts": moe.get("num_shared_experts", 1),
            "moe_intermediate_size": moe.get("expert_hidden_dim", 2048),
            "routed_scaling_factor": moe.get("routed_scale", 1.0),
            "n_group": moe.get("num_expert_groups", 1),
            "topk_group": moe.get("num_expert_groups_per_tok", 1),
            "norm_topk_prob": True,
        }
    )

    qk_rope_head_dim = text_kwargs.get("qk_rope_head_dim", 64)
    qk_nope_head_dim = text_kwargs.get("qk_nope_head_dim", 64)

    yarn = original_config["yarn"]
    text_kwargs["rope_parameters"] = {
        "type": "yarn",
        "rope_theta": original_config.get("rope_theta", 10_000.0),
        "factor": float(yarn["factor"]),
        "original_max_position_embeddings": yarn["original_max_position_embeddings"],
        # Mistral's `beta` / `alpha` map onto HF's `beta_fast` / `beta_slow`.
        "beta_fast": float(yarn["beta"]),
        "beta_slow": float(yarn["alpha"]),
        "mscale_all_dim": 1.0,
        "mscale": 1.0,
        "llama_4_scaling_beta": original_config.get("llama_4_scaling", {}).get("beta", 0.1),
        # Only the rotary slice of each query/key head is rotated.
        "partial_rotary_factor": qk_rope_head_dim / (qk_nope_head_dim + qk_rope_head_dim),
    }

    quant_kwargs: dict[str, Any] = {}
    quant = original_config.get("quantization", {})
    if output_fp8 and quant.get("qformat_weight") == "fp8_e4m3":
        assert quant["qscheme_act"] == "TENSOR", (
            f"unsupported activation quantization scheme: {quant['qscheme_act']!r}"
        )
        quant_kwargs["quantization_config"] = AutoQuantizationConfig.from_dict(
            {
                "activation_scheme": "static",
                "modules_to_not_convert": ["model.vision_tower", "model.multi_modal_projector", "lm_head"],
                "quant_method": "fp8",
                "weight_block_size": None,
            }
        )

    if not is_vision:
        return Mistral4Config(**text_kwargs, **quant_kwargs)

    text_config = Mistral4Config(**text_kwargs)

    # Work on a shallow copy: the keys below are consumed here or are not accepted by
    # `PixtralVisionConfig`, but they must remain available in the caller's dict.
    vision_kwargs = dict(original_vision_config)
    adapter_bias = vision_kwargs.pop("adapter_bias", False)
    spatial_merge_size = vision_kwargs.pop("spatial_merge_size")
    image_token_id = vision_kwargs.pop("image_token_id", 10)
    for drop_key in [
        "mm_projector_id",
        "add_pre_mm_projector_layer_norm",
        "image_break_token_id",
        "image_end_token_id",
        "max_image_size",
    ]:
        vision_kwargs.pop(drop_key, None)
    vision_config = PixtralVisionConfig(hidden_act="silu", **vision_kwargs)

    return Mistral3Config(
        vision_config=vision_config,
        text_config=text_config,
        multimodal_projector_bias=adapter_bias,
        image_token_id=image_token_id,
        spatial_merge_size=spatial_merge_size,
        vision_feature_layer=-1,
        tie_word_embeddings=text_kwargs["tie_word_embeddings"],
        **quant_kwargs,
    )
write the HF model to output_dir.""" + params = _read_json(input_dir / "params.json") + is_vision = params.get("vision_encoder") is not None + is_fp8_source = params.get("quantization", {}).get("qformat_weight") == "fp8_e4m3" + output_fp8 = output_format == "fp8" and is_fp8_source + output_bf16 = not output_fp8 + + config = convert_config(params, max_position_embeddings, is_vision, output_fp8) + + text_config = config.text_config if isinstance(config, Mistral3Config) else config + n_experts = text_config.n_routed_experts + vision_config = config.vision_config if isinstance(config, Mistral3Config) else None + + model_prefix = "model.language_model" if is_vision else "model" + text_renamings = _get_text_renamings(model_prefix) + vision_renamings = _get_vision_renamings() if is_vision else [] + + full_state_dict: dict[str, torch.Tensor] = {} + all_expert_weights: dict[tuple, torch.Tensor] = {} + total_keys_seen: set[str] = set() + shards = sorted(p for p in input_dir.iterdir() if p.suffix == ".safetensors") + assert shards, f"No .safetensors files found in {input_dir}" + + for shard_path in shards: + print(f"Processing shard: {shard_path.name}") + original = load_file(str(shard_path)) + new_dict, expert_weights = convert_state_dict( + original, + text_renamings, + vision_renamings, + total_keys_seen, + vision_config, + is_fp8_source, + output_bf16, + ) + del original + full_state_dict.update(new_dict) + del new_dict + all_expert_weights.update(expert_weights) + del expert_weights + + print(f"Fusing {len(all_expert_weights)} expert weight entries...") + fused = fuse_experts(all_expert_weights, n_experts, is_vision, output_fp8) + del all_expert_weights + full_state_dict.update(fused) + del fused + + if text_config.tie_word_embeddings: + full_state_dict["lm_head.weight"] = full_state_dict[f"{model_prefix}.embed_tokens.weight"] + + with torch.device("meta"): + if isinstance(config, Mistral3Config): + model = Mistral3ForConditionalGeneration(config) + else: + model = 
def convert_and_write_processor_and_tokenizer(
    input_dir: Path,
    output_dir: Path,
    model_config: Mistral3Config | Mistral4Config,
) -> None:
    r"""Write the tokenizer, and for vision models the processor and generation config, to *output_dir*.

    A text-only `Mistral4Config` gets just the converted Tekken tokenizer. Any other config
    additionally gets a `PixtralProcessor` built from the `vision_encoder` section of
    `params.json`, and a `GenerationConfig` carrying the tokenizer's special token ids.
    """
    destination = str(output_dir)
    tokenizer = convert_tekken_tokenizer(str(input_dir / "tekken.json"))

    if isinstance(model_config, Mistral4Config):
        # Text-only checkpoint: nothing but the tokenizer to write.
        tokenizer.save_pretrained(destination)
        return

    vision_params = _read_json(input_dir / "params.json")["vision_encoder"]

    image_processor = PixtralImageProcessorFast(
        patch_size=vision_params["patch_size"],
        size={"longest_edge": vision_params["max_image_size"]},
    )
    PixtralProcessor(
        tokenizer=tokenizer,
        image_processor=image_processor,
        image_token="[IMG]",
        patch_size=vision_params["patch_size"],
        spatial_merge_size=vision_params["spatial_merge_size"],
    ).save_pretrained(destination)

    # Composite configs keep the language-model settings under `text_config`.
    text_config = getattr(model_config, "text_config", model_config)
    generation_config = GenerationConfig(
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        max_length=text_config.max_position_embeddings,
    )
    generation_config.save_pretrained(destination)
params.json)", + ) + parser.add_argument( + "--output_format", + choices=["fp8", "bf16"], + default="fp8", + help="Output format: 'fp8' keeps FP8 quantization (default), 'bf16' descales to BF16", + ) + args = parser.parse_args() + + config = convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings, args.output_format) + convert_and_write_processor_and_tokenizer(args.input_dir, args.output_dir, config) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/mistral4/modeling_mistral4.py b/src/transformers/models/mistral4/modeling_mistral4.py new file mode 100644 index 000000000000..df836e52f2dd --- /dev/null +++ b/src/transformers/models/mistral4/modeling_mistral4.py @@ -0,0 +1,730 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mistral4/modular_mistral4.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_mistral4.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ... 
class Mistral4RotaryEmbedding(nn.Module):
    """Produces the rotary position embedding `(cos, sin)` pair for a batch of position ids.

    The inverse frequencies are computed once at construction, using the init function
    selected by `config.rope_parameters["rope_type"]`, and stored as non-persistent buffers.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Mistral4Config, device=None):
        super().__init__()
        # Both start at the configured maximum; nothing in this class changes them afterwards.
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        # "default" uses the local static method; any other type is looked up in the shared registry.
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent: the buffers are recomputed from the config and not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # A separate copy of the initial frequencies (presumably for `dynamic_rope_update` — confirm).
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: Mistral4Config | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        # Fall back to hidden_size / num_attention_heads when the config has no (or a falsy) `head_dim`.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `x` is only used for its device and dtype. `position_ids` is indexed as a 2-D
        `(batch, seq_len)` tensor below.
        """
        # (n_freq,) -> (batch, n_freq, 1)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        # (batch, seq_len) -> (batch, 1, seq_len)
        position_ids_expanded = position_ids[:, None, :].float()

        # "mps" (and a non-string device type) is mapped to "cpu" for the autocast context.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            # Outer product of frequencies and positions, then (batch, seq_len, n_freq).
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate along the last dim so the two halves share the same angles.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Mistral4MoE(nn.Module):
    """
    Mixture-of-experts feed-forward block: a router picks `top_k` routed experts per token and
    the output of a shared-expert MLP, applied to every token, is added on top.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.experts = Mistral4NaiveMoe(config)
        self.gate = Mistral4TopkRouter(config)
        # One MLP whose width is `n_shared_experts` times the per-expert intermediate size.
        self.shared_experts = Mistral4MLP(
            config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts
        )
        self.n_routed_experts = config.n_routed_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob
        self.routed_scaling_factor = config.routed_scaling_factor
        self.top_k = config.num_experts_per_tok

    def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Select `top_k` experts per token with group-limited routing.

        Args:
            router_logits: Router scores; viewed below as `(num_tokens, n_routed_experts)`.

        Returns:
            `(topk_indices, topk_weights)`, each `(num_tokens, top_k)`: the chosen expert ids and
            their softmax probabilities, optionally renormalized, times `routed_scaling_factor`.
        """
        # From here on `router_logits` holds probabilities, not logits.
        router_logits = router_logits.softmax(-1)
        # Score each expert group by the sum of its two highest expert probabilities.
        group_scores = (
            router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1)
        )
        # Keep the `topk_group` best groups and build a 0/1 mask over groups.
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        # Broadcast the group mask to one entry per expert.
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(-1, self.n_group, self.n_routed_experts // self.n_group)
            .reshape(-1, self.n_routed_experts)
        )
        # Experts outside the kept groups get score 0 for the selection step only.
        scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        # Weights are read from the unmasked probabilities.
        topk_weights = router_logits.gather(1, topk_indices)
        if self.norm_topk_prob:
            # Renormalize so the selected weights sum to 1; the epsilon guards against a zero sum.
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights /= denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights

    def forward(self, hidden_states):
        """Apply routed experts plus the shared experts; the output has the input's shape."""
        residuals = hidden_states
        orig_shape = hidden_states.shape
        router_logits = self.gate(hidden_states)
        topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
        # Flatten all leading dims into a token dim for the experts, then restore the shape.
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
        # The shared experts see the original (un-routed) input.
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain matmul/softmax attention.

    Args:
        module: Attention module; `num_key_value_groups` and `training` are read from it.
        query: Query states, `(batch, heads, q_len, head_dim)`.
        key: Key states, `(batch, kv_heads, kv_len, head_dim)`.
        value: Value states, `(batch, kv_heads, kv_len, v_dim)`.
        attention_mask: Optional additive mask, added to the scaled scores as-is.
        scaling: Factor multiplied into the raw `query @ key^T` scores.
        dropout: Dropout probability on the attention weights (active only in training mode).
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has its head and sequence dims
        swapped relative to `query`.
    """
    # Repeat each key/value head so their head count matches the query's.
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # Softmax is computed in float32, then cast back to the query dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    # (batch, heads, q_len, v_dim) -> (batch, q_len, heads, v_dim)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
class Mistral4Attention(nn.Module):
    """Attention with low-rank (LoRA-style) query and key/value projections.

    Each query/key head is the concatenation of a non-rotary part (`qk_nope_head_dim`) and a
    rotary part (`qk_rope_head_dim`). The rotary key part is projected once per token and
    shared by all heads. Queries are additionally multiplied by a position-dependent
    llama-4 style scale before attention.
    """

    def __init__(self, config: Mistral4Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads

        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim

        self.is_causal = True
        if self.q_lora_rank is None:
            # Full-rank query projection.
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        else:
            # Low-rank query projection: down-project, normalize, up-project.
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = Mistral4RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # Emits the compressed KV latent plus the single shared rotary key slice.
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = Mistral4RMSNorm(self.kv_lora_rank)
        # Expands the latent into per-head non-rotary keys and values.
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        self.scaling = self.qk_head_dim ** (-0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run attention over `hidden_states`.

        Args:
            hidden_states: Input of shape `(batch, seq_len, hidden_size)`.
            position_embeddings: `(cos, sin)` rotary tables for the current positions.
            attention_mask: Mask forwarded unchanged to the attention implementation.
            past_key_values: Optional cache; updated in place with this layer's keys/values.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            `(attn_output, attn_weights)`; `attn_output` is `(batch, seq_len, hidden_size)` and
            `attn_weights` is whatever the selected attention implementation returns.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        # (batch, heads, seq, qk_head_dim), then split into non-rotary / rotary parts.
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # A single rotary key "head" per token; broadcast to every head after rotation.
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        if self.config.rope_interleave:  # support using interleaved weights for efficiency
            q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin)
        else:
            q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        # Ask for THIS layer's cached length. The default (layer 0) has already been extended
        # with the current tokens by the time later layers run, which would shift every
        # position by the current sequence length for all layers but the first.
        past_seen_tokens = past_key_values.get_seq_length(self.layer_idx) if past_key_values is not None else 0
        cache_position = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens

        # Position-dependent query scale; broadcasts as (seq, 1) over (batch, heads, seq, dim).
        query_states = query_states * get_llama_4_attn_scale(
            cache_position,
            self.config.rope_parameters.get("llama_4_scaling_beta"),
            self.config.rope_parameters.get("original_max_position_embeddings"),
        ).to(query_states.dtype)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Flash attention needs equal q/k and v head dims: zero-pad values here, un-pad below.
        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
_supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Mistral4DecoderLayer, + "attentions": Mistral4Attention, + } + _keep_in_fp32_modules_strict = [] + _keys_to_ignore_on_load_unexpected = [] + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Mistral4TopkRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Mistral4NaiveMoe): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + + +@auto_docstring +class Mistral4Model(Mistral4PreTrainedModel): + def __init__(self, config: Mistral4Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Mistral4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Mistral4RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: 
torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_embeddings=position_embeddings, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Mistral4ForCausalLM(Mistral4PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Mistral4Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | 
None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Mistral4ForCausalLM + + >>> model = Mistral4ForCausalLM.from_pretrained("meta-mistral4/Mistral4-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral4/Mistral4-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Mistral4ForSequenceClassification(GenericForSequenceClassification, Mistral4PreTrainedModel): + pass + + +class Mistral4ForTokenClassification(GenericForTokenClassification, Mistral4PreTrainedModel): + pass + + +__all__ = [ + 
"Mistral4PreTrainedModel", + "Mistral4Model", + "Mistral4ForCausalLM", + "Mistral4ForSequenceClassification", + "Mistral4ForTokenClassification", +] diff --git a/src/transformers/models/mistral4/modular_mistral4.py b/src/transformers/models/mistral4/modular_mistral4.py new file mode 100644 index 000000000000..d9c73a3c19cc --- /dev/null +++ b/src/transformers/models/mistral4/modular_mistral4.py @@ -0,0 +1,291 @@ +# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable + +import torch +import torch.nn.functional as F +from torch import nn + +from ... 
import initialization as init +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GenericForSequenceClassification, GenericForTokenClassification +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import logging +from ...utils.generic import is_flash_attention_requested +from ..deepseek_v3.modeling_deepseek_v3 import ( + DeepseekV3Attention, + DeepseekV3DecoderLayer, + DeepseekV3MoE, + DeepseekV3NaiveMoe, + apply_rotary_pos_emb_interleave, +) +from ..llama.modeling_llama import ( + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from ..ministral3.modeling_ministral3 import get_llama_4_attn_scale +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeMLP +from .configuration_mistral4 import Mistral4Config + + +logger = logging.get_logger(__name__) + + +class Mistral4RMSNorm(LlamaRMSNorm): + pass + + +class Mistral4RotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class Mistral4MLP(Qwen2MoeMLP): + pass + + +class Mistral4TopkRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.n_routed_experts = config.n_routed_experts + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states, self.weight) + return router_logits + + +class Mistral4NaiveMoe(DeepseekV3NaiveMoe): + pass + + +class Mistral4MoE(DeepseekV3MoE): + def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + router_logits = router_logits.softmax(-1) + group_scores = ( + router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1) + ) + group_idx = torch.topk(group_scores, 
k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = router_logits.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + +class Mistral4Attention(DeepseekV3Attention): + def __init__(self, config: Mistral4Config, layer_idx: int): + nn.Module.__init__(self) + self.config = config + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + if self.q_lora_rank is None: + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = Mistral4RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = Mistral4RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + 
self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + else: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + cos, sin = position_embeddings + if self.config.rope_interleave: # support using interleaved weights for efficiency + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not 
None else 0 + cache_position = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens + + query_states = query_states * get_llama_4_attn_scale( + cache_position, + self.config.rope_parameters.get("llama_4_scaling_beta"), + self.config.rope_parameters.get("original_max_position_embeddings"), + ).to(query_states.dtype) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Mistral4DecoderLayer(DeepseekV3DecoderLayer): + def __init__(self, config: Mistral4Config, layer_idx: int): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + + self.self_attn = Mistral4Attention(config=config, layer_idx=layer_idx) + + if layer_idx >= config.first_k_dense_replace: + self.mlp = Mistral4MoE(config) + else: + self.mlp = Mistral4MLP(config) + + self.input_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class Mistral4PreTrainedModel(PreTrainedModel): + config: Mistral4Config + base_model_prefix = "model" + 
supports_gradient_checkpointing = True + _no_split_modules = ["Mistral4DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Mistral4DecoderLayer, + "attentions": Mistral4Attention, + } + _keep_in_fp32_modules_strict = [] + _keys_to_ignore_on_load_unexpected = [] + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Mistral4TopkRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Mistral4NaiveMoe): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + + +class Mistral4Model(LlamaModel): + pass + + +class Mistral4ForCausalLM(LlamaForCausalLM): + pass + + +class Mistral4ForSequenceClassification(GenericForSequenceClassification, Mistral4PreTrainedModel): + pass + + +class Mistral4ForTokenClassification(GenericForTokenClassification, Mistral4PreTrainedModel): + pass + + +__all__ = [ + "Mistral4PreTrainedModel", + "Mistral4Model", + "Mistral4ForCausalLM", + "Mistral4ForSequenceClassification", + "Mistral4ForTokenClassification", +] diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 32045d94f7ca..24cefad6a222 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -500,6 +500,10 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False): Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, the mask will contain 1. For user and system tokens, the mask will contain 0. This functionality is only available for chat templates that support it via the `{% generation %}` keyword. 
+ reasoning_effort (`str`, *optional*): + The reasoning effort level to use for the model's response. Supported values depend on the model + (e.g. `"none"`, `"low"`, `"medium"`, `"high"`). If the template does not support reasoning effort, + this argument will have no effect. """ tools: list[dict] | None = None @@ -507,6 +511,7 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False): add_generation_prompt: bool | None = False continue_final_message: bool | None = False return_assistant_tokens_mask: bool | None = False + reasoning_effort: str | None = None class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False): diff --git a/tests/models/mistral4/__init__.py b/tests/models/mistral4/__init__.py new file mode 100644 index 000000000000..d384c72ec156 --- /dev/null +++ b/tests/models/mistral4/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mistral4 model tests.""" + +from .test_modeling_mistral4 import Mistral4ConfigTest, Mistral4ModelTest diff --git a/tests/models/mistral4/test_modeling_mistral4.py b/tests/models/mistral4/test_modeling_mistral4.py new file mode 100644 index 000000000000..1703cb63e49e --- /dev/null +++ b/tests/models/mistral4/test_modeling_mistral4.py @@ -0,0 +1,133 @@ +# Copyright 2026 the HuggingFace and MistralAI Teams. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Mistral4 model.""" + +import gc +import unittest + +import pytest + +from transformers import AutoTokenizer, Mistral3ForConditionalGeneration, is_torch_available +from transformers.testing_utils import ( + Expectations, + backend_empty_cache, + cleanup, + require_deterministic_for_xpu, + require_flash_attn, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + + from transformers import ( + Mistral4Model, + ) + + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +class Mistral4ModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = Mistral4Model + + +@require_torch +class Mistral4ModelTest(CausalLMModelTest, unittest.TestCase): + _is_stateful = True + model_split_percents = [0.5, 0.6] + model_tester_class = Mistral4ModelTester + + # TODO (ydshieh): Check this. 
See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + return True + + @require_flash_attn + @require_torch_accelerator + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + self.skipTest(reason="Mistral4 flash attention does not support right padding") + + +@require_torch +class Mistral4IntegrationTest(unittest.TestCase): + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + def test_mistral_small_4_logits(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + model = Mistral3ForConditionalGeneration.from_pretrained( + "mistralai/Mistral-Small-4-119B-2603", device_map="auto" + ) + input_ids = torch.tensor([input_ids]).to(model.device) + with torch.no_grad(): + out = model(input_ids).logits.float().cpu() + # Expected mean on dim = -1 + # fmt: off + EXPECTED_MEANS = Expectations( + { + ("cuda", None): torch.tensor([[0.1793, -1.0928, -3.9925, -2.8699, -0.1250, -1.6851, -2.5565, -1.2263]]), + } + ) + # fmt: on + EXPECTED_MEAN = EXPECTED_MEANS.get_expectation() + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @slow + @require_deterministic_for_xpu + def test_mistral_small_4_generation(self): + # fmt: off + EXPECTED_TEXTS = Expectations( + { + ("cuda", None): "My favourite condiment is 1000 island dressing. I love it on burgers and hot dogs. I also like", + # ("xpu", None): "My favourite condiment is iced tea. I love the way it makes me feel. 
It’s like a little bubble bath for", + } + ) + # fmt: on + EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation() + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Small-4-119B-2603") + model = Mistral3ForConditionalGeneration.from_pretrained( + "mistralai/Mistral-Small-4-119B-2603", device_map="auto" + ) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(text, EXPECTED_TEXT) + + del model + backend_empty_cache(torch_device) + gc.collect() From 4d3e92ee00f14e1a161917c4608517a5b3f058e1 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Mon, 16 Mar 2026 16:02:51 +0000 Subject: [PATCH 2/9] add mistral4 to toctree --- docs/source/en/_toctree.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c42c2c7d870c..07aad5be5b57 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1232,6 +1232,8 @@ title: MGP-STR - local: model_doc/mistral3 title: Mistral3 + - local: model_doc/mistral4 + title: Mistral4 - local: model_doc/mllama title: mllama - local: model_doc/mm-grounding-dino From d30030d0f05d7c7a5a6a4dde0041e82c33f0f2ad Mon Sep 17 00:00:00 2001 From: juliendenize Date: Mon, 16 Mar 2026 16:32:21 +0000 Subject: [PATCH 3/9] Quality --- docs/source/en/model_doc/mistral4.md | 2 +- src/transformers/models/mistral4/configuration_mistral4.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/mistral4.md b/docs/source/en/model_doc/mistral4.md index 610883d85914..6b636e3c05f1 100644 --- a/docs/source/en/model_doc/mistral4.md +++ b/docs/source/en/model_doc/mistral4.md @@ -15,7 +15,7 @@ limitations under the License. 
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer. --> -*This model was released on 2026-16-02 and added to Hugging Face Transformers on 2026-16-02.* +*This model was released on 2026-03-16 and added to Hugging Face Transformers on 2026-03-16.* # Mistral4 diff --git a/src/transformers/models/mistral4/configuration_mistral4.py b/src/transformers/models/mistral4/configuration_mistral4.py index e9ea9981e7cf..6b2d382f486b 100644 --- a/src/transformers/models/mistral4/configuration_mistral4.py +++ b/src/transformers/models/mistral4/configuration_mistral4.py @@ -107,6 +107,7 @@ class Mistral4Config(PreTrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. + mlp_bias (`bool | None`, *optional*, defaults to `False`): ```python >>> from transformers import Mistral4Model, Mistral4Config From 8b5b5ef08f99f50f6bceeb16c64eeab649012873 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Mon, 16 Mar 2026 17:33:36 +0100 Subject: [PATCH 4/9] Update src/transformers/models/mistral4/__init__.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/mistral4/__init__.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/mistral4/__init__.py b/src/transformers/models/mistral4/__init__.py index fdde4abe6b16..cb7639a11977 100644 --- a/src/transformers/models/mistral4/__init__.py +++ b/src/transformers/models/mistral4/__init__.py @@ -24,14 +24,5 @@ import sys _file = globals()["__file__"] - # Explicitly define the import structure to include modeling classes - import_structure = define_import_structure(_file) - # Add the modeling classes explicitly - 
import_structure[frozenset({})]["modeling_mistral4"] = { - "Mistral4PreTrainedModel", - "Mistral4Model", - "Mistral4ForCausalLM", - "Mistral4ForSequenceClassification", - "Mistral4ForTokenClassification", - } - sys.modules[__name__] = _LazyModule(__name__, _file, import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) + From 06e9f10a09deb79b537f222c4a0d012aaac90bba Mon Sep 17 00:00:00 2001 From: juliendenize Date: Mon, 16 Mar 2026 16:37:52 +0000 Subject: [PATCH 5/9] wip --- src/transformers/models/mistral4/__init__.py | 1 - tests/models/mistral4/__init__.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/mistral4/__init__.py b/src/transformers/models/mistral4/__init__.py index cb7639a11977..2e484504111d 100644 --- a/src/transformers/models/mistral4/__init__.py +++ b/src/transformers/models/mistral4/__init__.py @@ -25,4 +25,3 @@ _file = globals()["__file__"] sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) - diff --git a/tests/models/mistral4/__init__.py b/tests/models/mistral4/__init__.py index d384c72ec156..b1e6bad8a4a1 100644 --- a/tests/models/mistral4/__init__.py +++ b/tests/models/mistral4/__init__.py @@ -14,4 +14,4 @@ """Mistral4 model tests.""" -from .test_modeling_mistral4 import Mistral4ConfigTest, Mistral4ModelTest +from .test_modeling_mistral4 import Mistral4IntegrationTest, Mistral4ModelTest From d93b507ca24c58add463bdf28b5cb3bdf0a9bf95 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Mon, 16 Mar 2026 16:49:59 +0000 Subject: [PATCH 6/9] docstring --- src/transformers/models/mistral4/configuration_mistral4.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/mistral4/configuration_mistral4.py b/src/transformers/models/mistral4/configuration_mistral4.py index 6b2d382f486b..a30d2a0b4420 100644 --- 
a/src/transformers/models/mistral4/configuration_mistral4.py +++ b/src/transformers/models/mistral4/configuration_mistral4.py @@ -107,7 +107,8 @@ class Mistral4Config(PreTrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - mlp_bias (`bool | None`, *optional*, defaults to `False`): + mlp_bias (`bool | None`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. ```python >>> from transformers import Mistral4Model, Mistral4Config From e9b9a1c68e9edbb0cc28db005447569391c46b12 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Mon, 16 Mar 2026 17:36:11 +0000 Subject: [PATCH 7/9] Make config dataclass --- .../models/mistral4/configuration_mistral4.py | 240 +++++------------- 1 file changed, 63 insertions(+), 177 deletions(-) diff --git a/src/transformers/models/mistral4/configuration_mistral4.py b/src/transformers/models/mistral4/configuration_mistral4.py index a30d2a0b4420..2123d5718b25 100644 --- a/src/transformers/models/mistral4/configuration_mistral4.py +++ b/src/transformers/models/mistral4/configuration_mistral4.py @@ -11,104 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Mistral4 model configuration""" + +from huggingface_hub.dataclasses import strict + from ...configuration_utils import PreTrainedConfig from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring +@auto_docstring(checkpoint="mistralai/Mistral-Small-4-119B-2603") +@strict(accept_kwargs=True) class Mistral4Config(PreTrainedConfig): r""" - This is the configuration class to store the configuration of a [`Mistral4Model`]. 
It is used to instantiate a Mistral4 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the Mistral-Small-4 model. - e.g. [mistralai/Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603) - - Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PreTrainedConfig`] for more information. + n_group (`int`, *optional*, defaults to 1): + Number of groups for routed experts. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + rope_interleave (`bool`, *optional*, defaults to `True`): + Whether to interleave the rotary position embeddings. - Args: - vocab_size (`int`, *optional*, defaults to 131072): - Vocabulary size of the Mistral4 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Mistral4Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 12288): - Dimension of the MLP representations. - moe_intermediate_size (`int`, *optional*, defaults to 2048): - Dimension of the MoE representations. - num_hidden_layers (`int`, *optional*, defaults to 36): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 32): - This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details, check out [this - paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to - `num_attention_heads`. - n_shared_experts (`int`, *optional*, defaults to 1): - Number of shared experts. - n_routed_experts (`int`, *optional*, defaults to 128): - Number of routed experts. - routed_scaling_factor (`float`, *optional*, defaults to 1.0): - Scaling factor or routed experts. - kv_lora_rank (`int`, *optional*, defaults to 256): - Rank of the LoRA matrices for key and value projections. - q_lora_rank (`int`, *optional*, defaults to 1024): - Rank of the LoRA matrices for query projections. - qk_rope_head_dim (`int`, *optional*, defaults to 64): - Dimension of the query/key heads that use rotary position embeddings. - v_head_dim (`int`, *optional*, defaults to 128): - Dimension of the value heads. - qk_nope_head_dim (`int`, *optional*, defaults to 64): - Dimension of the query/key heads that don't use rotary position embeddings. - n_group (`int`, *optional*, defaults to 1): - Number of groups for routed experts. - topk_group (`int`, *optional*, defaults to 1): - Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). - num_experts_per_tok (`int`, *optional*, defaults to 4): - Number of selected experts, None means dense model. - first_k_dense_replace (`int`, *optional*, defaults to 0): - Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). 
- \--k dense layers--/ - norm_topk_prob (`bool`, *optional*, defaults to `True`): - Whether to normalize the weights of the routed experts. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 1048576): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*, defaults to 11): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_parameters (`RopeParameters`, *optional*): - Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain - a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE - with longer `max_position_embeddings`. 
- rope_interleave (`bool`, *optional*, defaults to `True`): - Whether to interleave the rotary position embeddings. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - mlp_bias (`bool | None`, *optional*, defaults to `False`): - Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + Example: ```python >>> from transformers import Mistral4Model, Mistral4Config @@ -141,105 +65,67 @@ class Mistral4Config(PreTrainedConfig): "num_local_experts": "n_routed_experts", } - def __init__( - self, - vocab_size: int | None = 131072, - hidden_size: int | None = 4096, - intermediate_size: int | None = 12288, - moe_intermediate_size: int | None = 2048, - num_hidden_layers: int | None = 36, - num_attention_heads: int | None = 32, - num_key_value_heads: int | None = 32, - n_shared_experts: int | None = 1, - n_routed_experts: int | None = 128, - routed_scaling_factor: float | None = 1.0, - kv_lora_rank: int | None = 256, - q_lora_rank: int | None = 1024, - qk_rope_head_dim: int | None = 64, - v_head_dim: int | None = 128, - qk_nope_head_dim: int | None = 64, - n_group: int | None = 1, - topk_group: int | None = 1, - num_experts_per_tok: int | None = 4, - first_k_dense_replace: int | None = 0, - norm_topk_prob: bool | None = True, - hidden_act: str | None = "silu", - max_position_embeddings: int | None = 1048576, - initializer_range: float | None = 0.02, - rms_norm_eps: float | None = 1e-6, - use_cache: bool | None = True, - pad_token_id: int | None = 11, - bos_token_id: int | None = 1, - eos_token_id: int | None = 2, - pretraining_tp: int | None = 1, - tie_word_embeddings: bool | None = False, - rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, - rope_interleave: bool | None = True, - 
attention_bias: bool | None = False, - attention_dropout: float | None = 0.0, - mlp_bias: bool | None = False, - **kwargs, - ): - if rope_parameters is None: - rope_parameters = { + vocab_size: int = 131072 + hidden_size: int = 4096 + intermediate_size: int = 12288 + moe_intermediate_size: int = 2048 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int | None = 32 + n_shared_experts: int = 1 + n_routed_experts: int = 128 + routed_scaling_factor: float = 1.0 + kv_lora_rank: int = 256 + q_lora_rank: int = 1024 + qk_rope_head_dim: int = 64 + v_head_dim: int | None = 128 + qk_nope_head_dim: int = 64 + n_group: int | None = 1 + topk_group: int | None = 1 + num_experts_per_tok: int | None = 4 + first_k_dense_replace: int | None = 0 + norm_topk_prob: bool | None = True + hidden_act: str = "silu" + max_position_embeddings: int = 1048576 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int | None = 11 + bos_token_id: int | None = 1 + eos_token_id: int | None = 2 + pretraining_tp: int | None = 1 + tie_word_embeddings: bool = False + rope_parameters: RopeParameters | dict | None = None + rope_interleave: bool | None = True + attention_bias: bool = False + attention_dropout: float | int | None = 0.0 + mlp_bias: bool | None = False + + def __post_init__(self, **kwargs): + if self.rope_parameters is None: + self.rope_parameters = { "type": "yarn", "rope_theta": 10000.0, "factor": 128.0, "original_max_position_embeddings": 8192, - "max_position_embeddings": max_position_embeddings, + "max_position_embeddings": self.max_position_embeddings, "beta_fast": 32.0, "beta_slow": 1.0, "mscale_all_dim": 1.0, "mscale": 1.0, "llama_4_scaling_beta": 0.1, - "partial_rotary_factor": qk_rope_head_dim / (qk_nope_head_dim + qk_rope_head_dim), + "partial_rotary_factor": self.qk_rope_head_dim / (self.qk_nope_head_dim + self.qk_rope_head_dim), } - self.vocab_size = vocab_size - self.max_position_embeddings = 
max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.moe_intermediate_size = moe_intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.n_shared_experts = n_shared_experts - self.n_routed_experts = n_routed_experts - self.routed_scaling_factor = routed_scaling_factor - self.kv_lora_rank = kv_lora_rank - self.q_lora_rank = q_lora_rank - self.qk_rope_head_dim = qk_rope_head_dim - self.v_head_dim = v_head_dim - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.head_dim = qk_nope_head_dim + qk_rope_head_dim - self.n_group = n_group - self.topk_group = topk_group - self.num_experts_per_tok = num_experts_per_tok - self.first_k_dense_replace = first_k_dense_replace - self.norm_topk_prob = norm_topk_prob - self.rope_interleave = rope_interleave + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.rope_parameters = rope_parameters + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim self.rope_parameters.setdefault("partial_rotary_factor", self.qk_rope_head_dim / self.head_dim) - - self.tie_word_embeddings = tie_word_embeddings - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - self.mlp_bias = mlp_bias - super().__init__(ignore_keys_at_rope_validation={"llama_4_scaling_beta", "max_position_embeddings"}, 
**kwargs) + super().__post_init__( + ignore_keys_at_rope_validation={"llama_4_scaling_beta", "max_position_embeddings"}, **kwargs + ) def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs): rope_scaling = kwargs.pop("rope_scaling", None) From 421a575ba32ca001b927d04a61beab22b26911ad Mon Sep 17 00:00:00 2001 From: juliendenize Date: Mon, 16 Mar 2026 18:05:43 +0000 Subject: [PATCH 8/9] Fixes fixes fixes --- src/transformers/models/auto/modeling_auto.py | 2 ++ src/transformers/models/mistral4/configuration_mistral4.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 316e394aad87..764d3b770e86 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -1246,6 +1246,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ministral", "MinistralForSequenceClassification"), ("ministral3", "Ministral3ForSequenceClassification"), ("mistral", "MistralForSequenceClassification"), + ("mistral4", "Mistral4ForSequenceClassification"), ("mixtral", "MixtralForSequenceClassification"), ("mobilebert", "MobileBertForSequenceClassification"), ("modernbert", "ModernBertForSequenceClassification"), @@ -1459,6 +1460,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ministral", "MinistralForTokenClassification"), ("ministral3", "Ministral3ForTokenClassification"), ("mistral", "MistralForTokenClassification"), + ("mistral4", "Mistral4ForTokenClassification"), ("mixtral", "MixtralForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"), ("modernbert", "ModernBertForTokenClassification"), diff --git a/src/transformers/models/mistral4/configuration_mistral4.py b/src/transformers/models/mistral4/configuration_mistral4.py index 2123d5718b25..86f63cca23da 100644 --- a/src/transformers/models/mistral4/configuration_mistral4.py 
+++ b/src/transformers/models/mistral4/configuration_mistral4.py @@ -99,7 +99,6 @@ class Mistral4Config(PreTrainedConfig): rope_interleave: bool | None = True attention_bias: bool = False attention_dropout: float | int | None = 0.0 - mlp_bias: bool | None = False def __post_init__(self, **kwargs): if self.rope_parameters is None: From b98b67657a1da070c071a8450940849f8136e332 Mon Sep 17 00:00:00 2001 From: juliendenize Date: Mon, 16 Mar 2026 18:26:52 +0000 Subject: [PATCH 9/9] update tp plan --- src/transformers/models/mistral4/configuration_mistral4.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/mistral4/configuration_mistral4.py b/src/transformers/models/mistral4/configuration_mistral4.py index 86f63cca23da..ceb252929f80 100644 --- a/src/transformers/models/mistral4/configuration_mistral4.py +++ b/src/transformers/models/mistral4/configuration_mistral4.py @@ -47,8 +47,9 @@ class Mistral4Config(PreTrainedConfig): model_type = "mistral4" keys_to_ignore_at_inference = ["past_key_values"] base_model_tp_plan = { - "layers.*.mlp.experts.gate_up_proj": "rowwise", + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", "layers.*.mlp.shared_experts.gate_proj": "colwise", "layers.*.mlp.shared_experts.up_proj": "colwise", "layers.*.mlp.shared_experts.down_proj": "rowwise",