@@ -169,6 +169,7 @@ def __init__( # pylint: disable=too-many-locals
169169 rope_scaling : Dict [str , Any ],
170170 rope_ext_factors : rx .Expr ,
171171 rotary_dim : int ,
172+ enable_disaggregation : bool ,
172173 dtype : str ,
173174 target : Target ,
174175 name : str = "paged_kv_cache" ,
@@ -214,6 +215,8 @@ def __init__( # pylint: disable=too-many-locals
214215 The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
215216 rotary_dim : int
216217 The number of dimensions in the embedding that RoPE is applied to.
218+ enable_disaggregation : bool
219+ Whether to enable disaggregation in the KV cache.
217220 """
218221 if rope_mode == RopeMode .INLINE :
219222 assert rotary_dim == head_dim , "FlashInfer RoPE does not support partial rotary dim."
@@ -259,6 +262,7 @@ def __init__( # pylint: disable=too-many-locals
259262 bb .add_func (tree_attn (num_key_value_heads , num_attention_heads , head_dim , dtype , rope_scaling , target ), "tir_attention_prefill_with_tree_mask" ),
260263 bb .add_func (tree_attn_with_paged_kv_cache (num_key_value_heads , num_attention_heads , head_dim , dtype , rope_scaling , target ), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache" ),
261264 rope_ext_factors ,
265+ rx .PrimValue (enable_disaggregation ),
262266 # fmt: on
263267 # pylint: enable=line-too-long
264268 ]
@@ -293,6 +297,7 @@ def __init__( # pylint: disable=too-many-locals
293297 rope_scaling : Dict [str , Any ],
294298 rope_ext_factors : rx .Expr ,
295299 rotary_dim : int ,
300+ enable_disaggregation : bool ,
296301 dtype : str ,
297302 target : Target ,
298303 name : str = "paged_kv_cache" ,
@@ -338,6 +343,8 @@ def __init__( # pylint: disable=too-many-locals
338343 The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
339344 rotary_dim : int
340345 The number of dimensions in the embedding that RoPE is applied to.
346+ enable_disaggregation : bool
347+ Whether to enable disaggregation in the KV cache.
341348 target : Target
342349 The target to build the model to.
343350 """
@@ -377,6 +384,7 @@ def __init__( # pylint: disable=too-many-locals
377384 bb .add_func (tree_attn (num_key_value_heads , num_attention_heads , head_dim , dtype , rope_scaling , target ), "tir_attention_prefill_with_tree_mask" ),
378385 bb .add_func (tree_attn_with_paged_kv_cache (num_key_value_heads , num_attention_heads , head_dim , dtype , rope_scaling , target ), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache" ),
379386 rope_ext_factors ,
387+ rx .PrimValue (enable_disaggregation ),
380388 # fmt: on
381389 # pylint: enable=line-too-long
382390 ]
@@ -409,8 +417,9 @@ def tir_kv_cache_transpose_append(
409417 T .func_attr ({"tir.noalias" : T .bool (True )})
410418 ntoken = T .SizeVar ("num_tokens_excluding_cache" , "int64" )
411419 num_pages = T .int64 ()
420+ pages_elem_offset = T .int64 ()
412421 position_map_elem_offset = T .int32 ()
413- pages = T .match_buffer (var_pages , (num_pages , 2 , num_key_value_heads , 16 , head_dim ), dtype )
422+ pages = T .match_buffer (var_pages , (num_pages , 2 , num_key_value_heads , 16 , head_dim ), dtype , elem_offset = pages_elem_offset )
414423 k_data = T .match_buffer (var_k_data , (ntoken , num_key_value_heads , head_dim ), dtype )
415424 v_data = T .match_buffer (var_v_data , (ntoken , num_key_value_heads , head_dim ), dtype )
416425 position_map = T .match_buffer (
@@ -453,8 +462,9 @@ def tir_kv_cache_debug_get_kv(
453462 seqlen = T .SizeVar ("num_tokens_including_cache" , "int64" )
454463 page_size = T .SizeVar ("page_size" , "int64" )
455464 num_pages = T .int64 ()
465+ pages_elem_offset = T .int64 ()
456466 position_map_elem_offset = T .int64 ()
457- pages = T .match_buffer (var_pages , (num_pages , 2 , num_key_value_heads , page_size , head_dim ), dtype )
467+ pages = T .match_buffer (var_pages , (num_pages , 2 , num_key_value_heads , page_size , head_dim ), dtype , elem_offset = pages_elem_offset )
458468 position_map = T .match_buffer (
459469 var_position_map , (seqlen ,), "int32" , elem_offset = position_map_elem_offset
460470 )
@@ -594,6 +604,7 @@ def batch_prefill_paged_kv(
594604 total_len = T .int32 (is_size_var = True )
595605 nnz_pages = T .int32 (is_size_var = True )
596606 max_num_pages = T .int32 (is_size_var = True )
607+ pages_elem_offset = T .int64 (is_size_var = True )
597608 q_indptr_elem_offset = T .int32 (is_size_var = True )
598609 page_indptr_elem_offset = T .int32 (is_size_var = True )
599610 page_values_elem_offset = T .int32 (is_size_var = True )
@@ -603,7 +614,7 @@ def batch_prefill_paged_kv(
603614
604615 q = T .match_buffer (var_q , (total_len , h_q , d ), dtype )
605616 q_indptr = T .match_buffer (var_q_indptr , (batch_size + 1 ,), "int32" , elem_offset = q_indptr_elem_offset )
606- pages = T .match_buffer (var_pages , (max_num_pages , 2 , h_kv , 16 , d ), dtype )
617+ pages = T .match_buffer (var_pages , (max_num_pages , 2 , h_kv , 16 , d ), dtype , elem_offset = pages_elem_offset )
607618 page_indptr = T .match_buffer (var_page_indptr , (batch_size + 1 ,), "int32" , elem_offset = page_indptr_elem_offset )
608619 page_values = T .match_buffer (var_page_values , (nnz_pages ,), "int32" , elem_offset = page_values_elem_offset )
609620 k_rope_pos_offset = T .match_buffer (var_k_rope_pos_offset , (batch_size ,), "int32" , elem_offset = k_rope_pos_offset_elem_offset )
@@ -975,6 +986,7 @@ def batch_decode_paged_kv(
975986 B = T .int32 (is_size_var = True )
976987 nnz_pages = T .int32 (is_size_var = True )
977988 max_num_pages = T .int32 (is_size_var = True )
989+ pages_elem_offset = T .int64 (is_size_var = True )
978990 page_indptr_elem_offset = T .int32 (is_size_var = True )
979991 page_values_elem_offset = T .int32 (is_size_var = True )
980992 k_rope_pos_offset_elem_offset = T .int32 (is_size_var = True )
@@ -983,7 +995,7 @@ def batch_decode_paged_kv(
983995
984996 Q = T .match_buffer (Q_handle , (B , H_qo , D ), qkv_dtype )
985997 pages = T .match_buffer (
986- pages_handle , (max_num_pages , 2 , H_kv , 16 , D ), qkv_dtype
998+ pages_handle , (max_num_pages , 2 , H_kv , 16 , D ), qkv_dtype , elem_offset = pages_elem_offset
987999 )
9881000 page_table_indptr = T .match_buffer (page_table_indptr_handle , (B + 1 ,), "int32" , elem_offset = page_indptr_elem_offset )
9891001 page_table_values = T .match_buffer (page_table_values_handle , (nnz_pages ,), "int32" , elem_offset = page_values_elem_offset )
@@ -1949,7 +1961,13 @@ def copy_single_page(
19491961 ):
19501962 T .func_attr ({"tir.is_scheduled" : 1 })
19511963 num_pages = T .int32 ()
1952- pages = T .match_buffer (var_pages , (num_pages , 2 , num_heads , page_size , head_dim ), dtype )
1964+ pages_elem_offset = T .int64 ()
1965+ pages = T .match_buffer (
1966+ var_pages ,
1967+ (num_pages , 2 , num_heads , page_size , head_dim ),
1968+ dtype ,
1969+ elem_offset = pages_elem_offset ,
1970+ )
19531971
19541972 for b in T .thread_binding (
19551973 (copy_length * num_heads * head_dim + tx - 1 ) // tx , thread = "blockIdx.x"
@@ -1993,7 +2011,10 @@ def compact_kv_copy(
19932011 total_copy_length = T .int32 ()
19942012 copy_length_indptr_elem_offset = T .int32 ()
19952013 copy_src_dst_pos_elem_offset = T .int32 ()
1996- pages = T .match_buffer (var_pages , (num_pages , 2 , num_heads , 16 , head_dim ), dtype )
2014+ pages_elem_offset = T .int64 ()
2015+ pages = T .match_buffer (
2016+ var_pages , (num_pages , 2 , num_heads , 16 , head_dim ), dtype , elem_offset = pages_elem_offset
2017+ )
19972018 copy_length_indptr = T .match_buffer (
19982019 var_copy_length_indptr ,
19992020 (batch_size + 1 ,),
0 commit comments