@@ -147,13 +147,14 @@ struct Sequence {
147147 */
148148 int last_block_attn_sink_size = 0 ;
149149
150- explicit Sequence (const std::vector<Block>& global_block_pool, int32_t last_block_idx) {
150+ explicit Sequence (std::vector<Block>* global_block_pool, int32_t last_block_idx) {
151+ ++global_block_pool->at (last_block_idx).external_ref_cnt ;
151152 this ->last_block_idx = last_block_idx;
152153 int32_t block_ptr = last_block_idx;
153154 // Go through each block in the sequence, sum up the length.
154155 int depth = 0 ;
155156 while (true ) {
156- const Block& block = global_block_pool[ block_ptr] ;
157+ const Block& block = global_block_pool-> at ( block_ptr) ;
157158 this ->seq_length += block.seq_length ;
158159 ++depth;
159160 if (block.parent_idx == -1 ) {
@@ -965,17 +966,17 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
965966 CHECK (seq_map_.find (seq_id) == seq_map_.end ())
966967 << " The sequence \" " << seq_id << " \" is already in the KV cache." ;
967968 int32_t block_idx = GetFreeBlock ();
968- seq_map_.insert ({seq_id, Sequence (global_block_pool_, block_idx)});
969+ seq_map_.insert ({seq_id, Sequence (& global_block_pool_, block_idx)});
969970 dirty_aux_data_device_ = true ;
970971 }
971972
972973 void RemoveSequence (int64_t seq_id) final {
973974 auto it = seq_map_.find (seq_id);
974975 CHECK (it != seq_map_.end ()) << " The sequence \" " << seq_id << " \" cannot be found in KV cache." ;
975976 int32_t block_idx = it->second .last_block_idx ;
976- CHECK_EQ (global_block_pool_[block_idx]. external_ref_cnt , 0 )
977- << " The sequence is currently referenced by other sequence and thus cannot be removed. " ;
978- while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0 ) {
977+ // The block should have at least one reference, which comes from the sequence.
978+ ICHECK_GE (global_block_pool_[block_idx]. external_ref_cnt , 1 ) ;
979+ while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1 ) {
979980 // - Free pages in the last block.
980981 for (int32_t page_id : global_block_pool_[block_idx].page_ids ) {
981982 free_page_ids_.push_back (page_id);
@@ -985,7 +986,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
985986 }
986987 // - Decrease the external reference of the parent block.
987988 if (block_idx != -1 ) {
988- ICHECK_GT (global_block_pool_[block_idx].external_ref_cnt , 0 );
989+ ICHECK_GT (global_block_pool_[block_idx].external_ref_cnt , 1 );
989990 --global_block_pool_[block_idx].external_ref_cnt ;
990991 }
991992 seq_map_.erase (it);
@@ -1018,11 +1019,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
10181019 // Update child block start position and parent index
10191020 global_block_pool_[child_block_idx].start_pos = parent_it->second .seq_length ;
10201021 global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
1021- if (global_block_pool_[parent_block_idx].seq_length ) {
1022- // If parent is not empty, append a new block
1022+ if (parent_block_idx == parent_it->second .last_block_idx &&
1023+ global_block_pool_[parent_block_idx].seq_length ) {
1024+ // To enable the parent sequence to continue decode after the fork,
1025+ // we add a new empty block at the end of the parent sequence.
1026+ // So the new decoded KV data will go into the new block.
10231027 int32_t new_parent_block_idx = GetFreeBlock ();
10241028 global_block_pool_[new_parent_block_idx].start_pos = parent_it->second .seq_length ;
10251029 global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx;
1030+ global_block_pool_[new_parent_block_idx].external_ref_cnt = 1 ;
10261031 parent_it->second .last_block_idx = new_parent_block_idx;
10271032 }
10281033 } else {
@@ -1055,7 +1060,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
10551060 global_block_pool_[forked_block_idx].parent_idx ;
10561061 global_block_pool_[forked_block_idx].parent_idx = parent_block_idx;
10571062 global_block_pool_[child_block_idx].parent_idx = parent_block_idx;
1058- global_block_pool_[parent_block_idx].external_ref_cnt = 1 ;
1063+ global_block_pool_[parent_block_idx].external_ref_cnt = 2 ;
10591064
10601065 // Move common leading pages to new parent block
10611066 auto first_page = global_block_pool_[forked_block_idx].page_ids .begin ();
@@ -1085,7 +1090,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
10851090 }
10861091 }
10871092 // Create the child sequence with the child block.
1088- seq_map_.insert ({child_seq_id, Sequence (global_block_pool_, child_block_idx)});
1093+ seq_map_.insert ({child_seq_id, Sequence (& global_block_pool_, child_block_idx)});
10891094 dirty_aux_data_device_ = true ;
10901095 }
10911096
@@ -1119,7 +1124,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
11191124 << " A sequence cannot be enabled twice for sliding window." ;
11201125
11211126 // Compute the total length of the prefix blocks of this sequence.
1122- Block& last_block = global_block_pool_[it->second .last_block_idx ];
1127+ const Block& last_block = global_block_pool_[it->second .last_block_idx ];
11231128 int32_t prefix_length = it->second .seq_length - last_block.seq_length ;
11241129 ICHECK_GE (prefix_length, 0 );
11251130 // Since the prefix blocks cannot sliding, they are natural
@@ -1139,7 +1144,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
11391144 << " The sequence only has length " << it->second .seq_length
11401145 << " , while the length of pop is " << n << " which exceeds the whole sequence length." ;
11411146 int32_t block_idx = it->second .last_block_idx ;
1142- while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0 ) {
1147+ // The block should have at least one reference, which comes from the sequence.
1148+ ICHECK_GE (global_block_pool_[block_idx].external_ref_cnt , 1 );
1149+ while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 1 ) {
11431150 if (n > global_block_pool_[block_idx].seq_length ) {
11441151 n -= global_block_pool_[block_idx].seq_length ;
11451152 it->second .seq_length -= global_block_pool_[block_idx].seq_length ;
@@ -1168,14 +1175,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
11681175 }
11691176
11701177 if (n) {
1171- int32_t temp_seq_id = -1 - seq_id;
1178+ // We use a temporary sequence id for fork.
1179+ // This temporary seq id will immediately end its effect outside this function.
1180+ int64_t temp_seq_id = -1 - seq_id;
11721181 CHECK (seq_map_.find (temp_seq_id) == seq_map_.end ());
11731182 ForkSequence (seq_id, temp_seq_id, it->second .seq_length - n);
11741183 CHECK (seq_map_.find (temp_seq_id) != seq_map_.end ());
11751184 RemoveSequence (seq_id);
11761185 CHECK (seq_map_.find (seq_id) == seq_map_.end ());
11771186 auto it = seq_map_.find (temp_seq_id);
1178- seq_map_.insert ({seq_id, Sequence (global_block_pool_, it->second . last_block_idx ) });
1187+ seq_map_.insert ({seq_id, it->second });
11791188 seq_map_.erase (temp_seq_id);
11801189 }
11811190
@@ -1184,6 +1193,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
11841193
11851194 /* ************* Raw Info Query **************/
11861195
1196+ bool Empty () const final {
1197+ return seq_map_.empty () && //
1198+ free_block_idx_.size () == global_block_pool_.size () && //
1199+ free_page_ids_.size () == static_cast <size_t >(num_total_pages_);
1200+ }
1201+
11871202 int32_t GetNumAvailablePages () const final { return free_page_ids_.size (); }
11881203
11891204 int32_t GetTotalSequenceLength () const final {
@@ -1565,8 +1580,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
15651580 int32_t block_idx = seq->last_block_idx ;
15661581 Block& block = global_block_pool_[block_idx];
15671582 CHECK_GT (append_length, 0 ) << " Append with length 0 is not allowed." ;
1568- CHECK_EQ (block.external_ref_cnt , 0 )
1569- << " The block is " << block.external_ref_cnt
1583+ CHECK_EQ (block.external_ref_cnt , 1 )
1584+ << " The block is " << block.external_ref_cnt - 1
15701585 << " -time referenced by other blocks, thus cannot accept new KV values." ;
15711586
15721587 // ==================== Reserve ====================
0 commit comments