Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/Dockerfile.rocm_base
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
ARG FA_BRANCH="0e60e394"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="6af8b687"
ARG AITER_BRANCH="1f5a392"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG MORI_BRANCH="2d02c6a9"
ARG MORI_REPO="https://github.com/ROCm/mori.git"
Expand Down
43 changes: 39 additions & 4 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,8 +1143,25 @@ def forward(
return
assert attn_metadata.decode_metadata is not None

_, num_heads, head_size = query.shape
num_seqs = attn_metadata.seq_lens.shape[0]

if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
max_num_partitions = (
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
) // _PARTITION_SIZE_ROCM
tmp_out = torch.empty(
(num_seqs, num_heads, max_num_partitions, head_size),
dtype=query.dtype,
device=query.device,
)
exp_sums = torch.empty(
(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=query.device,
)
max_logits = torch.empty_like(exp_sums)
num_blocks, block_size, num_kv_heads, _ = key_cache.shape
x = 16 // key_cache.element_size()
k_cache_template = torch.empty(
[num_blocks, num_kv_heads, head_size // x, block_size, x],
Expand All @@ -1158,18 +1175,36 @@ def forward(
)
new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template)
aiter.pa_fwd_asm(
k_qscale_asm = (
layer._k_scale
if attn_metadata.k_scale is None
else attn_metadata.k_scale
)
v_qscale_asm = (
layer._v_scale
if attn_metadata.v_scale is None
else attn_metadata.v_scale
)
aiter.paged_attention_common(
Q=query[:num_decode_tokens],
K=new_key_cache,
V=new_value_cache,
tmp_out=tmp_out,
max_logits=max_logits,
exp_sums=exp_sums,
max_seq_len=attn_metadata.max_seq_len,
block_tables=attn_metadata.block_table[:num_decodes],
context_lens=attn_metadata.seq_lens[:num_decodes],
block_tables_stride0=attn_metadata.block_table[
:num_decodes
].stride(0),
K_QScale=attn_metadata.k_scale,
V_QScale=attn_metadata.v_scale,
scale=self.scale,
K_QScale_hip=layer._k_scale,
V_QScale_hip=layer._v_scale,
K_QScale_asm=k_qscale_asm,
V_QScale_asm=v_qscale_asm,
out_=output[:num_decode_tokens],
kv_cache_dtype=self.kv_cache_dtype,
)
else:
_, num_heads, head_size = query.shape
Expand Down