Commit e7cca24

upd
1 parent 39e36dc commit e7cca24

1 file changed: +2 -2 lines changed

flashinfer/decode.py

Lines changed: 2 additions & 2 deletions
@@ -2432,8 +2432,8 @@ def xqa_batch_decode_with_kv_cache(
     head_dim = k_cache.shape[3]
 
     workspace_u8 = workspace_buffer.view(torch.uint8)
-    semaphore = workspace_u8[: round_up(4 * sm_count, 16)]
-    scratch = workspace_u8[round_up(4 * sm_count, 16) :]
+    semaphore = workspace_u8[: 8 * 1024 * 1024]  # reserve 8MB for semaphore
+    scratch = workspace_u8[8 * 1024 * 1024 :]
     kv_scale_value = bmm2_scale
     q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
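
Note on the change: the semaphore region goes from a size derived from `sm_count` (4 bytes per SM, rounded up to 16 bytes) to a fixed `8 * 1024 * 1024` bytes, and the scratch region now starts at that fixed offset. The following is a minimal sketch of the same partitioning using only the standard library, not the library's implementation. `split_workspace`, `SEMAPHORE_BYTES`, the local `round_up` stand-in, the size check, and the `sm_count` value are all assumptions made for illustration; the commit itself adds no such check.

```python
SEMAPHORE_BYTES = 8 * 1024 * 1024  # fixed reservation used by the new code


def round_up(x: int, multiple: int) -> int:
    """Hypothetical stand-in for the round_up helper referenced in the old code."""
    return (x + multiple - 1) // multiple * multiple


def split_workspace(workspace: memoryview, semaphore_bytes: int = SEMAPHORE_BYTES):
    """Split a byte workspace into (semaphore, scratch) views without copying.

    The size check is an assumption added for this sketch; it guards against
    a workspace too small to leave any scratch space after the reservation.
    """
    if workspace.nbytes <= semaphore_bytes:
        raise ValueError(
            f"workspace has {workspace.nbytes} bytes, "
            f"need more than {semaphore_bytes}"
        )
    return workspace[:semaphore_bytes], workspace[semaphore_bytes:]


# Old sizing, for comparison: 4 bytes per SM, rounded up to a multiple of 16.
sm_count = 132  # hypothetical SM count, for illustration only
old_semaphore_bytes = round_up(4 * sm_count, 16)

# New sizing: a fixed prefix, independent of sm_count.
workspace = memoryview(bytearray(16 * 1024 * 1024))
semaphore, scratch = split_workspace(workspace)

print(old_semaphore_bytes, len(semaphore), len(scratch))
```

One consequence worth checking in review: with a fixed prefix, any caller passing a `workspace_buffer` of 8 MiB or less would get an empty `scratch` slice, so the minimum workspace size documented for this function may need to reflect the reservation.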
