Skip to content

Commit cd09ab6

Browse files
authored
[Runtime] Reorganize PagedKVCache attn kernel invocation (#17237)
This PR reorganizes the attention kernel invocation logic in the PagedKVCache, so that in cases of sequence fork, we can effectively merge one ragged-prefill kernel and a decode kernel into a single decode kernel.
1 parent 21c12fb commit cd09ab6

File tree

2 files changed

+65
-64
lines changed

2 files changed

+65
-64
lines changed

src/relax/transform/fuse_ops.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ class FunctionCreator : public ExprMutator {
646646
return tvm::tir::UndefinedVars(prim_value->value).empty();
647647
} else if (const auto* shape_expr = expr.as<ShapeExprNode>()) {
648648
return std::all_of(shape_expr->values.begin(), shape_expr->values.end(),
649-
[this](const PrimExpr& e) { return tvm::tir::UndefinedVars(e).empty(); });
649+
[](const PrimExpr& e) { return tvm::tir::UndefinedVars(e).empty(); });
650650
}
651651
return false;
652652
}

src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 64 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1535,7 +1535,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
15351535
CHECK_EQ(chunked_block_ids_arr[num_depths_ - 1].size(), cur_batch_size_);
15361536
}
15371537

1538-
append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[0];
1538+
append_before_attn_ = !support_sliding_window_ && use_decode_kernel_.back();
15391539
if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
15401540
// When GQA group size is at least 4 and FlashInfer is enabled,
15411541
// we always use prefill kernel for better performance.
@@ -2220,39 +2220,33 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
22202220
return;
22212221
}
22222222

