Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,14 @@

if current_platform.is_rocm():
import aiter
from aiter.ops.triton.utils.device_info import get_num_sms

from vllm.triton_utils import tl, triton

def block_size(x, head_dim):
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))

def num_programs(head_dim):
return min(head_dim, get_num_sms())
def num_programs(total_tokens):
return min(total_tokens, current_platform.get_cu_count())

@triton.jit
def cp_mha_gather_cache_kernel(
Expand All @@ -58,19 +57,19 @@ def cp_mha_gather_cache_kernel(
x,
max_block_num,
num_tokens,
num_programs,
DEQUANT: tl.constexpr,
PAGE_SIZE: tl.constexpr,
CACHE_FORMAT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
NUM_PRGMS: tl.constexpr,
):
bid = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
if DEQUANT:
k_scale = tl.load(k_scale_ptr)
v_scale = tl.load(v_scale_ptr)

for token_id in tl.range(bid, num_tokens, NUM_PRGMS):
for token_id in tl.range(bid, num_tokens, num_programs):
key_ptr_offset = key_ptr + token_id * head_size * num_heads
value_ptr_offset = value_ptr + token_id * head_size * num_heads
batch_idx = tl.load(token_to_batch_ptr + token_id)
Expand Down Expand Up @@ -162,11 +161,11 @@ def cp_mha_gather_cache(
x,
block_tables.size(1),
total_tokens,
NUM_PRGMS,
DEQUANT=dequant,
PAGE_SIZE=page_size,
CACHE_FORMAT=kv_cache_layout,
BLOCK_SIZE=BLOCK_SIZE,
NUM_PRGMS=NUM_PRGMS,
)


Expand Down