[Bugfix] Apply RMSNorm weight correction for Gemma2 GGUF models #31464

Closed

kitaekatt wants to merge 2 commits into vllm-project:main from kitaekatt:fix/gemma2-gguf-rmsnorm


Conversation

@kitaekatt
Contributor

Summary

llama.cpp adds 1 to RMSNorm weights during GGUF conversion (see convert_hf_to_gguf.py#L3397-L3400), but vLLM expects the original values. Without this correction, Gemma2 GGUF models produce gibberish output.

This fix applies the same correction that was added for Gemma3 in PR #26189: subtracting 1 from norm weights during GGUF loading.

Root Cause

When loading Gemma2 GGUF models, the RMSNorm weights are offset by +1 relative to what vLLM expects, because llama.cpp's GGUF conversion adds 1 to these weights. The Gemma3 model already handles this (added in #26189), but Gemma2 was missing the same correction.

Changes

  • Add a _process_weights generator to Gemma2ForCausalLM.load_weights() that subtracts 1 from norm weights when loading GGUF models (a hedged sketch of the approach follows below)
  • Pattern matches the existing Gemma3 implementation
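
For context, here is a minimal sketch of the generator approach described in the first bullet. It is illustrative only: the helper name _process_weights comes from this PR's description, and the exact gating and signature in the final diff may differ.

```python
# Illustrative sketch of the described approach, not the verbatim diff.
from collections.abc import Iterable

import torch


def _process_weights(
    weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
    """Undo llama.cpp's +1 offset on RMSNorm weights in GGUF checkpoints."""
    for name, loaded_weight in weights:
        if name.endswith("norm.weight"):
            # GGUF stores (w + 1); GemmaRMSNorm adds 1 again in forward(),
            # so subtract 1 here to recover the original HF weight.
            loaded_weight = loaded_weight - 1
        yield name, loaded_weight
```

load_weights() would wrap its incoming weight iterator with this generator only when the checkpoint is GGUF (e.g. quant_config.get_name() == "gguf"), leaving safetensors checkpoints untouched.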

Testing

Tested with bartowski/gemma-2-2b-it-GGUF on RTX 5090:

  • Before fix: gibberish output (excessive script mixing: ARABIC, CJK, CYRILLIC, GREEK, LATIN)
  • After fix: coherent output, 40% accuracy on a small MMLU sample, 344 tok/s throughput

Single-thread baseline results: 40.00% accuracy (4/10 correct), 343.70 tok/s throughput


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a bugfix for Gemma2 GGUF models, which were producing incorrect outputs due to a discrepancy in how RMSNorm weights are handled. The issue stems from llama.cpp's GGUF conversion process adding 1 to these weights, which, when combined with vLLM's GemmaRMSNorm implementation that also adds 1, resulted in an incorrect normalization factor. The fix correctly subtracts 1 from the norm weights during the loading process for Gemma2 GGUF models. The implementation is clean, localized within the Gemma2ForCausalLM.load_weights method, and uses a generator for efficient weight processing. The logic appears sound and effectively resolves the described problem. I have no further comments as the change is correct and well-implemented.

@mergify (Bot) added the bug (Something isn't working) label on Jan 14, 2026
llama.cpp adds 1 to RMSNorm weights during GGUF conversion (see
convert_hf_to_gguf.py#L3397-L3400), but vLLM expects original values.
Without this correction, Gemma2 GGUF models produce gibberish output.

This fix applies the same correction that was added for Gemma3 in
PR vllm-project#26189 - subtracting 1 from norm weights during GGUF loading.

Fixes gibberish output from bartowski/gemma-2-2b-it-GGUF and similar
Gemma2 GGUF models on Blackwell GPUs (RTX 5090).

Signed-off-by: Christina <kitaekatt@gmail.com>
Signed-off-by: Christina <truffle@gmail.com>
@kitaekatt
Contributor Author

Validation Results

| vLLM | transformers | Cherry-picked PRs | HumanEval | IFEval |
|------|--------------|-------------------|-----------|--------|
| HEAD | 5.x | #30410, #30411, #30412, #30413, #30424, #30434, #30699, #30702, #31464, #33846 | gem2-2b-gguf (42.1%), gemma3-1b (26.8%) | gem2-2b-gguf (65.6%) |
| HEAD | 4.x | #30410, #30411, #30412, #30413, #30424, #30434, #30699, #30702, #31464, #33846 | q3-moe-gguf (83.5%) | q3-moe-gguf (85.4%) |

Tested on RTX 5090 (Blackwell, SM 120) with all listed PRs cherry-picked together; models listed under each benchmark passed that benchmark in the given environment, while the same models crash or fail without these PRs applied.

Converting from draft to open. GGUF stores RMSNorm weights centered at 1.0 but GemmaRMSNorm expects 0.0 — this fix applies the -1 correction for all 105 norm.weight parameters. Without it, logits are flat and output is all <pad> tokens. Validated on RTX 5090 (Blackwell, SM 120).

Member

@hmellor left a comment


Instead of modifying the weights, can we just say that if self.quant_config.get_name() == "gguf" we use RMSNorm instead of GemmaRMSNorm?

@hmellor
Member

hmellor commented Mar 10, 2026

You could make the same change in Gemma3 too

GGUF files store RMSNorm weights with +1 already baked in (llama.cpp
convention). GemmaRMSNorm adds 1 in its forward pass, causing double
addition for GGUF models.

Instead of subtracting 1 during weight loading, select the norm class
at construction time: use plain RMSNorm (no +1) for GGUF, GemmaRMSNorm
for non-GGUF. Applied to all norm layers in Gemma2 and Gemma3 including
decoder layer norms and Gemma3's q_norm/k_norm in attention.

Removes the weight correction workaround from load_weights().

Addresses reviewer feedback from hmellor on PR vllm-project#31464.

Signed-off-by: Christina <truffle@gmail.com>
@kitaekatt
Contributor Author

Thanks @hmellor — addressed both pieces of feedback:

  • Gemma2: switched from _process_weights weight subtraction to using RMSNorm (instead of GemmaRMSNorm) at construction time when quant_config.get_name() == "gguf". Applied to all 5 norm layers (input_layernorm, post_attention_layernorm, pre_feedforward_layernorm, post_feedforward_layernorm, Gemma2Model.norm).
  • Gemma3: same change applied to all 5 decoder layer norms + q_norm/k_norm in Gemma3Attention (a hedged sketch of the attention-norm selection follows below). Removed the loaded_weight -= 1 block from Gemma3Model.load_weights.
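
For the Gemma3 attention norms, a rough sketch of the construction-time selection (select_rms_norm_cls below is a hypothetical helper used only for illustration, not the exact code in the diff; it assumes q_norm/k_norm normalize over head_dim with eps=config.rms_norm_eps):

```python
# Sketch only; select_rms_norm_cls is a hypothetical helper for illustration.
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm


def select_rms_norm_cls(quant_config):
    """Plain RMSNorm for GGUF (weights already carry the +1 offset);
    GemmaRMSNorm (which adds 1 in forward()) for everything else."""
    quant_name = quant_config.get_name() if quant_config is not None else None
    return RMSNorm if quant_name == "gguf" else GemmaRMSNorm


# In Gemma3Attention.__init__ the selected class would then back both norms:
#     norm_cls = select_rms_norm_cls(quant_config)
#     self.q_norm = norm_cls(self.head_dim, eps=config.rms_norm_eps)
#     self.k_norm = norm_cls(self.head_dim, eps=config.rms_norm_eps)
```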

Running validation benchmarks (HumanEval + IFEval on gem2-2b-gguf and gemma3-1b) now; will update when complete.

@kitaekatt
Contributor Author

Validation results (RTX 5090, Blackwell SM 120, v13-t5-pr with all PRs cherry-picked):

| Model | Benchmark | Score | vs baseline |
|-------|-----------|-------|-------------|
| gem2-2b-gguf | HumanEval | 42.1% (69/164) | ✅ matches |
| gem2-2b-gguf | IFEval | 65.6% (355/541) | ✅ matches |
| gemma3-1b | HumanEval | 26.8% (44/164) | ✅ matches |

No regression. Both approaches (_process_weights subtract-1 + GemmaRMSNorm and the new RMSNorm with weights as-loaded) are mathematically equivalent since (w_gguf - 1) + 1 == w_gguf.
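
To make that equivalence concrete, a small standalone check (plain PyTorch, no vLLM imports; rms_normalize below is a simplified stand-in for the normalization both layers perform before applying their weight):

```python
import torch


def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Shared normalization step; the two norm layers differ only in the weight term.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


x = torch.randn(4, 16)
w_gguf = torch.randn(16) + 1.0  # GGUF convention: weight stored with +1 baked in

# Old approach: subtract 1 at load time, then GemmaRMSNorm multiplies by (1 + w).
out_old = rms_normalize(x) * (1.0 + (w_gguf - 1.0))
# New approach: keep the GGUF weight as-is and let plain RMSNorm multiply by w.
out_new = rms_normalize(x) * w_gguf

assert torch.allclose(out_old, out_new)  # (w_gguf - 1) + 1 == w_gguf
```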

@kitaekatt requested a review from hmellor on March 11, 2026 at 23:08
Comment on lines +222 to 236
norm_cls = (
    RMSNorm
    if (quant_config is not None and quant_config.get_name() == "gguf")
    else GemmaRMSNorm
)
self.input_layernorm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = norm_cls(
    config.hidden_size, eps=config.rms_norm_eps
)
self.pre_feedforward_layernorm = norm_cls(
    config.hidden_size, eps=config.rms_norm_eps
)
self.post_feedforward_layernorm = norm_cls(
    config.hidden_size, eps=config.rms_norm_eps
)
Member

@hmellor Mar 12, 2026


Could you use this pattern as it's more compact? (and use it in the other places too)

Suggested change
# GGUF stores RMSNorm weights with +1 baked in (llama.cpp convention).
# GemmaRMSNorm adds 1 in its forward pass, so use plain RMSNorm for GGUF.
quant_name = quant_config.get_name() if quant_config else None
rms_norm_cls = RMSNorm if quant_name == "gguf" else GemmaRMSNorm
rms_norm_kwargs = dict(hidden_size=config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm = rms_norm_cls(**rms_norm_kwargs)
self.post_attention_layernorm = rms_norm_cls(**rms_norm_kwargs)
self.pre_feedforward_layernorm = rms_norm_cls(**rms_norm_kwargs)
self.post_feedforward_layernorm = rms_norm_cls(**rms_norm_kwargs)

kitaekatt added a commit to kitaekatt/vllm that referenced this pull request Mar 16, 2026
This PR consolidates four related GGUF bug fixes for Gemma2 and Gemma3
models, plus a style improvement from reviewer feedback.

**1. Add quant_config to embedding layer (PR vllm-project#30424)**
Pass quant_config to VocabParallelEmbedding in Gemma2Model so that
GGUFEmbeddingMethod is selected instead of UnquantizedEmbeddingMethod.
Without this, quantized bytes are read as raw floats producing gibberish.

**2. Fix EOS token extraction for HF blob paths (PR vllm-project#30434)**
GGUF files served from HuggingFace Hub use blob paths that don't match
the expected filename pattern. Extract EOS token ID directly from GGUF
metadata as a reliable fallback.

**3. Skip missing parameters during GGUF weight loading (PR vllm-project#30699)**
Gemma2 GGUF files may lack certain weight keys (e.g. embed_tokens.qweight_type).
Skip missing parameters gracefully instead of raising KeyError.

**4. Use RMSNorm instead of GemmaRMSNorm for GGUF (PR vllm-project#31464)**
GGUF files store RMSNorm weights with +1 baked in (llama.cpp convention).
GemmaRMSNorm adds 1 in its forward pass, causing double addition.
Select plain RMSNorm at construction time for GGUF models. Applied to
all norm layers in Gemma2 and Gemma3 (including q_norm/k_norm).

**Style: compact rms_norm_kwargs pattern (reviewer feedback)**
Use rms_norm_kwargs dict to avoid repeating constructor arguments,
per hmellor's review on PR vllm-project#31464.

Tested on RTX 5090 (Blackwell, SM 120) with gem2-2b-gguf and gemma3-1b.
Supersedes PRs vllm-project#30424, vllm-project#30434, vllm-project#30699, vllm-project#31464.

Signed-off-by: Christina <truffle@gmail.com>
@kitaekatt
Contributor Author

Closing in favor of consolidated PR #37220, as requested by @Isotr0py in #30434. All fixes from this PR are included in the consolidated version.

@kitaekatt closed this on Mar 16, 2026

Labels

bug Something isn't working
