@@ -277,15 +277,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   PackedFunc f_transpose_append_;
   PackedFunc f_attention_prefill_;
   PackedFunc f_attention_decode_;
-  PackedFunc f_attention_prefill_ragged_;
-  PackedFunc f_attention_prefill_ragged_begin_forward_;
-  PackedFunc f_attention_prefill_ragged_end_forward_;
-  PackedFunc f_attention_prefill_begin_forward_;
-  PackedFunc f_attention_prefill_end_forward_;
-  PackedFunc f_attention_decode_begin_forward_;
-  PackedFunc f_attention_decode_end_forward_;
+  Optional<PackedFunc> f_attention_prefill_ragged_;
+  Optional<PackedFunc> f_attention_prefill_ragged_begin_forward_;
+  Optional<PackedFunc> f_attention_prefill_ragged_end_forward_;
+  Optional<PackedFunc> f_attention_prefill_begin_forward_;
+  Optional<PackedFunc> f_attention_prefill_end_forward_;
+  Optional<PackedFunc> f_attention_decode_begin_forward_;
+  Optional<PackedFunc> f_attention_decode_end_forward_;
   PackedFunc f_rotary_;
-  PackedFunc f_merge_inplace_;
+  Optional<PackedFunc> f_merge_inplace_;
   Optional<PackedFunc> f_debug_get_kv_;

   /*! \brief Number of fork depth in the current round of forward. */
@@ -297,19 +297,23 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {

  public:
   /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */
-  explicit PagedAttentionKVCacheObj(
-      int64_t page_size,  //
-      int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads,
-      int64_t head_dim,  //
-      int64_t reserved_num_seqs, int64_t num_total_pages,  //
-      double rotary_scale, double rotary_theta,  //
-      DLDataType dtype, DLDevice device, PackedFunc f_transpose_append,
-      PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-      PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_ragged_begin_forward,
-      PackedFunc f_attention_prefill_ragged_end_forward,
-      PackedFunc f_attention_prefill_begin_forward, PackedFunc f_attention_prefill_end_forward,
-      PackedFunc f_attention_decode_begin_forward, PackedFunc f_attention_decode_end_forward,
-      PackedFunc f_rotary, PackedFunc f_merge_inplace, Optional<PackedFunc> f_debug_get_kv)
+  explicit PagedAttentionKVCacheObj(int64_t page_size,  //
+                                    int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads,
+                                    int64_t head_dim,  //
+                                    int64_t reserved_num_seqs, int64_t num_total_pages,  //
+                                    double rotary_scale, double rotary_theta,  //
+                                    DLDataType dtype, DLDevice device,
+                                    PackedFunc f_transpose_append, PackedFunc f_attention_prefill,
+                                    PackedFunc f_attention_decode,
+                                    Optional<PackedFunc> f_attention_prefill_ragged,
+                                    Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
+                                    Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
+                                    Optional<PackedFunc> f_attention_prefill_begin_forward,
+                                    Optional<PackedFunc> f_attention_prefill_end_forward,
+                                    Optional<PackedFunc> f_attention_decode_begin_forward,
+                                    Optional<PackedFunc> f_attention_decode_end_forward,
+                                    PackedFunc f_rotary, Optional<PackedFunc> f_merge_inplace,
+                                    Optional<PackedFunc> f_debug_get_kv)
       : page_size_(page_size),
         num_layers_(num_layers),
         num_qo_heads_(num_qo_heads),
@@ -418,6 +422,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
         << "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache.";
     CHECK(seq_map_.find(child_seq_id) == seq_map_.end())
         << "The child sequence \"" << child_seq_id << "\" is already in the KV cache.";
+    CHECK(f_merge_inplace_.defined() && f_attention_prefill_ragged_.defined())
+        << "Attention merge-score function not available. ForkSequence is thereby not supported.";

     int32_t parent_block_idx = parent_it->second.last_block_idx;
     // Create a child block with the parent block pointer.
@@ -558,13 +564,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   }

   void EndForward() final {
+    if (!f_attention_prefill_end_forward_.defined() || !f_attention_decode_end_forward_.defined() ||
+        !f_attention_prefill_ragged_end_forward_.defined()) {
+      return;
+    }
     // Mark the dirty flag as true, so that BeginForward is required
     // to be invoked before the next round of model forward.
     dirty_aux_data_device_ = true;
-    f_attention_prefill_ragged_end_forward_();
+    f_attention_prefill_ragged_end_forward_.value()();
     for (int d = 0; d < num_depths_; ++d) {
-      f_attention_prefill_end_forward_(d);
-      f_attention_decode_end_forward_(d);
+      f_attention_prefill_end_forward_.value()(d);
+      f_attention_decode_end_forward_.value()(d);
     }
   }

