We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 39e36dc commit e7cca24Copy full SHA for e7cca24
flashinfer/decode.py
@@ -2432,8 +2432,8 @@ def xqa_batch_decode_with_kv_cache(
2432
head_dim = k_cache.shape[3]
2433
2434
workspace_u8 = workspace_buffer.view(torch.uint8)
2435
- semaphore = workspace_u8[: round_up(4 * sm_count, 16)]
2436
- scratch = workspace_u8[round_up(4 * sm_count, 16) :]
+ semaphore = workspace_u8[: 8 * 1024 * 1024] # reserve 8MB for semaphore
+ scratch = workspace_u8[8 * 1024 * 1024 :]
2437
kv_scale_value = bmm2_scale
2438
q_scale_value = bmm1_scale / kv_scale_value * (head_dim**0.5)
2439
0 commit comments