Prefix Caching - fix T4 triton error #2517
Conversation
import triton
import triton.language as tl

TESLA = 'Tesla' in torch.cuda.get_device_name(0)
Would it be possible to check for compute capability instead? Also, we should do this inside `context_attention_fwd`, as calling CUDA APIs before we set `CUDA_VISIBLE_DEVICES` will lead to errors.
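A minimal sketch of what that suggestion could look like, assuming a hypothetical helper that `context_attention_fwd` would call lazily (the name `_get_prefix_block_size` is not from this PR):

```python
import torch

def _get_prefix_block_size() -> int:
    # Hypothetical helper, not the PR's actual code: query the compute
    # capability only when the kernel is first launched, i.e. after
    # CUDA_VISIBLE_DEVICES has already been applied, instead of at import time.
    # Turing is sm_75; sm_80 and newer can use the larger tile.
    major, _minor = torch.cuda.get_device_capability()
    return 128 if major >= 8 else 64
```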
Maybe we can set `prefix_block_size` as a parameter in CacheConfig and allow users to configure it in LLM?
This sort of thing should ideally be derived automatically.
@esmeetu The block size mainly depends on the shared memory size of the different GPU architectures. It affects the prefix-prefill kernel speed a little bit but has nothing to do with GPU memory utilization.
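For rough context on why the architecture matters, the per-SM shared memory limits differ quite a bit across generations (approximate numbers from NVIDIA's architecture whitepapers; this mapping is illustrative only and is not part of the PR):

```python
# Approximate maximum shared memory usable per SM, keyed by compute capability.
# Illustrative only; the prefix-prefill kernel's exact usage is not derived here.
SHARED_MEM_PER_SM_KB = {
    (7, 5): 64,   # Turing (e.g. T4)
    (8, 0): 164,  # Ampere A100
    (8, 6): 100,  # Ampere GA10x (e.g. A10, RTX 30xx)
}
```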
Amazing! @caoshiyi Thanks for your help! This works for me now, and the speedup is indeed 2-3x. But in further testing I ran into an engine-stuck issue when the GPU KV cache is full (I changed the prefix 5-6 times), and the request stays in the pending state. After one more prefix change (which would take up >10% of the KV cache), the engine gets stuck.
#2511 looks like it solves my second issue.
zhuohan123 left a comment
LGTM! Left a small comment.
import triton
import triton.language as tl

TESLA = 'Tesla' in torch.cuda.get_device_name(0)
Can we set this variable in a function instead of a global variable? Setting it as a global variable may lead to issues in distributed settings.
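A small sketch of that suggestion, assuming a hypothetical cached helper (each distributed worker would resolve the flag on first use, after its own `CUDA_VISIBLE_DEVICES` is set; the name `_needs_small_block` is not from this PR):

```python
from functools import lru_cache

import torch

@lru_cache(maxsize=None)
def _needs_small_block() -> bool:
    # Hypothetical sketch, not the PR's code: evaluate the device check
    # lazily inside each worker process rather than at module import,
    # so the value reflects the device that worker actually sees.
    return torch.cuda.get_device_capability()[0] < 8
```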
@caoshiyi What is the blocker for this PR? Could you address @Yard1's and @zhuohan123's comments?
alibi_slopes=None):
    BLOCK = 128

    cap = torch.cuda.get_device_capability()
Does prefix caching support other hardware, like AMD? This only considers the CUDA arch. Might it be better to define a global utility to get the block size that handles different hardware?
I believe this kernel only works for NVIDIA right now. Let me merge this fix first and we can systematically test for AMD later.
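If such a shared utility were added later, it might look roughly like the sketch below. This is entirely hypothetical and not part of this PR; as the reply above notes, the kernel is NVIDIA-only for now, so the ROCm branch is just a placeholder.

```python
import torch

def get_prefix_attn_block_size() -> int:
    # Hypothetical utility, not part of this PR: one place that picks the
    # Triton tile size per backend / architecture.
    if torch.version.hip is not None:
        # ROCm build of PyTorch (AMD GPUs) -- untested for this kernel.
        return 64
    major, _minor = torch.cuda.get_device_capability()
    # Turing (sm_75) and older need the smaller tile; sm_80+ can use 128.
    return 128 if major >= 8 else 64
```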
Fixes #2513; a smaller block size is needed for Turing GPUs.