Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
006e99e
change pa_fwd_asm --> paged_attention_common for shuffle kv-cache
samutamm Jan 20, 2026
7d24855
fix K/V q scale asm args
samutamm Jan 21, 2026
078fcf3
pin to latest aiter
samutamm Jan 22, 2026
6e72543
ruff check
samutamm Jan 22, 2026
ac04101
simplify arguments
samutamm Jan 23, 2026
3d36878
paged_attention_common to _aiter_ops
samutamm Feb 10, 2026
c8f41ae
pre-commit
samutamm Feb 11, 2026
4bad7d3
paged_attention_common to rocm_aiter_ops
samutamm Feb 12, 2026
cf0e506
Merge branch 'main' into pa_common_shuffle_kv_cache
samutamm Feb 12, 2026
5e87b9b
Merge branch 'main' into pa_common_shuffle_kv_cache
samutamm Feb 16, 2026
09e73a9
revert ruff
samutamm Feb 16, 2026
2506e60
Merge branch 'main' into pa_common_shuffle_kv_cache
samutamm Feb 24, 2026
aee35cc
fix typo
samutamm Feb 24, 2026
624c68f
fix typo
samutamm Feb 24, 2026
4d9941f
Merge branch 'main' into pa_common_shuffle_kv_cache
samutamm Feb 25, 2026
6c18d08
fix merge conflict
samutamm Feb 25, 2026
b4d17f0
Merge branch 'main' into pa_common_shuffle_kv_cache
samutamm Feb 27, 2026
14b52aa
Merge branch 'main' into pa_common_shuffle_kv_cache
tjtanaa Mar 12, 2026
6ff94bb
Merge branch 'main' into pa_common_shuffle_kv_cache
tuukkjs Mar 25, 2026
8a4944f
fix: use consistent KV scales for HIP path in paged_attention_common
tuukkjs Mar 25, 2026
356d850
Use same variables as for asm
tuukkjs Mar 25, 2026
60480e7
Merge branch 'main' into pa_common_shuffle_kv_cache
tuukkjs Mar 30, 2026
6c13eae
Merge branch 'main' into pa_common_shuffle_kv_cache
tjtanaa Mar 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2070,5 +2070,56 @@ def pa_fwd_asm(
out_=out_,
)

@staticmethod
def paged_attention_common(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    tmp_out: torch.Tensor,
    max_logits: torch.Tensor,
    exp_sums: torch.Tensor,
    max_seq_len: int,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables_stride0: int,
    scale: float,
    K_QScale_hip: torch.Tensor,
    V_QScale_hip: torch.Tensor,
    K_QScale_asm: torch.Tensor,
    V_QScale_asm: torch.Tensor,
    out_: torch.Tensor,
    kv_cache_dtype: str,
):
    """Thin pass-through to ``aiter.paged_attention_common``.

    Deliberately NOT wrapped with the ``@is_aiter_supported`` decorator,
    so that explicit backend selection via ``attention_config`` keeps
    working even when ``VLLM_ROCM_USE_AITER=0``.

    ``aiter`` is imported lazily inside the call so that importing this
    module does not require the package to be installed.
    """
    # Collect the forwarded arguments once; every argument is passed
    # through unchanged by keyword.
    call_kwargs = {
        "Q": Q,
        "K": K,
        "V": V,
        "tmp_out": tmp_out,
        "max_logits": max_logits,
        "exp_sums": exp_sums,
        "max_seq_len": max_seq_len,
        "block_tables": block_tables,
        "context_lens": context_lens,
        "block_tables_stride0": block_tables_stride0,
        "scale": scale,
        "K_QScale_hip": K_QScale_hip,
        "V_QScale_hip": V_QScale_hip,
        "K_QScale_asm": K_QScale_asm,
        "V_QScale_asm": V_QScale_asm,
        "out_": out_,
        "kv_cache_dtype": kv_cache_dtype,
    }

    # Lazy import: resolved only when the kernel is actually invoked.
    import aiter

    return aiter.paged_attention_common(**call_kwargs)


# Module-import side effect: register the rocm_aiter_ops custom ops.
# NOTE(review): per its name, register_ops_once is presumably idempotent
# (safe on repeated import) — confirm against its definition.
rocm_aiter_ops.register_ops_once()
42 changes: 38 additions & 4 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,7 +1247,23 @@ def forward(
v_descale=layer._v_scale.expand(descale_shape),
)
elif rocm_aiter_ops.is_shuffle_kv_cache_enabled():
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
_, num_heads, head_size = query.shape
num_seqs = attn_metadata.seq_lens.shape[0]
max_num_partitions = (
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
) // _PARTITION_SIZE_ROCM
tmp_out = torch.empty(
(num_seqs, num_heads, max_num_partitions, head_size),
dtype=query.dtype,
device=query.device,
)
exp_sums = torch.empty(
(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=query.device,
)
max_logits = torch.empty_like(exp_sums)
num_blocks, block_size, num_kv_heads, _ = key_cache.shape
x = 16 // key_cache.element_size()
k_cache_template = torch.empty(
[num_blocks, num_kv_heads, head_size // x, block_size, x],
Expand All @@ -1261,18 +1277,36 @@ def forward(
)
new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template)
rocm_aiter_ops.pa_fwd_asm(
k_qscale = (
layer._k_scale
if attn_metadata.k_scale is None
else attn_metadata.k_scale
)
v_qscale = (
layer._v_scale
if attn_metadata.v_scale is None
else attn_metadata.v_scale
)
rocm_aiter_ops.paged_attention_common(
Q=query[:num_decode_tokens],
K=new_key_cache,
V=new_value_cache,
tmp_out=tmp_out,
max_logits=max_logits,
exp_sums=exp_sums,
max_seq_len=attn_metadata.max_seq_len,
block_tables=attn_metadata.block_table[:num_decodes],
context_lens=attn_metadata.seq_lens[:num_decodes],
block_tables_stride0=attn_metadata.block_table[
:num_decodes
].stride(0),
K_QScale=attn_metadata.k_scale,
V_QScale=attn_metadata.v_scale,
scale=self.scale,
K_QScale_hip=k_qscale,
V_QScale_hip=v_qscale,
K_QScale_asm=k_qscale,
V_QScale_asm=v_qscale,
out_=output[:num_decode_tokens],
kv_cache_dtype=self.kv_cache_dtype,
)
else:
_, num_heads, head_size = query.shape
Expand Down
Loading