Skip to content

Commit 4a37f64

Browse files
authored
[KVCache] Increase coalesce threshold (#17280)
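This PR raises the coalescing threshold in the paged KV cache for better performance: for decode requests, the batch decode kernel (without coalescing) is now used whenever the coalesce ratio is below 32, instead of below 1.1. It also makes the boolean argument passed to `f_split_rotary_` an explicit `static_cast<int>`.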
1 parent 132daf6 commit 4a37f64

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1727,7 +1727,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
17271727
qkv_data->dtype);
17281728
// Part 2. Split fused qkv and apply rotary embedding to q/k data.
17291729
f_split_rotary_(qkv_data, q_rope_position_map_view_, q_data, k_data, v_data,
1730-
rope_mode_ == RoPEMode::kNormal);
1730+
static_cast<int>(rope_mode_ == RoPEMode::kNormal));
17311731

17321732
// Part 3. Append k/v data to kv-cache if flag "append_before_attn" is set.
17331733
if (append_before_attn_) {
@@ -2202,7 +2202,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
22022202
}
22032203
double coalesce_ratio = 1.0 * page_counter_uncoalesced / page_counter_coalesced;
22042204
// Do not coalesce and use batch decode kernel when coalesce ratio is small.
2205-
bool use_decode_kernel = is_decode_request_ && coalesce_ratio < 1.1;
2205+
bool use_decode_kernel = is_decode_request_ && coalesce_ratio < 32;
22062206
return {use_decode_kernel || !enable_coalesce ? uncoalesced_block_ids : coalesced_block_ids,
22072207
use_decode_kernel};
22082208
}

0 commit comments

Comments
 (0)