@@ -168,3 +168,76 @@ def test_positions_array(self):
         got = layer(x, positions=positions)

         np.testing.assert_allclose(expected, ops.convert_to_numpy(got))
+
+    def test_rope_scaling(self):
+        # Reference values computed from the Huggingface llama implementation
+        # with `scaling_factor=2.0`:
+        # from transformers.models.llama.modeling_llama import (
+        #     LlamaLinearScalingRotaryEmbedding, apply_rotary_pos_emb
+        # )
+        # import torch
+        # torch.set_printoptions(precision=9)
+        # rotary_emb = LlamaLinearScalingRotaryEmbedding(
+        #     dim=4, max_position_embeddings=3, scaling_factor=2.0
+        # )
+        # query = torch.ones((1, 2, 3, 4))  # [bsz, num_heads, seq_len, head_dim]
+        # cos, sin = rotary_emb(
+        #     query, torch.unsqueeze(torch.arange(3, dtype=torch.int32), 0)
+        # )
+        # query, _ = apply_rotary_pos_emb(query, query, cos, sin)
+        # print(query.transpose(1, 2))
+        expected = [
+            [
+                [
+                    [1.000000000, 1.000000000, 1.000000000, 1.000000000],
+                    [1.000000000, 1.000000000, 1.000000000, 1.000000000],
+                ],
+                [
+                    [0.398157001, 0.994987488, 1.357008100, 1.004987478],
+                    [0.398157001, 0.994987488, 1.357008100, 1.004987478],
+                ],
+                [
+                    [-0.301168621, 0.989950180, 1.381773233, 1.009949803],
+                    [-0.301168621, 0.989950180, 1.381773233, 1.009949803],
+                ],
+            ]
+        ]
+
+        layer = RotaryEmbedding(scaling_factor=2.0)
+        self.assertAllClose(
+            layer(ops.ones((1, 3, 2, 4))),
+            ops.convert_to_tensor(expected),
+        )
+
+    def test_rope_scaling_with_kv_cache(self):
+        # Reference values computed from the Huggingface llama implementation
+        # with `scaling_factor=5.0`:
+        # from transformers.models.llama.modeling_llama import (
+        #     LlamaLinearScalingRotaryEmbedding, apply_rotary_pos_emb
+        # )
+        # import torch
+        # torch.set_printoptions(precision=9)
+        # rotary_emb = LlamaLinearScalingRotaryEmbedding(
+        #     dim=4, max_position_embeddings=3, scaling_factor=5.0
+        # )
+
+        # query = torch.ones((1, 2, 1, 4))  # [bsz, num_heads, seq_len, head_dim]
+        # cos, sin = rotary_emb(
+        #     query, torch.unsqueeze(torch.arange(12, 13, dtype=torch.int32), 0)
+        # )
+        # query, _ = apply_rotary_pos_emb(query, query, cos, sin)
+        # print(query.transpose(1, 2))
+        expected = [
+            [
+                [
+                    [-1.412856817, 0.975714266, -0.061930716, 1.023709655],
+                    [-1.412856817, 0.975714266, -0.061930716, 1.023709655],
+                ]
+            ]
+        ]
+
+        layer = RotaryEmbedding(scaling_factor=5.0)
+        self.assertAllClose(
+            layer(ops.ones((1, 1, 2, 4)), start_index=12),
+            ops.convert_to_tensor(expected),
+        )