Skip to content

Commit 1eac178

Browse files
authored
[Runtime] Fix PagedKVCache for PopN and enhance tests (#17045)
This PR fixes a bug in the PagedKVCache that could occur when sequences were removed in an order inconsistent with the reverse of the order in which they were added or forked. With this fix, the PagedKVCache supports removing sequences in any order without breaking. This PR also adds an `empty` function to the PagedKVCache that checks whether the KV cache is empty. Right now this function is only used for testing purposes, where we verify that everything in the KV cache has been freed after removing all sequences.
1 parent 820f1b6 commit 1eac178

File tree

4 files changed

+62
-21
lines changed

4 files changed

+62
-21
lines changed

src/runtime/relax_vm/kv_state.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
4747
// Attention KV Cache methods
4848
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
4949
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
50+
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty")
51+
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::Empty);
5052
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
5153
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
5254
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_total_sequence_length")

src/runtime/relax_vm/kv_state.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ class AttentionKVCacheObj : public KVStateObj {
117117
public:
118118
/************** Raw Info Query **************/
119119

120+
/*! \brief Check if the KV cache is empty. */
121+
virtual bool Empty() const = 0;
120122
/*!
121123
* \brief Get the number of available pages in the KV cache.
122124
* When the underlying KV cache implementation is not

src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,14 @@ struct Sequence {
147147
*/
148148
int last_block_attn_sink_size = 0;
149149

150-
explicit Sequence(const std::vector<Block>& global_block_pool, int32_t last_block_idx) {
150+
explicit Sequence(std::vector<Block>* global_block_pool, int32_t last_block_idx) {
151+
++global_block_pool->at(last_block_idx).external_ref_cnt;
151152
this->last_block_idx = last_block_idx;
152153
int32_t block_ptr = last_block_idx;
153154
// Go through each block in the sequence, sum up the length.
154155
int depth = 0;
155156
while (true) {
156-
const Block& block = global_block_pool[block_ptr];
157+
const Block& block = global_block_pool->at(block_ptr);
157158
this->seq_length += block.seq_length;
158159
++depth;
159160
if (block.parent_idx == -1) {
@@ -965,17 +966,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
965966
CHECK(seq_map_.find(seq_id) == seq_map_.end())
966967
<< "The sequence \"" << seq_id << "\" is already in the KV cache.";
967968
int32_t block_idx = GetFreeBlock();
968-
seq_map_.insert({seq_id, Sequence(global_block_pool_, block_idx)});
969+
seq_map_.insert({seq_id, Sequence(&global_block_pool_, block_idx)});
969970
dirty_aux_data_device_ = true;
970971
}
971972

972973
void RemoveSequence(int64_t seq_id) final {
973974
auto it = seq_map_.find(seq_id);
974975
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache.";
975976
int32_t block_idx = it->second.last_block_idx;
976-
CHECK_EQ(global_block_pool_[block_idx].external_ref_cnt, 0)
977-
<< "The sequence is currently referenced by other sequence and thus cannot be removed.";
978-
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
977+
// The block should have at least one reference, which comes from the sequence.
978+
ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1);
979+
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) {
979980
// - Free pages in the last block.
980981
for (int32_t page_id : global_block_pool_[block_idx].page_ids) {
981982
free_page_ids_.push_back(page_id);
@@ -985,7 +986,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
985986
}
986987
// - Decrease the external reference of the parent block.
987988
if (block_idx != -1) {
988-
ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 0);
989+
ICHECK_GT(global_block_pool_[block_idx].external_ref_cnt, 1);
989990
--global_block_pool_[block_idx].external_ref_cnt;
990991
}
991992
seq_map_.erase(it);
@@ -1018,11 +1019,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
10181019
// Update child block start position and parent index
10191020
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
10201021
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
1021-
if (global_block_pool_[parent_block_idx].seq_length) {
1022-
// If parent is not empty, append a new block
1022+
if (parent_block_idx == parent_it->second.last_block_idx &&
1023+
global_block_pool_[parent_block_idx].seq_length) {
1024+
// To enable the parent sequence to continue decode after the fork,
1025+
// we add a new empty block at the end of the parent sequence.
1026+
// So the new decoded KV data will go into the new block.
10231027
int32_t new_parent_block_idx = GetFreeBlock();
10241028
global_block_pool_[new_parent_block_idx].start_pos = parent_it->second.seq_length;
10251029
global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx;
1030+
global_block_pool_[new_parent_block_idx].external_ref_cnt = 1;
10261031
parent_it->second.last_block_idx = new_parent_block_idx;
10271032
}
10281033
} else {
@@ -1055,7 +1060,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
10551060
global_block_pool_[forked_block_idx].parent_idx;
10561061
global_block_pool_[forked_block_idx].parent_idx = parent_block_idx;
10571062
global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
1058-
global_block_pool_[parent_block_idx].external_ref_cnt = 1;
1063+
global_block_pool_[parent_block_idx].external_ref_cnt = 2;
10591064

10601065
// Move common leading pages to new parent block
10611066
auto first_page = global_block_pool_[forked_block_idx].page_ids.begin();
@@ -1085,7 +1090,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
10851090
}
10861091
}
10871092
// Create the child sequence with the child block.
1088-
seq_map_.insert({child_seq_id, Sequence(global_block_pool_, child_block_idx)});
1093+
seq_map_.insert({child_seq_id, Sequence(&global_block_pool_, child_block_idx)});
10891094
dirty_aux_data_device_ = true;
10901095
}
10911096

@@ -1119,7 +1124,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
11191124
<< "A sequence cannot be enabled twice for sliding window.";
11201125

11211126
// Compute the total length of the prefix blocks of this sequence.
1122-
Block& last_block = global_block_pool_[it->second.last_block_idx];
1127+
const Block& last_block = global_block_pool_[it->second.last_block_idx];
11231128
int32_t prefix_length = it->second.seq_length - last_block.seq_length;
11241129
ICHECK_GE(prefix_length, 0);
11251130
// Since the prefix blocks cannot sliding, they are natural
@@ -1139,7 +1144,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
11391144
<< "The sequence only has length " << it->second.seq_length
11401145
<< ", while the length of pop is " << n << " which exceeds the whole sequence length.";
11411146
int32_t block_idx = it->second.last_block_idx;
1142-
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) {
1147+
// The block should have at least one reference, which comes from the sequence.
1148+
ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1);
1149+
while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1) {
11431150
if (n > global_block_pool_[block_idx].seq_length) {
11441151
n -= global_block_pool_[block_idx].seq_length;
11451152
it->second.seq_length -= global_block_pool_[block_idx].seq_length;
@@ -1168,14 +1175,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
11681175
}
11691176

11701177
if (n) {
1171-
int32_t temp_seq_id = -1 - seq_id;
1178+
// We use a temporary sequence id for fork.
1179+
// This temporary seq id will immediately end its effect outside this function.
1180+
int64_t temp_seq_id = -1 - seq_id;
11721181
CHECK(seq_map_.find(temp_seq_id) == seq_map_.end());
11731182
ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n);
11741183
CHECK(seq_map_.find(temp_seq_id) != seq_map_.end());
11751184
RemoveSequence(seq_id);
11761185
CHECK(seq_map_.find(seq_id) == seq_map_.end());
11771186
auto it = seq_map_.find(temp_seq_id);
1178-
seq_map_.insert({seq_id, Sequence(global_block_pool_, it->second.last_block_idx)});
1187+
seq_map_.insert({seq_id, it->second});
11791188
seq_map_.erase(temp_seq_id);
11801189
}
11811190

