
Commit efbb7dd

MasterJH5574, jinhongyii, and spectrometerHBH committed
[Unity] Support TIR kernel for PagedKVCache
This PR adds support for backing PagedKVCache with TIR kernels. We do not yet have the TIR kernels needed for multi-level sequences in PagedKVCache, so `Fork` is disabled whenever the corresponding function is not provided. The PR also adds a "reduced" creator of PagedKVCache, in which auxiliary functions such as the begin/end forward functions of prefill/decode default to None. CUDA tests are added to ensure correctness.

Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
1 parent b8230f6 commit efbb7dd

File tree

3 files changed (+1149 -40 lines)


src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 77 additions & 39 deletions
@@ -277,15 +277,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   PackedFunc f_transpose_append_;
   PackedFunc f_attention_prefill_;
   PackedFunc f_attention_decode_;
-  PackedFunc f_attention_prefill_ragged_;
-  PackedFunc f_attention_prefill_ragged_begin_forward_;
-  PackedFunc f_attention_prefill_ragged_end_forward_;
-  PackedFunc f_attention_prefill_begin_forward_;
-  PackedFunc f_attention_prefill_end_forward_;
-  PackedFunc f_attention_decode_begin_forward_;
-  PackedFunc f_attention_decode_end_forward_;
+  Optional<PackedFunc> f_attention_prefill_ragged_;
+  Optional<PackedFunc> f_attention_prefill_ragged_begin_forward_;
+  Optional<PackedFunc> f_attention_prefill_ragged_end_forward_;
+  Optional<PackedFunc> f_attention_prefill_begin_forward_;
+  Optional<PackedFunc> f_attention_prefill_end_forward_;
+  Optional<PackedFunc> f_attention_decode_begin_forward_;
+  Optional<PackedFunc> f_attention_decode_end_forward_;
   PackedFunc f_rotary_;
-  PackedFunc f_merge_inplace_;
+  Optional<PackedFunc> f_merge_inplace_;
   Optional<PackedFunc> f_debug_get_kv_;

   /*! \brief Number of fork depth in the current round of forward. */
@@ -297,19 +297,23 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {

  public:
   /*! \brief Constructor. Take the cache configuration and initialize the NDArrays. */
-  explicit PagedAttentionKVCacheObj(
-      int64_t page_size,  //
-      int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads,
-      int64_t head_dim,  //
-      int64_t reserved_num_seqs, int64_t num_total_pages,  //
-      double rotary_scale, double rotary_theta,  //
-      DLDataType dtype, DLDevice device, PackedFunc f_transpose_append,
-      PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-      PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_ragged_begin_forward,
-      PackedFunc f_attention_prefill_ragged_end_forward,
-      PackedFunc f_attention_prefill_begin_forward, PackedFunc f_attention_prefill_end_forward,
-      PackedFunc f_attention_decode_begin_forward, PackedFunc f_attention_decode_end_forward,
-      PackedFunc f_rotary, PackedFunc f_merge_inplace, Optional<PackedFunc> f_debug_get_kv)
+  explicit PagedAttentionKVCacheObj(int64_t page_size,  //
+                                    int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads,
+                                    int64_t head_dim,  //
+                                    int64_t reserved_num_seqs, int64_t num_total_pages,  //
+                                    double rotary_scale, double rotary_theta,  //
+                                    DLDataType dtype, DLDevice device,
+                                    PackedFunc f_transpose_append, PackedFunc f_attention_prefill,
+                                    PackedFunc f_attention_decode,
+                                    Optional<PackedFunc> f_attention_prefill_ragged,
+                                    Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
+                                    Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
+                                    Optional<PackedFunc> f_attention_prefill_begin_forward,
+                                    Optional<PackedFunc> f_attention_prefill_end_forward,
+                                    Optional<PackedFunc> f_attention_decode_begin_forward,
+                                    Optional<PackedFunc> f_attention_decode_end_forward,
+                                    PackedFunc f_rotary, Optional<PackedFunc> f_merge_inplace,
+                                    Optional<PackedFunc> f_debug_get_kv)
       : page_size_(page_size),
         num_layers_(num_layers),
         num_qo_heads_(num_qo_heads),
@@ -418,6 +422,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
         << "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache.";
     CHECK(seq_map_.find(child_seq_id) == seq_map_.end())
         << "The child sequence \"" << child_seq_id << "\" is already in the KV cache.";
+    CHECK(f_merge_inplace_.defined() && f_attention_prefill_ragged_.defined())
+        << "Attention merge-score function not available. ForkSequence is thereby not supported.";

     int32_t parent_block_idx = parent_it->second.last_block_idx;
     // Create a child block with the parent block pointer.
@@ -558,13 +564,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
   }

   void EndForward() final {
+    if (!f_attention_prefill_end_forward_.defined() || !f_attention_decode_end_forward_.defined() ||
+        !f_attention_prefill_ragged_end_forward_.defined()) {
+      return;
+    }
     // Mark the dirty flag as true, so that BeginForward is required
     // to be invoked before the next round of model forward.
     dirty_aux_data_device_ = true;
-    f_attention_prefill_ragged_end_forward_();
+    f_attention_prefill_ragged_end_forward_.value()();
     for (int d = 0; d < num_depths_; ++d) {
-      f_attention_prefill_end_forward_(d);
-      f_attention_decode_end_forward_(d);
+      f_attention_prefill_end_forward_.value()(d);
+      f_attention_decode_end_forward_.value()(d);
     }
   }

