
Commit d695af4

[3rdparty] Bump FlashInfer
This PR bumps FlashInfer and updates PagedKVCache accordingly for a performance improvement. Some notes on this bump:

* When the Grouped-Query Attention (GQA) group size is at least 4 and FlashInfer is enabled, we use the prefill attention kernel for better performance (a minimal sketch of this rule follows below).
* We enlarge the temporary workspace for FlashInfer accordingly, since the current version of FlashInfer may consume a much larger workspace. The workspace is not allocated when FlashInfer is not enabled.
* We reduce the maximum block depth to 1, since cascade inference helps little when the batch size is not large and prompt reuse is low.
1 parent 3a02309 commit d695af4
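
To make the GQA note above concrete, here is a minimal Python sketch of the dispatch rule. It is not the runtime implementation (that lives in the paged_kv_cache.cc diff below); the parameter names are illustrative.

def use_prefill_kernel(num_qo_heads: int, num_kv_heads: int, flashinfer_enabled: bool) -> bool:
    # Illustrative dispatch rule: with GQA group size >= 4 and FlashInfer
    # enabled, the prefill attention kernel is used even for decode steps.
    gqa_group_size = num_qo_heads // num_kv_heads
    return flashinfer_enabled and gqa_group_size >= 4

# Example: 32 query heads sharing 8 KV heads give group size 4,
# so the prefill kernel is selected when FlashInfer is available.
assert use_prefill_kernel(num_qo_heads=32, num_kv_heads=8, flashinfer_enabled=True)
assert not use_prefill_kernel(num_qo_heads=32, num_kv_heads=32, flashinfer_enabled=True)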

File tree

4 files changed: +58 -18 lines

3rdparty/flashinfer

Submodule flashinfer updated 138 files

src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 33 additions & 15 deletions
@@ -54,11 +54,11 @@ namespace relax_vm {
  * \brief The maximum allowed block depth (a.k.a. number of common
  * prefixes) in paged KV cache.
  */
-constexpr const int kPagedKVCacheMaxBlockDepth = 5;
+constexpr const int kPagedKVCacheMaxBlockDepth = 1;
 /*! \brief The maximum tree size of a single sequence in tree attention. */
 constexpr const int kTreeAttnMaxTreeSize = 256;
 /*! \brief The 8MB workspace size for attention auxiliary data. */
-constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024;
+constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024;
 /*! \brief The id of the temporary logical page, which is useful for sliding window. */
 constexpr const int kPagedKVCacheTempPageId = -1;
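
Sizing note on the two workspace-related constants above, as a sketch: assuming the allocation pattern shown later in this diff (each workspace buffer is a float32 NDArray of kAttnWorkspaceByte / 4 elements, one buffer per depth plus one for the ragged-prefill kernel, and none when the FlashInfer kernels are not present), the totals work out as follows.

kPagedKVCacheMaxBlockDepth = 1
kAttnWorkspaceByte = 128 * 1024 * 1024           # 128 MB per workspace buffer

elements_per_buffer = kAttnWorkspaceByte // 4    # buffers are allocated as float32 elements
num_buffers = kPagedKVCacheMaxBlockDepth + 1     # per-depth buffers + one ragged-prefill buffer
total_bytes = num_buffers * kAttnWorkspaceByte

print(elements_per_buffer)       # 33554432 float32 elements per buffer
print(total_bytes // (1 << 20))  # 256 MB in total when FlashInfer kernels are enabled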

@@ -119,6 +119,9 @@ struct Block {
   void Reset() {
     page_ids.clear();
     seq_length = 0;
+    start_pos = 0;
+    sink_length = 0;
+    sliding_window_offset = 0;
     parent_idx = -1;
     external_ref_cnt = 0;
   }
@@ -169,11 +172,9 @@ struct Sequence {
     this->last_block_idx = last_block_idx;
     int32_t block_ptr = last_block_idx;
     // Go through each block in the sequence, sum up the length.
-    int depth = 0;
     while (true) {
       const Block& block = global_block_pool->at(block_ptr);
       this->seq_length += block.seq_length;
-      ++depth;
       if (block.parent_idx == -1) {
         break;
       }
@@ -1078,17 +1079,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                            dtype_aux_, preferred_host_device);

     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
-      temp_attn_workspace_.push_back(
-          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      if (NeedKernelBeginForward()) {
+        temp_attn_workspace_.push_back(
+            NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      }
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
       page_indices_on_depths_view_.push_back(NDArray());
       length_info_on_depths_view_.push_back(NDArray());
       k_rope_pos_offset_view_.push_back(NDArray());
     }
     // Additional workspace for the "prefill with ragged kv" kernel.
-    temp_attn_workspace_.push_back(
-        NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+    if (NeedKernelBeginForward()) {
+      temp_attn_workspace_.push_back(
+          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+    }

     temp_attn_q_device_ =
         NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype, device);
@@ -1531,6 +1536,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }

     append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[0];
+    if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
+      // When GQA group size is at least 4 and FlashInfer is enabled,
+      // we always use prefill kernel for better performance.
+      std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), /*value=*/false);
+    }
+
     if (append_before_attn_) {
       // Right now we use different kernels when depth is 1 or not 1.
       // For the case where maximum depth is 1, we create the auxiliary
@@ -2196,11 +2207,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
             use_decode_kernel};
   }

+  /*! \brief Check whether BeginForward for kernels is needed. */
+  bool NeedKernelBeginForward() {
+    return f_attention_prefill_begin_forward_.defined() &&
+           f_attention_decode_begin_forward_.defined() &&
+           f_attention_prefill_ragged_begin_forward_.defined();
+  }
+
   /*! \brief Invoke the "begin forward" functions of underlying kernels. */
   void KernelBeginForward() {
-    if (!f_attention_prefill_begin_forward_.defined() ||
-        !f_attention_decode_begin_forward_.defined() ||
-        !f_attention_prefill_ragged_begin_forward_.defined()) {
+    if (!NeedKernelBeginForward()) {
       return;
     }

@@ -2214,8 +2230,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       }
     } else {
       f_attention_prefill_ragged_begin_forward_.value()(
-          temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_,
-          num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_);
+          temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
+          cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
+          num_kv_heads_, head_dim_, copy_stream_);
       if (support_sliding_window_) {
         return;
       }
@@ -2232,8 +2249,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       } else {
         f_attention_prefill_begin_forward_.value()(
             /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
-            length_info_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_, head_dim_,
-            copy_stream_);
+            page_indptr_on_depths_host_[d].as_ndarray(),
+            static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
+            num_kv_heads_, head_dim_, page_size_, copy_stream_);
       }
     }
   }

tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py

Lines changed: 12 additions & 1 deletion
@@ -29,7 +29,7 @@
 from tvm.script import tir as T

 reserved_nseq = 32
-maximum_total_seq_length = 1024
+maximum_total_seq_length = 2048
 prefill_chunk_size = 512
 page_size = 16
 num_layers = 4
@@ -249,6 +249,7 @@ def copy_single_page(
 ):
     for t in T.thread_binding(tx, thread="threadIdx.x"):
         with T.block("copy"):
+            T.where(b * tx + t < copy_length * num_heads * head_dim)
             vh = T.axis.spatial(
                 num_heads,
                 T.Cast("int32", (b * tx + t) // (copy_length * head_dim)),
@@ -662,6 +663,16 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
         cached_v.pop(i)
         verify_cached_kv(kv_cache, seq_ids=list(range(i)), expected_k=cached_k, expected_v=cached_v)

+    # Test fork after page recycle
+    apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, cached_v)
+

 @pytest.mark.skip(reason="Require FlashInfer enabled")
 def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):

tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py

Lines changed: 12 additions & 1 deletion
@@ -33,7 +33,7 @@
 from tvm.target import Target

 reserved_nseq = 32
-maximum_total_seq_length = 1024
+maximum_total_seq_length = 2048
 prefill_chunk_size = 512
 page_size = 16
 num_layers = 4
@@ -615,6 +615,16 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):

     assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"

+    # Test fork after page recycle
+    apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, cached_v)
+

 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
@@ -2547,6 +2557,7 @@ def copy_single_page(
 ):
     for t in T.thread_binding(tx, thread="threadIdx.x"):
         with T.block("copy"):
+            T.where(b * tx + t < copy_length * num_heads * head_dim)
             vh = T.axis.spatial(
                 num_heads,
                 T.Cast("int32", (b * tx + t) // (copy_length * head_dim)),
