From e6987dff37fcb98683b8d8b8b8fccf68961ed771 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Mon, 8 Jul 2024 18:55:14 -0400
Subject: [PATCH] [3rdparty] Bump FlashInfer

This PR bumps FlashInfer and updates PagedKVCache accordingly for
better performance.

Some notes on this bump:

* When the Grouped-Query Attention (GQA) group size is at least 4 and
  FlashInfer is enabled, we always use the prefill attention kernel,
  which is faster in this regime (see the sketch in the notes below).
* We enlarge the temporary attention workspace from 8MB to 128MB, since
  the new FlashInfer version may require a much larger workspace. The
  workspace is no longer allocated when FlashInfer is not enabled.
* We reduce the maximum block depth from 5 to 2, since cascade
  inference provides limited benefit when the batch size is small and
  prompt reuse is low.
---
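Not part of the patch: a minimal C++ sketch of the kernel-dispatch
heuristic from the first bullet above. The function name
ChooseDecodeKernel and the flashinfer_enabled flag are hypothetical
stand-ins; the flag corresponds to NeedKernelBeginForward() in
paged_kv_cache.cc.

    #include <algorithm>
    #include <vector>

    // Decide, per depth, whether to use the decode kernel or the prefill
    // kernel. `flashinfer_enabled` mirrors the check that the FlashInfer
    // "begin forward" functions are defined.
    std::vector<bool> ChooseDecodeKernel(std::vector<bool> use_decode_kernel,
                                         int num_qo_heads, int num_kv_heads,
                                         bool flashinfer_enabled) {
      int gqa_group_size = num_qo_heads / num_kv_heads;
      if (flashinfer_enabled && gqa_group_size >= 4) {
        // With GQA group size >= 4, the prefill kernel outperforms the
        // decode kernel, so force it for every depth.
        std::fill(use_decode_kernel.begin(), use_decode_kernel.end(), false);
      }
      return use_decode_kernel;
    }

For smaller GQA group sizes the per-depth decode/prefill choice is left
unchanged; in the patch itself the same logic runs in place over
use_decode_kernel_.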

 3rdparty/flashinfer                           |  2 +-
 src/runtime/relax_vm/paged_kv_cache.cc        | 48 +++++++++++++------
 ...tin_paged_attention_kv_cache_flashinfer.py | 13 ++++-
 ...me_builtin_paged_attention_kv_cache_tir.py | 13 ++++-
 4 files changed, 58 insertions(+), 18 deletions(-)

diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index 7e9cc7ff42ca..0dd801d2027a 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 7e9cc7ff42ca283c317061a877305d09a395fad2
+Subproject commit 0dd801d2027af89f3603cbbf68a76e9503bb2f57
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 2fb8a72f4279..5aa1411ec154 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -54,11 +54,11 @@ namespace relax_vm {
  * \brief The maximum allowed block depth (a.k.a. number of common
  * prefixes) in paged KV cache.
  */
-constexpr const int kPagedKVCacheMaxBlockDepth = 5;
+constexpr const int kPagedKVCacheMaxBlockDepth = 2;
 /*! \brief The maximum tree size of a single sequence in tree attention. */
 constexpr const int kTreeAttnMaxTreeSize = 256;
 /*! \brief The 8MB workspace size for attention auxiliary data. */
-constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024;
+constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024;
 /*! \brief The id of the temporary logical page, which is useful for sliding window. */
 constexpr const int kPagedKVCacheTempPageId = -1;
@@ -119,6 +119,9 @@ struct Block {
   void Reset() {
     page_ids.clear();
     seq_length = 0;
+    start_pos = 0;
+    sink_length = 0;
+    sliding_window_offset = 0;
     parent_idx = -1;
     external_ref_cnt = 0;
   }
@@ -169,11 +172,9 @@ struct Sequence {
     this->last_block_idx = last_block_idx;
     int32_t block_ptr = last_block_idx;
     // Go through each block in the sequence, sum up the length.
-    int depth = 0;
     while (true) {
       const Block& block = global_block_pool->at(block_ptr);
       this->seq_length += block.seq_length;
-      ++depth;
       if (block.parent_idx == -1) {
         break;
       }
@@ -1078,8 +1079,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
                                          dtype_aux_, preferred_host_device);
     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
-      temp_attn_workspace_.push_back(
-          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      if (NeedKernelBeginForward()) {
+        temp_attn_workspace_.push_back(
+            NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      }
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
       page_indices_on_depths_view_.push_back(NDArray());
@@ -1087,8 +1090,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       k_rope_pos_offset_view_.push_back(NDArray());
     }
     // Additional workspace for the "prefill with ragged kv" kernel.
-    temp_attn_workspace_.push_back(
-        NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+    if (NeedKernelBeginForward()) {
+      temp_attn_workspace_.push_back(
+          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+    }
     temp_attn_q_device_ =
         NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype, device);
@@ -1531,6 +1536,12 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
     append_before_attn_ = !support_sliding_window_ && num_depths_ == 1 && use_decode_kernel_[0];
+    if (NeedKernelBeginForward() && num_qo_heads_ / num_kv_heads_ >= 4) {
+      // When GQA group size is at least 4 and FlashInfer is enabled,
+      // we always use prefill kernel for better performance.
+      std::fill(use_decode_kernel_.begin(), use_decode_kernel_.end(), /*value=*/false);
+    }
+
     if (append_before_attn_) {
       // Right now we use different kernels when depth is 1 or not 1.
       // For the case where maximum depth is 1, we create the auxiliary
@@ -2196,11 +2207,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
             use_decode_kernel};
   }
 
+  /*! \brief Check whether BeginForward for kernels is needed. */
+  bool NeedKernelBeginForward() {
+    return f_attention_prefill_begin_forward_.defined() &&
+           f_attention_decode_begin_forward_.defined() &&
+           f_attention_prefill_ragged_begin_forward_.defined();
+  }
+
   /*! \brief Invoke the "begin forward" functions of underlying kernels.
   */
  void KernelBeginForward() {
-    if (!f_attention_prefill_begin_forward_.defined() ||
-        !f_attention_decode_begin_forward_.defined() ||
-        !f_attention_prefill_ragged_begin_forward_.defined()) {
+    if (!NeedKernelBeginForward()) {
       return;
     }
@@ -2214,8 +2230,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       }
     } else {
       f_attention_prefill_ragged_begin_forward_.value()(
-          temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_,
-          num_qo_heads_, num_kv_heads_, head_dim_, copy_stream_);
+          temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
+          cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
+          num_kv_heads_, head_dim_, copy_stream_);
       if (support_sliding_window_) {
         return;
       }
@@ -2232,8 +2249,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       } else {
         f_attention_prefill_begin_forward_.value()(
             /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
-            length_info_on_depths_view_[d]->shape[0], num_qo_heads_, num_kv_heads_, head_dim_,
-            copy_stream_);
+            page_indptr_on_depths_host_[d].as_ndarray(),
+            static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
+            num_kv_heads_, head_dim_, page_size_, copy_stream_);
       }
     }
   }
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index bade04a7d753..cab10f84cddf 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -29,7 +29,7 @@ from tvm.script import tir as T
 
 reserved_nseq = 32
-maximum_total_seq_length = 1024
+maximum_total_seq_length = 2048
 prefill_chunk_size = 512
 page_size = 16
 num_layers = 4
@@ -249,6 +249,7 @@ def copy_single_page(
         ):
             for t in T.thread_binding(tx, thread="threadIdx.x"):
                 with T.block("copy"):
+                    T.where(b * tx + t < copy_length * num_heads * head_dim)
                     vh = T.axis.spatial(
                         num_heads,
                         T.Cast("int32", (b * tx + t) // (copy_length * head_dim)),
@@ -662,6 +663,16 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode):
         cached_v.pop(i)
         verify_cached_kv(kv_cache, seq_ids=list(range(i)), expected_k=cached_k, expected_v=cached_v)
 
+    # Test fork after page recycle
+    apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, cached_v)
+
 
 @pytest.mark.skip(reason="Require FlashInfer enabled")
 def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode):
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 9192bb901ff0..3c85a13e4cfc 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -33,7 +33,7 @@ from tvm.target import Target
 
 reserved_nseq = 32
-maximum_total_seq_length = 1024
+maximum_total_seq_length = 2048
 prefill_chunk_size = 512
 page_size = 16
 num_layers = 4
@@ -615,6 +615,16 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config):
 
     assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences"
 
+    # Test fork after page recycle
+    apply_attention(kv_cache, rope_mode, [(0, 7), (1, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((2, 1, -1), 10)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((3, 0, -1), 20)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(2, 1), (3, 1)], cached_k, cached_v)
+
+    apply_attention(kv_cache, rope_mode, [(10, 7), (11, 24)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [((12, 11, -1), 200)], cached_k, cached_v)
+    apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, cached_v)
+
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
@@ -2547,6 +2557,7 @@ def copy_single_page(
         ):
             for t in T.thread_binding(tx, thread="threadIdx.x"):
                 with T.block("copy"):
+                    T.where(b * tx + t < copy_length * num_heads * head_dim)
                     vh = T.axis.spatial(
                         num_heads,
                         T.Cast("int32", (b * tx + t) // (copy_length * head_dim)),