
Commit 619fefd

[KVCache] Add tree attention with paged cache support

Parent: 6ca0bea
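
For context on what this commit enables (an illustrative sketch, not code from the commit): in speculative decoding, draft tokens form a tree rather than a single sequence, so the attention mask must allow each token to attend only to its ancestors in the tree instead of to every earlier position. A minimal NumPy sketch of such a tree mask, assuming a parent-pointer encoding of the token tree:

```python
# Hypothetical sketch of a tree attention mask (assumed encoding, not from
# this commit): token i may attend to token j iff j is an ancestor of i
# (or i itself).
import numpy as np

def tree_attention_mask(parents):
    """parents[i] is the index of token i's parent, or -1 for a root."""
    n = len(parents)
    mask = np.zeros((n, n), dtype=bool)
    for i in range(n):
        j = i
        while j != -1:       # walk up to the root, enabling all ancestors
            mask[i, j] = True
            j = parents[j]
    return mask

# A 5-token tree: 0 is the root; 1 and 2 branch from 0; 3 follows 1; 4 follows 2.
print(tree_attention_mask([-1, 0, 0, 1, 2]).astype(int))
```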

4 files changed, +829 -172 lines


python/tvm/relax/frontend/nn/llm/kv_cache.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -30,7 +30,7 @@
 from tvm.target import Target

 from .position_embedding import llama_rope_with_position_map, switch_rope_freq_func
-from .tree_attn import tree_attn
+from .tree_attn import tree_attn, tree_attn_with_paged_kv_cache


 def get_max_num_threads_per_block(target: Target) -> int:

@@ -257,6 +257,7 @@ def __init__( # pylint: disable=too-many-locals
         bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
         bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"),
         bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
+        bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
         rope_ext_factors,
         # fmt: on
         # pylint: enable=line-too-long
```
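
The diff registers the new kernel alongside the existing `tree_attn`, which presumably computes tree-mask prefill over contiguous K/V buffers, while `tree_attn_with_paged_kv_cache` reads K/V through the paged cache. As background (an assumed sketch of paged-cache indexing in general, not this kernel's TIR), a paged cache stores K/V in fixed-size pages drawn from a shared pool, and a per-sequence page table maps logical token positions to physical slots:

```python
# Hypothetical illustration of paged KV-cache indexing; names and the page
# size are assumptions for this sketch, not taken from kv_cache.py.
import numpy as np

PAGE_SIZE = 16  # tokens per page (assumed)

def gather_kv(kv_pool, page_table, pos):
    """Fetch cached K/V for logical position `pos` of one sequence.

    kv_pool:    [num_pages, PAGE_SIZE, num_kv_heads, head_dim] shared pool
    page_table: physical page ids owned by this sequence, in logical order
    """
    page = page_table[pos // PAGE_SIZE]  # which physical page holds this token
    slot = pos % PAGE_SIZE               # offset within that page
    return kv_pool[page, slot]           # -> [num_kv_heads, head_dim]

# Example: a sequence occupying pages [7, 2, 11]; token 33 lives in the
# third page (id 11) at slot 1.
kv_pool = np.zeros((32, PAGE_SIZE, 4, 64))
print(gather_kv(kv_pool, [7, 2, 11], 33).shape)  # (4, 64)
```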
