diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index c8366ecce543..d2f5d0ff3947 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -382,6 +382,12 @@ def _rocm_aiter_mla_decode_fwd_impl(
     logit_cap: float = 0.0,
     q_scale: torch.Tensor | None = None,
     kv_scale: torch.Tensor | None = None,
+    work_meta_data: torch.Tensor | None = None,
+    work_indptr: torch.Tensor | None = None,
+    work_info_set: torch.Tensor | None = None,
+    reduce_indptr: torch.Tensor | None = None,
+    reduce_final_map: torch.Tensor | None = None,
+    reduce_partial_map: torch.Tensor | None = None,
 ) -> None:
     from aiter.mla import mla_decode_fwd
 
@@ -395,6 +401,29 @@ def _rocm_aiter_mla_decode_fwd_impl(
         kwargs["q_scale"] = q_scale
         kwargs["kv_scale"] = kv_scale
 
+    if work_meta_data is not None:
+        assert work_indptr is not None, (
+            "work_indptr must be provided with work_meta_data"
+        )
+        assert work_info_set is not None, (
+            "work_info_set must be provided with work_meta_data"
+        )
+        assert reduce_indptr is not None, (
+            "reduce_indptr must be provided with work_meta_data"
+        )
+        assert reduce_final_map is not None, (
+            "reduce_final_map must be provided with work_meta_data"
+        )
+        assert reduce_partial_map is not None, (
+            "reduce_partial_map must be provided with work_meta_data"
+        )
+        kwargs["work_meta_data"] = work_meta_data
+        kwargs["work_indptr"] = work_indptr
+        kwargs["work_info_set"] = work_info_set
+        kwargs["reduce_indptr"] = reduce_indptr
+        kwargs["reduce_final_map"] = reduce_final_map
+        kwargs["reduce_partial_map"] = reduce_partial_map
+
     mla_decode_fwd(
         q,
         kv_buffer.view(-1, 1, 1, q.shape[-1]),
@@ -421,6 +450,12 @@ def _rocm_aiter_mla_decode_fwd_fake(
     logit_cap: float = 0.0,
     q_scale: torch.Tensor | None = None,
     kv_scale: torch.Tensor | None = None,
+    work_meta_data: torch.Tensor | None = None,
+    work_indptr: torch.Tensor | None = None,
+    work_info_set: torch.Tensor | None = None,
+    reduce_indptr: torch.Tensor | None = None,
+    reduce_final_map: torch.Tensor | None = None,
+    reduce_partial_map: torch.Tensor | None = None,
 ) -> None:
     pass
 
@@ -1614,6 +1649,12 @@ def mla_decode_fwd(
         logit_cap: float = 0.0,
         q_scale: torch.Tensor | None = None,
         kv_scale: torch.Tensor | None = None,
+        work_meta_data: torch.Tensor | None = None,
+        work_indptr: torch.Tensor | None = None,
+        work_info_set: torch.Tensor | None = None,
+        reduce_indptr: torch.Tensor | None = None,
+        reduce_final_map: torch.Tensor | None = None,
+        reduce_partial_map: torch.Tensor | None = None,
     ):
         torch.ops.vllm.rocm_aiter_mla_decode_fwd(
             q,
@@ -1628,6 +1669,12 @@ def mla_decode_fwd(
             logit_cap=logit_cap,
             q_scale=q_scale,
             kv_scale=kv_scale,
+            work_meta_data=work_meta_data,
+            work_indptr=work_indptr,
+            work_info_set=work_info_set,
+            reduce_indptr=reduce_indptr,
+            reduce_final_map=reduce_final_map,
+            reduce_partial_map=reduce_partial_map,
         )
 
     @staticmethod
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index 9ded911620d5..80ede377824f 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -18,7 +18,12 @@
     QueryLenSupport,
 )
 from vllm.triton_utils import tl, triton
-from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
+from vllm.v1.attention.backend import (
+    AttentionCGSupport,
+    AttentionLayer,
+    CommonAttentionMetadata,
+    MultipleOf,
+)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 
@@ -70,8 +75,14 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
     max_qo_len: int | None = None
 
 
+@dataclass
 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
-    pass
+    work_meta_data: torch.Tensor | None = None
+    work_indptr: torch.Tensor | None = None
+    work_info_set: torch.Tensor | None = None
+    reduce_indptr: torch.Tensor | None = None
+    reduce_final_map: torch.Tensor | None = None
+    reduce_partial_map: torch.Tensor | None = None
 
 
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
@@ -115,6 +126,55 @@ def __init__(
             max_num_pages, dtype=torch.int32, device=device
         )
 
+        from aiter import dtypes, get_mla_metadata_info_v1
+
+        self._num_attention_heads = vllm_config.model_config.get_num_attention_heads(
+            vllm_config.parallel_config
+        )
+        q_dtype = self.decode_attn_out_dtype
+        kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto")
+        if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):
+            kv_cache_dtype_str = "fp8"
+        else:
+            kv_cache_dtype_str = "bf16"
+        kv_dtype = dtypes.d_dtypes.get(kv_cache_dtype_str, dtypes.bf16)
+        (
+            (work_meta_data_size, work_meta_data_type),
+            (work_indptr_size, work_indptr_type),
+            (work_info_set_size, work_info_set_type),
+            (reduce_indptr_size, reduce_indptr_type),
+            (reduce_final_map_size, reduce_final_map_type),
+            (reduce_partial_map_size, reduce_partial_map_type),
+        ) = get_mla_metadata_info_v1(
+            max_num_reqs,
+            1,
+            self._num_attention_heads,
+            q_dtype,
+            kv_dtype,
+            is_sparse=False,
+            fast_mode=True,
+        )
+        self._mla_work_meta_data = torch.empty(
+            work_meta_data_size, dtype=work_meta_data_type, device=device
+        )
+        self._mla_work_indptr = torch.empty(
+            work_indptr_size, dtype=work_indptr_type, device=device
+        )
+        self._mla_work_info_set = torch.empty(
+            work_info_set_size, dtype=work_info_set_type, device=device
+        )
+        self._mla_reduce_indptr = torch.empty(
+            reduce_indptr_size, dtype=reduce_indptr_type, device=device
+        )
+        self._mla_reduce_final_map = torch.empty(
+            reduce_final_map_size, dtype=reduce_final_map_type, device=device
+        )
+        self._mla_reduce_partial_map = torch.empty(
+            reduce_partial_map_size,
+            dtype=reduce_partial_map_type,
+            device=device,
+        )
+
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             self.paged_kv_indptr = torch.zeros(
                 max_num_reqs + 1, dtype=torch.int32, device=device
@@ -183,6 +243,28 @@ def _build_decode(
             0, num_reqs + 1, step=1, dtype=torch.int32, device=device
         )
 
+        from aiter import get_mla_metadata_v1
+
+        get_mla_metadata_v1(
+            qo_indptr,
+            paged_kv_indptr,
+            paged_kv_last_page_len,
+            self._num_attention_heads,
+            1,
+            True,
+            self._mla_work_meta_data,
+            self._mla_work_info_set,
+            self._mla_work_indptr,
+            self._mla_reduce_indptr,
+            self._mla_reduce_final_map,
+            self._mla_reduce_partial_map,
+            page_size=1,
+            kv_granularity=16,
+            max_seqlen_qo=max_qo_len,
+            uni_seqlen_qo=max_qo_len,
+            fast_mode=True,
+        )
+
         attn_metadata = AiterMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens_device,
@@ -197,6 +279,23 @@ def _build_decode(
 
         return attn_metadata
 
+    def build(
+        self,
+        common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        fast_build: bool = False,
+    ) -> AiterMLAMetadata:
+        attn_metadata = super().build(
+            common_prefix_len, common_attn_metadata, fast_build
+        )
+        attn_metadata.work_meta_data = self._mla_work_meta_data
+        attn_metadata.work_indptr = self._mla_work_indptr
+        attn_metadata.work_info_set = self._mla_work_info_set
+        attn_metadata.reduce_indptr = self._mla_reduce_indptr
+        attn_metadata.reduce_final_map = self._mla_reduce_final_map
+        attn_metadata.reduce_partial_map = self._mla_reduce_partial_map
+        return attn_metadata
+
 
 @triton.jit
 def _copy_page_indices_kernel(
@@ -337,6 +436,12 @@ def forward_mqa(
             attn_metadata.decode.paged_kv_last_page_len,
             q_scale=layer._q_scale,
             kv_scale=layer._k_scale,
+            work_meta_data=attn_metadata.work_meta_data,
+            work_indptr=attn_metadata.work_indptr,
+            work_info_set=attn_metadata.work_info_set,
+            reduce_indptr=attn_metadata.reduce_indptr,
+            reduce_final_map=attn_metadata.reduce_final_map,
+            reduce_partial_map=attn_metadata.reduce_partial_map,
         )
 
         if self._needs_head_repeat:
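
Reviewer note (not part of the patch): the sketch below traces the end-to-end flow this change introduces: allocate the persistent scheduling buffers once via get_mla_metadata_info_v1 (builder __init__), refill them per decode batch via get_mla_metadata_v1 (_build_decode), then hand them to the kernel as the new optional kwargs. The aiter call signatures and argument order are copied verbatim from the diff; the toy batch values, bf16 dtypes, and standalone setup are illustrative assumptions, not vLLM code.

# Minimal standalone sketch of the new scheduling-buffer flow.
# Assumptions: 2-request pure-decode toy batch, bf16 q/kv dtypes,
# made-up paged-KV extents; only the aiter calls are taken from the diff.
import torch
from aiter import dtypes, get_mla_metadata_info_v1, get_mla_metadata_v1

device = "cuda"
max_num_reqs, num_heads = 4, 16

# 1) One-time allocation (mirrors AiterMLAMetadataBuilder.__init__):
#    aiter reports a (size, dtype) pair per buffer for the worst case.
infos = get_mla_metadata_info_v1(
    max_num_reqs, 1, num_heads, dtypes.bf16, dtypes.bf16,
    is_sparse=False, fast_mode=True,
)
(work_meta_data, work_indptr, work_info_set,
 reduce_indptr, reduce_final_map, reduce_partial_map) = (
    torch.empty(size, dtype=dtype, device=device) for size, dtype in infos
)

# 2) Per-batch fill (mirrors _build_decode): one query token per request,
#    so qo_indptr is arange-like and max/uni seqlen_qo are 1.
num_reqs = 2
qo_indptr = torch.arange(num_reqs + 1, dtype=torch.int32, device=device)
paged_kv_indptr = torch.tensor([0, 7, 19], dtype=torch.int32, device=device)
paged_kv_last_page_len = torch.ones(num_reqs, dtype=torch.int32, device=device)
get_mla_metadata_v1(
    qo_indptr, paged_kv_indptr, paged_kv_last_page_len,
    num_heads, 1, True,
    work_meta_data, work_info_set, work_indptr,  # info_set precedes indptr here
    reduce_indptr, reduce_final_map, reduce_partial_map,
    page_size=1, kv_granularity=16,
    max_seqlen_qo=1, uni_seqlen_qo=1, fast_mode=True,
)

# 3) Consume (mirrors forward_mqa): the six tensors ride along as the new
#    optional kwargs of aiter_ops.mla_decode_fwd; when work_meta_data is
#    None, the impl skips them and the kernel keeps its previous behavior.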