@@ -1535,7 +1535,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
15351535 CHECK_EQ (chunked_block_ids_arr[num_depths_ - 1 ].size (), cur_batch_size_);
15361536 }
15371537
1538- append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[ 0 ] ;
1538+ append_before_attn_ = !support_sliding_window_ && use_decode_kernel_. back () ;
15391539 if (NeedKernelBeginForward () && num_qo_heads_ / num_kv_heads_ >= 4 ) {
15401540 // When GQA group size is at least 4 and FlashInfer is enabled,
15411541 // we always use prefill kernel for better performance.
@@ -2220,39 +2220,33 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
22202220 return ;
22212221 }
22222222
2223- if (append_before_attn_) {
2224- if (!support_sliding_window_) {
2223+ if (!append_before_attn_) {
2224+ if (is_chain_) {
2225+ f_attention_prefill_ragged_begin_forward_.value ()(
2226+ temp_attn_workspace_[0 ], cur_append_lengths_indptr_host_.as_ndarray (),
2227+ cur_append_lengths_indptr_host_.as_ndarray (), cur_batch_size_, num_qo_heads_,
2228+ num_kv_heads_, head_dim_, copy_stream_);
2229+ } else {
2230+ LOG (FATAL) << " Kernel BeginForward doesn't support tree attn." ;
2231+ }
2232+ }
2233+ for (int d = 0 ; d < num_depths_; ++d) {
2234+ if (page_indices_on_depths_view_[d]->shape [0 ] == 0 ) {
2235+ continue ;
2236+ }
2237+ CHECK (!support_sliding_window_) << " Kernel BeginForward doesn't support sliding window." ;
2238+ if (use_decode_kernel_[d]) {
22252239 f_attention_decode_begin_forward_.value ()(
2226- /* depth= */ 0 , temp_attn_workspace_[1 ], page_indptr_on_depths_host_[0 ].as_ndarray (),
2227- last_page_len_on_depths_host_[0 ].as_ndarray (), num_qo_heads_, num_kv_heads_, head_dim_,
2240+ d , temp_attn_workspace_[d + 1 ], page_indptr_on_depths_host_[d ].as_ndarray (),
2241+ last_page_len_on_depths_host_[d ].as_ndarray (), num_qo_heads_, num_kv_heads_, head_dim_,
22282242 page_size_,
22292243 /* rotary_mode=*/ rope_mode_ == RoPEMode::kInline , copy_stream_);
2230- }
2231- } else {
2232- f_attention_prefill_ragged_begin_forward_.value ()(
2233- temp_attn_workspace_[0 ], cur_append_lengths_indptr_host_.as_ndarray (),
2234- cur_append_lengths_indptr_host_.as_ndarray (), cur_batch_size_, num_qo_heads_,
2235- num_kv_heads_, head_dim_, copy_stream_);
2236- if (support_sliding_window_) {
2237- return ;
2238- }
2239- for (int d = 0 ; d < num_depths_; ++d) {
2240- if (page_indices_on_depths_view_[d]->shape [0 ] == 0 ) {
2241- continue ;
2242- }
2243- if (use_decode_kernel_[d]) {
2244- f_attention_decode_begin_forward_.value ()(
2245- d, temp_attn_workspace_[d + 1 ], page_indptr_on_depths_host_[d].as_ndarray (),
2246- last_page_len_on_depths_host_[d].as_ndarray (), num_qo_heads_, num_kv_heads_,
2247- head_dim_, page_size_,
2248- /* rotary_mode=*/ rope_mode_ == RoPEMode::kInline , copy_stream_);
2249- } else {
2250- f_attention_prefill_begin_forward_.value ()(
2251- /* depth=*/ d, temp_attn_workspace_[d + 1 ], qo_indptr_on_depths_host_[d].as_ndarray (),
2252- page_indptr_on_depths_host_[d].as_ndarray (),
2253- static_cast <int >(page_indptr_on_depths_host_[d].size ()) - 1 , num_qo_heads_,
2254- num_kv_heads_, head_dim_, page_size_, copy_stream_);
2255- }
2244+ } else {
2245+ f_attention_prefill_begin_forward_.value ()(
2246+ /* depth=*/ d, temp_attn_workspace_[d + 1 ], qo_indptr_on_depths_host_[d].as_ndarray (),
2247+ page_indptr_on_depths_host_[d].as_ndarray (),
2248+ static_cast <int >(page_indptr_on_depths_host_[d].size ()) - 1 , num_qo_heads_,
2249+ num_kv_heads_, head_dim_, page_size_, copy_stream_);
22562250 }
22572251 }
22582252 }
@@ -2271,15 +2265,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
22712265 PackedFunc f_decode =
22722266 !support_sliding_window_ ? f_attention_decode_ : f_attention_decode_sliding_window_;
22732267 CHECK_GE (num_depths_, 1 ) << " The number of effective depths must be greater or equal to 1." ;
2274- if (append_before_attn_) {
2275- f_decode (
2276- /* depth=*/ 0 , q_data, pages_[local_layer_id], page_indptr_on_depths_view_[0 ],
2277- page_indices_on_depths_view_[0 ], length_info_on_depths_view_[0 ],
2278- k_rope_pos_offset_view_[0 ], q_rope_position_map_view_, output, merged_attn_scores_view_,
2279- /* rotary_mode=*/ rope_mode_ == RoPEMode::kInline , rotary_scale_, rotary_theta_,
2280- attn_score_scaling_factor);
2281- } else {
2282- // Compute appended text self-attention
2268+
2269+ bool is_first_kernel = true ;
2270+ if (!append_before_attn_) {
2271+ // The first part of attention, which only involves the q and the newly appended k/v.
2272+ is_first_kernel = false ;
22832273 if (is_chain_) {
22842274 // If the batch does not form a tree, use raggedness prefill kernel.
22852275 f_attention_prefill_ragged_ (q_data, cur_append_length_indptr_view_, k_data, v_data,
@@ -2301,32 +2291,43 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
23012291 merged_attn_scores_view_, /* rotary_mode=*/ rope_mode_ == RoPEMode::kInline ,
23022292 rotary_scale_, rotary_theta_, attn_score_scaling_factor, cur_batch_size_);
23032293 }
2294+ }
23042295
2305- for (int d = 0 ; d < num_depths_; ++d) {
2306- if (page_indices_on_depths_view_[d]->shape [0 ] == 0 ) {
2307- continue ;
2308- }
2309- if (use_decode_kernel_[d]) {
2310- // Use decode kernel for depth d
2311- f_decode (/* depth=*/ d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d],
2312- page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
2313- k_rope_pos_offset_view_[d], q_rope_position_map_view_, temp_attn_output_view_,
2314- temp_attn_scores_view_,
2315- /* rotary_mode=*/ rope_mode_ == RoPEMode::kInline , rotary_scale_, rotary_theta_,
2316- attn_score_scaling_factor);
2317- } else {
2318- // Use prefill kernel for depth d
2319- f_prefill (
2320- /* depth=*/ d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id],
2321- page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
2322- length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_,
2323- temp_attn_output_view_, temp_attn_scores_view_,
2324- /* causal=*/ 0 ,
2325- /* rotary_mode=*/ rope_mode_ == RoPEMode::kInline , rotary_scale_, rotary_theta_,
2326- attn_score_scaling_factor);
2327- }
2296+ for (int d = 0 ; d < num_depths_; ++d) {
2297+ if (page_indices_on_depths_view_[d]->shape [0 ] == 0 ) {
2298+ continue ;
2299+ }
2300+ NDArray attn_output;
2301+ NDArray attn_scores;
2302+ if (is_first_kernel) {
2303+ attn_output = output;
2304+ attn_scores = merged_attn_scores_view_;
2305+ } else {
2306+ attn_output = temp_attn_output_view_;
2307+ attn_scores = temp_attn_scores_view_;
2308+ }
2309+ if (use_decode_kernel_[d]) {
2310+ // Use decode kernel for depth d
2311+ f_decode (/* depth=*/ d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d],
2312+ page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
2313+ k_rope_pos_offset_view_[d], q_rope_position_map_view_, attn_output, attn_scores,
2314+ /* rotary_mode=*/ rope_mode_ == RoPEMode::kInline , rotary_scale_, rotary_theta_,
2315+ attn_score_scaling_factor);
2316+ } else {
2317+ // Use prefill kernel for depth d
2318+ f_prefill (/* depth=*/ d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id],
2319+ page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
2320+ length_info_on_depths_view_[d], k_rope_pos_offset_view_[d],
2321+ q_rope_position_map_view_, attn_output, attn_scores, /* causal=*/ 0 ,
2322+ /* rotary_mode=*/ rope_mode_ == RoPEMode::kInline , rotary_scale_, rotary_theta_,
2323+ attn_score_scaling_factor);
2324+ }
2325+
2326+ if (!is_first_kernel) {
23282327 f_merge_inplace_ (output, merged_attn_scores_view_, temp_attn_output_view_,
23292328 temp_attn_scores_view_);
2329+ } else {
2330+ is_first_kernel = false ;
23302331 }
23312332 }
23322333 }
0 commit comments