diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base
index 6f8c7222fdce..f6978dae00b0 100644
--- a/docker/Dockerfile.rocm_base
+++ b/docker/Dockerfile.rocm_base
@@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
 ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
 ARG FA_BRANCH="0e60e394"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="6af8b687"
+ARG AITER_BRANCH="1f5a392"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 ARG MORI_BRANCH="2d02c6a9"
 ARG MORI_REPO="https://github.com/ROCm/mori.git"
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 3febbe57a66f..c6d0301e975c 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -1143,8 +1143,25 @@ def forward(
             return
 
         assert attn_metadata.decode_metadata is not None
+        _, num_heads, head_size = query.shape
+        num_seqs = attn_metadata.seq_lens.shape[0]
+
         if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
-            num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
+            max_num_partitions = (
+                attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
+            ) // _PARTITION_SIZE_ROCM
+            tmp_out = torch.empty(
+                (num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=query.dtype,
+                device=query.device,
+            )
+            exp_sums = torch.empty(
+                (num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=query.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+            num_blocks, block_size, num_kv_heads, _ = key_cache.shape
             x = 16 // key_cache.element_size()
             k_cache_template = torch.empty(
                 [num_blocks, num_kv_heads, head_size // x, block_size, x],
@@ -1158,18 +1175,36 @@ def forward(
             )
             new_key_cache = key_cache.view_as(k_cache_template)
             new_value_cache = value_cache.view_as(v_cache_template)
-            aiter.pa_fwd_asm(
+            k_qscale_asm = (
+                layer._k_scale
+                if attn_metadata.k_scale is None
+                else attn_metadata.k_scale
+            )
+            v_qscale_asm = (
+                layer._v_scale
+                if attn_metadata.v_scale is None
+                else attn_metadata.v_scale
+            )
+            aiter.paged_attention_common(
                 Q=query[:num_decode_tokens],
                 K=new_key_cache,
                 V=new_value_cache,
+                tmp_out=tmp_out,
+                max_logits=max_logits,
+                exp_sums=exp_sums,
+                max_seq_len=attn_metadata.max_seq_len,
                 block_tables=attn_metadata.block_table[:num_decodes],
                 context_lens=attn_metadata.seq_lens[:num_decodes],
                 block_tables_stride0=attn_metadata.block_table[
                     :num_decodes
                 ].stride(0),
-                K_QScale=attn_metadata.k_scale,
-                V_QScale=attn_metadata.v_scale,
+                scale=self.scale,
+                K_QScale_hip=layer._k_scale,
+                V_QScale_hip=layer._v_scale,
+                K_QScale_asm=k_qscale_asm,
+                V_QScale_asm=v_qscale_asm,
                 out_=output[:num_decode_tokens],
+                kv_cache_dtype=self.kv_cache_dtype,
             )
         else:
             _, num_heads, head_size = query.shape