From 62f8cdc35de66249d5d4a516fec605e95e7d6e27 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 26 Aug 2025 07:41:10 +0000
Subject: [PATCH 1/5] Fix Gemma RMSNorm weight init

Fix #40224
---
 src/transformers/modeling_utils.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 7dc4f9324928..f1203a5ad987 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3085,6 +3085,14 @@ def _init_weights(self, module):
             module._reset_parameters()
         # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
         # between modelings (because they are prefixed with the model name)
+        # Some norms use the weight additively in `...*(1+weight)`,
+        # so those should init `weight` to 0 (https://github.com/huggingface/transformers/issues/40224).
+        elif (
+            "RMSNorm" in module.__class__.__name__
+            and "Gemma" in module.__class__.__name__
+            and "Gemma3n" not in module.__class__.__name__
+        ):
+            module.weight.data.zero_()
         elif (
             isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
             or "LayerNorm" in module.__class__.__name__

From d3516da1bba7c6aedbb8762018b3621e278c2f9b Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 30 Sep 2025 11:53:39 +0000
Subject: [PATCH 2/5] revert

---
 src/transformers/modeling_utils.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f1203a5ad987..7dc4f9324928 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3085,14 +3085,6 @@ def _init_weights(self, module):
             module._reset_parameters()
         # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
         # between modelings (because they are prefixed with the model name)
-        # Some norms use the weight additively in `...*(1+weight)`,
-        # so those should init `weight` to 0 (https://github.com/huggingface/transformers/issues/40224).
-        elif (
-            "RMSNorm" in module.__class__.__name__
-            and "Gemma" in module.__class__.__name__
-            and "Gemma3n" not in module.__class__.__name__
-        ):
-            module.weight.data.zero_()
         elif (
             isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
             or "LayerNorm" in module.__class__.__name__

From fba4f350bab7b1ef08960caf9346559bcff3def2 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 30 Sep 2025 11:54:00 +0000
Subject: [PATCH 3/5] GemmaModel _init_weights

---
 src/transformers/models/gemma/modular_gemma.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py
index 67aedbd55115..d61018b790ce 100644
--- a/src/transformers/models/gemma/modular_gemma.py
+++ b/src/transformers/models/gemma/modular_gemma.py
@@ -367,6 +367,14 @@ def __init__(self, config):


 class GemmaModel(LlamaModel):
+    def _init_weights(self, module):
+        if isinstance(module, GemmaRMSNorm):
+            # The norm uses the weight additively as in `...*(1+weight)`,
+            # so those should init `weight` to 0 (https://github.com/huggingface/transformers/issues/40224).
+            module.weight.zero_()
+        else:
+            super()._init_weights(module)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,

From 984a8fc11431a4a080ab82b3461ab38fcdadf5ac Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 30 Sep 2025 12:01:46 +0000
Subject: [PATCH 4/5] fix recgemma param init

---
 .../models/recurrent_gemma/modeling_recurrent_gemma.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index daef714ab883..a7b45379a14c 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -558,7 +558,7 @@ def _init_weights(self, module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, RecurrentGemmaRMSNorm):
-            module.weight.data.fill_(1.0)
+            module.weight.data.zero_()

     def _setup_cache(self, config, batch, device, dtype):
         layers = getattr(self, "model", self).layers

From c96a2b3fe8f0ace342e8ee0a2654a695dcd281d8 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Tue, 30 Sep 2025 12:12:03 +0000
Subject: [PATCH 5/5] revert

---
 src/transformers/models/gemma/modular_gemma.py | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py
index aac78d897676..e17545822da8 100644
--- a/src/transformers/models/gemma/modular_gemma.py
+++ b/src/transformers/models/gemma/modular_gemma.py
@@ -408,14 +408,6 @@


 class GemmaModel(LlamaModel):
-    def _init_weights(self, module):
-        if isinstance(module, GemmaRMSNorm):
-            # The norm uses the weight additively as in `...*(1+weight)`,
-            # so those should init `weight` to 0 (https://github.com/huggingface/transformers/issues/40224).
-            module.weight.zero_()
-        else:
-            super()._init_weights(module)
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
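
Note on the init convention behind this series: Gemma-style RMSNorm applies its
weight additively, scaling the normalized activations by (1 + weight) rather than
by weight, so a freshly initialized weight should be 0 (effective scale 1); the
earlier fill_(1.0) in RecurrentGemma amounted to scaling activations by 2 at init.
After the reverts in 2/5 and 5/5, the net change of the series is the
RecurrentGemmaRMSNorm zero init from 4/5. The sketch below is a simplified
illustration of that convention, not a copy of the transformers modeling code;
the class name and tensor shapes are made up for the example.

    import torch
    from torch import nn

    class GemmaStyleRMSNorm(nn.Module):
        """Simplified RMSNorm with an additive weight, in the style of the Gemma family."""

        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            # Zero init: the effective scale is (1 + 0) = 1 for a fresh module.
            self.weight = nn.Parameter(torch.zeros(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Normalize by the root mean square, then apply the additive weight.
            hidden = x.float()
            hidden = hidden * torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + self.eps)
            return (hidden * (1.0 + self.weight.float())).type_as(x)

    norm = GemmaStyleRMSNorm(dim=4)
    x = torch.randn(2, 4)
    plain_rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps)
    print(torch.allclose(norm(x), plain_rms))        # True: zero weight acts as a plain RMS norm
    norm.weight.data.fill_(1.0)                      # the old RecurrentGemma init
    print(torch.allclose(norm(x), 2.0 * plain_rms))  # True: weight == 1 doubles the activations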