@@ -845,30 +855,36 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {

   /*! \brief Invoke the "begin forward" functions of underlying kernels. */
   void KernelBeginForward() {
+    if (!f_attention_prefill_begin_forward_.defined() ||
+        !f_attention_decode_begin_forward_.defined() ||
+        !f_attention_prefill_ragged_begin_forward_.defined()) {
+      return;
+    }
+
     if (num_depths_ == 1) {
       if (use_decode_kernel_[0]) {
-        f_attention_decode_begin_forward_(
+        f_attention_decode_begin_forward_.value()(
             /*depth=*/0, page_indptr_on_depths_view_[0], last_page_len_on_depths_view_[0],
             num_qo_heads_, num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/true);
       } else {
-        f_attention_prefill_begin_forward_(/*depth=*/0, qo_indptr_on_depths_view_[0],
-                                           cur_batch_size_, num_qo_heads_, num_kv_heads_);
+        f_attention_prefill_begin_forward_.value()(/*depth=*/0, qo_indptr_on_depths_view_[0],
+                                                   cur_batch_size_, num_qo_heads_, num_kv_heads_);
       }
     } else {
-      f_attention_prefill_ragged_begin_forward_(cur_append_length_indptr_view_, cur_batch_size_,
-                                                num_qo_heads_, num_kv_heads_);
+      f_attention_prefill_ragged_begin_forward_.value()(
+          cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_, num_kv_heads_);
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
           continue;
         }
         if (use_decode_kernel_[d]) {
-          f_attention_decode_begin_forward_(
+          f_attention_decode_begin_forward_.value()(
               d, page_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d], num_qo_heads_,
               num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/false);
         } else {
-          f_attention_prefill_begin_forward_(/*depth=*/d, qo_indptr_on_depths_view_[d],
-                                             last_page_len_on_depths_view_[d]->shape[0],
-                                             num_qo_heads_, num_kv_heads_);
+          f_attention_prefill_begin_forward_.value()(/*depth=*/d, qo_indptr_on_depths_view_[d],
+                                                     last_page_len_on_depths_view_[d]->shape[0],
+                                                     num_qo_heads_, num_kv_heads_);
         }
       }
     }
@@ -896,10 +912,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
       }
     } else {
       // Compute appended text self-attention
-      f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data,
-                                  cur_append_length_indptr_view_, output, merged_attn_scores_view_,
-                                  /*causal=*/1,
-                                  /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
+      f_attention_prefill_ragged_.value()(q_data, cur_append_length_indptr_view_, k_data, v_data,
+                                          cur_append_length_indptr_view_, output,
+                                          merged_attn_scores_view_,
+                                          /*causal=*/1,
+                                          /*rotary_mode=*/0, rotary_scale_, rotary_theta_);

       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
@@ -920,8 +937,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
                              /*causal=*/0,
                              /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
         }
-        f_merge_inplace_(output, merged_attn_scores_view_, temp_attn_output_view_,
-                         temp_attn_scores_view_);
+        f_merge_inplace_.value()(output, merged_attn_scores_view_, temp_attn_output_view_,
+                                 temp_attn_scores_view_);
       }
     }
   }
@@ -1068,6 +1085,27 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
       return PagedAttentionKVCache(std::move(n));
     });

+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
+    .set_body_typed([](ShapeTuple cache_config, int64_t num_layers, int64_t num_qo_heads,
+                       int64_t num_kv_heads, int64_t head_dim, double rotary_scale,
+                       double rotary_theta, NDArray init, PackedFunc f_transpose_append,
+                       PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
+                       PackedFunc f_rotary, Optional<PackedFunc> f_debug_get_kv) {
+      CHECK_EQ(cache_config.size(), 3);
+      int64_t reserved_num_seqs = cache_config[0];
+      int64_t total_token_capacity = cache_config[1];
+      int64_t page_size = cache_config[2];
+      int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size;
+      ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
+          page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
+          num_total_pages, rotary_scale, rotary_theta, init->dtype, init->device,
+          std::move(f_transpose_append), std::move(f_attention_prefill),
+          std::move(f_attention_decode),  //
+          NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,  //
+          std::move(f_rotary), NullOpt, std::move(f_debug_get_kv));
+      return PagedAttentionKVCache(std::move(n));
+    });
+
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_clear")
     .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::Clear);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_add_sequence")