diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index c42c2c7d870c..07aad5be5b57 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1232,6 +1232,8 @@ title: MGP-STR - local: model_doc/mistral3 title: Mistral3 + - local: model_doc/mistral4 + title: Mistral4 - local: model_doc/mllama title: mllama - local: model_doc/mm-grounding-dino diff --git a/docs/source/en/model_doc/mistral4.md b/docs/source/en/model_doc/mistral4.md new file mode 100644 index 000000000000..6b636e3c05f1 --- /dev/null +++ b/docs/source/en/model_doc/mistral4.md @@ -0,0 +1,116 @@ + +*This model was released on 2026-03-16 and added to Hugging Face Transformers on 2026-03-16.* + +# Mistral4 + +## Overview + +Mistral 4 is a powerful hybrid model with the capability of acting as both a general instruction model and a reasoning model. It unifies the capabilities of three different model families - Instruct, Reasoning ( previous called Magistral ), and Devstral - into a single, unified model. + +[Mistral-Small-4](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603) consists of the following architectural choices: + +- MoE: 128 experts and 4 active. +- 119B with 6.5B activated parameters per token. +- 256k Context Length. +- Multimodal Input: Accepts both text and image input, with text output. +- Instruct and Reasoning functionalities with Function Calls + - Reasoning Effort configurable by request. + +Mistral 4 offers the following capabilities: + +- **Reasoning Mode**: Switch between a fast instant reply mode, and a reasoning thinking mode, boosting performance with test time compute when requested. +- **Vision**: Enables the model to analyze images and provide insights based on visual content, in addition to text. +- **Multilingual**: Supports dozens of languages, including English, French, Spanish, German, Italian, Portuguese, Dutch, Chinese, Japanese, Korean, Arabic. +- **System Prompt**: Maintains strong adherence and support for system prompts. +- **Agentic**: Offers best-in-class agentic capabilities with native function calling and JSON outputting. +- **Speed-Optimized**: Delivers best-in-class performance and speed. +- **Apache 2.0 License**: Open-source license allowing usage and modification for both commercial and non-commercial purposes. +- **Large Context Window**: Supports a 256k context window. + +## Usage examples + +```py +import torch +from transformers import AutoProcessor, Mistral3ForConditionalGeneration + + +model_id = "mistralai/Mistral-Small-4-119B-2603" + +processor = AutoProcessor.from_pretrained(model_id) +model = Mistral3ForConditionalGeneration.from_pretrained( + model_id, device_map="auto" +) + +image_url = "https://static.wikia.nocookie.net/essentialsdocs/images/7/70/Battle.png/revision/latest?cb=20220523172438" + +messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What action do you think I should take in this situation? List all the possible actions and explain why you think they are good or bad.", + }, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + }, +] + +inputs = processor.apply_chat_template(messages, return_tensors="pt", tokenize=True, return_dict=True, reasoning_effort="high") +inputs = inputs.to(model.device) + +output = model.generate( + **inputs, + max_new_tokens=512, +)[0] + +# Setting `skip_special_tokens=False` to visualize reasoning trace between [THINK] [/THINK] tags. +decoded_output = processor.decode(output[len(inputs["input_ids"][0]):], skip_special_tokens=False) +print(decoded_output) +``` + +## Mistral4Config + +[[autodoc]] Mistral4Config + +## Mistral4PreTrainedModel + +[[autodoc]] Mistral4PreTrainedModel + - forward + +## Mistral4Model + +[[autodoc]] Mistral4Model + - forward + +## Mistral4ForCausalLM + +[[autodoc]] Mistral4ForCausalLM + +## Mistral4ForSequenceClassification + +[[autodoc]] Mistral4ForSequenceClassification + +## Mistral4ForTokenClassification + +[[autodoc]] Mistral4ForTokenClassification + +## Mistral4ForQuestionAnswering + +[[autodoc]] Mistral4ForQuestionAnswering diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 7caa68039500..5f45081ac4a0 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -250,6 +250,7 @@ from .ministral3 import * from .mistral import * from .mistral3 import * + from .mistral4 import * from .mixtral import * from .mlcd import * from .mllama import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 3c71c30f120d..476b5362343f 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -287,6 +287,7 @@ ("ministral3", "Ministral3Config"), ("mistral", "MistralConfig"), ("mistral3", "Mistral3Config"), + ("mistral4", "Mistral4Config"), ("mixtral", "MixtralConfig"), ("mlcd", "MLCDVisionConfig"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works ("mlcd_vision_model", "MLCDVisionConfig"), @@ -797,6 +798,7 @@ ("ministral3", "Ministral3"), ("mistral", "Mistral"), ("mistral3", "Mistral3"), + ("mistral4", "Mistral4"), ("mixtral", "Mixtral"), ("mlcd", "MLCD"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works ("mlcd_vision_model", "MLCD"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 8a94da531eb8..764d3b770e86 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -282,6 +282,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ministral3", "Ministral3Model"), ("mistral", "MistralModel"), ("mistral3", "Mistral3Model"), + ("mistral4", "Mistral4Model"), ("mixtral", "MixtralModel"), ("mlcd", "MLCDVisionModel"), # Keep this to make some original hub repositories (from `DeepGlint-AI`) works ("mlcd_vision_model", "MLCDVisionModel"), @@ -541,6 +542,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mamba2", "Mamba2ForCausalLM"), ("megatron-bert", "MegatronBertForPreTraining"), ("mistral3", "Mistral3ForConditionalGeneration"), + ("mistral4", "Mistral4ForCausalLM"), ("mllama", "MllamaForConditionalGeneration"), ("mobilebert", "MobileBertForPreTraining"), ("mpnet", "MPNetForMaskedLM"), @@ -981,6 +983,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), ("mistral3", "Mistral3ForConditionalGeneration"), + ("mistral4", "Mistral4ForCausalLM"), ("mllama", "MllamaForConditionalGeneration"), ("ovis2", "Ovis2ForConditionalGeneration"), ("paddleocr_vl", "PaddleOCRVLForConditionalGeneration"), @@ -1243,6 +1246,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ministral", "MinistralForSequenceClassification"), ("ministral3", "Ministral3ForSequenceClassification"), ("mistral", "MistralForSequenceClassification"), + ("mistral4", "Mistral4ForSequenceClassification"), ("mixtral", "MixtralForSequenceClassification"), ("mobilebert", "MobileBertForSequenceClassification"), ("modernbert", "ModernBertForSequenceClassification"), @@ -1456,6 +1460,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("ministral", "MinistralForTokenClassification"), ("ministral3", "Ministral3ForTokenClassification"), ("mistral", "MistralForTokenClassification"), + ("mistral4", "Mistral4ForTokenClassification"), ("mixtral", "MixtralForTokenClassification"), ("mobilebert", "MobileBertForTokenClassification"), ("modernbert", "ModernBertForTokenClassification"), diff --git a/src/transformers/models/ministral3/modeling_ministral3.py b/src/transformers/models/ministral3/modeling_ministral3.py index be0659788d71..3d2d15fa1a84 100644 --- a/src/transformers/models/ministral3/modeling_ministral3.py +++ b/src/transformers/models/ministral3/modeling_ministral3.py @@ -102,7 +102,7 @@ def eager_attention_forward( return attn_output, attn_weights -def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: +def get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) return scaling.unsqueeze(-1) @@ -144,7 +144,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens - query_states = query_states * _get_llama_4_attn_scale( + query_states = query_states * get_llama_4_attn_scale( cache_position, self.config.rope_parameters.get("llama_4_scaling_beta"), self.config.rope_parameters.get("original_max_position_embeddings"), diff --git a/src/transformers/models/ministral3/modular_ministral3.py b/src/transformers/models/ministral3/modular_ministral3.py index bde7829d769a..9dcac11dd058 100644 --- a/src/transformers/models/ministral3/modular_ministral3.py +++ b/src/transformers/models/ministral3/modular_ministral3.py @@ -26,7 +26,7 @@ logger = logging.get_logger(__name__) -def _get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: +def get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) return scaling.unsqueeze(-1) @@ -51,7 +51,7 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens - query_states = query_states * _get_llama_4_attn_scale( + query_states = query_states * get_llama_4_attn_scale( cache_position, self.config.rope_parameters.get("llama_4_scaling_beta"), self.config.rope_parameters.get("original_max_position_embeddings"), diff --git a/src/transformers/models/mistral4/__init__.py b/src/transformers/models/mistral4/__init__.py new file mode 100644 index 000000000000..2e484504111d --- /dev/null +++ b/src/transformers/models/mistral4/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_mistral4 import * + from .modeling_mistral4 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/mistral4/configuration_mistral4.py b/src/transformers/models/mistral4/configuration_mistral4.py new file mode 100644 index 000000000000..ceb252929f80 --- /dev/null +++ b/src/transformers/models/mistral4/configuration_mistral4.py @@ -0,0 +1,149 @@ +# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Mistral4 model configuration""" + +from huggingface_hub.dataclasses import strict + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters +from ...utils import auto_docstring + + +@auto_docstring(checkpoint="mistralai/Mistral-Small-4-119B-2603") +@strict(accept_kwargs=True) +class Mistral4Config(PreTrainedConfig): + r""" + n_group (`int`, *optional*, defaults to 1): + Number of groups for routed experts. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + rope_interleave (`bool`, *optional*, defaults to `True`): + Whether to interleave the rotary position embeddings. + + Example: + + ```python + >>> from transformers import Mistral4Model, Mistral4Config + + >>> # Initializing a Mistral4 style configuration + >>> configuration = Mistral4Config() + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mistral4" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.mlp.experts.gate_up_proj": "packed_colwise", + "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.experts": "moe_tp_experts", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + attribute_map = { + "num_local_experts": "n_routed_experts", + } + + vocab_size: int = 131072 + hidden_size: int = 4096 + intermediate_size: int = 12288 + moe_intermediate_size: int = 2048 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int | None = 32 + n_shared_experts: int = 1 + n_routed_experts: int = 128 + routed_scaling_factor: float = 1.0 + kv_lora_rank: int = 256 + q_lora_rank: int = 1024 + qk_rope_head_dim: int = 64 + v_head_dim: int | None = 128 + qk_nope_head_dim: int = 64 + n_group: int | None = 1 + topk_group: int | None = 1 + num_experts_per_tok: int | None = 4 + first_k_dense_replace: int | None = 0 + norm_topk_prob: bool | None = True + hidden_act: str = "silu" + max_position_embeddings: int = 1048576 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int | None = 11 + bos_token_id: int | None = 1 + eos_token_id: int | None = 2 + pretraining_tp: int | None = 1 + tie_word_embeddings: bool = False + rope_parameters: RopeParameters | dict | None = None + rope_interleave: bool | None = True + attention_bias: bool = False + attention_dropout: float | int | None = 0.0 + + def __post_init__(self, **kwargs): + if self.rope_parameters is None: + self.rope_parameters = { + "type": "yarn", + "rope_theta": 10000.0, + "factor": 128.0, + "original_max_position_embeddings": 8192, + "max_position_embeddings": self.max_position_embeddings, + "beta_fast": 32.0, + "beta_slow": 1.0, + "mscale_all_dim": 1.0, + "mscale": 1.0, + "llama_4_scaling_beta": 0.1, + "partial_rotary_factor": self.qk_rope_head_dim / (self.qk_nope_head_dim + self.qk_rope_head_dim), + } + + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.rope_parameters.setdefault("partial_rotary_factor", self.qk_rope_head_dim / self.head_dim) + super().__post_init__( + ignore_keys_at_rope_validation={"llama_4_scaling_beta", "max_position_embeddings"}, **kwargs + ) + + def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs): + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or self.rope_parameters + self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {} + + # Standardize and validate the correctness of rotary position embeddings parameters + self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta)) + self.standardize_rope_params() + if ignore_keys_at_rope_validation is not None: + self.ignore_keys_at_rope_validation = self.ignore_keys_at_rope_validation | ignore_keys_at_rope_validation + self.validate_rope() + + # Convert to float because RoPE fn expect a float. Models on the hub were saved as int + for key in ["beta_fast", "beta_slow", "factor"]: + if key in self.rope_parameters: + self.rope_parameters[key] = float(self.rope_parameters[key]) + return kwargs + + +__all__ = ["Mistral4Config"] diff --git a/src/transformers/models/mistral4/convert_mistral4_weight_to_hf.py b/src/transformers/models/mistral4/convert_mistral4_weight_to_hf.py new file mode 100644 index 000000000000..b3c3b7714fca --- /dev/null +++ b/src/transformers/models/mistral4/convert_mistral4_weight_to_hf.py @@ -0,0 +1,608 @@ +# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import re +from collections import defaultdict +from pathlib import Path +from typing import Any + +import torch +from safetensors.torch import load_file + +from transformers import ( + GenerationConfig, + Mistral3Config, + Mistral3ForConditionalGeneration, + Mistral4Config, + PixtralImageProcessorFast, + PixtralProcessor, + PixtralVisionConfig, +) +from transformers.core_model_loading import ( + Concatenate, + ConversionOps, + MergeModulelist, + WeightRenaming, +) +from transformers.integrations.finegrained_fp8 import replace_with_fp8_linear +from transformers.integrations.mistral import convert_tekken_tokenizer +from transformers.models.mistral4.modeling_mistral4 import Mistral4ForCausalLM +from transformers.quantizers.auto import AutoQuantizationConfig + + +_FP8_DTYPE = torch.float8_e4m3fn +_FP8_MIN = torch.finfo(_FP8_DTYPE).min +_FP8_MAX = torch.finfo(_FP8_DTYPE).max + +EXPERT_KEY_PATTERN = re.compile(r"^layers\.(\d+)\.experts\.(\d+)\.(w[123])\.(weight|qscale_act|qscale_weight)$") + + +class FP8RescaleMergeAndConcatenate(ConversionOps): + r"""FP8-aware gate+up expert fusion with per-expert scale rescaling. + + Takes per-expert gate (w1) and up (w3) weight tensors together with their + FP8 `weight_scale_inv` values, rescales both to a common scale per expert + (the max of the two), concatenates gate+up along `dim=0`, and stacks + across experts along a new leading dimension. + """ + + @torch.no_grad() + def convert( + self, + input_dict: dict[str, list[torch.Tensor]], + source_patterns: list[str], + target_patterns: list[str], + **kwargs: Any, + ) -> dict[str, torch.Tensor]: + w1_weights = input_dict["w1.weight"] + w3_weights = input_dict["w3.weight"] + w1_scales = input_dict["w1.qscale_weight"] + w3_scales = input_dict["w3.qscale_weight"] + + gate_up_list: list[torch.Tensor] = [] + scale_inv_list: list[torch.Tensor] = [] + + for e in range(len(w1_weights)): + fused_scale_inv = torch.max(w1_scales[e], w3_scales[e]) + gate = _rescale_fp8(w1_weights[e], w1_scales[e], fused_scale_inv) + up = _rescale_fp8(w3_weights[e], w3_scales[e], fused_scale_inv) + gate_up_list.append(torch.cat([gate, up], dim=0)) + scale_inv_list.append(fused_scale_inv) + + gate_up_proj_scale_inv = torch.stack(scale_inv_list) + while gate_up_proj_scale_inv.ndim < 3: + gate_up_proj_scale_inv = gate_up_proj_scale_inv.unsqueeze(-1) + + return { + "gate_up_proj": torch.stack(gate_up_list, dim=0), + "gate_up_proj_scale_inv": gate_up_proj_scale_inv, + } + + +def _get_text_renamings(prefix: str) -> list[WeightRenaming]: + r"""Build `WeightRenaming` list for text-model keys.""" + return [ + WeightRenaming("^output", "lm_head"), + WeightRenaming("^norm", f"{prefix}.norm"), + WeightRenaming("^tok_embeddings", f"{prefix}.embed_tokens"), + WeightRenaming("^layers", f"{prefix}.layers"), + WeightRenaming("attention_norm", "input_layernorm"), + WeightRenaming("ffn_norm", "post_attention_layernorm"), + WeightRenaming(r"attention\.wkv_a_with_mqa", "self_attn.kv_a_proj_with_mqa"), + WeightRenaming(r"attention\.wkv_b", "self_attn.kv_b_proj"), + WeightRenaming(r"attention\.wq_a", "self_attn.q_a_proj"), + WeightRenaming(r"attention\.wq_b", "self_attn.q_b_proj"), + WeightRenaming(r"attention\.wo", "self_attn.o_proj"), + WeightRenaming(r"attention\.q_a_norm", "self_attn.q_a_layernorm"), + WeightRenaming(r"attention\.kv_a_norm", "self_attn.kv_a_layernorm"), + WeightRenaming(r"\.gate\.weight", ".mlp.gate.weight"), + WeightRenaming(r"shared_experts\.w1", "mlp.shared_experts.gate_proj"), + WeightRenaming(r"shared_experts\.w2", "mlp.shared_experts.down_proj"), + WeightRenaming(r"shared_experts\.w3", "mlp.shared_experts.up_proj"), + WeightRenaming(r"\.qscale_act", ".activation_scale"), + WeightRenaming(r"\.qscale_weight", ".weight_scale_inv"), + ] + + +def _get_vision_renamings() -> list[WeightRenaming]: + r"""Build `WeightRenaming` list for vision-model keys.""" + return [ + WeightRenaming("^vision_encoder", "model.vision_tower"), + WeightRenaming(r"^vision_language_adapter\.w_in", "model.multi_modal_projector.linear_1"), + WeightRenaming(r"^vision_language_adapter\.w_out", "model.multi_modal_projector.linear_2"), + WeightRenaming("^patch_merger", "model.multi_modal_projector.patch_merger"), + WeightRenaming("^pre_mm_projector_norm", "model.multi_modal_projector.norm"), + WeightRenaming(r"attention\.wq\.", "attention.q_proj."), + WeightRenaming(r"attention\.wk\.", "attention.k_proj."), + WeightRenaming(r"attention\.wv\.", "attention.v_proj."), + WeightRenaming(r"attention\.wo\.", "attention.o_proj."), + WeightRenaming(r"feed_forward\.w1", "feed_forward.gate_proj"), + WeightRenaming(r"feed_forward\.w2", "feed_forward.down_proj"), + WeightRenaming(r"feed_forward\.w3", "feed_forward.up_proj"), + ] + + +_VISION_KEY_PREFIXES = ("vision_encoder.", "vision_language_adapter.", "patch_merger.", "pre_mm_projector_norm.") + + +def _is_vision_key(key: str) -> bool: + r"""Return whether *key* belongs to the vision / projector components.""" + return key.startswith(_VISION_KEY_PREFIXES) + + +def _rename_key( + key: str, + text_renamings: list[WeightRenaming], + vision_renamings: list[WeightRenaming], +) -> str: + r"""Apply the appropriate `WeightRenaming` chain to *key*.""" + renamings = vision_renamings if _is_vision_key(key) else text_renamings + for renaming in renamings: + key, _ = renaming.rename_source_key(key) + return key + + +def _rescale_fp8( + tensor: torch.Tensor, + original_scale_inv: torch.Tensor, + target_scale_inv: torch.Tensor, +) -> torch.Tensor: + r"""Rescale an FP8 tensor from *original_scale_inv* to *target_scale_inv*.""" + ratio = original_scale_inv / target_scale_inv + return (tensor.to(torch.bfloat16) * ratio).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) + + +def _descale_fp8_to_bf16(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor: + r"""Descale an FP8 tensor back to BF16 using its `weight_scale_inv`.""" + return (tensor.to(torch.bfloat16) * scale_inv.to(torch.bfloat16)).to(torch.bfloat16) + + +def _permute_for_rope(tensor: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor: + r"""Permute Q/K weight matrices from Mistral's interleaved layout to HF's contiguous-half layout.""" + tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2) + tensor = tensor.transpose(1, 2) + tensor = tensor.reshape(dim1, dim2) + return tensor + + +def _maybe_permute_vision_rope( + new_key: str, + tensor: torch.Tensor, + vision_config: PixtralVisionConfig, +) -> torch.Tensor: + r"""Apply RoPE permutation to vision Q/K weight matrices if applicable.""" + num_attention_heads = vision_config.num_attention_heads + hidden_size = vision_config.hidden_size + head_dim = vision_config.head_dim + attn_dim = head_dim * num_attention_heads + + if "q_proj" in new_key and new_key.endswith("weight"): + tensor = _permute_for_rope(tensor, num_attention_heads, attn_dim, hidden_size) + elif "k_proj" in new_key and new_key.endswith("weight"): + tensor = _permute_for_rope(tensor, num_attention_heads, attn_dim, hidden_size) + + return tensor + + +def _fuse_experts_for_layer( + grouped: dict[tuple, dict[int, torch.Tensor]], + layer_idx: int, + n_experts: int, + base: str, + output_fp8: bool, +) -> dict[str, torch.Tensor]: + r"""Fuse per-expert weights for a single layer""" + merge_op = MergeModulelist(dim=0) + + w1 = grouped[(layer_idx, "w1", "weight")] + w2 = grouped[(layer_idx, "w2", "weight")] + w3 = grouped[(layer_idx, "w3", "weight")] + + w1_scales = grouped.get((layer_idx, "w1", "qscale_weight")) + w2_scales = grouped.get((layer_idx, "w2", "qscale_weight")) + w3_scales = grouped.get((layer_idx, "w3", "qscale_weight")) + + result: dict[str, torch.Tensor] = {} + + if output_fp8: + fp8_fuse_op = FP8RescaleMergeAndConcatenate() + gate_up_result = fp8_fuse_op.convert( + input_dict={ + "w1.weight": [w1[e] for e in range(n_experts)], + "w3.weight": [w3[e] for e in range(n_experts)], + "w1.qscale_weight": [w1_scales[e] for e in range(n_experts)], + "w3.qscale_weight": [w3_scales[e] for e in range(n_experts)], + }, + source_patterns=["w1.weight", "w3.weight", "w1.qscale_weight", "w3.qscale_weight"], + target_patterns=["gate_up_proj", "gate_up_proj_scale_inv"], + ) + result[f"{base}.gate_up_proj"] = gate_up_result["gate_up_proj"] + result[f"{base}.gate_up_proj_scale_inv"] = gate_up_result["gate_up_proj_scale_inv"] + + down_result = merge_op.convert( + input_dict={"w2": [w2[e] for e in range(n_experts)]}, + source_patterns=["w2"], + target_patterns=["down_proj"], + ) + result[f"{base}.down_proj"] = down_result["down_proj"] + down_proj_scale_inv = torch.stack([w2_scales[e] for e in range(n_experts)]) + while down_proj_scale_inv.ndim < 3: + down_proj_scale_inv = down_proj_scale_inv.unsqueeze(-1) + result[f"{base}.down_proj_scale_inv"] = down_proj_scale_inv + + w1_act = grouped.get((layer_idx, "w1", "qscale_act")) + if w1_act is not None: + w2_act = grouped[(layer_idx, "w2", "qscale_act")] + w3_act = grouped[(layer_idx, "w3", "qscale_act")] + result[f"{base}.gate_up_proj_activation_scale"] = torch.stack( + [torch.max(w1_act[e], w3_act[e]) for e in range(n_experts)] + ) + result[f"{base}.down_proj_activation_scale"] = torch.stack([w2_act[e] for e in range(n_experts)]) + else: + concat_op = Concatenate(dim=1) + + w1_list = [_descale_fp8_to_bf16(w1[e], w1_scales[e]) if w1_scales else w1[e] for e in range(n_experts)] + w3_list = [_descale_fp8_to_bf16(w3[e], w3_scales[e]) if w3_scales else w3[e] for e in range(n_experts)] + w2_list = [_descale_fp8_to_bf16(w2[e], w2_scales[e]) if w2_scales else w2[e] for e in range(n_experts)] + + step1 = merge_op.convert( + input_dict={"w1": w1_list, "w3": w3_list}, + source_patterns=["w1", "w3"], + target_patterns=["gate_up_proj"], + ) + gate_up = concat_op.convert(step1, source_patterns=["w1", "w3"], target_patterns=["gate_up_proj"]) + result[f"{base}.gate_up_proj"] = gate_up["gate_up_proj"] + + down = merge_op.convert( + input_dict={"w2": w2_list}, + source_patterns=["w2"], + target_patterns=["down_proj"], + ) + result[f"{base}.down_proj"] = down["down_proj"] + + return result + + +def fuse_experts( + expert_weights: dict[tuple, torch.Tensor], + n_experts: int, + has_vision: bool, + output_fp8: bool, +) -> dict[str, torch.Tensor]: + r"""Fuse per-expert weights across all layers.""" + prefix = "model.language_model" if has_vision else "model" + + grouped: dict[tuple, dict[int, torch.Tensor]] = defaultdict(dict) + for (layer_idx, expert_idx, param_type, suffix), tensor in expert_weights.items(): + grouped[(layer_idx, param_type, suffix)][int(expert_idx)] = tensor + + consumed_keys: set[tuple] = set() + result: dict[str, torch.Tensor] = {} + layers = sorted({layer_idx for (layer_idx, _, _) in grouped}) + + for layer_idx in layers: + base = f"{prefix}.layers.{layer_idx}.mlp.experts" + + w1_weight_key = (layer_idx, "w1", "weight") + assert w1_weight_key in grouped, f"Layer {layer_idx}: missing w1 weights" + assert len(grouped[w1_weight_key]) == n_experts, ( + f"Layer {layer_idx}: expected {n_experts} w1 experts, got {len(grouped[w1_weight_key])}" + ) + + for param in ("w1", "w2", "w3"): + for suffix in ("weight", "qscale_weight", "qscale_act"): + key = (layer_idx, param, suffix) + if key in grouped: + consumed_keys.add(key) + + layer_result = _fuse_experts_for_layer(grouped, layer_idx, n_experts, base, output_fp8) + + result.update(layer_result) + + unconsumed = set(grouped.keys()) - consumed_keys + assert not unconsumed, f"Unconsumed expert groups: {unconsumed}" + + return result + + +def convert_state_dict( + original_state_dict: dict[str, torch.Tensor], + text_renamings: list[WeightRenaming], + vision_renamings: list[WeightRenaming], + total_keys_seen: set[str], + vision_config: PixtralVisionConfig | None = None, + is_fp8_source: bool = False, + output_bf16: bool = False, +) -> tuple[dict[str, torch.Tensor], dict[tuple, torch.Tensor]]: + r"""Rename and optionally descale one shard of the original state dict.""" + new_dict: dict[str, torch.Tensor] = {} + expert_weights: dict[tuple, torch.Tensor] = {} + + for old_key, tensor in original_state_dict.items(): + assert old_key not in total_keys_seen, f"Duplicate key across shards: {old_key}" + total_keys_seen.add(old_key) + + match = EXPERT_KEY_PATTERN.match(old_key) + if match: + layer_idx, expert_idx, param_type, suffix = int(match[1]), int(match[2]), match[3], match[4] + expert_weights[(layer_idx, expert_idx, param_type, suffix)] = tensor + continue + + if output_bf16 and is_fp8_source: + if old_key.endswith((".qscale_act", ".qscale_weight")): + continue + if old_key.endswith(".weight"): + scale_key = old_key.rsplit(".weight", 1)[0] + ".qscale_weight" + if scale_key in original_state_dict: + tensor = _descale_fp8_to_bf16(tensor, original_state_dict[scale_key]) + + new_key = _rename_key(old_key, text_renamings, vision_renamings) + + if vision_config is not None and "vision_tower" in new_key: + tensor = _maybe_permute_vision_rope(new_key, tensor, vision_config) + + new_dict[new_key] = tensor + + return new_dict, expert_weights + + +def _read_json(path: Path) -> dict: + with open(path) as f: + return json.load(f) + + +def convert_config( + original_config: dict, + max_position_embeddings: int = 1_048_576, + is_vision: bool = True, + output_fp8: bool = True, +) -> Mistral3Config | Mistral4Config: + r"""Convert original Mistral `params.json` to a HF config object.""" + original_vision_config = original_config.pop("vision_encoder", None) + assert is_vision == (original_vision_config is not None) + + text_kwargs: dict[str, Any] = { + "hidden_size": original_config["dim"], + "num_hidden_layers": original_config["n_layers"], + "intermediate_size": original_config["hidden_dim"], + "num_attention_heads": original_config["n_heads"], + "num_key_value_heads": original_config["n_kv_heads"], + "rms_norm_eps": original_config["norm_eps"], + "vocab_size": original_config["vocab_size"], + "tie_word_embeddings": original_config.get("tied_embeddings", False), + "sliding_window": int(original_config["sliding_window"]) + if original_config.get("sliding_window") is not None + else None, + "max_position_embeddings": original_config.get( + "max_position_embeddings", original_config.get("max_seq_len", max_position_embeddings) + ), + "bos_token_id": 1, + "eos_token_id": 2, + "pad_token_id": 11, + } + + for key in ["q_lora_rank", "qk_rope_head_dim", "qk_nope_head_dim", "kv_lora_rank", "v_head_dim"]: + if key in original_config: + text_kwargs[key] = original_config[key] + + moe = original_config.get("moe") + assert moe is not None + text_kwargs.update( + { + "n_routed_experts": moe.get("num_experts", 128), + "num_experts_per_tok": moe.get("num_experts_per_tok", 4), + "first_k_dense_replace": moe.get("first_k_dense_replace", 0), + "n_shared_experts": moe.get("num_shared_experts", 1), + "moe_intermediate_size": moe.get("expert_hidden_dim", 2048), + "routed_scaling_factor": moe.get("routed_scale", 1.0), + "n_group": moe.get("num_expert_groups", 1), + "topk_group": moe.get("num_expert_groups_per_tok", 1), + "norm_topk_prob": True, + } + ) + + qk_rope_head_dim = text_kwargs.get("qk_rope_head_dim", 64) + qk_nope_head_dim = text_kwargs.get("qk_nope_head_dim", 64) + + text_kwargs["rope_parameters"] = { + "type": "yarn", + "rope_theta": original_config.get("rope_theta", 10_000.0), + "factor": float(original_config["yarn"]["factor"]), + "original_max_position_embeddings": original_config["yarn"]["original_max_position_embeddings"], + "beta_fast": float(original_config["yarn"]["beta"]), + "beta_slow": float(original_config["yarn"]["alpha"]), + "mscale_all_dim": 1.0, + "mscale": 1.0, + "llama_4_scaling_beta": original_config.get("llama_4_scaling", {}).get("beta", 0.1), + "partial_rotary_factor": qk_rope_head_dim / (qk_nope_head_dim + qk_rope_head_dim), + } + + quant_kwargs: dict[str, Any] = {} + quant = original_config.get("quantization", {}) + if output_fp8 and quant.get("qformat_weight") == "fp8_e4m3": + assert quant["qscheme_act"] == "TENSOR" + quant_kwargs["quantization_config"] = AutoQuantizationConfig.from_dict( + { + "activation_scheme": "static", + "modules_to_not_convert": ["model.vision_tower", "model.multi_modal_projector", "lm_head"], + "quant_method": "fp8", + "weight_block_size": None, + } + ) + + if not is_vision: + return Mistral4Config(**text_kwargs, **quant_kwargs) + + text_config = Mistral4Config(**text_kwargs) + adapter_bias = original_vision_config.pop("adapter_bias", False) + spatial_merge_size = original_vision_config.pop("spatial_merge_size") + image_token_id = original_vision_config.pop("image_token_id", 10) + for drop_key in [ + "mm_projector_id", + "add_pre_mm_projector_layer_norm", + "image_break_token_id", + "image_end_token_id", + "max_image_size", + ]: + original_vision_config.pop(drop_key, None) + vision_config = PixtralVisionConfig(hidden_act="silu", **original_vision_config) + + return Mistral3Config( + vision_config=vision_config, + text_config=text_config, + multimodal_projector_bias=adapter_bias, + image_token_id=image_token_id, + spatial_merge_size=spatial_merge_size, + vision_feature_layer=-1, + tie_word_embeddings=text_kwargs["tie_word_embeddings"], + **quant_kwargs, + ) + + +def convert_and_write_model( + input_dir: Path, + output_dir: Path, + max_position_embeddings: int, + output_format: str, +) -> Mistral3Config | Mistral4Config: + r"""Convert weights and write the HF model to output_dir.""" + params = _read_json(input_dir / "params.json") + is_vision = params.get("vision_encoder") is not None + is_fp8_source = params.get("quantization", {}).get("qformat_weight") == "fp8_e4m3" + output_fp8 = output_format == "fp8" and is_fp8_source + output_bf16 = not output_fp8 + + config = convert_config(params, max_position_embeddings, is_vision, output_fp8) + + text_config = config.text_config if isinstance(config, Mistral3Config) else config + n_experts = text_config.n_routed_experts + vision_config = config.vision_config if isinstance(config, Mistral3Config) else None + + model_prefix = "model.language_model" if is_vision else "model" + text_renamings = _get_text_renamings(model_prefix) + vision_renamings = _get_vision_renamings() if is_vision else [] + + full_state_dict: dict[str, torch.Tensor] = {} + all_expert_weights: dict[tuple, torch.Tensor] = {} + total_keys_seen: set[str] = set() + shards = sorted(p for p in input_dir.iterdir() if p.suffix == ".safetensors") + assert shards, f"No .safetensors files found in {input_dir}" + + for shard_path in shards: + print(f"Processing shard: {shard_path.name}") + original = load_file(str(shard_path)) + new_dict, expert_weights = convert_state_dict( + original, + text_renamings, + vision_renamings, + total_keys_seen, + vision_config, + is_fp8_source, + output_bf16, + ) + del original + full_state_dict.update(new_dict) + del new_dict + all_expert_weights.update(expert_weights) + del expert_weights + + print(f"Fusing {len(all_expert_weights)} expert weight entries...") + fused = fuse_experts(all_expert_weights, n_experts, is_vision, output_fp8) + del all_expert_weights + full_state_dict.update(fused) + del fused + + if text_config.tie_word_embeddings: + full_state_dict["lm_head.weight"] = full_state_dict[f"{model_prefix}.embed_tokens.weight"] + + with torch.device("meta"): + if isinstance(config, Mistral3Config): + model = Mistral3ForConditionalGeneration(config) + else: + model = Mistral4ForCausalLM(config) + + if output_fp8 and hasattr(model.config, "quantization_config"): + qconfig = model.config.quantization_config + model = replace_with_fp8_linear(model, qconfig.modules_to_not_convert, qconfig) + + model.load_state_dict(full_state_dict, strict=True, assign=True) + model.save_pretrained(str(output_dir)) + return config + + +def convert_and_write_processor_and_tokenizer( + input_dir: Path, + output_dir: Path, + model_config: Mistral3Config | Mistral4Config, +) -> None: + r"""Convert and write tokenizer (and processor for VLMs) to *output_dir*.""" + tokenizer = convert_tekken_tokenizer(str(input_dir / "tekken.json")) + + if isinstance(model_config, Mistral4Config): + tokenizer.save_pretrained(str(output_dir)) + return + + params = _read_json(input_dir / "params.json") + ve = params["vision_encoder"] + + processor = PixtralProcessor( + tokenizer=tokenizer, + image_processor=PixtralImageProcessorFast( + patch_size=ve["patch_size"], size={"longest_edge": ve["max_image_size"]} + ), + image_token="[IMG]", + patch_size=ve["patch_size"], + spatial_merge_size=ve["spatial_merge_size"], + ) + processor.save_pretrained(str(output_dir)) + + text_config = model_config.text_config if hasattr(model_config, "text_config") else model_config + GenerationConfig( + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + max_length=text_config.max_position_embeddings, + ).save_pretrained(str(output_dir)) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Convert Mistral4 weights to HuggingFace format.") + parser.add_argument( + "input_dir", + type=Path, + help="Directory containing Mistral weights (params.json, tekken.json, *.safetensors)", + ) + parser.add_argument("output_dir", type=Path, help="Output directory for HF model") + parser.add_argument( + "--max_position_embeddings", + type=int, + default=1_048_576, + help="max_position_embeddings (used when not specified in params.json)", + ) + parser.add_argument( + "--output_format", + choices=["fp8", "bf16"], + default="fp8", + help="Output format: 'fp8' keeps FP8 quantization (default), 'bf16' descales to BF16", + ) + args = parser.parse_args() + + config = convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings, args.output_format) + convert_and_write_processor_and_tokenizer(args.input_dir, args.output_dir, config) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/mistral4/modeling_mistral4.py b/src/transformers/models/mistral4/modeling_mistral4.py new file mode 100644 index 000000000000..df836e52f2dd --- /dev/null +++ b/src/transformers/models/mistral4/modeling_mistral4.py @@ -0,0 +1,730 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/mistral4/modular_mistral4.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_mistral4.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_mistral4 import Mistral4Config + + +@use_kernel_forward_from_hub("RMSNorm") +class Mistral4RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Mistral4RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Mistral4RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Mistral4Config, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: Mistral4Config | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Mistral4MLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class Mistral4TopkRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.n_routed_experts = config.n_routed_experts + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states, self.weight) + return router_logits + + +@use_experts_implementation +class Mistral4NaiveMoe(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + +class Mistral4MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.experts = Mistral4NaiveMoe(config) + self.gate = Mistral4TopkRouter(config) + self.shared_experts = Mistral4MLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) + self.n_routed_experts = config.n_routed_experts + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + self.top_k = config.num_experts_per_tok + + def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + router_logits = router_logits.softmax(-1) + group_scores = ( + router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = router_logits.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + router_logits = self.gate(hidden_states) + topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + r""" + TODO let's just use the original freqcis computation to not have the view + transpose + reshape! This is not optimized! + Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def get_llama_4_attn_scale(positions_ids: torch.Tensor, beta: float, max_position_embeddings: int) -> torch.Tensor: + scaling = 1 + beta * torch.log(1 + torch.floor(positions_ids / max_position_embeddings)) + return scaling.unsqueeze(-1) + + +class Mistral4Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Mistral4Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + if self.q_lora_rank is None: + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = Mistral4RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = Mistral4RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + else: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + cos, sin = position_embeddings + if self.config.rope_interleave: # support using interleaved weights for efficiency + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens + + query_states = query_states * get_llama_4_attn_scale( + cache_position, + self.config.rope_parameters.get("llama_4_scaling_beta"), + self.config.rope_parameters.get("original_max_position_embeddings"), + ).to(query_states.dtype) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Mistral4DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Mistral4Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Mistral4Attention(config=config, layer_idx=layer_idx) + + if layer_idx >= config.first_k_dense_replace: + self.mlp = Mistral4MoE(config) + else: + self.mlp = Mistral4MLP(config) + + self.input_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Mistral4PreTrainedModel(PreTrainedModel): + config: Mistral4Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Mistral4DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Mistral4DecoderLayer, + "attentions": Mistral4Attention, + } + _keep_in_fp32_modules_strict = [] + _keys_to_ignore_on_load_unexpected = [] + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Mistral4TopkRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Mistral4NaiveMoe): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + + +@auto_docstring +class Mistral4Model(Mistral4PreTrainedModel): + def __init__(self, config: Mistral4Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Mistral4DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Mistral4RotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if position_ids is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + position_ids = position_ids.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_embeddings=position_embeddings, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class Mistral4ForCausalLM(Mistral4PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Mistral4Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, Mistral4ForCausalLM + + >>> model = Mistral4ForCausalLM.from_pretrained("meta-mistral4/Mistral4-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral4/Mistral4-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Mistral4ForSequenceClassification(GenericForSequenceClassification, Mistral4PreTrainedModel): + pass + + +class Mistral4ForTokenClassification(GenericForTokenClassification, Mistral4PreTrainedModel): + pass + + +__all__ = [ + "Mistral4PreTrainedModel", + "Mistral4Model", + "Mistral4ForCausalLM", + "Mistral4ForSequenceClassification", + "Mistral4ForTokenClassification", +] diff --git a/src/transformers/models/mistral4/modular_mistral4.py b/src/transformers/models/mistral4/modular_mistral4.py new file mode 100644 index 000000000000..d9c73a3c19cc --- /dev/null +++ b/src/transformers/models/mistral4/modular_mistral4.py @@ -0,0 +1,291 @@ +# Copyright 2026 Mistral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable + +import torch +import torch.nn.functional as F +from torch import nn + +from ... import initialization as init +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GenericForSequenceClassification, GenericForTokenClassification +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import logging +from ...utils.generic import is_flash_attention_requested +from ..deepseek_v3.modeling_deepseek_v3 import ( + DeepseekV3Attention, + DeepseekV3DecoderLayer, + DeepseekV3MoE, + DeepseekV3NaiveMoe, + apply_rotary_pos_emb_interleave, +) +from ..llama.modeling_llama import ( + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) +from ..ministral3.modeling_ministral3 import get_llama_4_attn_scale +from ..qwen2_moe.modeling_qwen2_moe import Qwen2MoeMLP +from .configuration_mistral4 import Mistral4Config + + +logger = logging.get_logger(__name__) + + +class Mistral4RMSNorm(LlamaRMSNorm): + pass + + +class Mistral4RotaryEmbedding(LlamaRotaryEmbedding): + pass + + +class Mistral4MLP(Qwen2MoeMLP): + pass + + +class Mistral4TopkRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.n_routed_experts = config.n_routed_experts + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states, self.weight) + return router_logits + + +class Mistral4NaiveMoe(DeepseekV3NaiveMoe): + pass + + +class Mistral4MoE(DeepseekV3MoE): + def route_tokens_to_experts(self, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + router_logits = router_logits.softmax(-1) + group_scores = ( + router_logits.view(-1, self.n_group, self.n_routed_experts // self.n_group).topk(2, dim=-1)[0].sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = router_logits.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = router_logits.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + +class Mistral4Attention(DeepseekV3Attention): + def __init__(self, config: Mistral4Config, layer_idx: int): + nn.Module.__init__(self) + self.config = config + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + if self.q_lora_rank is None: + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = Mistral4RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = Mistral4RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + else: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + cos, sin = position_embeddings + if self.config.rope_interleave: # support using interleaved weights for efficiency + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange(query_states.shape[2], device=query_states.device) + past_seen_tokens + + query_states = query_states * get_llama_4_attn_scale( + cache_position, + self.config.rope_parameters.get("llama_4_scaling_beta"), + self.config.rope_parameters.get("original_max_position_embeddings"), + ).to(query_states.dtype) + + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Mistral4DecoderLayer(DeepseekV3DecoderLayer): + def __init__(self, config: Mistral4Config, layer_idx: int): + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + + self.self_attn = Mistral4Attention(config=config, layer_idx=layer_idx) + + if layer_idx >= config.first_k_dense_replace: + self.mlp = Mistral4MoE(config) + else: + self.mlp = Mistral4MLP(config) + + self.input_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Mistral4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class Mistral4PreTrainedModel(PreTrainedModel): + config: Mistral4Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Mistral4DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Mistral4DecoderLayer, + "attentions": Mistral4Attention, + } + _keep_in_fp32_modules_strict = [] + _keys_to_ignore_on_load_unexpected = [] + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, Mistral4TopkRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, Mistral4NaiveMoe): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + + +class Mistral4Model(LlamaModel): + pass + + +class Mistral4ForCausalLM(LlamaForCausalLM): + pass + + +class Mistral4ForSequenceClassification(GenericForSequenceClassification, Mistral4PreTrainedModel): + pass + + +class Mistral4ForTokenClassification(GenericForTokenClassification, Mistral4PreTrainedModel): + pass + + +__all__ = [ + "Mistral4PreTrainedModel", + "Mistral4Model", + "Mistral4ForCausalLM", + "Mistral4ForSequenceClassification", + "Mistral4ForTokenClassification", +] diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index 32045d94f7ca..24cefad6a222 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -500,6 +500,10 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False): Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant, the mask will contain 1. For user and system tokens, the mask will contain 0. This functionality is only available for chat templates that support it via the `{% generation %}` keyword. + reasoning_effort (`str`, *optional*): + The reasoning effort level to use for the model's response. Supported values depend on the model + (e.g. `"none"`, "low"`, `"medium"`, `"high"`). If the template does not support reasoning effort, + this argument will have no effect. """ tools: list[dict] | None = None @@ -507,6 +511,7 @@ class TokenizerChatTemplateKwargs(TypedDict, total=False): add_generation_prompt: bool | None = False continue_final_message: bool | None = False return_assistant_tokens_mask: bool | None = False + reasoning_effort: str | None = None class ProcessorChatTemplateKwargs(TokenizerChatTemplateKwargs, total=False): diff --git a/tests/models/mistral4/__init__.py b/tests/models/mistral4/__init__.py new file mode 100644 index 000000000000..b1e6bad8a4a1 --- /dev/null +++ b/tests/models/mistral4/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mistral4 model tests.""" + +from .test_modeling_mistral4 import Mistral4IntegrationTest, Mistral4ModelTest diff --git a/tests/models/mistral4/test_modeling_mistral4.py b/tests/models/mistral4/test_modeling_mistral4.py new file mode 100644 index 000000000000..1703cb63e49e --- /dev/null +++ b/tests/models/mistral4/test_modeling_mistral4.py @@ -0,0 +1,133 @@ +# Copyright 2026 the HuggingFace and MistralAI Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch Mistral4 model.""" + +import gc +import unittest + +import pytest + +from transformers import AutoTokenizer, Mistral3ForConditionalGeneration, is_torch_available +from transformers.testing_utils import ( + Expectations, + backend_empty_cache, + cleanup, + require_deterministic_for_xpu, + require_flash_attn, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + + +if is_torch_available(): + import torch + + from transformers import ( + Mistral4Model, + ) + + +from ...causal_lm_tester import CausalLMModelTest, CausalLMModelTester + + +class Mistral4ModelTester(CausalLMModelTester): + if is_torch_available(): + base_model_class = Mistral4Model + + +@require_torch +class Mistral4ModelTest(CausalLMModelTest, unittest.TestCase): + _is_stateful = True + model_split_percents = [0.5, 0.6] + model_tester_class = Mistral4ModelTester + + # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146 + def is_pipeline_test_to_skip( + self, + pipeline_test_case_name, + config_class, + model_architecture, + tokenizer_name, + image_processor_name, + feature_extractor_name, + processor_name, + ): + return True + + @require_flash_attn + @require_torch_accelerator + @pytest.mark.flash_attn_test + @slow + def test_flash_attn_2_inference_equivalence_right_padding(self): + self.skipTest(reason="Mistral4 flash attention does not support right padding") + + +@require_torch +class Mistral4IntegrationTest(unittest.TestCase): + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + def test_mistral_small_4_logits(self): + input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338] + model = Mistral3ForConditionalGeneration.from_pretrained( + "mistralai/Mistral-Small-4-119B-2603", device_map="auto" + ) + input_ids = torch.tensor([input_ids]).to(model.device) + with torch.no_grad(): + out = model(input_ids).logits.float().cpu() + # Expected mean on dim = -1 + # fmt: off + EXPECTED_MEANS = Expectations( + { + ("cuda", None): torch.tensor([[0.1793, -1.0928, -3.9925, -2.8699, -0.1250, -1.6851, -2.5565, -1.2263]]), + } + ) + # fmt: on + EXPECTED_MEAN = EXPECTED_MEANS.get_expectation() + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2) + + del model + backend_empty_cache(torch_device) + gc.collect() + + @slow + @require_deterministic_for_xpu + def test_mistral_small_4_generation(self): + # fmt: off + EXPECTED_TEXTS = Expectations( + { + ("cuda", None): "My favourite condiment is 1000 island dressing. I love it on burgers and hot dogs. I also like", + # ("xpu", None): "My favourite condiment is iced tea. I love the way it makes me feel. It’s like a little bubble bath for", + } + ) + # fmt: on + EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation() + prompt = "My favourite condiment is " + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Small-4-119B-2603") + model = Mistral3ForConditionalGeneration.from_pretrained( + "mistralai/Mistral-Small-4-119B-2603", device_map="auto" + ) + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=20, temperature=0) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(text, EXPECTED_TEXT) + + del model + backend_empty_cache(torch_device) + gc.collect()