@@ -128,3 +128,45 @@ def test_remove_einsum_from_complex_attn_linear():
    # Check if the results are the same
    assert torch.allclose(result_new, result_old, atol=1e-4)
+
+
+@pytest.mark.skipif(
+    not torch.backends.mps.is_available() or torch.__version__.startswith("2.8"),
+    reason="Issue with F.linear exclusive to mps and PyTorch 2.8: "
+    "https://github.com/pytorch/pytorch/issues/161640",
+)
+def test_cpu_mps_outputs_match():
+    torch.manual_seed(0)
+
+    cfg = {
+        "n_layers": 1,
+        "d_model": 48,
+        "n_ctx": 256,
+        "d_head": 16,
+        "n_heads": 3,
+        "load_in_4bit": False,
+        "dtype": torch.float32,
+        "act_fn": "relu",
+    }
+
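+    # Give every projection matrix non-trivial values so the CPU/MPS comparison is meaningful.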
+    def init_weights(attn_layer: nn.Module):
+        nn.init.normal_(attn_layer.W_Q, mean=0.0, std=0.02)
+        nn.init.normal_(attn_layer.W_K, mean=0.0, std=0.02)
+        nn.init.normal_(attn_layer.W_V, mean=0.0, std=0.02)
+        nn.init.normal_(attn_layer.W_O, mean=0.0, std=0.02)
+        return attn_layer
+
+    attn_cpu = Attention(cfg)
+    attn_cpu = init_weights(attn_cpu)
+
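+    # Copy the exact CPU weights onto the MPS module so any divergence comes from the backend.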
+    attn_mps = Attention(cfg).to("mps")
+    attn_mps.load_state_dict(attn_cpu.state_dict(), strict=True)
+
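+    # Build one random input on CPU and reuse the same tensor on MPS so both modules see identical values.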
+    batch = 1
+    input_cpu = torch.randn(batch, cfg["n_ctx"], cfg["d_model"])
+    input_mps = input_cpu.to("mps")
+
+    cpu_output = attn_cpu(input_cpu, input_cpu, input_cpu)
+    mps_output = attn_mps(input_mps, input_mps, input_mps)
+
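+    # CPU and MPS float32 kernels are not bit-identical, so compare with the same tolerance used above.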
+    assert torch.allclose(cpu_output, mps_output.cpu(), atol=1e-4)