From ed38719595e8c67ec726260388ae892e6ee71df6 Mon Sep 17 00:00:00 2001
From: degenfabian
Date: Sun, 21 Sep 2025 18:21:23 +0100
Subject: [PATCH] Properly set up normalization_type and layer_norm_folding
 attributes for already initialized components

---
 transformer_lens/model_bridge/bridge.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index 52f1d6cfa..197173538 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -623,12 +623,18 @@ def set_compatibility_mode(component: Any) -> None:
 
         if not no_processing:
             self.cfg.layer_norm_folding = True
-
             if self.cfg.normalization_type == "RMS":
                 self.cfg.normalization_type = "RMSPre"
             elif self.cfg.normalization_type == "LN":
                 self.cfg.normalization_type = "LNPre"
 
+            def setup_normalization_for_folding(component: Any) -> None:
+                if hasattr(component, "config") and component.config is not None:
+                    component.config.normalization_type = self.cfg.normalization_type
+                    component.config.layer_norm_folding = True
+
+            apply_fn_to_all_components(self, setup_normalization_for_folding)
+
             # Apply weight processing using the centralized ProcessWeights class
             self.process_weights(
                 fold_ln=True,
@@ -673,12 +679,7 @@ def process_weights(
             adapter=self.adapter,
         )
 
-        # # Step 3: Replace LayerNorm components with LayerNormPre-like operations if fold_ln is True
-        # This is equivalent to what HookedTransformer does when it replaces LayerNorm with LayerNormPre
-        if fold_ln:
-            self._replace_layer_norm_with_identity(self.original_model)
-
-        # # Step 4: Load processed weights into the original model using the bridge's load_state_dict method
+        # # Step 3: Load processed weights into the original model using the bridge's load_state_dict method
         # This handles the key mapping between clean keys and _original_component keys
         # Use strict=False because weight processing may remove some keys (e.g., individual Q,K,V -> combined QKV)
         self.load_state_dict(processed_hf_state_dict, strict=False, assign=True)
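
Note (illustrative sketch, not part of the patch): the bug being fixed is a config-propagation gap. Mutating the top-level config after sub-components have already been constructed does nothing for components that hold their own config object, so the new normalization_type and layer_norm_folding values must be pushed down explicitly, as setup_normalization_for_folding does via apply_fn_to_all_components. The minimal Python sketch below reproduces that pattern under assumed names: SimpleConfig, Block, and walk_components are hypothetical stand-ins, not the TransformerBridge API; only the inner setup function mirrors the patch.

    from dataclasses import dataclass
    from typing import Any, Callable, List

    @dataclass
    class SimpleConfig:
        normalization_type: str = "LN"
        layer_norm_folding: bool = False

    class Block:
        """Hypothetical component that copies its config at construction time."""

        def __init__(self, config: SimpleConfig):
            # Each block keeps its own copy, so later mutations of the
            # top-level config never reach already-initialized blocks.
            self.config = SimpleConfig(**vars(config))
            self.children: List["Block"] = []

    def walk_components(root: Any, fn: Callable[[Any], None]) -> None:
        """Stand-in for apply_fn_to_all_components: visit every component."""
        fn(root)
        for child in getattr(root, "children", []):
            walk_components(child, fn)

    top_cfg = SimpleConfig()
    model = Block(top_cfg)
    model.children.append(Block(top_cfg))

    # Fold layer norm: update the top-level config first, as the patch does.
    top_cfg.layer_norm_folding = True
    if top_cfg.normalization_type == "RMS":
        top_cfg.normalization_type = "RMSPre"
    elif top_cfg.normalization_type == "LN":
        top_cfg.normalization_type = "LNPre"

    # Then push the new values into every already-initialized component,
    # mirroring setup_normalization_for_folding from the patch.
    def setup_normalization_for_folding(component: Any) -> None:
        if hasattr(component, "config") and component.config is not None:
            component.config.normalization_type = top_cfg.normalization_type
            component.config.layer_norm_folding = True

    walk_components(model, setup_normalization_for_folding)

    # Without the walk, the child would still report "LN" / False here.
    assert model.children[0].config.normalization_type == "LNPre"
    assert model.children[0].config.layer_norm_folding is True

The same reasoning explains the second hunk: once every component's config says layer norm folding is in effect, the components can normalize without scale/bias themselves, so the old Step 3 (swapping LayerNorm modules for identity operations after folding) becomes redundant and is dropped, renumbering the state-dict load to Step 3.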