diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index b6427b866aa3..e265308937d4 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -152,24 +152,14 @@ def __init__(
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+
         head_dim = getattr(config, "head_dim", None)
-        if head_dim is None:
-            head_dim = self.hidden_size // self.total_num_heads
-        self.head_dim = head_dim
+        self.head_dim = head_dim or self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.max_position_embeddings = max_position_embeddings

-        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
-        self.do_llama_4_scaling = llama_4_scaling_config is not None
-        if self.do_llama_4_scaling:
-            self.llama_4_scaling_original_max_position_embeddings = (
-                llama_4_scaling_config["original_max_position_embeddings"]
-            )
-            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]
-
         self.qkv_proj = QKVParallelLinear(
             hidden_size=hidden_size,
             head_size=self.head_dim,
@@ -229,17 +219,6 @@ def __init__(
             prefix=f"{prefix}.attn",
         )

-    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
-        # Llama4 scaling
-        scaling = 1 + self.llama_4_scaling_beta * torch.log(
-            1
-            + torch.floor(
-                positions / self.llama_4_scaling_original_max_position_embeddings
-            )
-        )
-        # Broadcast over head_dim
-        return scaling.unsqueeze(-1)
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -248,9 +227,6 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        if self.do_llama_4_scaling:
-            attn_scale = self._get_llama_4_attn_scale(positions)
-            q = (q * attn_scale).to(q.dtype)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
@@ -279,6 +255,7 @@ def __init__(
         vllm_config: VllmConfig,
         prefix: str = "",
         config: LlamaConfig | None = None,
+        attn_layer_type: type[nn.Module] = LlamaAttention,
     ) -> None:
         super().__init__()

@@ -307,7 +284,7 @@ def __init__(
         else:
             attn_type = AttentionType.ENCODER_ONLY

-        self.self_attn = LlamaAttention(
+        self.self_attn = attn_layer_type(
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -537,32 +514,6 @@ class LlamaForCausalLM(
         "lm_head": "output_embeddings",
     }

-    # Mistral/Llama models can also be loaded with --load-format mistral
-    # from consolidated.safetensors checkpoints
-    mistral_mapping = {
-        "layers": "model.layers",
-        "attention": "self_attn",
-        "qscale_act": "input_scale",
-        "qscale_weight": "weight_scale",
-        "kv_fake_quantizer.qscale_act": "kv_scale",
-        "q_fake_quantizer.qscale_act": "attn.q_scale",
-        "k_fake_quantizer.qscale_act": "k_scale",
-        "v_fake_quantizer.qscale_act": "v_scale",
-        "wq": "q_proj",
-        "wk": "k_proj",
-        "wv": "v_proj",
-        "wo": "o_proj",
-        "attention_norm": "input_layernorm",
-        "feed_forward": "mlp",
-        "w1": "gate_proj",
-        "w2": "down_proj",
-        "w3": "up_proj",
-        "ffn_norm": "post_attention_layernorm",
-        "tok_embeddings": "model.embed_tokens",
-        "output": "lm_head",
-        "norm": "model.norm",
-    }
-
     def __init__(
         self,
         *,
@@ -649,67 +600,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             self,
             skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
         )
-        return loader.load_weights(
-            self.maybe_remap_mistral(name, loaded_weight)
-            for name, loaded_weight in weights
-        )
-
-    # This function is used to remap the mistral format as
-    # used by Mistral and Llama <=2
-    def maybe_remap_mistral(
-        self,
-        name: str,
-        loaded_weight: torch.Tensor,
-    ) -> tuple[str, torch.Tensor]:
-        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
-            attn_in = self.config.head_dim * n_heads
-
-            return (
-                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
-                .transpose(1, 2)
-                .reshape(attn_in, attn_out)
-            )
-
-        mapping = self.mistral_mapping
-        modules = name.split(".")
-
-        # rotary embeds should be sliced
-        # If using quantized model in mistral format,
-        # quantization scales (qscale_weight) also need to be sliced
-        if "wk" in modules and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
-            )
-        elif (
-            "wk" in modules
-            and modules[-1] == "qscale_weight"
-            and loaded_weight.numel() > 1
-        ):
-            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
-        elif "wq" in modules and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
-            )
-        elif (
-            "wq" in modules
-            and modules[-1] == "qscale_weight"
-            and loaded_weight.numel() > 1
-        ):
-            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)
-
-        num_modules = len(modules)
-        for i in range(num_modules):
-            item = modules[i]
-            next_item = modules[i + 1] if i < num_modules - 1 else None
-
-            combined_item = f"{item}.{next_item}" if next_item is not None else None
-
-            if combined_item in mapping:
-                name = name.replace(combined_item, mapping[combined_item])
-            elif item in mapping and mapping[item] not in name:
-                name = name.replace(item, mapping[item])
-
-        return name, loaded_weight
+        return loader.load_weights(weights)


 class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py
new file mode 100644
index 000000000000..a1d88d088ddb
--- /dev/null
+++ b/vllm/model_executor/models/mistral.py
@@ -0,0 +1,242 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Mistral adaptation of the LLaMA architecture."""
+
+from collections.abc import Iterable
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.models.llama import (
+    LlamaAttention,
+    LlamaDecoderLayer,
+    LlamaForCausalLM,
+    LlamaModel,
+)
+from vllm.v1.attention.backend import AttentionType
+
+from .utils import AutoWeightsLoader
+
+
+class MistralAttention(LlamaAttention):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position_embeddings: int = 8192,
+        quant_config: QuantizationConfig | None = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: CacheConfig | None = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
+        super().__init__(
+            config=config,
+            hidden_size=hidden_size,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            bias=bias,
+            bias_o_proj=bias_o_proj,
+            cache_config=cache_config,
+            prefix=prefix,
+            attn_type=attn_type,
+        )
+
+        llama_4_scaling_config: dict[str, int | float | str] | None = getattr(
+            config, "llama_4_scaling", None
+        )
+        self.do_llama_4_scaling = llama_4_scaling_config is not None
+        if self.do_llama_4_scaling:
+            assert llama_4_scaling_config is not None
+            self.llama_4_scaling_original_max_position_embeddings = (
+                llama_4_scaling_config["original_max_position_embeddings"]
+            )
+            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]
+
+    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
+        # Llama4 scaling
+        scaling = 1 + self.llama_4_scaling_beta * torch.log(
+            1
+            + torch.floor(
+                positions / self.llama_4_scaling_original_max_position_embeddings
+            )
+        )
+        # Broadcast over head_dim
+        return scaling.unsqueeze(-1)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        if self.do_llama_4_scaling:
+            attn_scale = self._get_llama_4_attn_scale(positions)
+            q = (q * attn_scale).to(q.dtype)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MistralDecoderLayer(LlamaDecoderLayer):
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        config: LlamaConfig | None = None,
+    ) -> None:
+        super().__init__(
+            vllm_config=vllm_config,
+            prefix=prefix,
+            config=config,
+            attn_layer_type=MistralAttention,
+        )
+
+        self.layer_idx = int(prefix.split(sep=".")[-1])
+        quant_config = self.get_quant_config(vllm_config)
+        config = config or vllm_config.model_config.hf_config
+
+        do_fusion = getattr(
+            quant_config, "enable_quantization_scaling_fusion", False
+        ) and vllm_config.cache_config.cache_dtype.startswith("fp8")
+        if do_fusion:
+            self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
+            self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj
+
+
+@support_torch_compile
+class MistralModel(LlamaModel):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = MistralDecoderLayer,
+    ):
+        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
+
+
+class MistralForCausalLM(LlamaForCausalLM):
+    # Mistral: We don't support LoRA on the embedding layers.
+    embedding_modules: dict[str, str] = {}
+
+    # Mistral/Llama models can also be loaded with --load-format mistral
+    # from consolidated.safetensors checkpoints
+    mistral_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "qscale_act": "input_scale",
+        "qscale_weight": "weight_scale",
+        "kv_fake_quantizer.qscale_act": "kv_scale",
+        "q_fake_quantizer.qscale_act": "attn.q_scale",
+        "k_fake_quantizer.qscale_act": "k_scale",
+        "v_fake_quantizer.qscale_act": "v_scale",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm",
+    }
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = MistralDecoderLayer,
+    ):
+        super().__init__(vllm_config=vllm_config, prefix=prefix, layer_type=layer_type)
+
+    def _init_model(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        layer_type: type[nn.Module] = MistralDecoderLayer,
+    ):
+        return MistralModel(
+            vllm_config=vllm_config, prefix=prefix, layer_type=layer_type
+        )
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(
+            self.maybe_remap_mistral(name, loaded_weight)
+            for name, loaded_weight in weights
+        )
+
+    def maybe_remap_mistral(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+    ) -> tuple[str, torch.Tensor]:
+        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
+            attn_in = self.config.head_dim * n_heads
+
+            return (
+                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
+                .transpose(1, 2)
+                .reshape(attn_in, attn_out)
+            )
+
+        mapping = self.mistral_mapping
+        modules = name.split(".")
+
+        # rotary embeds should be sliced
+        # If using quantized model in mistral format,
+        # quantization scales (qscale_weight) also need to be sliced
+        if "wk" in modules and modules[-1] == "weight":
+            loaded_weight = permute(
+                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
+            )
+        elif (
+            "wk" in modules
+            and modules[-1] == "qscale_weight"
+            and loaded_weight.numel() > 1
+        ):
+            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
+        elif "wq" in modules and modules[-1] == "weight":
+            loaded_weight = permute(
+                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
+            )
+        elif (
+            "wq" in modules
+            and modules[-1] == "qscale_weight"
+            and loaded_weight.numel() > 1
+        ):
+            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)
+
+        num_modules = len(modules)
+        for i in range(num_modules):
+            item = modules[i]
+            next_item = modules[i + 1] if i < num_modules - 1 else None
+
+            combined_item = f"{item}.{next_item}" if next_item is not None else None
+
+            if combined_item in mapping:
+                name = name.replace(combined_item, mapping[combined_item])
+            elif item in mapping and mapping[item] not in name:
+                name = name.replace(item, mapping[item])
+
+        return name, loaded_weight
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 6a76c93883e2..5b1881dba8ad 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -153,7 +153,7 @@
     "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
     "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
     "MiniMaxM2ForCausalLM": ("minimax_m2", "MiniMaxM2ForCausalLM"),
-    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
+    "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
     "MistralLarge3ForCausalLM": ("mistral_large_3", "MistralLarge3ForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     # transformers's mpt class has lower case