Skip to content

Commit 880733a

Browse files
committed
[KVCache] Support mode "None" for Rotary Embedding
This PR supports a "None" Rotary Embedding mode in PagedKVCache. When the mode is None, the rotary embedding will not be applied to q and k when computing attention.
1 parent daa37e7 commit 880733a

File tree

3 files changed

+69
-25
lines changed

3 files changed

+69
-25
lines changed

src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,14 @@ struct Sequence {
153153
/*!
154154
* \brief The rotary embedding mode adopted by the paged KV cache
155155
* when computing attention.
156+
* "None" means RoPE is never applied to q and k.
156157
* "Normal" means RoPE is computed in a standalone kernel.
157158
* "Inline" means RoPE is computed on-the-fly in attention kernels.
158159
*/
159160
enum class RoPEMode : int {
160-
kNormal = 0,
161-
kInline = 1,
161+
kNone = 0,
162+
kNormal = 1,
163+
kInline = 2,
162164
};
163165

164166
/*!

tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import enum
1718
from typing import Dict, List, Tuple, Union
1819

1920
import numpy as np
@@ -322,7 +323,19 @@ def create_kv_cache(rope_mode):
322323
return cache
323324

324325

325-
@pytest.fixture(params=[0, 1])
326+
class RopeMode(enum.IntEnum):
327+
"""The RoPE mode of the Paged KV cache.
328+
If it is none, the KV cache will not apply RoPE to q and k.
329+
If it is normal, RoPE will be applied to k before adding k to cache.
330+
Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
331+
"""
332+
333+
NONE = 0
334+
NORMAL = 1
335+
INLINE = 2
336+
337+
338+
@pytest.fixture(params=[RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE])
326339
def kv_cache_and_rope_mode(request):
327340
set_global_func()
328341
return create_kv_cache(request.param), request.param
@@ -361,7 +374,7 @@ def f_apply_rotary(x, offset, scale, theta):
361374

362375
def apply_attention(
363376
kv_cache,
364-
rope_mode: int,
377+
rope_mode: RopeMode,
365378
batch: List[Tuple[Union[int, Tuple[int, int]], int]],
366379
cached_k: Dict[int, np.ndarray],
367380
cached_v: Dict[int, np.ndarray],
@@ -406,10 +419,12 @@ def apply_attention(
406419
cached_k[seq_id],
407420
np.stack(
408421
[
409-
new_k[l]
410-
if rope_mode == 1
411-
else f_apply_rotary(
412-
new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta
422+
(
423+
new_k[l]
424+
if rope_mode != RopeMode.NORMAL
425+
else f_apply_rotary(
426+
new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta
427+
)
413428
)
414429
for l in range(num_layers)
415430
],
@@ -445,15 +460,19 @@ def apply_attention(
445460
assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length
446461

447462
rope_offset = cached_k[seq_id].shape[1] - append_length
448-
q_seq = f_apply_rotary(
449-
q_array[i][layer_id],
450-
rope_offset,
451-
rope_scale,
452-
rope_theta,
463+
q_seq = (
464+
q_array[i][layer_id]
465+
if rope_mode == RopeMode.NONE
466+
else f_apply_rotary(
467+
q_array[i][layer_id],
468+
rope_offset,
469+
rope_scale,
470+
rope_theta,
471+
)
453472
).transpose(1, 0, 2)
454473
k_seq = (
455474
cached_k[seq_id][layer_id]
456-
if rope_mode == 0
475+
if rope_mode != RopeMode.INLINE
457476
else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta)
458477
).transpose(1, 2, 0)
459478
v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
@@ -586,7 +605,7 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode, fuse_qkv):
586605

587606
if __name__ == "__main__":
588607
set_global_func()
589-
for rope_mode in [0, 1]:
608+
for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
590609
cache = create_kv_cache(rope_mode)
591610
for fuse_qkv in [False, True]:
592611
test_paged_attention_kv_cache_prefill_and_decode((cache, rope_mode), fuse_qkv)

tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import enum
1718
import itertools
1819
import math
1920
from typing import Dict, List, Tuple, Union
@@ -140,7 +141,25 @@ def create_kv_cache(head_dim, dtype, rope_mode):
140141
return cache
141142

142143

143-
@pytest.fixture(params=itertools.product([64, 128], ["float16", "float32"], [0, 1]))
144+
class RopeMode(enum.IntEnum):
145+
"""The RoPE mode of the Paged KV cache.
146+
If it is none, the KV cache will not apply RoPE to q and k.
147+
If it is normal, RoPE will be applied to k before adding k to cache.
148+
Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
149+
"""
150+
151+
NONE = 0
152+
NORMAL = 1
153+
INLINE = 2
154+
155+
156+
@pytest.fixture(
157+
params=itertools.product(
158+
[64, 128],
159+
["float16", "float32"],
160+
[RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE],
161+
)
162+
)
144163
def kv_cache_and_rope_mode(request):
145164
global head_dim, dtype
146165
head_dim, dtype, rope_mode = request.param
@@ -181,7 +200,7 @@ def f_apply_rotary(x, offset, scale, theta):
181200

182201
def apply_attention(
183202
kv_cache,
184-
rope_mode: int,
203+
rope_mode: RopeMode,
185204
batch: List[Tuple[Union[int, Tuple[int, int]], int]],
186205
cached_k: Dict[int, np.ndarray],
187206
cached_v: Dict[int, np.ndarray],
@@ -228,7 +247,7 @@ def apply_attention(
228247
[
229248
(
230249
new_k[l]
231-
if rope_mode == 1
250+
if rope_mode != RopeMode.NORMAL
232251
else f_apply_rotary(
233252
new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta
234253
)
@@ -267,15 +286,19 @@ def apply_attention(
267286
assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length
268287

269288
rope_offset = cached_k[seq_id].shape[1] - append_length
270-
q_seq = f_apply_rotary(
271-
q_array[i][layer_id],
272-
rope_offset,
273-
rope_scale,
274-
rope_theta,
289+
q_seq = (
290+
q_array[i][layer_id]
291+
if rope_mode == RopeMode.NONE
292+
else f_apply_rotary(
293+
q_array[i][layer_id],
294+
rope_offset,
295+
rope_scale,
296+
rope_theta,
297+
)
275298
).transpose(1, 0, 2)
276299
k_seq = (
277300
cached_k[seq_id][layer_id]
278-
if rope_mode == 0
301+
if rope_mode != RopeMode.INLINE
279302
else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta)
280303
).transpose(1, 2, 0)
281304
v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
@@ -1639,7 +1662,7 @@ def merge_state_inplace(
16391662
if __name__ == "__main__":
16401663
for head_dim in [64, 128]:
16411664
for dtype in ["float16", "float32"]:
1642-
for rope_mode in [0, 1]:
1665+
for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
16431666
set_global_func(head_dim, dtype)
16441667
cache = create_kv_cache(head_dim, dtype, rope_mode)
16451668
for fuse_qkv in [False, True]:

0 commit comments

Comments
 (0)