4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -1181,15 +1181,15 @@ def forward(
             )

             descale_shape = (
-                attn_metadata.query_start_loc[:num_decodes].shape[0] - 1,
+                num_decodes,
                 key_cache.shape[2],
             )
             unified_attention(
                 q=query[:num_decode_tokens],
                 k=key_cache,
                 v=value_cache,
                 out=output[:num_decode_tokens],
-                cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes],
+                cu_seqlens_q=attn_metadata.query_start_loc[: num_decodes + 1],
                 max_seqlen_q=decode_max_query_len,
                 seqused_k=attn_metadata.seq_lens[:num_decodes],
                 max_seqlen_k=attn_metadata.max_seq_len,
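For context, here is a minimal sketch (not vLLM code) of why the new slice needs num_decodes + 1 entries. It assumes query_start_loc follows the usual cumulative-sum layout used by flash-attention-style kernels, where N sequences are described by N + 1 boundary offsets; the tensor values are purely illustrative.

    # Illustrative off-by-one: 3 decode sequences of length 1 each.
    import torch

    num_decodes = 3
    query_start_loc = torch.tensor([0, 1, 2, 3], dtype=torch.int32)

    # Old slice: drops the final boundary, so the last decode sequence is cut off.
    old_cu_seqlens_q = query_start_loc[:num_decodes]        # tensor([0, 1, 2])

    # Fixed slice: keeps all num_decodes + 1 boundaries.
    new_cu_seqlens_q = query_start_loc[: num_decodes + 1]   # tensor([0, 1, 2, 3])

    # The batch size implied by a cu_seqlens array is len(array) - 1, which is
    # also why descale_shape can use num_decodes directly instead of deriving
    # it from the (previously mis-sliced) tensor.
    assert new_cu_seqlens_q.shape[0] - 1 == num_decodes
    assert old_cu_seqlens_q.shape[0] - 1 == num_decodes - 1  # off by one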