3131
3232if current_platform .is_rocm ():
3333 import aiter
34- from aiter .ops .triton .utils .device_info import get_num_sms
3534
3635 from vllm .triton_utils import tl , triton
3736
3837 def block_size (x , head_dim ):
3938 return min (65536 // x .element_size (), triton .next_power_of_2 (head_dim ))
4039
41- def num_programs (head_dim ):
42- return min (head_dim , get_num_sms ())
40+ def num_programs (total_tokens ):
41+ return min (total_tokens , current_platform . get_cu_count ())
4342
4443 @triton .jit
4544 def cp_mha_gather_cache_kernel (
@@ -58,19 +57,19 @@ def cp_mha_gather_cache_kernel(
5857 x ,
5958 max_block_num ,
6059 num_tokens ,
60+ num_programs ,
6161 DEQUANT : tl .constexpr ,
6262 PAGE_SIZE : tl .constexpr ,
6363 CACHE_FORMAT : tl .constexpr ,
6464 BLOCK_SIZE : tl .constexpr ,
65- NUM_PRGMS : tl .constexpr ,
6665 ):
6766 bid = tl .program_id (0 )
6867 col_offsets = tl .arange (0 , BLOCK_SIZE )
6968 if DEQUANT :
7069 k_scale = tl .load (k_scale_ptr )
7170 v_scale = tl .load (v_scale_ptr )
7271
73- for token_id in tl .range (bid , num_tokens , NUM_PRGMS ):
72+ for token_id in tl .range (bid , num_tokens , num_programs ):
7473 key_ptr_offset = key_ptr + token_id * head_size * num_heads
7574 value_ptr_offset = value_ptr + token_id * head_size * num_heads
7675 batch_idx = tl .load (token_to_batch_ptr + token_id )
@@ -162,11 +161,11 @@ def cp_mha_gather_cache(
162161 x ,
163162 block_tables .size (1 ),
164163 total_tokens ,
164+ NUM_PRGMS ,
165165 DEQUANT = dequant ,
166166 PAGE_SIZE = page_size ,
167167 CACHE_FORMAT = kv_cache_layout ,
168168 BLOCK_SIZE = BLOCK_SIZE ,
169- NUM_PRGMS = NUM_PRGMS ,
170169 )
171170
172171
0 commit comments