24 changes: 18 additions & 6 deletions vllm/model_executor/models/gemma2.py
@@ -29,7 +29,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.attention import Attention
-from vllm.model_executor.layers.layernorm import GemmaRMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -217,14 +217,21 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(
+        # GGUF stores RMSNorm weights with +1 baked in (llama.cpp convention).
+        # GemmaRMSNorm adds 1 in its forward pass, so use plain RMSNorm for GGUF.
+        norm_cls = (
+            RMSNorm
+            if (quant_config is not None and quant_config.get_name() == "gguf")
+            else GemmaRMSNorm
+        )
+        self.input_layernorm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = norm_cls(
             config.hidden_size, eps=config.rms_norm_eps
         )
-        self.pre_feedforward_layernorm = GemmaRMSNorm(
+        self.pre_feedforward_layernorm = norm_cls(
             config.hidden_size, eps=config.rms_norm_eps
         )
-        self.post_feedforward_layernorm = GemmaRMSNorm(
+        self.post_feedforward_layernorm = norm_cls(
             config.hidden_size, eps=config.rms_norm_eps
         )
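
The two new comment lines above are the heart of the change. llama.cpp's HF-to-GGUF converter folds Gemma's +1 into the stored norm weights, so a GGUF checkpoint already contains (1 + w); feeding that through GemmaRMSNorm, which multiplies by (1 + weight) at forward time, would apply the offset twice. A minimal sketch in plain PyTorch (illustrative tensors only, not vLLM's actual layer classes) of why plain RMSNorm is the right consumer for GGUF weights:

import torch

def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Shared normalization step: x / sqrt(mean(x^2) + eps)
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(4, 8, dtype=torch.float64)
w = 0.1 * torch.randn(8, dtype=torch.float64)  # HF-style Gemma weight, centered at 0

# GemmaRMSNorm semantics: scale by (1 + weight) in the forward pass.
gemma_out = rms_normalize(x) * (1.0 + w)

# GGUF checkpoints store the weight with the +1 already baked in...
gguf_weight = w + 1.0
# ...so plain RMSNorm semantics (scale by weight as stored) give the same output.
rms_out = rms_normalize(x) * gguf_weight

assert torch.allclose(gemma_out, rms_out)
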
Comment on lines +222 to 236

@hmellor (Member) commented on Mar 12, 2026:
Could you use this pattern as it's more compact? (and use it in the other places too)

Suggested change
-        norm_cls = (
-            RMSNorm
-            if (quant_config is not None and quant_config.get_name() == "gguf")
-            else GemmaRMSNorm
-        )
-        self.input_layernorm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = norm_cls(
-            config.hidden_size, eps=config.rms_norm_eps
-        )
-        self.pre_feedforward_layernorm = norm_cls(
-            config.hidden_size, eps=config.rms_norm_eps
-        )
-        self.post_feedforward_layernorm = norm_cls(
-            config.hidden_size, eps=config.rms_norm_eps
-        )
+        # GGUF stores RMSNorm weights with +1 baked in (llama.cpp convention).
+        # GemmaRMSNorm adds 1 in its forward pass, so use plain RMSNorm for GGUF.
+        quant_name = quant_config.get_name() if quant_config else None
+        rms_norm_cls = RMSNorm if quant_name == "gguf" else GemmaRMSNorm
+        rms_norm_kwargs = dict(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = rms_norm_cls(**rms_norm_kwargs)
+        self.post_attention_layernorm = rms_norm_cls(**rms_norm_kwargs)
+        self.pre_feedforward_layernorm = rms_norm_cls(**rms_norm_kwargs)
+        self.post_feedforward_layernorm = rms_norm_cls(**rms_norm_kwargs)
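
A note on the suggested pattern: hoisting the constructor arguments into `rms_norm_kwargs` works because `RMSNorm` and `GemmaRMSNorm` share a compatible `(hidden_size, eps)` signature, so the four layer norms collapse to one class selection plus four identical calls.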


@@ -274,7 +281,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             ),
             prefix=f"{prefix}.layers",
         )
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        norm_cls = (
+            RMSNorm
+            if (quant_config is not None and quant_config.get_name() == "gguf")
+            else GemmaRMSNorm
+        )
+        self.norm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)
 
         # Normalize the embedding by sqrt(hidden_size)
         # The normalizer's data type should be downcasted to the model's
42 changes: 25 additions & 17 deletions vllm/model_executor/models/gemma3.py
@@ -31,7 +31,7 @@
     Attention,
     EncoderOnlyAttention,
 )
-from vllm.model_executor.layers.layernorm import GemmaRMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -156,8 +156,15 @@ def __init__(
             prefix=f"{prefix}.o_proj",
         )
 
-        self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
-        self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        # GGUF stores RMSNorm weights with +1 baked in (llama.cpp convention).
+        # GemmaRMSNorm adds 1 in its forward pass, so use plain RMSNorm for GGUF.
+        norm_cls = (
+            RMSNorm
+            if (quant_config is not None and quant_config.get_name() == "gguf")
+            else GemmaRMSNorm
+        )
+        self.q_norm = norm_cls(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = norm_cls(self.head_dim, eps=config.rms_norm_eps)
 
         layer_idx = extract_layer_index(prefix)
         layer_type = config.layer_types[layer_idx]
@@ -261,14 +268,19 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(
+        norm_cls = (
+            RMSNorm
+            if (quant_config is not None and quant_config.get_name() == "gguf")
+            else GemmaRMSNorm
+        )
+        self.input_layernorm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = norm_cls(
             config.hidden_size, eps=config.rms_norm_eps
         )
-        self.pre_feedforward_layernorm = GemmaRMSNorm(
+        self.pre_feedforward_layernorm = norm_cls(
             config.hidden_size, eps=config.rms_norm_eps
         )
-        self.post_feedforward_layernorm = GemmaRMSNorm(
+        self.post_feedforward_layernorm = norm_cls(
             config.hidden_size, eps=config.rms_norm_eps
         )
 
@@ -322,7 +334,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             ),
             prefix=f"{prefix}.layers",
         )
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        norm_cls = (
+            RMSNorm
+            if (quant_config is not None and quant_config.get_name() == "gguf")
+            else GemmaRMSNorm
+        )
+        self.norm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)
 
         # Normalize the embedding by sqrt(hidden_size)
         # The normalizer's data type should be downcasted to the model's
@@ -383,15 +400,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            # Revert +1 during llama.cpp conversion
-            # see: https://github.com/ggml-org/llama.cpp/blob/be7c3034108473beda214fd1d7c98fd6a7a3bdf5/convert_hf_to_gguf.py#L3397-L3400
-            if (
-                self.quant_config
-                and self.quant_config.get_name() == "gguf"
-                and name.endswith("norm.weight")
-            ):
-                loaded_weight -= 1
-
             if self.quant_config is not None and (
                 scale_name := self.quant_config.get_cache_scale(name)
             ):
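
With the norm classes now selected at construction time, the GGUF branch deleted in this last hunk becomes dead code: it used to subtract 1 from every `*norm.weight` at load time so that GemmaRMSNorm could re-add it, whereas plain RMSNorm consumes the stored (w + 1) tensors from the GGUF checkpoint directly.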