Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/runtime/relax_vm/kv_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,26 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence")
.set_body_method<KVState>(&KVStateObj::ForkSequence);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method<KVState>(&KVStateObj::PopN);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward")
.set_body_method<KVState>(&KVStateObj::BeginForward);
.set_body([](TVMArgs args, TVMRetValue* rv) {
// Packed arguments: (kv_state, seq_ids, append_lengths[, token_tree_parent_ptr]).
// The fourth argument is optional, so both the 3- and 4-argument call forms are accepted.
CHECK(args.size() == 3 || args.size() == 4)
<< "KVState BeginForward only accepts 3 or 4 arguments";
KVState kv_state = args[0];
IntTuple seq_ids = args[1];
IntTuple append_lengths = args[2];
// Stays null when the caller passes only three arguments.
Optional<IntTuple> token_tree_parent_ptr{nullptr};
if (args.size() == 4) {
token_tree_parent_ptr = Optional<IntTuple>(args[3]);
}
// Forward to the KVState object; nothing is written to rv.
kv_state->BeginForward(seq_ids, append_lengths, token_tree_parent_ptr);
});
// Bind the global name "vm.builtin.kv_state_end_forward" to KVStateObj::EndForward.
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
.set_body_method<KVState>(&KVStateObj::EndForward);

// Attention KV Cache methods
// Each registration below binds a "vm.builtin.attention_kv_cache_*" global name
// to the AttentionKVCacheObj member function passed to set_body_method.
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::CommitAcceptedTokenTreeNodes);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::Empty);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
Expand Down
15 changes: 14 additions & 1 deletion src/runtime/relax_vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,12 @@ class KVStateObj : public Object {
* in the model forward function.
* \param seq_ids The ids of the sequence to run in the incoming model forward.
* \param append_lengths The sequence lengths to run forward for each sequence.
* \param token_tree_parent_ptr The parent idx array of the token trees. Its length
* is the sum of "append_lengths". Nullptr means the token tree of each sequence
* is a chain.
*/
virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths) = 0;
virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
const Optional<IntTuple>& token_tree_parent_ptr = NullOpt) = 0;

/*!
* \brief Mark the start of the forward function.
Expand Down Expand Up @@ -142,6 +146,15 @@ class AttentionKVCacheObj : public KVStateObj {
virtual void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
int32_t attn_sink_size) = 0;

/*!
* \brief Commit the accepted token tree nodes to the KV cache.
* The commit updates the KV cache by compacting the KV data and discarding
* the KV data of rejected tokens.
* This step is mandatory when BeginForward is given a token tree.
* \param leaf_indices The leaf token tree node index of each sequence.
*/
virtual void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) = 0;

/************** Attention **************/

/*!
Expand Down
Loading