From c1a867ad8800af4e0b06aec45eabee9e3996fbb5 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Mon, 29 Sep 2025 22:45:08 +0200 Subject: [PATCH 1/5] moved final layer norm --- transformer_lens/model_bridge/bridge.py | 9 --- .../generalized_components/__init__.py | 4 + .../final_normalization.py | 78 +++++++++++++++++++ .../supported_architectures/gpt2.py | 1 + 4 files changed, 83 insertions(+), 9 deletions(-) create mode 100644 transformer_lens/model_bridge/generalized_components/final_normalization.py diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index ae1e3bec0..c9784494a 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -882,15 +882,6 @@ def _port_unembed_component(self): b_U = processed_weights.get("unembed.b_U") self.unembed.set_processed_weight(W_U, b_U) - # Also port final layer norm if it exists - if hasattr(self, "ln_final"): - - def ln_final_forward(x): - # When layer norm is folded, just return input unchanged - return x - - self.ln_final.forward = ln_final_forward - print(f" ✅ Final LayerNorm ported (identity)") def _ported_forward_pass( self, diff --git a/transformer_lens/model_bridge/generalized_components/__init__.py b/transformer_lens/model_bridge/generalized_components/__init__.py index 252095f15..57e72f522 100644 --- a/transformer_lens/model_bridge/generalized_components/__init__.py +++ b/transformer_lens/model_bridge/generalized_components/__init__.py @@ -15,6 +15,9 @@ from transformer_lens.model_bridge.generalized_components.normalization import ( NormalizationBridge, ) +from transformer_lens.model_bridge.generalized_components.final_normalization import ( + FinalNormalizationBridge, +) from transformer_lens.model_bridge.generalized_components.linear import ( LinearBridge, @@ -37,6 +40,7 @@ "EmbeddingBridge", "PosEmbedBridge", "NormalizationBridge", + "FinalNormalizationBridge", "JointQKVAttentionBridge", "JointGateUpMLPBridge", "LinearBridge", diff --git a/transformer_lens/model_bridge/generalized_components/final_normalization.py b/transformer_lens/model_bridge/generalized_components/final_normalization.py new file mode 100644 index 000000000..1576feac1 --- /dev/null +++ b/transformer_lens/model_bridge/generalized_components/final_normalization.py @@ -0,0 +1,78 @@ +"""Final normalization bridge component implementation.""" + +from typing import Any, Dict, Optional + +import torch + +from transformer_lens.model_bridge.generalized_components.base import ( + GeneralizedComponent, +) +from transformer_lens.model_bridge.generalized_components.normalization import ( + NormalizationBridge, +) + + +class FinalNormalizationBridge(NormalizationBridge): + """Final layer normalization bridge that behaves as identity when weights are folded. + + This component extends NormalizationBridge and overrides the forward method to return + identity (input unchanged) when layer norm folding is enabled, otherwise falls back + to the standard normalization functionality. + """ + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + """Forward pass through the final normalization bridge. + + Args: + hidden_states: Input hidden states + **kwargs: Additional arguments to pass to the original component + + Returns: + Normalized output or identity if folded + """ + if self.original_component is None: + raise RuntimeError( + f"Original component not set for {self.name}. Call set_original_component() first." 
+ ) + + # keep mypy happy + assert self.config is not None + + # Check if layer norm folding is enabled - if so, behave as identity + if hasattr(self.config, "layer_norm_folding") and self.config.layer_norm_folding: + # Final layer norm becomes identity when folding is enabled + # (weights are absorbed into other components during processing) + # Simply return the input unchanged (identity function) + return hidden_states + else: + # Fall back to standard normalization behavior + return super().forward(hidden_states, **kwargs) + + @classmethod + def create_final_normalization_bridge( + cls, + name: str, + config: Any, + original_component: Any, + ) -> "FinalNormalizationBridge": + """Create a final normalization bridge for final layer norm components. + + Args: + name: The name of this component + config: Configuration object + original_component: The original layer norm component + + Returns: + FinalNormalizationBridge that behaves as identity when folding is enabled + """ + # Create the bridge + bridge = cls(name=name, config=config) + + # Set the original component + bridge.set_original_component(original_component) + + return bridge \ No newline at end of file diff --git a/transformer_lens/model_bridge/supported_architectures/gpt2.py b/transformer_lens/model_bridge/supported_architectures/gpt2.py index 6dbd8b530..5917bf037 100644 --- a/transformer_lens/model_bridge/supported_architectures/gpt2.py +++ b/transformer_lens/model_bridge/supported_architectures/gpt2.py @@ -18,6 +18,7 @@ LinearBridge, MLPBridge, NormalizationBridge, + FinalNormalizationBridge, UnembeddingBridge, ) From 8ffd2ae83e756739d300c81f378fd134cb033d34 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Mon, 29 Sep 2025 23:08:40 +0200 Subject: [PATCH 2/5] moved layer norm forward --- transformer_lens/model_bridge/bridge.py | 27 ++----- .../generalized_components/__init__.py | 4 - .../final_normalization.py | 78 ------------------- .../supported_architectures/gpt2.py | 1 - 4 files changed, 7 insertions(+), 103 deletions(-) delete mode 100644 transformer_lens/model_bridge/generalized_components/final_normalization.py diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index c9784494a..7f17a12b9 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -743,15 +743,6 @@ def _port_transformer_blocks(self): if hasattr(block, "mlp"): self._port_mlp_component(block.mlp, layer_idx, processed_weights) - # Port layer norms (should be identity if folded) - if hasattr(block, "ln1"): - self._port_layernorm_component( - block.ln1, f"blocks.{layer_idx}.ln1", processed_weights - ) - if hasattr(block, "ln2"): - self._port_layernorm_component( - block.ln2, f"blocks.{layer_idx}.ln2", processed_weights - ) def _port_attention_component(self, attn_component, layer_idx, processed_weights): """Port attention component using reference model's exact computation.""" @@ -861,17 +852,6 @@ def mlp_forward(x): mlp_component.forward = mlp_forward print(f" ✅ MLP ported (fallback) for layer {layer_idx}") - def _port_layernorm_component(self, ln_component, ln_name, processed_weights): - """Port layer norm component (usually identity when folded).""" - - # For folded layer norms, these should be identity operations - def layernorm_forward(x): - # When layer norm is folded, just return input unchanged - return x - - ln_component.forward = layernorm_forward - print(f" ✅ LayerNorm {ln_name} ported (identity)") - def _port_unembed_component(self): """Port unembedding 
component from processed weights.""" processed_weights = self._processed_tl_weights @@ -882,6 +862,13 @@ def _port_unembed_component(self): b_U = processed_weights.get("unembed.b_U") self.unembed.set_processed_weight(W_U, b_U) + # Also port final layer norm if it exists + if hasattr(self, "ln_final"): + def ln_final_forward(x): + # When layer norm is folded, just return input unchanged + return x + self.ln_final.forward = ln_final_forward + print(f" ✅ Final LayerNorm ported (identity)") def _ported_forward_pass( self, diff --git a/transformer_lens/model_bridge/generalized_components/__init__.py b/transformer_lens/model_bridge/generalized_components/__init__.py index 57e72f522..252095f15 100644 --- a/transformer_lens/model_bridge/generalized_components/__init__.py +++ b/transformer_lens/model_bridge/generalized_components/__init__.py @@ -15,9 +15,6 @@ from transformer_lens.model_bridge.generalized_components.normalization import ( NormalizationBridge, ) -from transformer_lens.model_bridge.generalized_components.final_normalization import ( - FinalNormalizationBridge, -) from transformer_lens.model_bridge.generalized_components.linear import ( LinearBridge, @@ -40,7 +37,6 @@ "EmbeddingBridge", "PosEmbedBridge", "NormalizationBridge", - "FinalNormalizationBridge", "JointQKVAttentionBridge", "JointGateUpMLPBridge", "LinearBridge", diff --git a/transformer_lens/model_bridge/generalized_components/final_normalization.py b/transformer_lens/model_bridge/generalized_components/final_normalization.py deleted file mode 100644 index 1576feac1..000000000 --- a/transformer_lens/model_bridge/generalized_components/final_normalization.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Final normalization bridge component implementation.""" - -from typing import Any, Dict, Optional - -import torch - -from transformer_lens.model_bridge.generalized_components.base import ( - GeneralizedComponent, -) -from transformer_lens.model_bridge.generalized_components.normalization import ( - NormalizationBridge, -) - - -class FinalNormalizationBridge(NormalizationBridge): - """Final layer normalization bridge that behaves as identity when weights are folded. - - This component extends NormalizationBridge and overrides the forward method to return - identity (input unchanged) when layer norm folding is enabled, otherwise falls back - to the standard normalization functionality. - """ - - def forward( - self, - hidden_states: torch.Tensor, - **kwargs: Any, - ) -> torch.Tensor: - """Forward pass through the final normalization bridge. - - Args: - hidden_states: Input hidden states - **kwargs: Additional arguments to pass to the original component - - Returns: - Normalized output or identity if folded - """ - if self.original_component is None: - raise RuntimeError( - f"Original component not set for {self.name}. Call set_original_component() first." 
- ) - - # keep mypy happy - assert self.config is not None - - # Check if layer norm folding is enabled - if so, behave as identity - if hasattr(self.config, "layer_norm_folding") and self.config.layer_norm_folding: - # Final layer norm becomes identity when folding is enabled - # (weights are absorbed into other components during processing) - # Simply return the input unchanged (identity function) - return hidden_states - else: - # Fall back to standard normalization behavior - return super().forward(hidden_states, **kwargs) - - @classmethod - def create_final_normalization_bridge( - cls, - name: str, - config: Any, - original_component: Any, - ) -> "FinalNormalizationBridge": - """Create a final normalization bridge for final layer norm components. - - Args: - name: The name of this component - config: Configuration object - original_component: The original layer norm component - - Returns: - FinalNormalizationBridge that behaves as identity when folding is enabled - """ - # Create the bridge - bridge = cls(name=name, config=config) - - # Set the original component - bridge.set_original_component(original_component) - - return bridge \ No newline at end of file diff --git a/transformer_lens/model_bridge/supported_architectures/gpt2.py b/transformer_lens/model_bridge/supported_architectures/gpt2.py index 5917bf037..6dbd8b530 100644 --- a/transformer_lens/model_bridge/supported_architectures/gpt2.py +++ b/transformer_lens/model_bridge/supported_architectures/gpt2.py @@ -18,7 +18,6 @@ LinearBridge, MLPBridge, NormalizationBridge, - FinalNormalizationBridge, UnembeddingBridge, ) From c0b392bf352c393b21c17dd6d8c0de205bfb97c2 Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Mon, 29 Sep 2025 23:10:17 +0200 Subject: [PATCH 3/5] cleaned up more things --- transformer_lens/model_bridge/bridge.py | 30 ++----------------------- 1 file changed, 2 insertions(+), 28 deletions(-) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 7f17a12b9..d0de874e6 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -831,26 +831,8 @@ def _port_mlp_component(self, mlp_component, layer_idx, processed_weights): if W_in is None or W_out is None: print(f" ⚠️ Missing MLP weights for layer {layer_idx}, skipping port") return - - # Use the new set_processed_weights method if available (integrated into MLPBridge) - if hasattr(mlp_component, 'set_processed_weights'): - mlp_component.set_processed_weights(W_in, W_out, b_in, b_out) - print(f" ✅ MLP set in MLPBridge for layer {layer_idx}") - else: - # Fallback: Replace the bridge MLP component with direct tensor operations - def mlp_forward(x): - """Port of reference model's MLP forward pass.""" - # Input projection using TransformerLens format - hidden = torch.nn.functional.linear(x, W_in.T, b_in) - # Apply activation (GELU for GPT-2) - hidden = torch.nn.functional.gelu(hidden) - # Output projection using TransformerLens format - result = torch.nn.functional.linear(hidden, W_out.T, b_out) - return result - - # Replace the MLP component's forward method - mlp_component.forward = mlp_forward - print(f" ✅ MLP ported (fallback) for layer {layer_idx}") + mlp_component.set_processed_weights(W_in, W_out, b_in, b_out) + print(f" ✅ MLP set in MLPBridge for layer {layer_idx}") def _port_unembed_component(self): """Port unembedding component from processed weights.""" @@ -862,14 +844,6 @@ def _port_unembed_component(self): b_U = processed_weights.get("unembed.b_U") 
self.unembed.set_processed_weight(W_U, b_U) - # Also port final layer norm if it exists - if hasattr(self, "ln_final"): - def ln_final_forward(x): - # When layer norm is folded, just return input unchanged - return x - self.ln_final.forward = ln_final_forward - print(f" ✅ Final LayerNorm ported (identity)") - def _ported_forward_pass( self, input: Union[str, List[str], torch.Tensor], From 4536b3a8b9a84612030ede2b150401e32dfc02cb Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Mon, 29 Sep 2025 23:17:52 +0200 Subject: [PATCH 4/5] updated attention weight loading --- transformer_lens/model_bridge/bridge.py | 50 +---------- .../generalized_components/attention.py | 88 ++++++++++++++++++- 2 files changed, 88 insertions(+), 50 deletions(-) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index d0de874e6..f70b8c886 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -766,55 +766,7 @@ def _port_attention_component(self, attn_component, layer_idx, processed_weights b_V = processed_weights.get(b_V_key) b_O = processed_weights.get(b_O_key) - if W_Q is None or W_K is None or W_V is None or W_O is None: - print(f" ⚠️ Missing attention weights for layer {layer_idx}, skipping port") - return - - def attention_forward(x): - """Direct implementation of reference model's attention computation with hooks.""" - batch_size, seq_len, d_model = x.shape - - # Compute Q, K, V using TransformerLens format weights - # W_Q shape: [n_heads, d_model, d_head], b_Q shape: [n_heads, d_head] - # x shape: [batch, seq, d_model] - q = torch.einsum("bsd,hdc->bshc", x, W_Q) + b_Q.unsqueeze(0).unsqueeze(0) - k = torch.einsum("bsd,hdc->bshc", x, W_K) + b_K.unsqueeze(0).unsqueeze(0) - v = torch.einsum("bsd,hdc->bshc", x, W_V) + b_V.unsqueeze(0).unsqueeze(0) - - # Apply hook for V if it exists (this is what gets ablated in the comparison script) - if hasattr(attn_component, "hook_v"): - v = attn_component.hook_v(v) - - # Transpose to [batch, n_heads, seq, d_head] for attention computation - q = q.transpose(1, 2) # [batch, n_heads, seq, d_head] - k = k.transpose(1, 2) - v = v.transpose(1, 2) - - # Compute attention scores - attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.cfg.d_head**0.5) - - # Apply causal mask - causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)) - attn_scores = attn_scores.masked_fill(causal_mask == 0, float("-inf")) - - # Apply softmax - attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) - - # Apply attention to values - attn_out = torch.matmul(attn_weights, v) # [batch, n_heads, seq, d_head] - - # Transpose back to [batch, seq, n_heads, d_head] for output projection - attn_out = attn_out.transpose(1, 2) - - # Apply output projection using TransformerLens format - # attn_out: [batch, seq, n_heads, d_head], W_O: [n_heads, d_head, d_model] - result = torch.einsum("bshc,hcd->bsd", attn_out, W_O) + b_O.unsqueeze(0).unsqueeze(0) - - return result - - # Replace the attention component's forward method - attn_component.forward = attention_forward - print(f" ✅ Attention ported for layer {layer_idx}") + attn_component.set_processed_weights(W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O) def _port_mlp_component(self, mlp_component, layer_idx, processed_weights): """Port MLP component using reference model's exact computation.""" diff --git a/transformer_lens/model_bridge/generalized_components/attention.py b/transformer_lens/model_bridge/generalized_components/attention.py index 
7a54217cd..063a7e3c8 100644 --- a/transformer_lens/model_bridge/generalized_components/attention.py +++ b/transformer_lens/model_bridge/generalized_components/attention.py @@ -383,7 +383,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: """Forward pass through the attention layer. This method forwards all arguments to the original component and applies hooks - to the output. + to the output, or uses processed weights if available. Args: *args: Input arguments to pass to the original component @@ -392,6 +392,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: Returns: The output from the original component, with hooks applied """ + # Check if we're using processed weights from a reference model (layer norm folding case) + if hasattr(self, '_use_processed_weights') and self._use_processed_weights: + return self._forward_with_processed_weights(*args, **kwargs) + if self.original_component is None: raise RuntimeError( f"Original component not set for {self.name}. Call set_original_component() first." @@ -413,6 +417,88 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: return output + def set_processed_weights(self, W_Q: torch.Tensor, W_K: torch.Tensor, W_V: torch.Tensor, W_O: torch.Tensor, + b_Q: Optional[torch.Tensor] = None, b_K: Optional[torch.Tensor] = None, + b_V: Optional[torch.Tensor] = None, b_O: Optional[torch.Tensor] = None) -> None: + """Set the processed weights to use when layer norm is folded. + + Args: + W_Q: Query weight tensor [n_heads, d_model, d_head] + W_K: Key weight tensor [n_heads, d_model, d_head] + W_V: Value weight tensor [n_heads, d_model, d_head] + W_O: Output projection weight tensor [n_heads, d_head, d_model] + b_Q: Query bias tensor [n_heads, d_head] (optional) + b_K: Key bias tensor [n_heads, d_head] (optional) + b_V: Value bias tensor [n_heads, d_head] (optional) + b_O: Output bias tensor [d_model] (optional) + """ + self._processed_W_Q = W_Q + self._processed_W_K = W_K + self._processed_W_V = W_V + self._processed_W_O = W_O + self._processed_b_Q = b_Q + self._processed_b_K = b_K + self._processed_b_V = b_V + self._processed_b_O = b_O + self._use_processed_weights = True + + def _forward_with_processed_weights(self, *args: Any, **kwargs: Any) -> torch.Tensor: + """Direct implementation of reference model's attention computation with hooks.""" + # Extract input from args/kwargs + if len(args) > 0 and isinstance(args[0], torch.Tensor): + x = args[0] + elif "hidden_states" in kwargs: + x = kwargs["hidden_states"] + else: + raise ValueError("No valid input tensor found in args or kwargs") + + # Apply input hook + x = self.hook_in(x) + + batch_size, seq_len, d_model = x.shape + + # Compute Q, K, V using TransformerLens format weights + # W_Q shape: [n_heads, d_model, d_head], b_Q shape: [n_heads, d_head] + # x shape: [batch, seq, d_model] + q = torch.einsum("bsd,hdc->bshc", x, self._processed_W_Q) + self._processed_b_Q.unsqueeze(0).unsqueeze(0) + k = torch.einsum("bsd,hdc->bshc", x, self._processed_W_K) + self._processed_b_K.unsqueeze(0).unsqueeze(0) + v = torch.einsum("bsd,hdc->bshc", x, self._processed_W_V) + self._processed_b_V.unsqueeze(0).unsqueeze(0) + + # Apply hook for V if it exists (this is what gets ablated in the comparison script) + if hasattr(self, "hook_v"): + v = self.hook_v(v) + + # Transpose to [batch, n_heads, seq, d_head] for attention computation + q = q.transpose(1, 2) # [batch, n_heads, seq, d_head] + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # Compute attention scores + d_head = self._processed_W_Q.shape[-1] # Get 
d_head from weight shape + attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (d_head**0.5) + + # Apply causal mask + causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)) + attn_scores = attn_scores.masked_fill(causal_mask == 0, float("-inf")) + + # Apply softmax + attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1) + + # Apply attention to values + attn_out = torch.matmul(attn_weights, v) # [batch, n_heads, seq, d_head] + + # Transpose back to [batch, seq, n_heads, d_head] for output projection + attn_out = attn_out.transpose(1, 2) + + # Apply output projection using TransformerLens format + # attn_out: [batch, seq, n_heads, d_head], W_O: [n_heads, d_head, d_model] + result = torch.einsum("bshc,hcd->bsd", attn_out, self._processed_W_O) + self._processed_b_O.unsqueeze(0).unsqueeze(0) + + # Apply output hook + result = self.hook_out(result) + + return result + def get_attention_weights(self) -> Optional[torch.Tensor]: """Get cached attention weights if available. From 204ef51bc3374a96d6062897a8db03bdb24da90f Mon Sep 17 00:00:00 2001 From: Bryce Meyer Date: Tue, 30 Sep 2025 08:47:58 +0200 Subject: [PATCH 5/5] fixed function names --- transformer_lens/model_bridge/bridge.py | 61 +++++++++++++------------ 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index f70b8c886..3c2c1a83c 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -698,34 +698,36 @@ def _configure_components_for_processing(self): def _load_all_processed_weights(self): """Load processed weights into all components (Phase 2).""" - print("Porting reference model embedding components...") - self._port_embedding_components() + print("Loading embedding weights...") + self._load_embedding_weights() - print("Porting reference model transformer block components...") - self._port_transformer_blocks() + print("Loading transformer block weights...") + self._load_transformer_block_weights() - print("Porting reference model unembed component...") - self._port_unembed_component() + print("Loading unembedding weights...") + self._load_unembed_weights() - print("✅ All reference model components ported successfully") + print("✅ All weights loaded successfully") - def _port_embedding_components(self): - """Port embedding and positional embedding from processed weights.""" + def _load_embedding_weights(self): + """Load embedding and positional embedding weights into components.""" processed_weights = self._processed_tl_weights - # Port token embedding (embed.W_E) - now handled by EmbeddingBridge.set_processed_weight() + # Load token embedding (embed.W_E) into EmbeddingBridge if hasattr(self, "embed") and "embed.W_E" in processed_weights: embed_weight = processed_weights["embed.W_E"] self.embed.set_processed_weight(embed_weight) + print(f" ✅ Token embedding loaded: {embed_weight.shape}") - # Port positional embedding (pos_embed.W_pos) + # Load positional embedding (pos_embed.W_pos) into PosEmbedBridge if hasattr(self, "pos_embed") and "pos_embed.W_pos" in processed_weights: pos_embed_weight = processed_weights["pos_embed.W_pos"] self.pos_embed.set_processed_weight(pos_embed_weight) + print(f" ✅ Positional embedding loaded: {pos_embed_weight.shape}") - def _port_transformer_blocks(self): - """Port transformer block functionality from processed weights.""" + def _load_transformer_block_weights(self): + """Load transformer block weights into attention and MLP components.""" 
processed_weights = self._processed_tl_weights for layer_idx in range(self.cfg.n_layers): @@ -733,19 +735,19 @@ def _port_transformer_blocks(self): continue block = self.blocks[layer_idx] - print(f" Porting layer {layer_idx}...") + print(f" Loading layer {layer_idx} weights...") - # Port attention component + # Load attention weights if hasattr(block, "attn"): - self._port_attention_component(block.attn, layer_idx, processed_weights) + self._load_attention_weights(block.attn, layer_idx, processed_weights) - # Port MLP component + # Load MLP weights if hasattr(block, "mlp"): - self._port_mlp_component(block.mlp, layer_idx, processed_weights) + self._load_mlp_weights(block.mlp, layer_idx, processed_weights) - def _port_attention_component(self, attn_component, layer_idx, processed_weights): - """Port attention component using reference model's exact computation.""" + def _load_attention_weights(self, attn_component, layer_idx, processed_weights): + """Load attention weights into the AttentionBridge component.""" # Get the processed attention weights in TransformerLens format W_Q_key = f"blocks.{layer_idx}.attn.W_Q" W_K_key = f"blocks.{layer_idx}.attn.W_K" @@ -768,8 +770,8 @@ def _port_attention_component(self, attn_component, layer_idx, processed_weights attn_component.set_processed_weights(W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O) - def _port_mlp_component(self, mlp_component, layer_idx, processed_weights): - """Port MLP component using reference model's exact computation.""" + def _load_mlp_weights(self, mlp_component, layer_idx, processed_weights): + """Load MLP weights into the MLPBridge component.""" W_in_key = f"blocks.{layer_idx}.mlp.W_in" W_out_key = f"blocks.{layer_idx}.mlp.W_out" b_in_key = f"blocks.{layer_idx}.mlp.b_in" @@ -786,15 +788,16 @@ def _port_mlp_component(self, mlp_component, layer_idx, processed_weights): mlp_component.set_processed_weights(W_in, W_out, b_in, b_out) print(f" ✅ MLP set in MLPBridge for layer {layer_idx}") - def _port_unembed_component(self): - """Port unembedding component from processed weights.""" + def _load_unembed_weights(self): + """Load unembedding weights into the UnembeddingBridge component.""" processed_weights = self._processed_tl_weights - # Port unembedding (unembed.W_U) - now handled by UnembeddingBridge.set_processed_weight() + # Load unembedding (unembed.W_U) into UnembeddingBridge if hasattr(self, "unembed") and "unembed.W_U" in processed_weights: W_U = processed_weights["unembed.W_U"] b_U = processed_weights.get("unembed.b_U") self.unembed.set_processed_weight(W_U, b_U) + print(f" ✅ Unembedding weights loaded: {W_U.shape}") def _ported_forward_pass( self, @@ -2853,16 +2856,16 @@ def _load_tl_weights_into_bridge_components(self, tl_state_dict): block = self.blocks[layer_idx] # Load attention weights - self._load_attention_weights(block.attn, layer_idx, tl_state_dict) + self._load_attention_weights_from_tl_dict(block.attn, layer_idx, tl_state_dict) # Load MLP weights - self._load_mlp_weights(block.mlp, layer_idx, tl_state_dict) + self._load_mlp_weights_from_tl_dict(block.mlp, layer_idx, tl_state_dict) # Layer norms should already be handled by LayerNormPre behavior print("Finished loading TL weights into bridge components") - def _load_attention_weights(self, attn_component, layer_idx, tl_state_dict): + def _load_attention_weights_from_tl_dict(self, attn_component, layer_idx, tl_state_dict): """Load attention weights from TL format into bridge attention component.""" prefix = f"blocks.{layer_idx}.attn" @@ -2930,7 +2933,7 @@ def 
_load_attention_weights(self, attn_component, layer_idx, tl_state_dict): bias_data = bias_data.flatten() component.bias.data = bias_data - def _load_mlp_weights(self, mlp_component, layer_idx, tl_state_dict): + def _load_mlp_weights_from_tl_dict(self, mlp_component, layer_idx, tl_state_dict): """Load MLP weights from TL format into bridge MLP component.""" prefix = f"blocks.{layer_idx}.mlp"
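
For reference, below is a minimal standalone sketch of the attention computation that patch 4 relocates into `AttentionBridge._forward_with_processed_weights`: causal self-attention run directly over TransformerLens-format processed weights (W_Q/W_K/W_V shaped [n_heads, d_model, d_head], W_O shaped [n_heads, d_head, d_model]). The helper name `attention_with_tl_weights`, the toy dimensions, and the random weights are illustrative assumptions, not part of the patch; the bridge's hook points (`hook_in`, `hook_v`, `hook_out`) and the `set_processed_weights` bookkeeping are omitted.

```python
import torch


def attention_with_tl_weights(x, W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O):
    """Causal self-attention over TransformerLens-format weights.

    Shapes (assumed): x [batch, seq, d_model]; W_Q/W_K/W_V [n_heads, d_model, d_head];
    W_O [n_heads, d_head, d_model]; b_Q/b_K/b_V [n_heads, d_head]; b_O [d_model].
    """
    batch, seq_len, d_model = x.shape
    d_head = W_Q.shape[-1]

    # Per-head Q/K/V projections: [batch, seq, n_heads, d_head]
    q = torch.einsum("bsd,hdc->bshc", x, W_Q) + b_Q
    k = torch.einsum("bsd,hdc->bshc", x, W_K) + b_K
    v = torch.einsum("bsd,hdc->bshc", x, W_V) + b_V

    # Move heads to dim 1 for the attention matmuls: [batch, n_heads, seq, d_head]
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))

    # Scaled dot-product scores with a lower-triangular (causal) mask
    scores = q @ k.transpose(-2, -1) / d_head**0.5
    mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool))
    scores = scores.masked_fill(~mask, float("-inf"))
    pattern = torch.softmax(scores, dim=-1)

    # Weighted sum of values, back to [batch, seq, n_heads, d_head], then output projection
    z = (pattern @ v).transpose(1, 2)
    return torch.einsum("bshc,hcd->bsd", z, W_O) + b_O


if __name__ == "__main__":
    n_heads, d_model, d_head, batch, seq = 4, 32, 8, 2, 5
    x = torch.randn(batch, seq, d_model)
    W_Q, W_K, W_V = (torch.randn(n_heads, d_model, d_head) for _ in range(3))
    W_O = torch.randn(n_heads, d_head, d_model)
    b_Q, b_K, b_V = (torch.zeros(n_heads, d_head) for _ in range(3))
    b_O = torch.zeros(d_model)
    out = attention_with_tl_weights(x, W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O)
    print(out.shape)  # torch.Size([2, 5, 32])
```

In the actual bridge component the same steps are wrapped by the input, value, and output hooks so that ablations on `hook_v` continue to work after the forward logic moves out of `bridge.py`.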