transformer_lens/model_bridge/bridge.py (163 changes: 35 additions & 128 deletions)
@@ -698,63 +698,56 @@ def _configure_components_for_processing(self):

def _load_all_processed_weights(self):
"""Load processed weights into all components (Phase 2)."""
print("Porting reference model embedding components...")
self._port_embedding_components()
print("Loading embedding weights...")
self._load_embedding_weights()

print("Porting reference model transformer block components...")
self._port_transformer_blocks()
print("Loading transformer block weights...")
self._load_transformer_block_weights()

print("Porting reference model unembed component...")
self._port_unembed_component()
print("Loading unembedding weights...")
self._load_unembed_weights()

print("✅ All reference model components ported successfully")
print("✅ All weights loaded successfully")

def _port_embedding_components(self):
"""Port embedding and positional embedding from processed weights."""
def _load_embedding_weights(self):
"""Load embedding and positional embedding weights into components."""
processed_weights = self._processed_tl_weights

# Port token embedding (embed.W_E) - now handled by EmbeddingBridge.set_processed_weight()
# Load token embedding (embed.W_E) into EmbeddingBridge
if hasattr(self, "embed") and "embed.W_E" in processed_weights:
embed_weight = processed_weights["embed.W_E"]
self.embed.set_processed_weight(embed_weight)
print(f" ✅ Token embedding loaded: {embed_weight.shape}")

# Port positional embedding (pos_embed.W_pos)
# Load positional embedding (pos_embed.W_pos) into PosEmbedBridge
if hasattr(self, "pos_embed") and "pos_embed.W_pos" in processed_weights:
pos_embed_weight = processed_weights["pos_embed.W_pos"]
self.pos_embed.set_processed_weight(pos_embed_weight)
print(f" ✅ Positional embedding loaded: {pos_embed_weight.shape}")


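For reference, a runnable sketch of the TransformerLens embedding layout these loaders assume (GPT-2 small sizes; the tensors below are random stand-ins, not the model's weights):

import torch

# W_E is [d_vocab, d_model] and W_pos is [n_ctx, d_model] in TransformerLens format;
# embedding a batch of tokens is an index lookup plus a positional slice.
d_vocab, n_ctx, d_model = 50257, 1024, 768
W_E = torch.randn(d_vocab, d_model)
W_pos = torch.randn(n_ctx, d_model)
tokens = torch.randint(0, d_vocab, (2, 16))      # [batch, seq]
resid = W_E[tokens] + W_pos[: tokens.shape[1]]   # [batch, seq, d_model]
print(resid.shape)                               # torch.Size([2, 16, 768])
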
def _port_transformer_blocks(self):
"""Port transformer block functionality from processed weights."""
def _load_transformer_block_weights(self):
"""Load transformer block weights into attention and MLP components."""
processed_weights = self._processed_tl_weights

for layer_idx in range(self.cfg.n_layers):
if not hasattr(self, "blocks") or layer_idx >= len(self.blocks):
continue

block = self.blocks[layer_idx]
print(f" Porting layer {layer_idx}...")
print(f" Loading layer {layer_idx} weights...")

# Port attention component
# Load attention weights
if hasattr(block, "attn"):
self._port_attention_component(block.attn, layer_idx, processed_weights)
self._load_attention_weights(block.attn, layer_idx, processed_weights)

# Port MLP component
# Load MLP weights
if hasattr(block, "mlp"):
self._port_mlp_component(block.mlp, layer_idx, processed_weights)
self._load_mlp_weights(block.mlp, layer_idx, processed_weights)

# Port layer norms (should be identity if folded)
if hasattr(block, "ln1"):
self._port_layernorm_component(
block.ln1, f"blocks.{layer_idx}.ln1", processed_weights
)
if hasattr(block, "ln2"):
self._port_layernorm_component(
block.ln2, f"blocks.{layer_idx}.ln2", processed_weights
)

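The removed block above made ln1/ln2 explicit identities; the new path instead expects folded layer norms to act like TransformerLens's LayerNormPre (centre and rescale, no learnable parameters). A minimal sketch of that behaviour, with illustrative shapes and eps:

import torch

def layernorm_pre(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Centre the residual stream, then divide by the RMS of the centred values.
    x = x - x.mean(dim=-1, keepdim=True)
    scale = (x.pow(2).mean(dim=-1, keepdim=True) + eps).sqrt()
    return x / scale

print(layernorm_pre(torch.randn(2, 16, 768)).shape)  # torch.Size([2, 16, 768])
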
def _port_attention_component(self, attn_component, layer_idx, processed_weights):
"""Port attention component using reference model's exact computation."""
def _load_attention_weights(self, attn_component, layer_idx, processed_weights):
"""Load attention weights into the AttentionBridge component."""
# Get the processed attention weights in TransformerLens format
W_Q_key = f"blocks.{layer_idx}.attn.W_Q"
W_K_key = f"blocks.{layer_idx}.attn.W_K"
@@ -775,58 +768,10 @@ def _port_attention_component(self, attn_component, layer_idx, processed_weights
b_V = processed_weights.get(b_V_key)
b_O = processed_weights.get(b_O_key)

if W_Q is None or W_K is None or W_V is None or W_O is None:
print(f" ⚠️ Missing attention weights for layer {layer_idx}, skipping port")
return

def attention_forward(x):
"""Direct implementation of reference model's attention computation with hooks."""
batch_size, seq_len, d_model = x.shape

# Compute Q, K, V using TransformerLens format weights
# W_Q shape: [n_heads, d_model, d_head], b_Q shape: [n_heads, d_head]
# x shape: [batch, seq, d_model]
q = torch.einsum("bsd,hdc->bshc", x, W_Q) + b_Q.unsqueeze(0).unsqueeze(0)
k = torch.einsum("bsd,hdc->bshc", x, W_K) + b_K.unsqueeze(0).unsqueeze(0)
v = torch.einsum("bsd,hdc->bshc", x, W_V) + b_V.unsqueeze(0).unsqueeze(0)

# Apply hook for V if it exists (this is what gets ablated in the comparison script)
if hasattr(attn_component, "hook_v"):
v = attn_component.hook_v(v)

# Transpose to [batch, n_heads, seq, d_head] for attention computation
q = q.transpose(1, 2) # [batch, n_heads, seq, d_head]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Compute attention scores
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.cfg.d_head**0.5)

# Apply causal mask
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
attn_scores = attn_scores.masked_fill(causal_mask == 0, float("-inf"))
attn_component.set_processed_weights(W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O)

# Apply softmax
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)

# Apply attention to values
attn_out = torch.matmul(attn_weights, v) # [batch, n_heads, seq, d_head]

# Transpose back to [batch, seq, n_heads, d_head] for output projection
attn_out = attn_out.transpose(1, 2)

# Apply output projection using TransformerLens format
# attn_out: [batch, seq, n_heads, d_head], W_O: [n_heads, d_head, d_model]
result = torch.einsum("bshc,hcd->bsd", attn_out, W_O) + b_O.unsqueeze(0).unsqueeze(0)

return result

# Replace the attention component's forward method
attn_component.forward = attention_forward
print(f" ✅ Attention ported for layer {layer_idx}")

def _port_mlp_component(self, mlp_component, layer_idx, processed_weights):
"""Port MLP component using reference model's exact computation."""
def _load_mlp_weights(self, mlp_component, layer_idx, processed_weights):
"""Load MLP weights into the MLPBridge component."""
W_in_key = f"blocks.{layer_idx}.mlp.W_in"
W_out_key = f"blocks.{layer_idx}.mlp.W_out"
b_in_key = f"blocks.{layer_idx}.mlp.b_in"
@@ -840,57 +785,19 @@ def _port_mlp_component(self, mlp_component, layer_idx, processed_weights):
if W_in is None or W_out is None:
print(f" ⚠️ Missing MLP weights for layer {layer_idx}, skipping port")
return
mlp_component.set_processed_weights(W_in, W_out, b_in, b_out)
print(f" ✅ MLP set in MLPBridge for layer {layer_idx}")

# Use the new set_processed_weights method if available (integrated into MLPBridge)
if hasattr(mlp_component, 'set_processed_weights'):
mlp_component.set_processed_weights(W_in, W_out, b_in, b_out)
print(f" ✅ MLP set in MLPBridge for layer {layer_idx}")
else:
# Fallback: Replace the bridge MLP component with direct tensor operations
def mlp_forward(x):
"""Port of reference model's MLP forward pass."""
# Input projection using TransformerLens format
hidden = torch.nn.functional.linear(x, W_in.T, b_in)
# Apply activation (GELU for GPT-2)
hidden = torch.nn.functional.gelu(hidden)
# Output projection using TransformerLens format
result = torch.nn.functional.linear(hidden, W_out.T, b_out)
return result

# Replace the MLP component's forward method
mlp_component.forward = mlp_forward
print(f" ✅ MLP ported (fallback) for layer {layer_idx}")

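A runnable sketch of the MLP weight layout set_processed_weights receives, matching the removed fallback above (GPT-2 small sizes; random stand-in tensors):

import torch
import torch.nn.functional as F

# W_in: [d_model, d_mlp], b_in: [d_mlp]; W_out: [d_mlp, d_model], b_out: [d_model].
d_model, d_mlp = 768, 3072
W_in, b_in = torch.randn(d_model, d_mlp), torch.zeros(d_mlp)
W_out, b_out = torch.randn(d_mlp, d_model), torch.zeros(d_model)

x = torch.randn(2, 16, d_model)
hidden = F.gelu(x @ W_in + b_in)   # same as F.linear(x, W_in.T, b_in) in the fallback above
out = hidden @ W_out + b_out       # [2, 16, 768]
print(out.shape)
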
def _port_layernorm_component(self, ln_component, ln_name, processed_weights):
"""Port layer norm component (usually identity when folded)."""

# For folded layer norms, these should be identity operations
def layernorm_forward(x):
# When layer norm is folded, just return input unchanged
return x

ln_component.forward = layernorm_forward
print(f" ✅ LayerNorm {ln_name} ported (identity)")

def _port_unembed_component(self):
"""Port unembedding component from processed weights."""
def _load_unembed_weights(self):
"""Load unembedding weights into the UnembeddingBridge component."""
processed_weights = self._processed_tl_weights

# Port unembedding (unembed.W_U) - now handled by UnembeddingBridge.set_processed_weight()
# Load unembedding (unembed.W_U) into UnembeddingBridge
if hasattr(self, "unembed") and "unembed.W_U" in processed_weights:
W_U = processed_weights["unembed.W_U"]
b_U = processed_weights.get("unembed.b_U")
self.unembed.set_processed_weight(W_U, b_U)

# Also port final layer norm if it exists
if hasattr(self, "ln_final"):

def ln_final_forward(x):
# When layer norm is folded, just return input unchanged
return x

self.ln_final.forward = ln_final_forward
print(f" ✅ Final LayerNorm ported (identity)")
print(f" ✅ Unembedding weights loaded: {W_U.shape}")

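A short sketch of the unembedding step these weights feed (random stand-ins; W_U is [d_model, d_vocab] and b_U is [d_vocab] in TransformerLens format):

import torch

d_model, d_vocab = 768, 50257
W_U, b_U = torch.randn(d_model, d_vocab), torch.zeros(d_vocab)
resid_final = torch.randn(2, 16, d_model)
logits = resid_final @ W_U + b_U   # [batch, seq, d_vocab]
print(logits.shape)                # torch.Size([2, 16, 50257])
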
def _ported_forward_pass(
self,
@@ -2949,16 +2856,16 @@ def _load_tl_weights_into_bridge_components(self, tl_state_dict):
block = self.blocks[layer_idx]

# Load attention weights
self._load_attention_weights(block.attn, layer_idx, tl_state_dict)
self._load_attention_weights_from_tl_dict(block.attn, layer_idx, tl_state_dict)

# Load MLP weights
self._load_mlp_weights(block.mlp, layer_idx, tl_state_dict)
self._load_mlp_weights_from_tl_dict(block.mlp, layer_idx, tl_state_dict)

# Layer norms should already be handled by LayerNormPre behavior

print("Finished loading TL weights into bridge components")

def _load_attention_weights(self, attn_component, layer_idx, tl_state_dict):
def _load_attention_weights_from_tl_dict(self, attn_component, layer_idx, tl_state_dict):
"""Load attention weights from TL format into bridge attention component."""
prefix = f"blocks.{layer_idx}.attn"

@@ -3026,7 +2933,7 @@ def _load_attention_weights(self, attn_component, layer_idx, tl_state_dict):
bias_data = bias_data.flatten()
component.bias.data = bias_data

def _load_mlp_weights(self, mlp_component, layer_idx, tl_state_dict):
def _load_mlp_weights_from_tl_dict(self, mlp_component, layer_idx, tl_state_dict):
"""Load MLP weights from TL format into bridge MLP component."""
prefix = f"blocks.{layer_idx}.mlp"

@@ -383,7 +383,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Forward pass through the attention layer.

This method forwards all arguments to the original component and applies hooks
to the output.
to the output, or uses processed weights if available.

Args:
*args: Input arguments to pass to the original component
@@ -392,6 +392,10 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
Returns:
The output from the original component, with hooks applied
"""
# Check if we're using processed weights from a reference model (layer norm folding case)
if hasattr(self, '_use_processed_weights') and self._use_processed_weights:
return self._forward_with_processed_weights(*args, **kwargs)

if self.original_component is None:
raise RuntimeError(
f"Original component not set for {self.name}. Call set_original_component() first."
@@ -413,6 +417,88 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:

return output

def set_processed_weights(self, W_Q: torch.Tensor, W_K: torch.Tensor, W_V: torch.Tensor, W_O: torch.Tensor,
b_Q: Optional[torch.Tensor] = None, b_K: Optional[torch.Tensor] = None,
b_V: Optional[torch.Tensor] = None, b_O: Optional[torch.Tensor] = None) -> None:
"""Set the processed weights to use when layer norm is folded.

Args:
W_Q: Query weight tensor [n_heads, d_model, d_head]
W_K: Key weight tensor [n_heads, d_model, d_head]
W_V: Value weight tensor [n_heads, d_model, d_head]
W_O: Output projection weight tensor [n_heads, d_head, d_model]
b_Q: Query bias tensor [n_heads, d_head] (optional)
b_K: Key bias tensor [n_heads, d_head] (optional)
b_V: Value bias tensor [n_heads, d_head] (optional)
b_O: Output bias tensor [d_model] (optional)
"""
self._processed_W_Q = W_Q
self._processed_W_K = W_K
self._processed_W_V = W_V
self._processed_W_O = W_O
self._processed_b_Q = b_Q
self._processed_b_K = b_K
self._processed_b_V = b_V
self._processed_b_O = b_O
self._use_processed_weights = True

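A usage sketch for set_processed_weights with the shapes documented above (random stand-ins, GPT-2 small sizes; attn_bridge is a hypothetical AttentionBridge instance, so the call itself is left commented):

import torch

n_heads, d_model, d_head = 12, 768, 64
W_Q = torch.randn(n_heads, d_model, d_head)
W_K = torch.randn(n_heads, d_model, d_head)
W_V = torch.randn(n_heads, d_model, d_head)
W_O = torch.randn(n_heads, d_head, d_model)
b_Q = b_K = b_V = torch.zeros(n_heads, d_head)
b_O = torch.zeros(d_model)
# attn_bridge.set_processed_weights(W_Q, W_K, W_V, W_O, b_Q, b_K, b_V, b_O)
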
def _forward_with_processed_weights(self, *args: Any, **kwargs: Any) -> torch.Tensor:
"""Direct implementation of reference model's attention computation with hooks."""
# Extract input from args/kwargs
if len(args) > 0 and isinstance(args[0], torch.Tensor):
x = args[0]
elif "hidden_states" in kwargs:
x = kwargs["hidden_states"]
else:
raise ValueError("No valid input tensor found in args or kwargs")

# Apply input hook
x = self.hook_in(x)

batch_size, seq_len, d_model = x.shape

# Compute Q, K, V using TransformerLens format weights
# W_Q shape: [n_heads, d_model, d_head], b_Q shape: [n_heads, d_head]
# x shape: [batch, seq, d_model]
q = torch.einsum("bsd,hdc->bshc", x, self._processed_W_Q) + self._processed_b_Q.unsqueeze(0).unsqueeze(0)
k = torch.einsum("bsd,hdc->bshc", x, self._processed_W_K) + self._processed_b_K.unsqueeze(0).unsqueeze(0)
v = torch.einsum("bsd,hdc->bshc", x, self._processed_W_V) + self._processed_b_V.unsqueeze(0).unsqueeze(0)

# Apply hook for V if it exists (this is what gets ablated in the comparison script)
if hasattr(self, "hook_v"):
v = self.hook_v(v)

# Transpose to [batch, n_heads, seq, d_head] for attention computation
q = q.transpose(1, 2) # [batch, n_heads, seq, d_head]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

# Compute attention scores
d_head = self._processed_W_Q.shape[-1] # Get d_head from weight shape
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (d_head**0.5)

# Apply causal mask
causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
attn_scores = attn_scores.masked_fill(causal_mask == 0, float("-inf"))

# Apply softmax
attn_weights = torch.nn.functional.softmax(attn_scores, dim=-1)

# Apply attention to values
attn_out = torch.matmul(attn_weights, v) # [batch, n_heads, seq, d_head]

# Transpose back to [batch, seq, n_heads, d_head] for output projection
attn_out = attn_out.transpose(1, 2)

# Apply output projection using TransformerLens format
# attn_out: [batch, seq, n_heads, d_head], W_O: [n_heads, d_head, d_model]
result = torch.einsum("bshc,hcd->bsd", attn_out, self._processed_W_O) + self._processed_b_O.unsqueeze(0).unsqueeze(0)

# Apply output hook
result = self.hook_out(result)

return result

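For reference, a compact runnable restatement of the attention math above, written entirely with einsums over random stand-in tensors (it assumes all biases are provided, as _forward_with_processed_weights does):

import torch

batch, seq, n_heads, d_model, d_head = 2, 5, 12, 768, 64
x = torch.randn(batch, seq, d_model)
W_Q, W_K, W_V = (torch.randn(n_heads, d_model, d_head) for _ in range(3))
W_O = torch.randn(n_heads, d_head, d_model)
b_Q = b_K = b_V = torch.zeros(n_heads, d_head)
b_O = torch.zeros(d_model)

q = torch.einsum("bsd,hdc->bshc", x, W_Q) + b_Q            # [batch, seq, n_heads, d_head]
k = torch.einsum("bsd,hdc->bshc", x, W_K) + b_K
v = torch.einsum("bsd,hdc->bshc", x, W_V) + b_V
scores = torch.einsum("bqhc,bkhc->bhqk", q, k) / d_head**0.5
mask = torch.tril(torch.ones(seq, seq, dtype=torch.bool))
scores = scores.masked_fill(~mask, float("-inf"))          # causal mask
pattern = scores.softmax(dim=-1)                            # [batch, n_heads, seq_q, seq_k]
z = torch.einsum("bhqk,bkhc->bqhc", pattern, v)             # [batch, seq, n_heads, d_head]
out = torch.einsum("bshc,hcd->bsd", z, W_O) + b_O           # [batch, seq, d_model]
print(out.shape)                                            # torch.Size([2, 5, 768])
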
def get_attention_weights(self) -> Optional[torch.Tensor]:
"""Get cached attention weights if available.
