Commit f903629

Test: Add unit test to show mps vs cpu diff
1 parent: d4872f0


tests/unit/components/test_attention.py (42 additions, 0 deletions)
```diff
@@ -128,3 +128,45 @@ def test_remove_einsum_from_complex_attn_linear():
 
     # Check if the results are the same
     assert torch.allclose(result_new, result_old, atol=1e-4)
+
+
+@pytest.mark.skipif(
+    # Skip unless running on an MPS device with PyTorch 2.8.0, the
+    # combination where the upstream F.linear bug manifests.
+    not torch.backends.mps.is_available() or torch.__version__ != "2.8.0",
+    reason="Issue with F.linear exclusive to MPS and PyTorch 2.8: "
+    "https://github.com/pytorch/pytorch/issues/161640",
+)
+def test_cpu_mps_outputs_match():
+    torch.manual_seed(0)
+
+    cfg = {
+        "n_layers": 1,
+        "d_model": 48,
+        "n_ctx": 256,
+        "d_head": 16,
+        "n_heads": 3,
+        "load_in_4bit": False,
+        "dtype": torch.float32,
+        "act_fn": "relu",
+    }
+
+    def init_weights(attn_layer: nn.Module):
+        nn.init.normal_(attn_layer.W_Q, mean=0.0, std=0.02)
+        nn.init.normal_(attn_layer.W_K, mean=0.0, std=0.02)
+        nn.init.normal_(attn_layer.W_V, mean=0.0, std=0.02)
+        nn.init.normal_(attn_layer.W_O, mean=0.0, std=0.02)
+        return attn_layer
+
+    # One attention layer on CPU and a weight-identical copy on MPS.
+    attn_cpu = Attention(cfg)
+    attn_cpu = init_weights(attn_cpu)
+
+    attn_mps = Attention(cfg).to("mps")
+    attn_mps.load_state_dict(attn_cpu.state_dict(), strict=True)
+
+    batch = 1
+    input_cpu = torch.randn(batch, cfg["n_ctx"], cfg["d_model"])
+    input_mps = input_cpu.to("mps")
+
+    cpu_output = attn_cpu(input_cpu, input_cpu, input_cpu)
+    mps_output = attn_mps(input_mps, input_mps, input_mps)
+
+    # Identical weights and inputs should produce matching outputs across
+    # backends; on the affected MPS/2.8.0 combination this assert fails.
+    assert torch.allclose(cpu_output, mps_output.cpu())
```
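
For reference, the skip reason points at an upstream `F.linear` problem on the MPS backend. Below is a minimal standalone sketch of the same CPU-vs-MPS comparison, independent of the repo's `Attention` class; the shapes mirror the test's `cfg` but are otherwise illustrative assumptions, not taken from the commit:

```python
# Hypothetical standalone probe of the CPU/MPS F.linear discrepancy
# referenced in https://github.com/pytorch/pytorch/issues/161640.
import torch
import torch.nn.functional as F

if torch.backends.mps.is_available():
    torch.manual_seed(0)
    x = torch.randn(1, 256, 48)        # (batch, n_ctx, d_model)
    w = torch.randn(48, 48) * 0.02     # (out_features, in_features)

    out_cpu = F.linear(x, w)
    out_mps = F.linear(x.to("mps"), w.to("mps")).cpu()

    # On an affected setup this reports a non-trivial deviation.
    print("max abs diff:", (out_cpu - out_mps).abs().max().item())
```

To run just the new test: `pytest tests/unit/components/test_attention.py::test_cpu_mps_outputs_match`.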
