
Commit 778ccd7

Fix rope scaling factor (#1605)
* Fix rope scaling factor
* Fix format
* Add tests
* Fix format
1 parent 026c6ed commit 778ccd7

File tree

2 files changed: +80 -6 lines
keras_nlp/src/layers/modeling/rotary_embedding.py

Lines changed: 7 additions & 6 deletions
@@ -35,7 +35,8 @@ class RotaryEmbedding(keras.layers.Layer):
     Args:
         max_wavelength: int. The maximum angular wavelength of the sine/cosine
             curves.
-        scaling_factor: float. The scaling factor used to scale frequency range.
+        scaling_factor: float. The scaling factor used to scale positions of
+            the tokens.
         sequence_axis: int. Sequence axis in the input tensor.
         feature_axis: int. Feature axis in the input tensor.
         **kwargs: other keyword arguments passed to `keras.layers.Layer`,
@@ -125,6 +126,7 @@ def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
         else:
             positions = ops.cast(positions, "float32")

+        positions = positions / ops.cast(self.scaling_factor, "float32")
         freq = ops.einsum("i,j->ij", positions, inverse_freq)
         embedding = ops.stack((freq, freq), axis=-2)
         embedding = ops.reshape(
@@ -143,12 +145,11 @@ def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
         return cos_emb, sin_emb

     def _get_inverse_freq(self, rotary_dim):
-        freq_range = ops.arange(0, rotary_dim, 2, dtype="float32")
-        freq_range = freq_range / ops.cast(self.scaling_factor, "float32")
-        inverse_freq = 1.0 / (
-            self.max_wavelength
-            ** (freq_range / ops.cast(rotary_dim, "float32"))
+        freq_range = ops.divide(
+            ops.arange(0, rotary_dim, 2, dtype="float32"),
+            ops.cast(rotary_dim, "float32"),
         )
+        inverse_freq = 1.0 / (self.max_wavelength**freq_range)
         return inverse_freq

     def get_config(self):
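
Before this change, `scaling_factor` was divided into the frequency exponent range inside `_get_inverse_freq`, which warps the per-dimension wavelengths rather than interpolating positions. The fix leaves the inverse frequencies untouched and divides the token positions instead, i.e. linear (position-interpolation) RoPE scaling. A minimal NumPy sketch of the fixed computation with illustrative values (not the layer's actual code path):

    import numpy as np

    # Illustrative values; the names mirror the layer's arguments.
    rotary_dim = 4
    max_wavelength = 10000
    scaling_factor = 2.0

    # After the fix, the inverse frequencies depend only on the feature index,
    # not on `scaling_factor`.
    freq_range = np.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim
    inverse_freq = 1.0 / (max_wavelength**freq_range)

    # The scaling is applied to the token positions instead, so position p
    # uses the rotation angles of position p / scaling_factor.
    positions = np.arange(6, dtype="float32") / scaling_factor
    angles = np.einsum("i,j->ij", positions, inverse_freq)  # [seq_len, rotary_dim // 2]
    cos_emb, sin_emb = np.cos(angles), np.sin(angles)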

keras_nlp/src/layers/modeling/rotary_embedding_test.py

Lines changed: 73 additions & 0 deletions
@@ -168,3 +168,76 @@ def test_positions_array(self):
         got = layer(x, positions=positions)

         np.testing.assert_allclose(expected, ops.convert_to_numpy(got))
+
+    def test_rope_scaling(self):
+        # Reference values computed from Huggingface llama implementation
+        # With `scaling_factor` = 2.0
+        # from transformers.models.llama.modeling_llama import (
+        #     LlamaLinearScalingRotaryEmbedding, apply_rotary_pos_emb
+        # )
+        # import torch
+        # torch.set_printoptions(precision=9)
+        # rotary_emb = LlamaLinearScalingRotaryEmbedding(
+        #     dim=4, max_position_embeddings=3, scaling_factor=2.0
+        # )
+        # query = torch.ones((1, 2, 3, 4))  # [bsz, num_heads, seq_len, head_dim]
+        # cos, sin = rotary_emb(
+        #     query, torch.unsqueeze(torch.arange(3, dtype=torch.int32), 0)
+        # )
+        # query, _ = apply_rotary_pos_emb(query, query, cos, sin)
+        # print(query.transpose(1, 2))
+        expected = [
+            [
+                [
+                    [1.000000000, 1.000000000, 1.000000000, 1.000000000],
+                    [1.000000000, 1.000000000, 1.000000000, 1.000000000],
+                ],
+                [
+                    [0.398157001, 0.994987488, 1.357008100, 1.004987478],
+                    [0.398157001, 0.994987488, 1.357008100, 1.004987478],
+                ],
+                [
+                    [-0.301168621, 0.989950180, 1.381773233, 1.009949803],
+                    [-0.301168621, 0.989950180, 1.381773233, 1.009949803],
+                ],
+            ]
+        ]
+
+        layer = RotaryEmbedding(scaling_factor=2.0)
+        self.assertAllClose(
+            layer(ops.ones((1, 3, 2, 4))),
+            ops.convert_to_tensor(expected),
+        )
+
+    def test_rope_scaling_with_kv_cache(self):
+        # Reference values computed from Huggingface llama implementation
+        # With `scaling_factor` = 5.0
+        # from transformers.models.llama.modeling_llama import (
+        #     LlamaLinearScalingRotaryEmbedding, apply_rotary_pos_emb
+        # )
+        # import torch
+        # torch.set_printoptions(precision=9)
+        # rotary_emb = LlamaLinearScalingRotaryEmbedding(
+        #     dim=4, max_position_embeddings=3, scaling_factor=5.0
+        # )
+
+        # query = torch.ones((1, 2, 1, 4))  # [bsz, num_heads, seq_len, head_dim]
+        # cos, sin = rotary_emb(
+        #     query, torch.unsqueeze(torch.arange(12, 13, dtype=torch.int32), 0)
+        # )
+        # query, _ = apply_rotary_pos_emb(query, query, cos, sin)
+        # query.transpose(1, 2)
+        expected = [
+            [
+                [
+                    [-1.412856817, 0.975714266, -0.061930716, 1.023709655],
+                    [-1.412856817, 0.975714266, -0.061930716, 1.023709655],
+                ]
+            ]
+        ]
+
+        layer = RotaryEmbedding(scaling_factor=5.0)
+        self.assertAllClose(
+            layer(ops.ones((1, 1, 2, 4)), start_index=12),
+            ops.convert_to_tensor(expected),
+        )
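
For orientation, a hedged usage sketch of the layer the new tests exercise; the public import path and the exact input shapes are assumptions based on the shapes used in the tests ([batch, seq_len, num_heads, head_dim]):

    import numpy as np
    from keras_nlp.layers import RotaryEmbedding  # public path; may differ by version

    layer = RotaryEmbedding(scaling_factor=2.0)
    query = np.ones((1, 3, 2, 4), dtype="float32")  # [batch, seq_len, num_heads, head_dim]
    rotated = layer(query)  # positions 0..2 are divided by scaling_factor before rotation

    # Decoding with a KV cache: rotate a single new step at absolute position 12.
    step = np.ones((1, 1, 2, 4), dtype="float32")
    rotated_step = layer(step, start_index=12)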
