diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 4888ae51d1d3..c7f925817a6a 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -31,15 +31,14 @@ if current_platform.is_rocm(): import aiter - from aiter.ops.triton.utils.device_info import get_num_sms from vllm.triton_utils import tl, triton def block_size(x, head_dim): return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) - def num_programs(head_dim): - return min(head_dim, get_num_sms()) + def num_programs(total_tokens): + return min(total_tokens, current_platform.get_cu_count()) @triton.jit def cp_mha_gather_cache_kernel( @@ -58,11 +57,11 @@ def cp_mha_gather_cache_kernel( x, max_block_num, num_tokens, + num_programs, DEQUANT: tl.constexpr, PAGE_SIZE: tl.constexpr, CACHE_FORMAT: tl.constexpr, BLOCK_SIZE: tl.constexpr, - NUM_PRGMS: tl.constexpr, ): bid = tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) @@ -70,7 +69,7 @@ def cp_mha_gather_cache_kernel( k_scale = tl.load(k_scale_ptr) v_scale = tl.load(v_scale_ptr) - for token_id in tl.range(bid, num_tokens, NUM_PRGMS): + for token_id in tl.range(bid, num_tokens, num_programs): key_ptr_offset = key_ptr + token_id * head_size * num_heads value_ptr_offset = value_ptr + token_id * head_size * num_heads batch_idx = tl.load(token_to_batch_ptr + token_id) @@ -162,11 +161,11 @@ def cp_mha_gather_cache( x, block_tables.size(1), total_tokens, + NUM_PRGMS, DEQUANT=dequant, PAGE_SIZE=page_size, CACHE_FORMAT=kv_cache_layout, BLOCK_SIZE=BLOCK_SIZE, - NUM_PRGMS=NUM_PRGMS, )