2223-
if (append_before_attn_) {
2224-
if (!support_sliding_window_) {
2223+
if (!append_before_attn_) {
2224+
if (is_chain_) {
2225+
f_attention_prefill_ragged_begin_forward_.value()(
2226+
temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
2227+
cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
2228+
num_kv_heads_, head_dim_, copy_stream_);
2229+
} else {
2230+
LOG(FATAL) << "Kernel BeginForward doesn't support tree attn.";
2231+
}
2232+
}
2233+
for (int d = 0; d < num_depths_; ++d) {
2234+
if (page_indices_on_depths_view_[d]->shape[0] == 0) {
2235+
continue;
2236+
}
2237+
CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window.";
2238+
if (use_decode_kernel_[d]) {
22252239
f_attention_decode_begin_forward_.value()(
2226-
/*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_host_[0].as_ndarray(),
2227-
last_page_len_on_depths_host_[0].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_,
2240+
d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(),
2241+
last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_,
22282242
page_size_,
22292243
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
2230-
}
2231-
} else {
2232-
f_attention_prefill_ragged_begin_forward_.value()(
2233-
temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
2234-
cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
2235-
num_kv_heads_, head_dim_, copy_stream_);
2236-
if (support_sliding_window_) {
2237-
return;
2238-
}
2239-
for (int d = 0; d < num_depths_; ++d) {
2240-
if (page_indices_on_depths_view_[d]->shape[0] == 0) {
2241-
continue;
2242-
}
2243-
if (use_decode_kernel_[d]) {
2244-
f_attention_decode_begin_forward_.value()(
2245-
d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(),
2246-
last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_,
2247-
head_dim_, page_size_,
2248-
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
2249-
} else {
2250-
f_attention_prefill_begin_forward_.value()(
2251-
/*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
2252-
page_indptr_on_depths_host_[d].as_ndarray(),
2253-
static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
2254-
num_kv_heads_, head_dim_, page_size_, copy_stream_);
2255-
}
2244+
} else {
2245+
f_attention_prefill_begin_forward_.value()(
2246+
/*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
2247+
page_indptr_on_depths_host_[d].as_ndarray(),
2248+
static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
2249+
num_kv_heads_, head_dim_, page_size_, copy_stream_);
22562250
}
22572251
}
22582252
}
@@ -2271,15 +2265,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
22712265
PackedFunc f_decode =
22722266
!support_sliding_window_ ? f_attention_decode_ : f_attention_decode_sliding_window_;
22732267
CHECK_GE(num_depths_, 1) << "The number of effective depths must be greater or equal to 1.";
2274-
if (append_before_attn_) {
2275-
f_decode(
2276-
/*depth=*/0, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[0],
2277-
page_indices_on_depths_view_[0], length_info_on_depths_view_[0],
2278-
k_rope_pos_offset_view_[0], q_rope_position_map_view_, output, merged_attn_scores_view_,
2279-
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
2280-
attn_score_scaling_factor);
2281-
} else {
2282-
// Compute appended text self-attention
2268+
2269+
bool is_first_kernel = true;
2270+
if (!append_before_attn_) {
2271+
// The first part of attention, which only involves the q and the newly appended k/v.
2272+
is_first_kernel = false;
22832273
if (is_chain_) {
22842274
// If the batch does not form a tree, use raggedness prefill kernel.
22852275
f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data,
@@ -2301,32 +2291,43 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
23012291
merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline,
23022292
rotary_scale_, rotary_theta_, attn_score_scaling_factor, cur_batch_size_);
23032293
}
2294+
}
23042295

2305-
for (int d = 0; d < num_depths_; ++d) {
2306-
if (page_indices_on_depths_view_[d]->shape[0] == 0) {
2307-
continue;
2308-
}
2309-
if (use_decode_kernel_[d]) {
2310-
// Use decode kernel for depth d
2311-
f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d],
2312-
page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
2313-
k_rope_pos_offset_view_[d], q_rope_position_map_view_, temp_attn_output_view_,
2314-
temp_attn_scores_view_,
2315-
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
2316-
attn_score_scaling_factor);
2317-
} else {
2318-
// Use prefill kernel for depth d
2319-
f_prefill(
2320-
/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id],
2321-
page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
2322-
length_info_on_depths_view_[d], k_rope_pos_offset_view_[d], q_rope_position_map_view_,
2323-
temp_attn_output_view_, temp_attn_scores_view_,
2324-
/*causal=*/0,
2325-
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
2326-
attn_score_scaling_factor);
2327-
}
2296+
for (int d = 0; d < num_depths_; ++d) {
2297+
if (page_indices_on_depths_view_[d]->shape[0] == 0) {
2298+
continue;
2299+
}
2300+
NDArray attn_output;
2301+
NDArray attn_scores;
2302+
if (is_first_kernel) {
2303+
attn_output = output;
2304+
attn_scores = merged_attn_scores_view_;
2305+
} else {
2306+
attn_output = temp_attn_output_view_;
2307+
attn_scores = temp_attn_scores_view_;
2308+
}
2309+
if (use_decode_kernel_[d]) {
2310+
// Use decode kernel for depth d
2311+
f_decode(/*depth=*/d, q_data, pages_[local_layer_id], page_indptr_on_depths_view_[d],
2312+
page_indices_on_depths_view_[d], length_info_on_depths_view_[d],
2313+
k_rope_pos_offset_view_[d], q_rope_position_map_view_, attn_output, attn_scores,
2314+
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
2315+
attn_score_scaling_factor);
2316+
} else {
2317+
// Use prefill kernel for depth d
2318+
f_prefill(/*depth=*/d, q_data, qo_indptr_on_depths_view_[d], pages_[local_layer_id],
2319+
page_indptr_on_depths_view_[d], page_indices_on_depths_view_[d],
2320+
length_info_on_depths_view_[d], k_rope_pos_offset_view_[d],
2321+
q_rope_position_map_view_, attn_output, attn_scores, /*causal=*/0,
2322+
/*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_, rotary_theta_,
2323+
attn_score_scaling_factor);
2324+
}
2325+
2326+
if (!is_first_kernel) {
23282327
f_merge_inplace_(output, merged_attn_scores_view_, temp_attn_output_view_,
23292328
temp_attn_scores_view_);
2329+
} else {
2330+
is_first_kernel = false;
23302331
}
23312332
}
23322333
}

0 commit comments

Comments (0)