Merged
4 changes: 3 additions & 1 deletion src/runtime/relax_vm/kv_state.h
@@ -151,9 +151,11 @@ class AttentionKVCacheObj : public KVStateObj {
* The commit will update the KV cache by compacting the KV data and discarding
* the KV data of rejected tokens.
* This is a mandatory step when BeginForward is given a token tree.
* \param seq_ids The ids of the sequences to commit.
* \param leaf_indices The leaf token tree node index of each sequence.
*/
virtual void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) = 0;
virtual void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids,
const IntTuple& leaf_indices) = 0;

/************** Attention **************/

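For readers new to the commit step, the following standalone Python sketch (not part of this diff; function names are illustrative) shows what committing an accepted leaf means in terms of the parent-pointer tree: the path from the root to the accepted leaf is kept and compacted to the front of the appended range, and every other speculated token is discarded, mirroring the verification logic in the test file at the end of this diff.

# Illustration only -- not part of the PR. Mirrors the semantics of
# CommitAcceptedTokenTreeNodes for a single sequence.

def accepted_path(token_tree_parent_ptr, leaf_index):
    """Walk parent pointers from the accepted leaf up to the root.

    Returns the accepted node indices ordered root -> leaf. A leaf_index of -1
    means no node was accepted and the whole speculated tree is discarded.
    """
    path = []
    node = leaf_index
    while node != -1:
        path.append(node)
        node = token_tree_parent_ptr[node]
    return list(reversed(path))

def commit(token_tree_parent_ptr, leaf_index):
    """Return (kept_positions, num_tokens_to_pop) for one sequence."""
    path = accepted_path(token_tree_parent_ptr, leaf_index)
    append_length = len(token_tree_parent_ptr)
    # Accepted nodes are compacted to the front of the appended range;
    # everything else is popped from the KV cache.
    return path, append_length - len(path)

# A tree of 5 speculated tokens: node 0 is the root, nodes 1 and 2 are its
# children, nodes 3 and 4 are children of node 2.
parents = [-1, 0, 0, 2, 2]
print(commit(parents, 3))   # ([0, 2, 3], 2): keep 3 tokens, pop 2
print(commit(parents, -1))  # ([], 5): reject everything, pop all 5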
177 changes: 110 additions & 67 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -151,6 +151,18 @@ struct Sequence {
*/
int last_block_attn_sink_size = 0;

/*! \brief Whether the current appended tokens form a chain (not a tree). */
bool is_chain = true;
/*! \brief The token tree parent pointer array of the current appended tokens. */
std::vector<int32_t> token_tree_parent_ptr;
/*! \brief The depth of each node in the token tree. */
std::vector<int32_t> token_tree_node_depths;
/*!
* \brief A boolean denoting whether the accepted token tree indices of
* this sequence are committed
*/
bool accepted_indices_committed = true;

explicit Sequence(std::vector<Block>* global_block_pool, int32_t last_block_idx) {
++global_block_pool->at(last_block_idx).external_ref_cnt;
this->last_block_idx = last_block_idx;
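The new per-sequence fields mirror the token tree layout used throughout this file. As a rough illustration (assumption: plain Python lists standing in for the IntTuple/std::vector fields), a token tree over the appended tokens is encoded as a parent-pointer array in which parent[n] is either -1 (root) or an earlier node, and the tree degenerates into a chain exactly when every node's parent is the node right before it:

# Illustration only: how token_tree_parent_ptr encodes the appended tokens
# and how the is_chain flag can be derived from it.

def validate_parents(parents):
    # Same invariants ConstructTokenTreeMask checks below: a parent must be
    # -1 (root) or an earlier node of the same tree.
    for n, p in enumerate(parents):
        assert -1 <= p < n, f"node {n} has invalid parent {p}"

def is_chain(parents):
    # A chain means node n always attaches to node n - 1 (the previous token).
    return all(p == n - 1 for n, p in enumerate(parents))

validate_parents([-1, 0, 1, 2])      # a plain chain of 4 tokens
validate_parents([-1, 0, 0, 1, 1])   # a branching tree of 5 tokens
print(is_chain([-1, 0, 1, 2]))       # True
print(is_chain([-1, 0, 0, 1, 1]))    # False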
@@ -879,10 +891,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
IntTuple cur_seq_ids_;
/*! \brief The append lengths of the sequences in the current round of forwarding. */
IntTuple cur_append_lengths_;
/*! \brief The token tree parent array of the sequences in the current round of forwarding. */
IntTuple cur_token_tree_parent_ptr_{nullptr};
/*! \brief The depth of each node in the token tree, for the sequences in the current batch. */
std::vector<std::vector<int32_t>> cur_token_tree_node_depths_;
/*! \brief Whether the current batch of sequences are token chains (not token trees). */
bool is_chain_;
/*! \brief Number of fork depth in the current round of forward. */
@@ -1187,6 +1195,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
<< "The forked position should be non-negative, or -1 for last position as default.";
CHECK_LE(fork_pos, parent_it->second.seq_length)
<< "The forked position should not exceed the total length of parent sequence.";
CHECK(parent_it->second.accepted_indices_committed)
<< "The parent sequence's token tree computed in the last round of forward has not been "
"committed with accepted nodes.";

int32_t child_block_idx = GetFreeBlock();
if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
@@ -1434,25 +1445,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {

void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
const Optional<IntTuple>& opt_token_tree_parent_ptr) final {
CHECK(!cur_token_tree_parent_ptr_.defined())
<< "The last round of forward which involves token tree has not been committed. Please "
"call \"CommitAcceptedTreeNodes\" to commit the accepted tokens.";

CHECK_EQ(seq_ids.size(), append_lengths.size())
<< "The seq_ids size (" << seq_ids.size() << ") and append_lengths size ("
<< append_lengths.size() << ") mismatch.";
cur_batch_size_ = seq_ids.size();
cur_seq_ids_ = seq_ids;
cur_append_lengths_ = append_lengths;

// - Check token tree validity and process the token tree.
is_chain_ = true;
tree_attn_mask_host_.clear();
tree_attn_mn_indptr_host_.clear();
if (opt_token_tree_parent_ptr.defined()) {
is_chain_ = ConstructTokenTreeMask(opt_token_tree_parent_ptr.value());
}

// - Collect sequence/block/page information for attention.
std::vector<Sequence*> sequences;
std::vector<int32_t> last_block_length_before_append;
@@ -1474,6 +1473,29 @@
}
}

// - Check token tree validity and process the token tree.
is_chain_ = true;
tree_attn_mask_host_.clear();
tree_attn_mn_indptr_host_.clear();
if (opt_token_tree_parent_ptr.defined()) {
is_chain_ = ConstructTokenTreeMask(sequences, opt_token_tree_parent_ptr.value());
} else {
// The input batch does not form trees. So each sequence in the batch
// is required to have all past accepted tokens committed.
for (int i = 0; i < cur_batch_size_; ++i) {
Sequence* sequence = sequences[i];
CHECK(sequence->accepted_indices_committed)
<< "The input batch does not form a tree, in which case the sequences in the input "
"batch are expected to have their accepted tokens token tree nodes committed. "
"Please invoke CommitAcceptedTokenTreeNodes for sequence "
<< seq_ids[i];
sequence->is_chain = true;
sequence->token_tree_parent_ptr.clear();
sequence->token_tree_node_depths.clear();
}
is_chain_ = true;
}

std::vector<std::vector<int32_t>> block_ids_on_depths = GetBlockIdsOnDepth(sequences);
num_depths_ = block_ids_on_depths.size();
ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth);
@@ -1559,7 +1581,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
for (int64_t pos = 0; pos < append_length; ++pos) {
q_rope_position_map_host_.push_back(
k_ragged_rope_pos_offset_host_[i] +
(is_chain_ ? pos : cur_token_tree_node_depths_[i][pos]));
(is_chain_ ? pos : sequences[i]->token_tree_node_depths[pos]));

int32_t pos_in_block = block.seq_length - append_length + pos;
if (last_block_length_before_append[i] + pos < block.sink_length) {
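The loop above picks the RoPE position of each appended token: for a chain it is the running offset plus the position within the appended range, while for a tree it is the offset plus the node's depth, so sibling branches at the same depth share a rotary position. A small sketch of that rule (illustrative names; the real code writes into q_rope_position_map_host_):

# Illustration: RoPE positions for the appended tokens of one sequence.
def q_rope_positions(pos_offset, append_length, node_depths=None):
    if node_depths is None:
        # Chain case: consecutive positions.
        return [pos_offset + pos for pos in range(append_length)]
    # Tree case: use the depth of each token tree node instead, so two
    # sibling branches get the same rotary position.
    return [pos_offset + node_depths[pos] for pos in range(append_length)]

print(q_rope_positions(16, 4))                   # [16, 17, 18, 19]
print(q_rope_positions(16, 5, [0, 1, 1, 2, 2]))  # [16, 17, 17, 18, 18]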
@@ -1649,19 +1671,26 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}

void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) final {
CHECK_NE(cur_batch_size_, -1)
<< "Cannot commit accepted token tree nodes since BeginForward is not invoked.";
CHECK_EQ(leaf_indices.size(), cur_batch_size_)
<< "The number of input leaf indices does not equal to the current batch size.";
void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids, const IntTuple& leaf_indices) final {
CHECK_EQ(seq_ids.size(), leaf_indices.size())
<< "The given seq_ids and leaf_indices have different size.";
int num_seq_to_commit = seq_ids.size();

for (int i = 0; i < cur_batch_size_; ++i) {
CHECK_GE(leaf_indices[i], 0)
<< "Invalid tree index " << leaf_indices[i] << " which is negative";
CHECK_LT(leaf_indices[i], cur_append_lengths_[i])
std::vector<Sequence*> sequences;
sequences.reserve(num_seq_to_commit);
for (int i = 0; i < num_seq_to_commit; ++i) {
auto it = seq_map_.find(seq_ids[i]);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
<< "\" cannot be found in KV cache.";
sequences.push_back(&it->second);
CHECK(!it->second.accepted_indices_committed)
<< "The accepted nodes of sequence " << seq_ids[i] << " are already committed.";
CHECK_GE(leaf_indices[i], -1)
<< "Invalid tree index " << leaf_indices[i] << " which is less than -1";
CHECK_LT(leaf_indices[i], static_cast<int64_t>(it->second.token_tree_parent_ptr.size()))
<< "Invalid tree index " << leaf_indices[i]
<< " which is larger than or equals to the append length " << cur_append_lengths_[i]
<< " of the sequence";
<< " which is larger than or equals to the append length "
<< it->second.token_tree_parent_ptr.size() << " of the sequence";
}

if (!is_chain_) {
@@ -1670,16 +1699,21 @@
commit_copy_dst_pos_in_page_table_host_.clear();
commit_copy_length_indptr_host_.push_back(0);

for (int i = 0; i < cur_batch_size_; ++i) {
for (int i = 0; i < num_seq_to_commit; ++i) {
if (leaf_indices[i] == -1) {
// No node is accepted. All nodes in the token tree need to be popped.
continue;
}

// Get the accepted node path on the token tree.
std::vector<int32_t> path_on_tree;
path_on_tree.reserve(cur_token_tree_node_depths_[i][leaf_indices[i]] + 1);
path_on_tree.reserve(sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1);
int node = leaf_indices[i];
while (node != -1) {
path_on_tree.push_back(node);
node = cur_token_tree_parent_ptr_[cur_append_lengths_indptr_host_[i] + node];
node = sequences[i]->token_tree_parent_ptr[node];
}
ICHECK_EQ(path_on_tree.size(), cur_token_tree_node_depths_[i][leaf_indices[i]] + 1);
ICHECK_EQ(path_on_tree.size(), sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1);
// Get the destination array (range [0, path_length - 1)) of KV cache copy.
std::vector<int32_t> copy_dst_pos_in_seq;
copy_dst_pos_in_seq.resize(path_on_tree.size());
@@ -1714,14 +1748,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// Note: Function "PopN" only changes the page table structure and does not
// change the KV cache data. Therefore, we can directly use it, since
// we have already launched all copies.
for (int i = 0; i < cur_batch_size_; ++i) {
for (int i = 0; i < num_seq_to_commit; ++i) {
int64_t length_to_pop =
cur_append_lengths_[i] - cur_token_tree_node_depths_[i][leaf_indices[i]] - 1;
cur_append_lengths_[i] -
(leaf_indices[i] != -1 ? (sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1) : 0);
PopN(cur_seq_ids_[i], length_to_pop);
// Reset the sequence states.
sequences[i]->accepted_indices_committed = true;
sequences[i]->token_tree_parent_ptr.clear();
sequences[i]->token_tree_node_depths.clear();
}

// Reset the token tree.
cur_token_tree_parent_ptr_ = IntTuple{nullptr};
}
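As a quick worked check of the pop length computed above (hypothetical numbers): a leaf index of -1 keeps nothing and pops the whole appended range, while an accepted leaf at depth d keeps d + 1 tokens.

# Illustration of the pop-length computation in CommitAcceptedTokenTreeNodes.
def length_to_pop(append_length, node_depths, leaf_index):
    kept = node_depths[leaf_index] + 1 if leaf_index != -1 else 0
    return append_length - kept

depths = [0, 1, 1, 2, 2, 3]          # depths of a 6-node speculated tree
print(length_to_pop(6, depths, 5))   # 6 - (3 + 1) = 2
print(length_to_pop(6, depths, 3))   # 6 - (2 + 1) = 3
print(length_to_pop(6, depths, -1))  # reject all: pop 6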

NDArray GetQueryPositions() final {
@@ -1814,70 +1850,77 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
return block_idx;
}

bool ConstructTokenTreeMask(const IntTuple& token_tree_parent_ptr) {
bool ConstructTokenTreeMask(const std::vector<Sequence*>& sequences,
const IntTuple& token_tree_parent_ptr) {
// We check whether the token tree degenerates into a chain,
// because chain cases can use a simplified attention workflow.
bool is_chain = true;
cur_token_tree_parent_ptr_ = token_tree_parent_ptr;
cur_token_tree_node_depths_.clear();
cur_token_tree_node_depths_.reserve(cur_batch_size_);

int64_t sum_append_length = 0;
int64_t sum_new_append_length = 0;
// - Construct the mn indptr array, which holds the cumulative offsets of the per-sequence mask sizes.
tree_attn_mn_indptr_host_.push_back(0);
for (int64_t append_length : cur_append_lengths_) {
sum_append_length += append_length;
ICHECK_EQ(sequences.size(), cur_batch_size_);
ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_);
for (int i = 0; i < cur_batch_size_; ++i) {
int64_t append_length = cur_append_lengths_[i];
// Update the token tree parent pointers.
sequences[i]->token_tree_parent_ptr = {
token_tree_parent_ptr->data + sum_new_append_length,
token_tree_parent_ptr->data + sum_new_append_length + cur_append_lengths_[i]};
sum_new_append_length += cur_append_lengths_[i];

CHECK_LE(append_length, kTreeAttnMaxTreeSize)
<< "The tree size is " << append_length << " which exceeds the maximum tree size limit "
<< kTreeAttnMaxTreeSize;
tree_attn_mn_indptr_host_.push_back(tree_attn_mn_indptr_host_.back() +
static_cast<int32_t>(append_length * append_length));
append_length * append_length);
}
CHECK_EQ(token_tree_parent_ptr.size(), sum_append_length)
<< "Invalid token tree size. The sum of \"append_lengths\" is " << sum_append_length
CHECK_EQ(token_tree_parent_ptr.size(), sum_new_append_length)
<< "Invalid token tree size. The sum of \"append_lengths\" is " << sum_new_append_length
<< " while there are " << token_tree_parent_ptr.size()
<< " elements in \"token_tree_parent_ptr\".";

// - Construct the mask of each sequence.
int processed_pos = 0;
for (int i = 0; i < cur_batch_size_; ++i) {
int64_t append_length = cur_append_lengths_[i];
int64_t tree_size = sequences[i]->token_tree_parent_ptr.size();
std::vector<std::vector<int32_t>> mask;
std::vector<int32_t> depth;
mask.reserve(append_length);
depth.reserve(append_length);
for (int64_t n = 0; n < append_length; ++n) {
CHECK_LT(token_tree_parent_ptr[processed_pos], n)
mask.reserve(tree_size);
depth.reserve(tree_size);
sequences[i]->is_chain = true;
sequences[i]->accepted_indices_committed = false;
for (int64_t n = 0; n < tree_size; ++n) {
CHECK_LT(sequences[i]->token_tree_parent_ptr[n], n)
<< "Invalid token tree. The parent of node " << n << " in tree " << i << " is "
<< token_tree_parent_ptr[processed_pos] << ", which is not smaller than " << n;
CHECK_GE(token_tree_parent_ptr[processed_pos], -1)
<< sequences[i]->token_tree_parent_ptr[n] << ", which is not smaller than " << n;
CHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1)
<< "Invalid token tree. The parent of node " << n << " in tree " << i << " is "
<< token_tree_parent_ptr[processed_pos];
if (token_tree_parent_ptr[processed_pos] != n - 1) {
<< sequences[i]->token_tree_parent_ptr[n];
if (sequences[i]->token_tree_parent_ptr[n] != n - 1) {
// The parent of the current node is not the immediately preceding node.
// Therefore the tree is not a chain.
sequences[i]->is_chain = false;
is_chain = false;
}

std::vector<int32_t> single_pos_mask;
if (token_tree_parent_ptr[processed_pos] != -1) {
if (sequences[i]->token_tree_parent_ptr[n] != -1) {
// The current node has a parent in the token tree.
single_pos_mask = {mask[token_tree_parent_ptr[processed_pos]].begin(),
mask[token_tree_parent_ptr[processed_pos]].end()};
depth.push_back(depth[token_tree_parent_ptr[processed_pos]] + 1);
single_pos_mask = {mask[sequences[i]->token_tree_parent_ptr[n]].begin(),
mask[sequences[i]->token_tree_parent_ptr[n]].end()};
depth.push_back(depth[sequences[i]->token_tree_parent_ptr[n]] + 1);
} else {
// The current node is root in the token tree.
single_pos_mask.resize(append_length, /*value=*/0);
single_pos_mask.resize(tree_size, /*value=*/0);
depth.push_back(0);
}
single_pos_mask[n] = 1;
mask.push_back(single_pos_mask);
for (int32_t mask_val : single_pos_mask) {
tree_attn_mask_host_.push_back(mask_val);
}

++processed_pos;
}
cur_token_tree_node_depths_.push_back(std::move(depth));
sequences[i]->token_tree_node_depths = std::move(depth);
}

return is_chain;
}
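To make the mask layout concrete, here is a standalone sketch (illustrative only, not the TVM implementation; the helper name is hypothetical) of what ConstructTokenTreeMask builds per sequence: an n x n mask, flattened row by row, in which node n attends to itself and all of its ancestors, together with the per-node depths used for RoPE positions, plus the batch-level indptr as a prefix sum of the n * n mask sizes.

# Illustration of the tree attention mask construction: node n may attend to
# itself and all of its ancestors in the token tree.
def build_tree_attn_data(parents):
    n = len(parents)
    mask_rows, depths = [], []
    for node, parent in enumerate(parents):
        if parent != -1:
            row = list(mask_rows[parent])        # inherit the ancestor set
            depths.append(depths[parent] + 1)
        else:
            row = [0] * n                        # root attends only to itself
            depths.append(0)
        row[node] = 1
        mask_rows.append(row)
    flat_mask = [v for row in mask_rows for v in row]   # n * n entries
    return flat_mask, depths

parents = [-1, 0, 0, 2]          # node 0 is the root; 1, 2 its children; 3 under 2
mask, depths = build_tree_attn_data(parents)
print(depths)                                     # [0, 1, 1, 2]
print([mask[i * 4:(i + 1) * 4] for i in range(4)])
# [[1, 0, 0, 0], [1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 1, 1]]

# The batch-level indptr is the prefix sum of the n * n mask sizes:
append_lengths = [4, 2, 3]
indptr = [0]
for n in append_lengths:
    indptr.append(indptr[-1] + n * n)
print(indptr)                                     # [0, 16, 20, 29]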

@@ -438,7 +438,10 @@ def apply_attention(
fend_forward(kv_cache)

if accepted_leaf_indices is not None:
fcommit_accepted_token_tree_nodes(kv_cache, ShapeTuple(accepted_leaf_indices))
seq_ids = [seq_id for seq_id, _ in batch]
fcommit_accepted_token_tree_nodes(
kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices)
)
for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
zip(accepted_leaf_indices, batch)
):
@@ -449,7 +452,7 @@ def apply_attention(
node = token_tree_parent_ptr_list[i][node]
offset = cached_k[seq_id].shape[1] - append_length
length_to_pop = append_length - len(tree_path)
assert 0 <= length_to_pop < append_length
assert 0 <= length_to_pop <= append_length
for dst_pos, src_pos in enumerate(reversed(tree_path)):
if dst_pos == src_pos:
continue
@@ -773,7 +776,7 @@ def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config):
[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], # chain of length 10
[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], # chain of length 14
],
accepted_leaf_indices=[2, 6, 6, 4],
accepted_leaf_indices=[2, 6, -1, 4],
)
# Do 5 rounds of decode.
for _ in range(5):