[Llama.py -> mistral.py] Extract mistral-only relevant code into separate file #32780
Changes from all commits
```diff
@@ -152,24 +152,14 @@ def __init__(
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        head_dim = getattr(config, "head_dim", None)
-        if head_dim is None:
-            head_dim = self.hidden_size // self.total_num_heads
-        self.head_dim = head_dim
+        self.head_dim = head_dim or self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.max_position_embeddings = max_position_embeddings

-        llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
-        self.do_llama_4_scaling = llama_4_scaling_config is not None
-        if self.do_llama_4_scaling:
-            self.llama_4_scaling_original_max_position_embeddings = (
-                llama_4_scaling_config["original_max_position_embeddings"]
-            )
-            self.llama_4_scaling_beta = llama_4_scaling_config["beta"]
-
         self.qkv_proj = QKVParallelLinear(
             hidden_size=hidden_size,
             head_size=self.head_dim,
```
```diff
@@ -229,17 +219,6 @@ def __init__(
             prefix=f"{prefix}.attn",
         )

-    def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
-        # Llama4 scaling
-        scaling = 1 + self.llama_4_scaling_beta * torch.log(
-            1
-            + torch.floor(
-                positions / self.llama_4_scaling_original_max_position_embeddings
-            )
-        )
-        # Broadcast over head_dim
-        return scaling.unsqueeze(-1)
-
     def forward(
         self,
         positions: torch.Tensor,
```

Author (collaborator) comment on `_get_llama_4_attn_scale`: only mistral makes use of llama_4 scaling.
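For reference, the removed scale is a pure function of token position: it stays at 1 inside the original context window and grows logarithmically beyond it. A standalone sketch of that behaviour, with `beta` and `original_max_position_embeddings` set to illustrative values rather than anything taken from this PR:

```python
import torch

# Illustrative values only; the real ones come from the model config's
# llama_4_scaling entry ("beta", "original_max_position_embeddings").
beta = 0.1
original_max_position_embeddings = 8192


def llama_4_attn_scale(positions: torch.Tensor) -> torch.Tensor:
    # Same formula as the removed _get_llama_4_attn_scale: 1.0 inside the
    # original context window, then growing logarithmically past it.
    scaling = 1 + beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings)
    )
    return scaling.unsqueeze(-1)  # broadcast over head_dim


positions = torch.tensor([0.0, 4096.0, 8192.0, 65536.0])
print(llama_4_attn_scale(positions).squeeze(-1))
# tensor([1.0000, 1.0000, 1.0693, 1.2197])
```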
```diff
@@ -248,9 +227,6 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        if self.do_llama_4_scaling:
-            attn_scale = self._get_llama_4_attn_scale(positions)
-            q = (q * attn_scale).to(q.dtype)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
         return output
```
```diff
@@ -279,6 +255,7 @@ def __init__(
         vllm_config: VllmConfig,
         prefix: str = "",
         config: LlamaConfig | None = None,
+        attn_layer_type: type[nn.Module] = LlamaAttention,
     ) -> None:
         super().__init__()
```
```diff
@@ -307,7 +284,7 @@ def __init__(
         else:
             attn_type = AttentionType.ENCODER_ONLY

-        self.self_attn = LlamaAttention(
+        self.self_attn = attn_layer_type(
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
```
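The new `attn_layer_type` hook is what lets the extracted mistral code reuse the Llama decoder layer while swapping in its own attention class. A minimal sketch of how a subclass might use it, assuming the classes are importable from `vllm.model_executor.models.llama`; the `MyMistralAttention`/`MyMistralDecoderLayer` names are illustrative and not part of this PR:

```python
# Minimal sketch (not from this PR): route a custom attention class through
# the new attn_layer_type hook instead of copying the whole __init__.
from vllm.model_executor.models.llama import LlamaAttention, LlamaDecoderLayer


class MyMistralAttention(LlamaAttention):
    """Illustrative attention variant that would re-add the mistral-only
    behaviour (MistralConfig head_dim, llama_4 scaling) removed above."""


class MyMistralDecoderLayer(LlamaDecoderLayer):
    def __init__(self, vllm_config, prefix: str = "", config=None) -> None:
        super().__init__(
            vllm_config,
            prefix=prefix,
            config=config,
            attn_layer_type=MyMistralAttention,  # swapped-in attention class
        )
```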
```diff
@@ -537,32 +514,6 @@ class LlamaForCausalLM(
         "lm_head": "output_embeddings",
     }

-    # Mistral/Llama models can also be loaded with --load-format mistral
-    # from consolidated.safetensors checkpoints
-    mistral_mapping = {
-        "layers": "model.layers",
-        "attention": "self_attn",
-        "qscale_act": "input_scale",
-        "qscale_weight": "weight_scale",
-        "kv_fake_quantizer.qscale_act": "kv_scale",
-        "q_fake_quantizer.qscale_act": "attn.q_scale",
-        "k_fake_quantizer.qscale_act": "k_scale",
-        "v_fake_quantizer.qscale_act": "v_scale",
-        "wq": "q_proj",
-        "wk": "k_proj",
-        "wv": "v_proj",
-        "wo": "o_proj",
-        "attention_norm": "input_layernorm",
-        "feed_forward": "mlp",
-        "w1": "gate_proj",
-        "w2": "down_proj",
-        "w3": "up_proj",
-        "ffn_norm": "post_attention_layernorm",
-        "tok_embeddings": "model.embed_tokens",
-        "output": "lm_head",
-        "norm": "model.norm",
-    }
-
     def __init__(
         self,
         *,
```
```diff
@@ -649,67 +600,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             self,
             skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
         )
-        return loader.load_weights(
-            self.maybe_remap_mistral(name, loaded_weight)
-            for name, loaded_weight in weights
-        )
-
-    # This function is used to remap the mistral format as
-    # used by Mistral and Llama <=2
-    def maybe_remap_mistral(
-        self,
-        name: str,
-        loaded_weight: torch.Tensor,
-    ) -> tuple[str, torch.Tensor]:
-        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
-            attn_in = self.config.head_dim * n_heads
-
-            return (
-                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
-                .transpose(1, 2)
-                .reshape(attn_in, attn_out)
-            )
-
-        mapping = self.mistral_mapping
-        modules = name.split(".")
-
-        # rotary embeds should be sliced
-        # If using quantized model in mistral format,
-        # quantization scales (qscale_weight) also need to be sliced
-        if "wk" in modules and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
-            )
-        elif (
-            "wk" in modules
-            and modules[-1] == "qscale_weight"
-            and loaded_weight.numel() > 1
-        ):
-            loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
-        elif "wq" in modules and modules[-1] == "weight":
-            loaded_weight = permute(
-                loaded_weight, self.config.num_attention_heads, self.config.hidden_size
-            )
-        elif (
-            "wq" in modules
-            and modules[-1] == "qscale_weight"
-            and loaded_weight.numel() > 1
-        ):
-            loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)
-
-        num_modules = len(modules)
-        for i in range(num_modules):
-            item = modules[i]
-            next_item = modules[i + 1] if i < num_modules - 1 else None
-
-            combined_item = f"{item}.{next_item}" if next_item is not None else None
-
-            if combined_item in mapping:
-                name = name.replace(combined_item, mapping[combined_item])
-            elif item in mapping and mapping[item] not in name:
-                name = name.replace(item, mapping[item])
-
-        return name, loaded_weight
+        return loader.load_weights(weights)


 class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):
```

Author (collaborator) comment on `maybe_remap_mistral`: this code comes from a very old PR (#8168), and I'm quite convinced that it's only mistral checkpoints that actually make use of this function, so moving it out.
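For context, the bulk of what `maybe_remap_mistral` did was rename dotted checkpoint keys via `mistral_mapping` (plus a rotary-embedding permutation for `wq`/`wk` weights, omitted here). A simplified sketch of the renaming step, using a subset of the deleted mapping and an illustrative key:

```python
# Simplified sketch of the renaming performed by the removed
# maybe_remap_mistral helper (the wq/wk weight permutation is omitted).
# The mapping is a subset of the mistral_mapping dict deleted above.
mistral_mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "wq": "q_proj",
    "wk": "k_proj",
    "wv": "v_proj",
    "wo": "o_proj",
    "attention_norm": "input_layernorm",
    "feed_forward": "mlp",
    "w1": "gate_proj",
    "w2": "down_proj",
    "w3": "up_proj",
    "ffn_norm": "post_attention_layernorm",
    "tok_embeddings": "model.embed_tokens",
    "output": "lm_head",
    "norm": "model.norm",
}


def remap_name(name: str) -> str:
    # Walk the dotted path and substitute any component found in the
    # mapping, mirroring the token-by-token replacement of the original.
    for src, dst in mistral_mapping.items():
        if src in name.split(".") and dst not in name:
            name = name.replace(src, dst)
    return name


# Illustrative consolidated.safetensors-style key, not taken from the PR:
print(remap_name("layers.0.attention.wq.weight"))
# -> model.layers.0.self_attn.q_proj.weight
```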
Review comment: afaik only mistral-nemo ever used this

Reply: Doesn't seem to be the case 😅