Skip to content

Commit a51c9c6

Browse files
committed
[Runtime] Support PagedKVCache with tree attention
This PR introduces the tree attention to PagedKVCache. With this feature, now the KV cache is ready for tree attention cases such as speculative decoding trees. This PR adds tree attention tests to test the correctness. The changes in this PR to KVState interface are backward compatible.
1 parent 1eac178 commit a51c9c6

File tree

5 files changed

+1083
-134
lines changed

5 files changed

+1083
-134
lines changed

src/runtime/relax_vm/kv_state.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,26 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence")
4040
.set_body_method<KVState>(&KVStateObj::ForkSequence);
4141
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method<KVState>(&KVStateObj::PopN);
4242
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward")
43-
.set_body_method<KVState>(&KVStateObj::BeginForward);
43+
.set_body([](TVMArgs args, TVMRetValue* rv) {
44+
CHECK(args.size() == 3 || args.size() == 4)
45+
<< "KVState BeginForward only accepts 3 or 4 arguments";
46+
KVState kv_state = args[0];
47+
IntTuple seq_ids = args[1];
48+
IntTuple append_lengths = args[2];
49+
Optional<IntTuple> token_tree_parents{nullptr};
50+
if (args.size() == 4) {
51+
token_tree_parents = Optional<IntTuple>(args[3]);
52+
}
53+
kv_state->BeginForward(seq_ids, append_lengths, token_tree_parents);
54+
});
4455
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
4556
.set_body_method<KVState>(&KVStateObj::EndForward);
4657

4758
// Attention KV Cache methods
4859
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
4960
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
61+
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes")
62+
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::CommitAcceptedTokenTreeNodes);
5063
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty")
5164
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::Empty);
5265
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")

src/runtime/relax_vm/kv_state.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,12 @@ class KVStateObj : public Object {
8989
* in the model forward function.
9090
* \param seq_ids The ids of the sequence to run in the incoming model forward.
9191
* \param append_lengths The sequence lengths to run forward for for each sequence.
92+
* \param token_tree_parents The parent idx array of the token trees. Its length
93+
* is the sum of "append_lengths". Nullptr means the token tree of each sequence
94+
* is a chain.
9295
*/
93-
virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths) = 0;
96+
virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
97+
const Optional<IntTuple>& token_tree_parents = NullOpt) = 0;
9498

9599
/*!
96100
* \brief Mark the start of the forward function.
@@ -142,6 +146,15 @@ class AttentionKVCacheObj : public KVStateObj {
142146
virtual void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
143147
int32_t attn_sink_size) = 0;
144148

149+
/*!
150+
* \brief Committed the accepted token tree nodes to KV cache.
151+
* The commit will update the KV cache, by compacting the KV data and discard
152+
* the KV data of rejected tokens.
153+
* This is a mandatory step when the BeginForward is given with a token tree.
154+
* \param leaf_indices The leaf token tree node index of each sequence.
155+
*/
156+
virtual void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) = 0;
157+
145158
/************** Attention **************/
146159

147160
/*!

0 commit comments

Comments
 (0)