@@ -845,30 +855,36 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {

   /*! \brief Invoke the "begin forward" functions of underlying kernels. */
   void KernelBeginForward() {
+    if (!f_attention_prefill_begin_forward_.defined() ||
+        !f_attention_decode_begin_forward_.defined() ||
+        !f_attention_prefill_ragged_begin_forward_.defined()) {
+      return;
+    }
+
     if (num_depths_ == 1) {
       if (use_decode_kernel_[0]) {
-        f_attention_decode_begin_forward_(
+        f_attention_decode_begin_forward_.value()(
             /*depth=*/0, page_indptr_on_depths_view_[0], last_page_len_on_depths_view_[0],
             num_qo_heads_, num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/true);
       } else {
-        f_attention_prefill_begin_forward_(/*depth=*/0, qo_indptr_on_depths_view_[0],
-                                           cur_batch_size_, num_qo_heads_, num_kv_heads_);
+        f_attention_prefill_begin_forward_.value()(/*depth=*/0, qo_indptr_on_depths_view_[0],
+                                                   cur_batch_size_, num_qo_heads_, num_kv_heads_);
       }
     } else {
-      f_attention_prefill_ragged_begin_forward_(cur_append_length_indptr_view_, cur_batch_size_,
-                                                num_qo_heads_, num_kv_heads_);
+      f_attention_prefill_ragged_begin_forward_.value()(
+          cur_append_length_indptr_view_, cur_batch_size_, num_qo_heads_, num_kv_heads_);
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
           continue;
         }
         if (use_decode_kernel_[d]) {
-          f_attention_decode_begin_forward_(
+          f_attention_decode_begin_forward_.value()(
              d, page_indptr_on_depths_view_[d], last_page_len_on_depths_view_[d], num_qo_heads_,
              num_kv_heads_, head_dim_, page_size_, /*rotary_mode=*/false);
         } else {
-          f_attention_prefill_begin_forward_(/*depth=*/d, qo_indptr_on_depths_view_[d],
-                                             last_page_len_on_depths_view_[d]->shape[0],
-                                             num_qo_heads_, num_kv_heads_);
+          f_attention_prefill_begin_forward_.value()(/*depth=*/d, qo_indptr_on_depths_view_[d],
                                                     last_page_len_on_depths_view_[d]->shape[0],
+                                                     num_qo_heads_, num_kv_heads_);
         }
       }
     }
@@ -896,10 +912,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
       }
     } else {
       // Compute appended text self-attention
-      f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data,
-                                  cur_append_length_indptr_view_, output, merged_attn_scores_view_,
-                                  /*causal=*/1,
-                                  /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
+      f_attention_prefill_ragged_.value()(q_data, cur_append_length_indptr_view_, k_data, v_data,
+                                          cur_append_length_indptr_view_, output,
+                                          merged_attn_scores_view_,
+                                          /*causal=*/1,
+                                          /*rotary_mode=*/0, rotary_scale_, rotary_theta_);

       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
@@ -920,8 +937,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
                             /*causal=*/0,
                             /*rotary_mode=*/0, rotary_scale_, rotary_theta_);
         }
