Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions transformer_lens/model_bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,12 +623,18 @@ def set_compatibility_mode(component: Any) -> None:

if not no_processing:
self.cfg.layer_norm_folding = True

if self.cfg.normalization_type == "RMS":
self.cfg.normalization_type = "RMSPre"
elif self.cfg.normalization_type == "LN":
self.cfg.normalization_type = "LNPre"

def setup_normalization_for_folding(component: Any) -> None:
    """Mirror the bridge's normalization settings onto one component.

    Components that have no ``config`` attribute, or whose ``config`` is
    ``None``, are left untouched. Otherwise the component's config gets
    the bridge's current ``normalization_type`` and has
    ``layer_norm_folding`` switched on.

    Args:
        component: Any object visited by the component traversal.
    """
    component_cfg = getattr(component, "config", None)
    if component_cfg is None:
        # Nothing to configure on this component.
        return

    # ``self`` is captured from the enclosing scope.
    component_cfg.normalization_type = self.cfg.normalization_type
    component_cfg.layer_norm_folding = True

apply_fn_to_all_components(self, setup_normalization_for_folding)

# Apply weight processing using the centralized ProcessWeights class
self.process_weights(
fold_ln=True,
Expand Down Expand Up @@ -673,12 +679,7 @@ def process_weights(
adapter=self.adapter,
)

# Step 3: Replace LayerNorm components with LayerNormPre-like operations if fold_ln is True.
# This is equivalent to what HookedTransformer does when it replaces LayerNorm with LayerNormPre.
if fold_ln:
self._replace_layer_norm_with_identity(self.original_model)

# Step 4: Load processed weights into the original model using the bridge's load_state_dict method
# Step 3: Load processed weights into the original model using the bridge's load_state_dict method
# This handles the key mapping between clean keys and _original_component keys
# Use strict=False because weight processing may remove some keys (e.g., individual Q,K,V -> combined QKV)
self.load_state_dict(processed_hf_state_dict, strict=False, assign=True)
Expand Down
Loading