@@ -54,11 +54,11 @@ namespace relax_vm {
  * \brief The maximum allowed block depth (a.k.a. number of common
  * prefixes) in paged KV cache.
  */
-constexpr const int kPagedKVCacheMaxBlockDepth = 5;
+constexpr const int kPagedKVCacheMaxBlockDepth = 1;
 /*! \brief The maximum tree size of a single sequence in tree attention. */
 constexpr const int kTreeAttnMaxTreeSize = 256;
 /*! \brief The 8MB workspace size for attention auxiliary data. */
-constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024;
+constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024;
 /*! \brief The id of the temporary logical page, which is useful for sliding window. */
 constexpr const int kPagedKVCacheTempPageId = -1;
 
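Note on the two constant changes above: the maximum shared-prefix block depth drops from 5 to 1, and the attention auxiliary workspace grows from 8 * 1024 * 1024 bytes (8 MB) to 128 * 1024 * 1024 bytes (128 MB); the unchanged "8MB workspace size" comment now understates the value. A minimal sketch of the size arithmetic, matching the allocations later in this diff where each workspace holds kAttnWorkspaceByte / 4 float32 elements (kWorkspaceFloat32Elems is an illustrative name, not from the source):

  constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024;           // 134,217,728 bytes
  constexpr const int kWorkspaceFloat32Elems = kAttnWorkspaceByte / 4;  // 33,554,432 four-byte floats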
@@ -119,6 +119,9 @@ struct Block {
   void Reset() {
     page_ids.clear();
     seq_length = 0;
+    start_pos = 0;
+    sink_length = 0;
+    sliding_window_offset = 0;
     parent_idx = -1;
     external_ref_cnt = 0;
   }
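Block::Reset() now also zeroes start_pos, sink_length, and sliding_window_offset, so a reset block (presumably one returned to and later taken from the block pool) no longer carries stale sliding-window bookkeeping from its previous use.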
@@ -169,11 +172,9 @@ struct Sequence {
     this->last_block_idx = last_block_idx;
     int32_t block_ptr = last_block_idx;
     // Go through each block in the sequence, sum up the length.
-    int depth = 0;
     while (true) {
       const Block& block = global_block_pool->at(block_ptr);
       this->seq_length += block.seq_length;
-      ++depth;
       if (block.parent_idx == -1) {
         break;
       }
@@ -1078,17 +1079,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         dtype_aux_, preferred_host_device);
 
     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
-      temp_attn_workspace_.push_back(
-          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      if (NeedKernelBeginForward()) {
+        temp_attn_workspace_.push_back(
+            NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      }
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
       page_indices_on_depths_view_.push_back(NDArray());
       length_info_on_depths_view_.push_back(NDArray());
       k_rope_pos_offset_view_.push_back(NDArray());
     }
     // Additional workspace for the "prefill with ragged kv" kernel.
-    temp_attn_workspace_.push_back(
-        NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+    if (NeedKernelBeginForward()) {
+      temp_attn_workspace_.push_back(
+          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+    }
 
     temp_attn_q_device_ =
         NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype, device);
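In the constructor above, the per-depth attention workspaces and the extra workspace for the "prefill with ragged kv" kernel are now allocated only when NeedKernelBeginForward() returns true, i.e. when all three begin-forward functions are defined (see the helper added near the end of this diff). That matches KernelBeginForward(), which returns early under the same predicate, so the temp_attn_workspace_[0] and temp_attn_workspace_[d + 1] reads there are only reached when the vector was actually populated.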
@@ -1531,6 +1536,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
 
     append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[0];
+    if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
+      // When GQA group size is at least 4 and FlashInfer is enabled,
+      // we always use prefill kernel for better performance.
+      std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), /*value=*/false);
+    }
+
     if (append_before_attn_) {
       // Right now we use different kernels when depth is 1 or not 1.
       // For the case where maximum depth is 1, we create the auxiliary
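The block added above forces the prefill kernel whenever the begin-forward functions are available and the GQA group size (query heads per KV head) is at least 4. A minimal standalone sketch of the arithmetic, with hypothetical head counts that are not taken from the diff:

  #include <iostream>

  int main() {
    int num_qo_heads = 32;  // query heads
    int num_kv_heads = 8;   // KV heads, giving a GQA group size of 32 / 8 = 4
    bool force_prefill_kernel = (num_qo_heads / num_kv_heads >= 4);
    // When this holds and the begin-forward functions are defined, use_decode_kernel_
    // is filled with false, so decode steps also go through the prefill kernel.
    std::cout << force_prefill_kernel << "\n";  // prints 1
    return 0;
  }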
@@ -2196,11 +2207,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
           use_decode_kernel};
   }
 
+  /*! \brief Check whether BeginForward for kernels is needed. */
+  bool NeedKernelBeginForward() {
+    return f_attention_prefill_begin_forward_.defined() &&
+           f_attention_decode_begin_forward_.defined() &&
+           f_attention_prefill_ragged_begin_forward_.defined();
+  }
+
   /*! \brief Invoke the "begin forward" functions of underlying kernels. */
   void KernelBeginForward() {
-    if (!f_attention_prefill_begin_forward_.defined() ||
-        !f_attention_decode_begin_forward_.defined() ||
-        !f_attention_prefill_ragged_begin_forward_.defined()) {
+    if (!NeedKernelBeginForward()) {
       return;
     }
 
@@ -2214,8 +2230,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       }
     } else {
       f_attention_prefill_ragged_begin_forward_.value()(
-          temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_,
-          num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_);
+          temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
+          cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
+          num_kv_heads_, head_dim_, copy_stream_);
       if (support_sliding_window_) {
         return;
       }
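In the call above, the ragged-KV begin-forward function now receives cur_append_lengths_indptr_host_ twice. Judging only from this call site, the planning function appears to take separate indptr arrays for the query side and the KV side, which coincide here because the ragged prefill attends over exactly the newly appended tokens; the parameter roles are an inference, since the callee's signature is not part of this diff.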
@@ -2232,8 +2249,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       } else {
         f_attention_prefill_begin_forward_.value()(
             /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
-            length_info_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_, head_dim_,
-            copy_stream_);
+            page_indptr_on_depths_host_[d].as_ndarray(),
+            static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
+            num_kv_heads_, head_dim_, page_size_, copy_stream_);
       }
     }
   }
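Likewise, the per-depth paged prefill begin-forward call now passes the host-side page_indptr_on_depths_host_[d] array and page_size_, and derives the batch size at this depth as the indptr length minus 1 instead of reading length_info_on_depths_view_[d]->shape[0]. This suggests the updated planning interface needs the page-table layout rather than just the number of sequences, though the callee's signature is not shown here.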