45 changes: 44 additions & 1 deletion vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -1114,7 +1114,50 @@ def forward(
             )
             return
 
-        if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
+        # The ll4mi kernel in paged_attention_v1 requires
+        # HEAD_SIZE >= 16 * NWARPS (= 64 on ROCm with NWARPS=4).
+        # For smaller head sizes, fall back to the
+        # unified_attention triton kernel, which also handles
+        # sliding window attention correctly.
+        _MIN_HEAD_SIZE_FOR_LL4MI = 64
+        use_unified_attention = self.head_size < _MIN_HEAD_SIZE_FOR_LL4MI
+
+        if use_unified_attention:
+            assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
+                "unified_attention fallback with shuffle layout "
+                "is not supported yet."
+            )
+            from aiter.ops.triton.unified_attention import (
+                unified_attention,
+            )
+
+            decode_cu_seqlens_q = attn_metadata.query_start_loc[
+                : num_decodes + 1
+            ]
+            descale_shape = (
+                num_decodes,
+                key_cache.shape[2],
+            )
+            unified_attention(
+                q=query[:num_decode_tokens],
+                k=key_cache,
+                v=value_cache,
+                out=output[:num_decode_tokens],
+                cu_seqlens_q=decode_cu_seqlens_q,
+                max_seqlen_q=1,
+                seqused_k=attn_metadata.seq_lens[:num_decodes],
+                max_seqlen_k=attn_metadata.max_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+                window_size=self.sliding_window,
+                block_table=attn_metadata.block_table[:num_decodes],
+                softcap=self.logits_soft_cap,
+                q_descale=None,
+                k_descale=layer._k_scale.expand(descale_shape),
+                v_descale=layer._v_scale.expand(descale_shape),
+            )
+        elif rocm_aiter_ops.is_shuffle_kv_cache_enabled():
             num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
             x = 16 // key_cache.element_size()
             k_cache_template = torch.empty(
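
For context on the k_descale / v_descale arguments in the new decode path: layer._k_scale and layer._v_scale are per-tensor scale factors, and expand(descale_shape) broadcasts them to the per-request, per-KV-head shape the triton kernel reads, without copying any data. A minimal sketch of that broadcast, assuming the scale is stored as a one-element tensor; the sizes below are made up for illustration:

import torch

# Hypothetical sizes, for illustration only.
num_decodes = 4    # decode requests in the batch
num_kv_heads = 8   # key_cache.shape[2] in the diff above

# Per-tensor scale stored as a one-element tensor (assumed here).
k_scale = torch.tensor([0.02])

# expand() broadcasts the scalar out to (num_decodes, num_kv_heads)
# without allocating new storage: every (request, kv_head) pair
# reads the same descale factor.
descale_shape = (num_decodes, num_kv_heads)
k_descale = k_scale.expand(descale_shape)

assert k_descale.shape == (num_decodes, num_kv_heads)
assert k_descale.stride() == (0, 0)  # zero strides, so no copy was made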
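
On the shuffle branch that starts at the elif: x = 16 // key_cache.element_size() is the number of cache elements that fit into one 16-byte group, presumably so that each innermost slice of the shuffled cache is 16 bytes wide. A quick sketch of what x works out to for common cache dtypes; the dtype list is illustrative, with the fp8 variant shown being the ROCm one:

import torch

# x = 16 // key_cache.element_size(), as computed in the shuffle branch.
for dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fnuz):
    element_size = torch.empty((), dtype=dtype).element_size()
    x = 16 // element_size
    print(f"{dtype}: element_size={element_size} -> x={x}")

# float16 / bfloat16 are 2 bytes per element, so x = 8;
# fp8 formats are 1 byte per element, so x = 16.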