diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index 0dd801d2027a..1e379898a589 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 0dd801d2027af89f3603cbbf68a76e9503bb2f57
+Subproject commit 1e379898a589cdd4ff18a4621fcbe18d63501545
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 8809a1b0729e..78a7ed1dd1f8 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -57,8 +57,10 @@ namespace relax_vm {
 constexpr const int kPagedKVCacheMaxBlockDepth = 2;
 /*! \brief The maximum tree size of a single sequence in tree attention. */
 constexpr const int kTreeAttnMaxTreeSize = 256;
-/*! \brief The 8MB workspace size for attention auxiliary data. */
-constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024;
+/*! \brief The 1MB workspace size for integer attention auxiliary data. */
+constexpr const int kIntAttnWorkspaceByte = 1 * 1024 * 1024;
+/*! \brief The 768MB workspace size for floating-point attention auxiliary data. */
+constexpr const int kFloatAttnWorkspaceByte = 768 * 1024 * 1024;
 /*! \brief The id of the temporary logical page, which is useful for sliding window. */
 constexpr const int kPagedKVCacheTempPageId = -1;
 
@@ -915,7 +917,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   NDArray temp_attn_output_device_;
   NDArray temp_attn_scores_device_;
   NDArray merged_attn_scores_device_;
-  std::vector<NDArray> temp_attn_workspace_;
+  std::vector<NDArray> temp_int_attn_workspace_;
+  NDArray temp_float_attn_workspace_;
 
   //-------------------------------------------
   // Below are the auxiliary data structure on CPU.
@@ -1089,8 +1092,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
 
     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
       if (NeedKernelBeginForward()) {
-        temp_attn_workspace_.push_back(
-            NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+        temp_int_attn_workspace_.push_back(
+            NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device));
       }
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
@@ -1103,8 +1106,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
     // Additional workspace for the "prefill with ragged kv" kernel.
     if (NeedKernelBeginForward()) {
-      temp_attn_workspace_.push_back(
-          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      temp_int_attn_workspace_.push_back(
+          NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      temp_float_attn_workspace_ =
+          NDArray::Empty({kFloatAttnWorkspaceByte / 4}, DataType::Float(32), device);
     }
 
     temp_attn_q_device_ =
@@ -2324,7 +2329,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     if (!append_before_attn_) {
      if (is_chain_on_depths_[0]) {
        f_attention_prefill_ragged_begin_forward_.value()(
-            temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
+            temp_float_attn_workspace_, temp_int_attn_workspace_[0],
+            cur_append_lengths_indptr_host_.as_ndarray(),
             cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
             num_kv_heads_, head_dim_, copy_stream_);
       }
@@ -2336,14 +2342,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window.";
       if (use_decode_kernel_[d]) {
         f_attention_decode_begin_forward_.value()(
-            d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(),
+            d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
+            page_indptr_on_depths_host_[d].as_ndarray(),
             last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_,
             page_size_,
             /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
       } else {
         f_attention_prefill_begin_forward_.value()(
-            /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
-            page_indptr_on_depths_host_[d].as_ndarray(),
+            /*depth=*/d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
+            qo_indptr_on_depths_host_[d].as_ndarray(), page_indptr_on_depths_host_[d].as_ndarray(),
             static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
             num_kv_heads_, head_dim_, page_size_, copy_stream_);
       }
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 2252cb8d9c09..4c25383178ac 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -324,7 +324,7 @@ def set_global_func():
     )
     fattention_merge_state = tvm.get_global_func("flashinfer.merge_state_in_place")
 
-    target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")
+    target = tvm.target.Target.from_device(device)
     builts = []
     for tir_func in [
         kv_cache_transpose_append,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 5ab96caa9bc0..82f85f4b17fa 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -111,7 +111,7 @@ def set_global_func(head_dim, dtype):
     fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
     fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
 
-    target = tvm.target.Target("cuda")
+    target = tvm.target.Target.from_device(device)
     builts = []
     for tir_func in [
         _kv_cache_transpose_append(num_kv_heads, head_dim, dtype),
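
The runtime change above splits the former single temp_attn_workspace_ buffer into a small per-depth
integer workspace (kIntAttnWorkspaceByte, 1 MB each) plus one large floating-point workspace
(kFloatAttnWorkspaceByte) shared across depths, and passes both to the attention begin_forward
functions. As a rough Python sketch of the resulting layout, using the public tvm.nd.empty API purely
for illustration (the constant names and sizes mirror the C++ constants; none of this code is part of
the patch):

    import tvm

    # Illustrative mirrors of the C++ constants in paged_kv_cache.cc.
    INT_ATTN_WORKSPACE_BYTE = 1 * 1024 * 1024      # per-depth integer auxiliary data
    FLOAT_ATTN_WORKSPACE_BYTE = 768 * 1024 * 1024  # shared floating-point auxiliary data
    MAX_BLOCK_DEPTH = 2

    device = tvm.cuda(0)

    # One small int workspace per depth, plus one more for the "prefill with ragged kv"
    # kernel, allocated as float32 words exactly like the C++ side (bytes / 4 elements).
    int_workspaces = [
        tvm.nd.empty((INT_ATTN_WORKSPACE_BYTE // 4,), "float32", device)
        for _ in range(MAX_BLOCK_DEPTH + 1)
    ]
    # A single large float workspace shared by all depths.
    float_workspace = tvm.nd.empty((FLOAT_ATTN_WORKSPACE_BYTE // 4,), "float32", device)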
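
The test changes drop the hard-coded targets ("nvidia/geforce-rtx-3090-ti" and "cuda") in favor of
deriving the target from the device that actually runs the test. A minimal sketch of the new pattern,
assuming a CUDA device as in these tests:

    import tvm

    device = tvm.cuda(0)

    # Derive the compilation target (arch, max threads, warp size, ...) from the GPU
    # that is actually present instead of hard-coding a specific model.
    target = tvm.target.Target.from_device(device)
    print(target)  # e.g. a cuda target with the -arch flag of the installed GPU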