Merged
119 changes: 5 additions & 114 deletions vllm/model_executor/models/llama.py
@@ -152,24 +152,14 @@ def __init__(
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo

head_dim = getattr(config, "head_dim", None)
Collaborator (Author):
afaik only Mistral-Nemo ever used this

Collaborator (Author):
Doesn't seem to be the case 😅
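For context, the explicit head_dim matters whenever it differs from hidden_size // num_attention_heads. A quick sanity check with values from the Mistral-Nemo config (quoted from memory, so treat them as an assumption):

    # Mistral-Nemo sets head_dim explicitly, and it does NOT equal
    # hidden_size // num_attention_heads, so the getattr fallback below
    # cannot be removed.  (Values assumed from the Mistral-Nemo config.)
    hidden_size, num_heads, head_dim = 5120, 32, 128
    assert hidden_size // num_heads != head_dim  # 160 != 128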

if head_dim is None:
head_dim = self.hidden_size // self.total_num_heads
self.head_dim = head_dim
self.head_dim = head_dim or self.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings

llama_4_scaling_config = getattr(config, "llama_4_scaling", None)
self.do_llama_4_scaling = llama_4_scaling_config is not None
if self.do_llama_4_scaling:
self.llama_4_scaling_original_max_position_embeddings = (
llama_4_scaling_config["original_max_position_embeddings"]
)
self.llama_4_scaling_beta = llama_4_scaling_config["beta"]

self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
@@ -229,17 +219,6 @@ def __init__(
prefix=f"{prefix}.attn",
)

def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
Collaborator (Author):
only Mistral makes use of llama_4 scaling

# Llama4 scaling
scaling = 1 + self.llama_4_scaling_beta * torch.log(
1
+ torch.floor(
positions / self.llama_4_scaling_original_max_position_embeddings
)
)
# Broadcast over head_dim
return scaling.unsqueeze(-1)

def forward(
self,
positions: torch.Tensor,
@@ -248,9 +227,6 @@ def forward(
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
if self.do_llama_4_scaling:
attn_scale = self._get_llama_4_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
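For reference, the removed _get_llama_4_attn_scale applies a position-dependent temperature to the query vectors: scale(p) = 1 + beta * log(1 + floor(p / original_max_position_embeddings)). A standalone sketch of the same formula (argument values are up to the caller; nothing here is taken from a real config):

    import torch

    def llama_4_attn_scale(
        positions: torch.Tensor,
        beta: float,
        original_max_position_embeddings: int,
    ) -> torch.Tensor:
        # Identity (scale == 1.0) within the original context window, then
        # grows logarithmically with each multiple of that window.
        scaling = 1 + beta * torch.log(
            1 + torch.floor(positions / original_max_position_embeddings)
        )
        # unsqueeze(-1) so the scale broadcasts over head_dim when
        # multiplied into q of shape [num_tokens, num_heads * head_dim].
        return scaling.unsqueeze(-1)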
@@ -279,6 +255,7 @@ def __init__(
vllm_config: VllmConfig,
prefix: str = "",
config: LlamaConfig | None = None,
attn_layer_type: type[nn.Module] = LlamaAttention,
) -> None:
super().__init__()

@@ -307,7 +284,7 @@ def __init__(
else:
attn_type = AttentionType.ENCODER_ONLY

self.self_attn = LlamaAttention(
self.self_attn = attn_layer_type(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
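The new attn_layer_type hook lets a derived model swap in a custom attention class without re-implementing the decoder layer. A minimal sketch (MyAttention is hypothetical; the constructor signature follows the diff above):

    from vllm.model_executor.models.llama import LlamaAttention, LlamaDecoderLayer

    class MyAttention(LlamaAttention):
        # Hypothetical subclass: override forward to add custom behavior.
        def forward(self, positions, hidden_states):
            return super().forward(positions, hidden_states)

    # A derived decoder layer passes the class instead of copying __init__:
    # layer = LlamaDecoderLayer(vllm_config, prefix="model.layers.0",
    #                           attn_layer_type=MyAttention)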
@@ -537,32 +514,6 @@ class LlamaForCausalLM(
"lm_head": "output_embeddings",
}

# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"qscale_act": "input_scale",
"qscale_weight": "weight_scale",
"kv_fake_quantizer.qscale_act": "kv_scale",
"q_fake_quantizer.qscale_act": "attn.q_scale",
"k_fake_quantizer.qscale_act": "k_scale",
"v_fake_quantizer.qscale_act": "v_scale",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm",
}

def __init__(
self,
*,
@@ -649,67 +600,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(
self.maybe_remap_mistral(name, loaded_weight)
for name, loaded_weight in weights
)

# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def maybe_remap_mistral(
Collaborator (Author):
This code comes from a very old PR (#8168), and I'm quite convinced that only Mistral checkpoints actually make use of this function, so I'm moving it out.

self,
name: str,
loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]:
def permute(w: torch.Tensor, n_heads: int, attn_out: int):
attn_in = self.config.head_dim * n_heads

return (
w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
.transpose(1, 2)
.reshape(attn_in, attn_out)
)

mapping = self.mistral_mapping
modules = name.split(".")

# rotary embeds should be sliced
# If using quantized model in mistral format,
# quantization scales (qscale_weight) also need to be sliced
if "wk" in modules and modules[-1] == "weight":
loaded_weight = permute(
loaded_weight, self.config.num_key_value_heads, self.config.hidden_size
)
elif (
"wk" in modules
and modules[-1] == "qscale_weight"
and loaded_weight.numel() > 1
):
loaded_weight = permute(loaded_weight, self.config.num_key_value_heads, 1)
elif "wq" in modules and modules[-1] == "weight":
loaded_weight = permute(
loaded_weight, self.config.num_attention_heads, self.config.hidden_size
)
elif (
"wq" in modules
and modules[-1] == "qscale_weight"
and loaded_weight.numel() > 1
):
loaded_weight = permute(loaded_weight, self.config.num_attention_heads, 1)

num_modules = len(modules)
for i in range(num_modules):
item = modules[i]
next_item = modules[i + 1] if i < num_modules - 1 else None

combined_item = f"{item}.{next_item}" if next_item is not None else None

if combined_item in mapping:
name = name.replace(combined_item, mapping[combined_item])
elif item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item])

return name, loaded_weight
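The permute above converts Mistral's interleaved rotary layout (even/odd rotary dims stored in adjacent pairs) into the HF half-split layout, one head at a time. A small demonstration of the row reordering it performs (the layout interpretation is an assumption based on the code, not on Mistral documentation):

    import torch

    n_heads, head_dim, attn_out = 1, 8, 4
    w = torch.arange(head_dim * attn_out, dtype=torch.float32).view(head_dim, attn_out)
    out = (
        w.view(n_heads, head_dim // 2, 2, attn_out)
        .transpose(1, 2)
        .reshape(head_dim, attn_out)
    )
    # Rows move from order (0,1,2,3,4,5,6,7) to (0,2,4,6,1,3,5,7):
    # interleaved pairs become first-half / second-half.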
return loader.load_weights(weights)
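For reference, here is the name-translation step of the removed remapper in isolation (weight permutation omitted), using a subset of the mistral_mapping table deleted above:

    MISTRAL_MAPPING = {
        "layers": "model.layers",
        "attention": "self_attn",
        "wq": "q_proj",
        "wk": "k_proj",
        "wv": "v_proj",
        "wo": "o_proj",
        "attention_norm": "input_layernorm",
        "feed_forward": "mlp",
        "w1": "gate_proj",
        "w2": "down_proj",
        "w3": "up_proj",
        "ffn_norm": "post_attention_layernorm",
        "tok_embeddings": "model.embed_tokens",
        "output": "lm_head",
        "norm": "model.norm",
    }

    def remap_mistral_name(name: str) -> str:
        # Try two-token matches first (the full removed table has keys
        # like "kv_fake_quantizer.qscale_act"), then single tokens,
        # skipping tokens already in target form.
        modules = name.split(".")
        for i, item in enumerate(modules):
            nxt = modules[i + 1] if i + 1 < len(modules) else None
            combined = f"{item}.{nxt}" if nxt is not None else None
            if combined in MISTRAL_MAPPING:
                name = name.replace(combined, MISTRAL_MAPPING[combined])
            elif item in MISTRAL_MAPPING and MISTRAL_MAPPING[item] not in name:
                name = name.replace(item, MISTRAL_MAPPING[item])
        return name

    # "layers.0.attention.wq.weight" -> "model.layers.0.self_attn.q_proj.weight"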


class LlamaBidirectionalForSequenceClassification(as_seq_cls_model(LlamaForCausalLM)):