3rdparty/flashinfer (2 changes: 1 addition & 1 deletion)
Submodule flashinfer updated 138 files
src/runtime/relax_vm/paged_kv_cache.cc (48 changes: 33 additions & 15 deletions)
@@ -54,11 +54,11 @@ namespace relax_vm {
  * \brief The maximum allowed block depth (a.k.a. number of common
  * prefixes) in paged KV cache.
  */
-constexpr const int kPagedKVCacheMaxBlockDepth = 5;
+constexpr const int kPagedKVCacheMaxBlockDepth = 2;
 /*! \brief The maximum tree size of a single sequence in tree attention. */
 constexpr const int kTreeAttnMaxTreeSize = 256;
 /*! \brief The 8MB workspace size for attention auxiliary data. */
-constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024;
+constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024;
 /*! \brief The id of the temporary logical page, which is useful for sliding window. */
 constexpr const int kPagedKVCacheTempPageId = -1;

@@ -119,6 +119,9 @@ struct Block {
   void Reset() {
     page_ids.clear();
     seq_length = 0;
+    start_pos = 0;
+    sink_length = 0;
+    sliding_window_offset = 0;
     parent_idx = -1;
     external_ref_cnt = 0;
   }
@@ -169,11 +172,9 @@ struct Sequence {
     this->last_block_idx = last_block_idx;
     int32_t block_ptr = last_block_idx;
     // Go through each block in the sequence, sum up the length.
-    int depth = 0;
     while (true) {
       const Block& block = global_block_pool->at(block_ptr);
       this->seq_length += block.seq_length;
-      ++depth;
       if (block.parent_idx == -1) {
         break;
       }
@@ -1078,17 +1079,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                      dtype_aux_, preferred_host_device);

     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
-      temp_attn_workspace_.push_back(
-          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      if (NeedKernelBeginForward()) {
+        temp_attn_workspace_.push_back(
+            NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      }
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
       page_indices_on_depths_view_.push_back(NDArray());
       length_info_on_depths_view_.push_back(NDArray());
       k_rope_pos_offset_view_.push_back(NDArray());
     }
     // Additional workspace for the "prefill with ragged kv" kernel.
-    temp_attn_workspace_.push_back(
-        NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+    if (NeedKernelBeginForward()) {
+      temp_attn_workspace_.push_back(
+          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+    }

     temp_attn_q_device_ =
         NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype, device);
@@ -1531,6 +1536,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }

     append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[0];
+    if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
+      // When GQA group size is at least 4 and FlashInfer is enabled,
+      // we always use prefill kernel for better performance.
+      std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), /*value=*/false);
+    }
+
     if (append_before_attn_) {
       // Right now we use different kernels when depth is 1 or not 1.
       // For the case where maximum depth is 1, we create the auxiliary
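The branch added in the hunk above encodes a simple dispatch rule: under grouped-query attention, once each KV head serves at least four query heads and the FlashInfer BeginForward hooks are available, decode steps are routed to the prefill kernel as well. Below is a minimal sketch of that rule in plain Python, with hypothetical stand-ins (num_qo_heads, num_kv_heads, flashinfer_enabled) for the member variables used in the C++ code; it is an illustration, not part of the patch.

# Sketch of the kernel-selection heuristic (not part of the patch). With
# grouped-query attention, `group_size` query heads share one KV head; for
# group_size >= 4 the batched prefill kernel is preferred, so decode-kernel
# selection is overridden whenever FlashInfer is in use.
def choose_use_decode_kernel(use_decode_kernel: list[bool],
                             num_qo_heads: int, num_kv_heads: int,
                             flashinfer_enabled: bool) -> list[bool]:
    group_size = num_qo_heads // num_kv_heads  # mirrors num_qo_heads_ / num_kv_heads_
    if flashinfer_enabled and group_size >= 4:
        return [False] * len(use_decode_kernel)  # always take the prefill path
    return use_decode_kernel

# Example: 32 query heads over 8 KV heads -> group size 4 -> prefill everywhere.
print(choose_use_decode_kernel([True, False], 32, 8, True))   # [False, False]
print(choose_use_decode_kernel([True, False], 32, 32, True))  # [True, False]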
@@ -2196,11 +2207,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
             use_decode_kernel};
   }

+  /*! \brief Check whether BeginForward for kernels is needed. */
+  bool NeedKernelBeginForward() {
+    return f_attention_prefill_begin_forward_.defined() &&
+           f_attention_decode_begin_forward_.defined() &&
+           f_attention_prefill_ragged_begin_forward_.defined();
+  }
+
   /*! \brief Invoke the "begin forward" functions of underlying kernels. */
   void KernelBeginForward() {
-    if (!f_attention_prefill_begin_forward_.defined() ||
-        !f_attention_decode_begin_forward_.defined() ||
-        !f_attention_prefill_ragged_begin_forward_.defined()) {
+    if (!NeedKernelBeginForward()) {
       return;
     }

@@ -2214,8 +2230,9 @@
       }
     } else {
       f_attention_prefill_ragged_begin_forward_.value()(
-          temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_,
-          num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_);
+          temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
+          cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
+          num_kv_heads_, head_dim_, copy_stream_);
       if (support_sliding_window_) {
         return;
       }
@@ -2232,8 +2249,9 @@
       } else {
         f_attention_prefill_begin_forward_.value()(
             /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
-            length_info_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_, head_dim_,
-            copy_stream_);
+            page_indptr_on_depths_host_[d].as_ndarray(),
+            static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
+            num_kv_heads_, head_dim_, page_size_, copy_stream_);
       }
     }
   }
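Several of the paged_kv_cache.cc changes above interact: kAttnWorkspaceByte grows from 8 MB to 128 MB, kPagedKVCacheMaxBlockDepth drops from 5 to 2, and the workspaces are now allocated only when the kernel BeginForward functions are defined. A rough back-of-the-envelope sketch of the resulting allocation follows, using a hypothetical helper in which flashinfer_enabled stands for NeedKernelBeginForward() returning true.

# Back-of-the-envelope sketch (not part of the patch): float32 workspace
# bytes the cache sets aside for the attention kernels' BeginForward calls.
# One buffer per block depth plus one for the ragged-KV prefill kernel;
# after this change the buffers are skipped entirely when the FlashInfer
# BeginForward hooks are absent.
def attn_workspace_bytes(max_block_depth: int, workspace_byte: int,
                         flashinfer_enabled: bool) -> int:
    if not flashinfer_enabled:
        return 0
    return (max_block_depth + 1) * workspace_byte

MB = 1024 * 1024
# Before: depth 5 and 8 MB buffers, allocated unconditionally -> always 48 MB.
print(attn_workspace_bytes(5, 8 * MB, flashinfer_enabled=True) // MB)     # 48
# After: depth 2 and 128 MB buffers, only when the hooks exist -> 384 MB or 0.
print(attn_workspace_bytes(2, 128 * MB, flashinfer_enabled=True) // MB)   # 384
print(attn_workspace_bytes(2, 128 * MB, flashinfer_enabled=False) // MB)  # 0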
@@ -29,7 +29,7 @@
 from tvm.script import tir as T

 reserved_nseq = 32
-maximum_total_seq_length = 1024
+maximum_total_seq_length = 2048
 prefill_chunk_size = 512
 page_size = 16
 num_layers = 4
@@ -249,6 +249,7 @@ def copy_single_page(
     ):
         for t in T.thread_binding(tx, thread="threadIdx.x"):
             with T.block("copy"):
+                T.where(b * tx + t < copy_length * num_heads * head_dim)
                 vh = T.axis.spatial(
                     num_heads,
                     T.Cast("int32", (b * tx + t) // (copy_length * head_dim)),
@@ -662,6 +663,16 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
         cached_v.pop(i)
         verify_cached_kv(kv_cache, seq_ids=list(range(i)), expected_k=cached_k, expected_v=cached_v)

+    # Test fork after page recycle
+    apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, cached_v)
+

 @pytest.mark.skip(reason="Require FlashInfer enabled")
 def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):
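The single T.where line added to copy_single_page above guards the flattened index b * tx + t against running past copy_length * num_heads * head_dim, which matters whenever the launched thread count is rounded up to a multiple of tx. Below is a small sketch of the guard's effect, assuming a hypothetical helper that mirrors a ceil-divided launch (the grid computation itself is outside the hunk shown above).

# Sketch (not part of the patch): why copy_single_page needs the T.where
# guard. If the copy flattens copy_length * num_heads * head_dim elements
# onto ceil(total / tx) blocks of tx threads, the last block can contain
# threads whose flattened index falls past the end; the guard masks them off.
def active_threads(copy_length: int, num_heads: int, head_dim: int, tx: int):
    total = copy_length * num_heads * head_dim
    num_blocks = (total + tx - 1) // tx          # grid size, rounded up
    launched = num_blocks * tx
    guarded = sum(
        1
        for b in range(num_blocks)
        for t in range(tx)
        if b * tx + t < total                    # the T.where(...) condition
    )
    return launched, guarded

# copy_length=3, num_heads=5, head_dim=64 -> 960 elements; with tx=128 that
# rounds up to 8 blocks: 1024 threads launched, the guard keeps 960 active.
print(active_threads(3, 5, 64, tx=128))  # (1024, 960)

The same guard appears in the second copy of this kernel in the test file below.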
@@ -33,7 +33,7 @@
 from tvm.target import Target

 reserved_nseq = 32
-maximum_total_seq_length = 1024
+maximum_total_seq_length = 2048
 prefill_chunk_size = 512
 page_size = 16
 num_layers = 4
@@ -615,6 +615,16 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):

     assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"

+    # Test fork after page recycle
+    apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, cached_v)
+

 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
@@ -2547,6 +2557,7 @@ def copy_single_page(
     ):
         for t in T.thread_binding(tx, thread="threadIdx.x"):
             with T.block("copy"):
+                T.where(b * tx + t < copy_length * num_heads * head_dim)
                 vh = T.axis.spatial(
                     num_heads,
                     T.Cast("int32", (b * tx + t) // (copy_length * head_dim)),