@@ -1184,6 +1193,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
11841193

11851194
/************** Raw Info Query **************/
11861195

1196+
bool Empty() const final {
1197+
return seq_map_.empty() && //
1198+
free_block_idx_.size() == global_block_pool_.size() && //
1199+
free_page_ids_.size() == static_cast<size_t>(num_total_pages_);
1200+
}
1201+
11871202
int32_t GetNumAvailablePages() const final { return free_page_ids_.size(); }
11881203

11891204
int32_t GetTotalSequenceLength() const final {
@@ -1565,8 +1580,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
15651580
int32_t block_idx = seq->last_block_idx;
15661581
Block& block = global_block_pool_[block_idx];
15671582
CHECK_GT(append_length, 0) << "Append with length 0 is not allowed.";
1568-
CHECK_EQ(block.external_ref_cnt, 0)
1569-
<< "The block is " << block.external_ref_cnt
1583+
CHECK_EQ(block.external_ref_cnt, 1)
1584+
<< "The block is " << block.external_ref_cnt - 1
15701585
<< "-time referenced by other blocks, thus cannot accept new KV values.";
15711586

15721587
// ==================== Reserve ====================

tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
fbegin_forward = None
5555
fend_forward = None
5656
fattention_with_fuse_qkv = None
57+
fis_empty = None
5758
fdebug_get_kv = None
5859

5960
ftranspose_append = None
@@ -71,7 +72,7 @@
7172

7273
def set_global_func(head_dim, dtype):
7374
global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq
74-
global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fdebug_get_kv
75+
global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fis_empty, fdebug_get_kv
7576
global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode, fattn_prefill_ragged
7677
global fattn_prefill_sliding_window, fattn_decode_sliding_window
7778
global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page
@@ -89,6 +90,7 @@ def set_global_func(head_dim, dtype):
8990
fattention_with_fuse_qkv = tvm.get_global_func(
9091
"vm.builtin.attention_kv_cache_attention_with_fused_qkv"
9192
)
93+
fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
9294
fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
9395

9496
target = tvm.target.Target("cuda")
@@ -489,11 +491,19 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
489491
for batch in operation_seq:
490492
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
491493

492-
for i in range(19, -1, -1):
494+
num_sequence = 20
495+
for i in range(num_sequence):
493496
fremove_sequence(kv_cache, i)
494497
cached_k.pop(i)
495498
cached_v.pop(i)
496-
verify_cached_kv(kv_cache, seq_ids=list(range(i)), expected_k=cached_k, expected_v=cached_v)
499+
verify_cached_kv(
500+
kv_cache,
501+
seq_ids=list(range(i + 1, num_sequence)),
502+
expected_k=cached_k,
503+
expected_v=cached_v,
504+
)
505+
506+
assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"
497507

498508

499509
@tvm.testing.requires_gpu
@@ -510,14 +520,26 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_config):
510520
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
511521
apply_attention(kv_cache, rope_mode, [((4, 3, -1), 35)], cached_k, cached_v)
512522

513-
popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0)]
523+
popn_operations = [(0, 17), (1, 57), (2, 16), (3, 0), (4, 37)]
514524
for seq_id, pop_length in popn_operations:
515525
fpopn(kv_cache, seq_id, pop_length)
516526
if pop_length != 0:
517527
cached_k[seq_id] = cached_k[seq_id][:, :-pop_length, ...]
518528
cached_v[seq_id] = cached_v[seq_id][:, :-pop_length, ...]
519529
verify_cached_kv(kv_cache, seq_ids=list(range(4)), expected_k=cached_k, expected_v=cached_v)
520530

531+
num_sequence = 5
532+
for seq_id in range(num_sequence):
533+
fremove_sequence(kv_cache, seq_id)
534+
verify_cached_kv(
535+
kv_cache,
536+
seq_ids=list(range(seq_id + 1, num_sequence)),
537+
expected_k=cached_k,
538+
expected_v=cached_v,
539+
)
540+
541+
assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"
542+
521543

522544
@tvm.testing.requires_gpu
523545
@tvm.testing.requires_cuda

0 commit comments

Comments (0)