Commit 062bfd9

[3rdparty] Bump FlashInfer for tmp workspace reduction
This PR bumps FlashInfer to reduce the size of the required temporary workspace.
1 parent 931efc7 commit 062bfd9

4 files changed: +27 -20 lines changed

3rdparty/flashinfer

Submodule flashinfer updated 72 files

src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 18 additions & 11 deletions
@@ -57,8 +57,10 @@ namespace relax_vm {
 constexpr const int kPagedKVCacheMaxBlockDepth = 2;
 /*! \brief The maximum tree size of a single sequence in tree attention. */
 constexpr const int kTreeAttnMaxTreeSize = 256;
-/*! \brief The 8MB workspace size for attention auxiliary data. */
-constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024;
+/*! \brief The 1MB workspace size for integer attention auxiliary data. */
+constexpr const int kIntAttnWorkspaceByte = 1 * 1024 * 1024;
+/*! \brief The 128MB workspace size for floating-point attention auxiliary data. */
+constexpr const int kFloatAttnWorkspaceByte = 128 * 1024 * 1024;
 /*! \brief The id of the temporary logical page, which is useful for sliding window. */
 constexpr const int kPagedKVCacheTempPageId = -1;

@@ -915,7 +917,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   NDArray temp_attn_output_device_;
   NDArray temp_attn_scores_device_;
   NDArray merged_attn_scores_device_;
-  std::vector<NDArray> temp_attn_workspace_;
+  std::vector<NDArray> temp_int_attn_workspace_;
+  NDArray temp_float_attn_workspace_;

   //-------------------------------------------
   // Below are the auxiliary data structure on CPU.
@@ -1089,8 +1092,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {

     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
       if (NeedKernelBeginForward()) {
-        temp_attn_workspace_.push_back(
-            NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+        temp_int_attn_workspace_.push_back(
+            NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device));
       }
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
@@ -1103,8 +1106,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
     // Additional workspace for the "prefill with ragged kv" kernel.
     if (NeedKernelBeginForward()) {
-      temp_attn_workspace_.push_back(
-          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      temp_int_attn_workspace_.push_back(
+          NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      temp_float_attn_workspace_ =
+          NDArray::Empty({kFloatAttnWorkspaceByte / 4}, DataType::Float(32), device);
     }

     temp_attn_q_device_ =
@@ -2324,7 +2329,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     if (!append_before_attn_) {
       if (is_chain_on_depths_[0]) {
         f_attention_prefill_ragged_begin_forward_.value()(
-            temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
+            temp_float_attn_workspace_, temp_int_attn_workspace_[0],
+            cur_append_lengths_indptr_host_.as_ndarray(),
             cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
             num_kv_heads_, head_dim_, copy_stream_);
       }
@@ -2336,14 +2342,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window.";
       if (use_decode_kernel_[d]) {
         f_attention_decode_begin_forward_.value()(
-            d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(),
+            d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
+            page_indptr_on_depths_host_[d].as_ndarray(),
             last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_,
             page_size_,
             /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
       } else {
         f_attention_prefill_begin_forward_.value()(
-            /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
-            page_indptr_on_depths_host_[d].as_ndarray(),
+            /*depth=*/d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
+            qo_indptr_on_depths_host_[d].as_ndarray(), page_indptr_on_depths_host_[d].as_ndarray(),
             static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
             num_kv_heads_, head_dim_, page_size_, copy_stream_);
       }
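
To make the effect of the new constants concrete, here is a small back-of-the-envelope script. It is not part of the commit; the Python names are stand-ins for the C++ constants shown above. Before this change, one 128MB workspace was allocated per block depth plus one more for the ragged-kv prefill kernel; afterwards each of those slots holds only a 1MB integer workspace, and a single 128MB floating-point workspace is shared across them.

# Illustrative memory math only; sizes and slot count come from the diff above.
MB = 1024 * 1024

PAGED_KV_CACHE_MAX_BLOCK_DEPTH = 2
ATTN_WORKSPACE_BYTE = 128 * MB        # old: allocated once per slot
INT_ATTN_WORKSPACE_BYTE = 1 * MB      # new: allocated once per slot
FLOAT_ATTN_WORKSPACE_BYTE = 128 * MB  # new: single shared buffer

# One slot per block depth, plus one for the "prefill with ragged kv" kernel.
num_slots = PAGED_KV_CACHE_MAX_BLOCK_DEPTH + 1

before = num_slots * ATTN_WORKSPACE_BYTE
after = num_slots * INT_ATTN_WORKSPACE_BYTE + FLOAT_ATTN_WORKSPACE_BYTE

print(f"before: {before // MB} MB of temporary workspace")  # 384 MB
print(f"after:  {after // MB} MB of temporary workspace")   # 131 MB

Assuming all slots are actually allocated (i.e. NeedKernelBeginForward() returns true), this drops the auxiliary-workspace footprint from roughly 384MB to 131MB; the exact device-side saving depends on which kernels are launched.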

tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py

Lines changed: 7 additions & 7 deletions
@@ -96,19 +96,19 @@ def kv_cache_transpose_append(
                 pages[position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf]
             )
             position: T.int64 = T.Cast("int64", position_map[vgpos])
-            pages[
-                T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf
-            ] = k_data[vgpos, vh, vf]
+            pages[T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf] = (
+                k_data[vgpos, vh, vf]
+            )
         with T.block("v_transpose_append"):
             vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
             T.reads(position_map[vgpos], k_data[vgpos, vh, vf])
             T.writes(
                 pages[position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf]
             )
             position: T.int64 = T.Cast("int64", position_map[vgpos])
-            pages[
-                T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf
-            ] = v_data[vgpos, vh, vf]
+            pages[T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf] = (
+                v_data[vgpos, vh, vf]
+            )


 def llama_rope_with_position_map(  # pylint: disable=too-many-arguments
@@ -324,7 +324,7 @@ def set_global_func():
     )
     fattention_merge_state = tvm.get_global_func("flashinfer.merge_state_in_place")

-    target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")
+    target = tvm.target.Target.from_device(device)
     builts = []
     for tir_func in [
         kv_cache_transpose_append,

tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ def set_global_func(head_dim, dtype):
     fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
     fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")

-    target = tvm.target.Target("cuda")
+    target = tvm.target.Target.from_device(device)
     builts = []
     for tir_func in [
         _kv_cache_transpose_append(num_kv_heads, head_dim, dtype),
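
Both test files also stop hard-coding a target ("cuda" or "nvidia/geforce-rtx-3090-ti") and instead derive it from the device under test. A minimal usage sketch of that pattern, assuming a CUDA-capable machine and a TVM build where tvm.target.Target.from_device is available:

import tvm

# The tests run against a concrete device; tvm.cuda(0) assumes a CUDA GPU is present.
device = tvm.cuda(0)

# Derive the compilation target from that device instead of hard-coding a
# target string, so the generated kernels match the GPU actually in use.
target = tvm.target.Target.from_device(device)
print(target)  # e.g. a cuda target carrying the device's architecture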
