
Commit 037c6fb

Attempted Processing match (#1063)
* created individual processing functions
* extracted state dict and inserted back into instance after processing
* created weight processing shared class
* added test coverage for new functions
* updated hooked transformer to use new shared functions
* created test
* moved over weight processing
* replaced keys
* used the correct function
* created test for making sure path translation works correctly
* fixed weight processing
* added additional tests
* formatted tests a bit
* cleaned up
* fixed unit test
* fixed indentation
* fixed doc string
* fixed unit test
* fixed type
* fixed some tests
* fixed test
* fixed setup of tests
* cleaned up test
* started working through individual matches
* added test coverage
* tested function a bit
* integrated weight conversion into weight processing
* simplified functions
* identified individual problem lines
* identified divergences more clearly
* brought back error lines
1 parent f88db75 commit 037c6fb
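
The commit message above describes splitting the layer-norm fold/center logic into small shared helpers on ProcessWeights, which the new tests in this diff call both one at a time and as a composed pipeline. As a rough orientation only, here is a minimal sketch of how those helpers compose per attention component, using only calls that appear in the diff below; the wrapper name fold_one_layer and its argument list are hypothetical, not part of the library, and the tests skip this path when the LayerNorm parameters are None.

from transformer_lens.weight_processing import ProcessWeights

def fold_one_layer(state_dict, cfg, layer, adapter=None):
    """Sketch: fold LayerNorm into Q/K/V for one layer, in the order the tests compose the helpers."""
    t = ProcessWeights.extract_attention_tensors_for_folding(state_dict, cfg, layer, adapter=adapter)
    out = {}
    for c in ("q", "k", "v"):
        w, b = t[f"w{c}"], t[f"b{c}"]
        # Step 1: fold the LayerNorm bias into the attention bias
        b = ProcessWeights.fold_layer_norm_bias_single(w, b, t["ln1_b"])
        # Step 2: fold the LayerNorm scale into the attention weight
        w = ProcessWeights.fold_layer_norm_weight_single(w, t["ln1_w"])
        # Step 3: center the folded weight
        w = ProcessWeights.center_weight_single(w)
        out[f"w{c}"], out[f"b{c}"] = w, b
    return out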

File tree

10 files changed (+2180 / -415 lines)


tests/integration/test_fold_layer_integration.py

Lines changed: 566 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 289 additions & 0 deletions
@@ -0,0 +1,289 @@
"""Integration tests for tensor extraction and math function consistency."""

import pytest
import torch

from transformer_lens import HookedTransformer
from transformer_lens.model_bridge import TransformerBridge
from transformer_lens.weight_processing import ProcessWeights


@pytest.fixture(scope="class")
def test_models():
    """Set up test models for consistency testing."""
    device = "cpu"
    model_name = "gpt2"

    # Load HookedTransformer (no processing)
    hooked_model = HookedTransformer.from_pretrained(
        model_name,
        device=device,
        fold_ln=False,
        center_writing_weights=False,
        center_unembed=False,
    )

    # Load TransformerBridge (no processing)
    bridge_model = TransformerBridge.boot_transformers(model_name, device=device)

    return {
        "hooked_model": hooked_model,
        "bridge_model": bridge_model,
        "hooked_state_dict": hooked_model.state_dict(),
        "bridge_state_dict": bridge_model.original_model.state_dict(),
    }


class TestTensorExtractionConsistency:
    """Test that tensor extraction returns consistent results between models."""

    def test_extract_attention_tensors_shapes_match(self, test_models):
        """Test that extracted tensors have matching shapes."""
        layer = 0

        hooked_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["hooked_state_dict"],
            test_models["hooked_model"].cfg,
            layer,
            adapter=None,
        )

        bridge_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["bridge_state_dict"],
            test_models["bridge_model"].cfg,
            layer,
            adapter=test_models["bridge_model"].adapter,
        )

        tensor_names = ['wq', 'wk', 'wv', 'bq', 'bk', 'bv', 'ln1_b', 'ln1_w']

        for tensor_name in tensor_names:
            hooked_tensor = hooked_tensors[tensor_name]
            bridge_tensor = bridge_tensors[tensor_name]

            if hooked_tensor is None and bridge_tensor is None:
                continue
            elif hooked_tensor is None or bridge_tensor is None:
                pytest.fail(f"{tensor_name}: One is None, other is not")

            assert hooked_tensor.shape == bridge_tensor.shape, \
                f"{tensor_name} shape mismatch: {hooked_tensor.shape} vs {bridge_tensor.shape}"

    def test_extract_attention_tensors_values_match(self, test_models):
        """Test that extracted tensors have matching values."""
        layer = 0

        hooked_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["hooked_state_dict"],
            test_models["hooked_model"].cfg,
            layer,
            adapter=None,
        )

        bridge_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["bridge_state_dict"],
            test_models["bridge_model"].cfg,
            layer,
            adapter=test_models["bridge_model"].adapter,
        )

        tensor_names = ['wq', 'wk', 'wv', 'bq', 'bk', 'bv', 'ln1_b', 'ln1_w']

        for tensor_name in tensor_names:
            hooked_tensor = hooked_tensors[tensor_name]
            bridge_tensor = bridge_tensors[tensor_name]

            if hooked_tensor is None or bridge_tensor is None:
                continue

            max_diff = torch.max(torch.abs(hooked_tensor - bridge_tensor)).item()
            assert max_diff < 1e-6, \
                f"{tensor_name} value mismatch: max_diff={max_diff:.2e}"

    @pytest.mark.parametrize("component", ['q', 'k', 'v'])
    def test_fold_layer_norm_bias_single_consistency(self, test_models, component):
        """Test fold_layer_norm_bias_single consistency for each component."""
        layer = 0

        hooked_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["hooked_state_dict"],
            test_models["hooked_model"].cfg,
            layer,
            adapter=None,
        )

        bridge_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["bridge_state_dict"],
            test_models["bridge_model"].cfg,
            layer,
            adapter=test_models["bridge_model"].adapter,
        )

        if hooked_tensors['ln1_b'] is None:
            pytest.skip("No LayerNorm bias to test")

        # Get tensors for the component
        w_key = f'w{component}'
        b_key = f'b{component}'

        hooked_result = ProcessWeights.fold_layer_norm_bias_single(
            hooked_tensors[w_key], hooked_tensors[b_key], hooked_tensors['ln1_b']
        )
        bridge_result = ProcessWeights.fold_layer_norm_bias_single(
            bridge_tensors[w_key], bridge_tensors[b_key], bridge_tensors['ln1_b']
        )

        max_diff = torch.max(torch.abs(hooked_result - bridge_result)).item()
        assert max_diff < 1e-6, \
            f"fold_layer_norm_bias_single({component}) mismatch: max_diff={max_diff:.2e}"

    @pytest.mark.parametrize("component", ['q', 'k', 'v'])
    def test_fold_layer_norm_weight_single_consistency(self, test_models, component):
        """Test fold_layer_norm_weight_single consistency for each component."""
        layer = 0

        hooked_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["hooked_state_dict"],
            test_models["hooked_model"].cfg,
            layer,
            adapter=None,
        )

        bridge_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["bridge_state_dict"],
            test_models["bridge_model"].cfg,
            layer,
            adapter=test_models["bridge_model"].adapter,
        )

        if hooked_tensors['ln1_w'] is None:
            pytest.skip("No LayerNorm weight to test")

        # Get tensor for the component
        w_key = f'w{component}'

        hooked_result = ProcessWeights.fold_layer_norm_weight_single(
            hooked_tensors[w_key], hooked_tensors['ln1_w']
        )
        bridge_result = ProcessWeights.fold_layer_norm_weight_single(
            bridge_tensors[w_key], bridge_tensors['ln1_w']
        )

        max_diff = torch.max(torch.abs(hooked_result - bridge_result)).item()
        assert max_diff < 1e-6, \
            f"fold_layer_norm_weight_single({component}) mismatch: max_diff={max_diff:.2e}"

    @pytest.mark.parametrize("component", ['q', 'k', 'v'])
    def test_center_weight_single_consistency(self, test_models, component):
        """Test center_weight_single consistency for each component."""
        layer = 0

        hooked_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["hooked_state_dict"],
            test_models["hooked_model"].cfg,
            layer,
            adapter=None,
        )

        bridge_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["bridge_state_dict"],
            test_models["bridge_model"].cfg,
            layer,
            adapter=test_models["bridge_model"].adapter,
        )

        # Get tensor for the component
        w_key = f'w{component}'

        hooked_result = ProcessWeights.center_weight_single(hooked_tensors[w_key])
        bridge_result = ProcessWeights.center_weight_single(bridge_tensors[w_key])

        max_diff = torch.max(torch.abs(hooked_result - bridge_result)).item()
        assert max_diff < 1e-6, \
            f"center_weight_single({component}) mismatch: max_diff={max_diff:.2e}"

    def test_full_processing_pipeline_consistency(self, test_models):
        """Test that the full processing pipeline produces consistent results."""
        layer = 0

        hooked_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["hooked_state_dict"],
            test_models["hooked_model"].cfg,
            layer,
            adapter=None,
        )

        bridge_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["bridge_state_dict"],
            test_models["bridge_model"].cfg,
            layer,
            adapter=test_models["bridge_model"].adapter,
        )

        if hooked_tensors['ln1_b'] is None or hooked_tensors['ln1_w'] is None:
            pytest.skip("No LayerNorm parameters to test full pipeline")

        # Apply full processing pipeline
        def process_tensors(tensors):
            wq, wk, wv = tensors['wq'], tensors['wk'], tensors['wv']
            bq, bk, bv = tensors['bq'], tensors['bk'], tensors['bv']
            ln1_b, ln1_w = tensors['ln1_b'], tensors['ln1_w']

            # Step 1: Fold biases
            bq = ProcessWeights.fold_layer_norm_bias_single(wq, bq, ln1_b)
            bk = ProcessWeights.fold_layer_norm_bias_single(wk, bk, ln1_b)
            bv = ProcessWeights.fold_layer_norm_bias_single(wv, bv, ln1_b)

            # Step 2: Fold weights
            wq = ProcessWeights.fold_layer_norm_weight_single(wq, ln1_w)
            wk = ProcessWeights.fold_layer_norm_weight_single(wk, ln1_w)
            wv = ProcessWeights.fold_layer_norm_weight_single(wv, ln1_w)

            # Step 3: Center weights
            wq = ProcessWeights.center_weight_single(wq)
            wk = ProcessWeights.center_weight_single(wk)
            wv = ProcessWeights.center_weight_single(wv)

            return wq, wk, wv, bq, bk, bv

        hooked_final = process_tensors(hooked_tensors)
        bridge_final = process_tensors(bridge_tensors)

        # Compare final results
        components = ['wq', 'wk', 'wv', 'bq', 'bk', 'bv']

        for comp, hooked_result, bridge_result in zip(components, hooked_final, bridge_final):
            max_diff = torch.max(torch.abs(hooked_result - bridge_result)).item()
            assert max_diff < 1e-6, \
                f"Full pipeline mismatch for {comp}: max_diff={max_diff:.2e}"

    @pytest.mark.parametrize("layer", [0, 1, 2])
    def test_multiple_layers_consistency(self, test_models, layer):
        """Test consistency across multiple layers."""
        if layer >= test_models["hooked_model"].cfg.n_layers:
            pytest.skip(f"Layer {layer} doesn't exist in model")

        hooked_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["hooked_state_dict"],
            test_models["hooked_model"].cfg,
            layer,
            adapter=None,
        )

        bridge_tensors = ProcessWeights.extract_attention_tensors_for_folding(
            test_models["bridge_state_dict"],
            test_models["bridge_model"].cfg,
            layer,
            adapter=test_models["bridge_model"].adapter,
        )

        # Test that tensors match
        tensor_names = ['wq', 'wk', 'wv', 'bq', 'bk', 'bv']

        for tensor_name in tensor_names:
            hooked_tensor = hooked_tensors[tensor_name]
            bridge_tensor = bridge_tensors[tensor_name]

            max_diff = torch.max(torch.abs(hooked_tensor - bridge_tensor)).item()
            assert max_diff < 1e-6, \
                f"Layer {layer}, {tensor_name} mismatch: max_diff={max_diff:.2e}"

0 commit comments
