
Commit e468426

[Fix][Relax] Add the missing tree-attn func arg for KV cache creation (#17345)
This PR fixes a TIRPagedKVCache construction failure caused by the missing registration of the tree-attention-with-paged-KV-cache kernel.
Parent: 995524a

File tree

1 file changed: +1 -0 lines changed

python/tvm/relax/frontend/nn/llm/kv_cache.py

Lines changed: 1 addition & 0 deletions
@@ -375,6 +375,7 @@ def __init__(  # pylint: disable=too-many-locals
     bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
     bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"),
     bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
+    bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
     rope_ext_factors,
     # fmt: on
     # pylint: enable=line-too-long
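
For context, here is a minimal, purely illustrative Python sketch of why the missing registration broke KV cache creation: the paged KV cache expects each attention and utility kernel to be registered under its global symbol, and before this commit the tree-attention kernel for the paged KV cache was never added. The kernel names mirror the diff above; the helper function and checking logic below are hypothetical, not TVM's actual implementation.

    # Illustrative sketch only (not TVM source). The KV cache needs every
    # kernel below; before this commit the last entry was never registered,
    # so constructing TIRPagedKVCache failed.
    REQUIRED_KERNELS = [
        "kv_cache_debug_get_kv",
        "kv_cache_compact_kv_copy",
        "tir_attention_prefill_with_tree_mask",
        "tir_attention_prefill_with_tree_mask_with_paged_kv_cache",  # added by this commit
    ]

    def create_paged_kv_cache(registered_funcs):
        """Hypothetical helper: fail fast when a required kernel is missing."""
        missing = [name for name in REQUIRED_KERNELS if name not in registered_funcs]
        if missing:
            raise ValueError(f"KV cache creation failed, missing kernels: {missing}")
        return {name: registered_funcs[name] for name in REQUIRED_KERNELS}

With the extra bb.add_func call in the diff, the tree-attention-with-paged-KV-cache kernel is added to the module alongside the other kernels, so its global symbol is available when the KV cache is constructed.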