-        f_merge_inplace_(output, merged_attn_scores_view_, temp_attn_output_view_,
-                         temp_attn_scores_view_);
+        f_merge_inplace_.value()(output, merged_attn_scores_view_, temp_attn_output_view_,
+                                 temp_attn_scores_view_);
       }
     }
   }
@@ -1068,6 +1085,27 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
       return PagedAttentionKVCache(std::move(n));
     });

+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
+    .set_body_typed([](ShapeTuple cache_config, int64_t num_layers, int64_t num_qo_heads,
+                       int64_t num_kv_heads, int64_t head_dim, double rotary_scale,
+                       double rotary_theta, NDArray init, PackedFunc f_transpose_append,
+                       PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
+                       PackedFunc f_rotary, Optional<PackedFunc> f_debug_get_kv) {
+      CHECK_EQ(cache_config.size(), 3);
+      int64_t reserved_num_seqs = cache_config[0];
+      int64_t total_token_capacity = cache_config[1];
+      int64_t page_size = cache_config[2];
+      int64_t num_total_pages = (total_token_capacity + page_size - 1) / page_size;
+      ObjectPtr<PagedAttentionKVCacheObj> n = make_object<PagedAttentionKVCacheObj>(
+          page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
+          num_total_pages, rotary_scale, rotary_theta, init->dtype, init->device,
+          std::move(f_transpose_append), std::move(f_attention_prefill),
+          std::move(f_attention_decode),  //
+          NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,  //
+          std::move(f_rotary), NullOpt, std::move(f_debug_get_kv));
+      return PagedAttentionKVCache(std::move(n));
+    });
+
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_clear")
     .set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::Clear);
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_add_sequence")

tests/python/relax/test_runtime_builtin_paged_attention_kv_cache.py

Lines changed: 14 additions & 1 deletion
@@ -81,12 +81,20 @@ def kv_cache_transpose_append(
     for global_pos, h, f in T.grid(ntoken, num_kv_heads, head_dim):
         with T.block("k_transpose_append"):
             vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
+            T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
+            T.writes(
+                pages[position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf]
+            )
             position: T.int64 = T.Cast("int64", position_map[vgpos])
             pages[
                 T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf
             ] = k_data[vgpos, vh, vf]
         with T.block("v_transpose_append"):
             vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
+            T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
+            T.writes(
+                pages[position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf]
+            )
             position: T.int64 = T.Cast("int64", position_map[vgpos])
             pages[
                 T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf
@@ -115,6 +123,11 @@ def copy_cache(
     for p, h, d in T.grid(seqlen, num_kv_heads, head_dim):
         with T.block("copy0"):
             vp, vh, vd = T.axis.remap("SSS", [p, h, d])
+            T.reads(
+                position_map[vp],
+                pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd],
+            )
+            T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd])
             position: T.int64 = T.Cast("int64", position_map[vp])
             k_data[layer_id, vp, vh, vd] = pages[
                 T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd
@@ -457,7 +470,7 @@ def test_paged_attention_kv_cache_popn(kv_cache):
         if pop_length != 0:
             cached_k[seq_id] = cached_k[seq_id][:, :-pop_length, ...]
             cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...]
-    verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v)
+    verify_cached_kv(kv_cache, seq_ids=list(range(5)), expected_k=cached_k, expected_v=cached_v)


 if __name__ == "__main__":
