diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 141d57d908eb..2ea3c346f5a8 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -1114,7 +1114,50 @@ def forward( ) return - if rocm_aiter_ops.is_shuffle_kv_cache_enabled(): + # The ll4mi kernel in paged_attention_v1 requires + # HEAD_SIZE >= 16 * NWARPS (= 64 on ROCm with NWARPS=4). + # For smaller head sizes or sliding window attention, + # fall back to the unified_attention triton kernel which + # handles both correctly. + _MIN_HEAD_SIZE_FOR_LL4MI = 64 + use_unified_attention = self.head_size < _MIN_HEAD_SIZE_FOR_LL4MI + + if use_unified_attention: + assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), ( + "unified_attention fallback with shuffle layout " + "is not supported yet." + ) + from aiter.ops.triton.unified_attention import ( + unified_attention, + ) + + decode_cu_seqlens_q = attn_metadata.query_start_loc[ + : num_decodes + 1 + ] + descale_shape = ( + num_decodes, + key_cache.shape[2], + ) + unified_attention( + q=query[:num_decode_tokens], + k=key_cache, + v=value_cache, + out=output[:num_decode_tokens], + cu_seqlens_q=decode_cu_seqlens_q, + max_seqlen_q=1, + seqused_k=attn_metadata.seq_lens[:num_decodes], + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=attn_metadata.block_table[:num_decodes], + softcap=self.logits_soft_cap, + q_descale=None, + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + elif rocm_aiter_ops.is_shuffle_kv_cache_enabled(): num_blocks, block_size, num_kv_heads, head_size = key_cache.shape x = 16 // key_cache.element_size() k_cache_template = torch.empty(