From 9cadcc01f86ed00dd11542973259a6b17457a271 Mon Sep 17 00:00:00 2001 From: Sathish Sanjeevi Date: Mon, 9 Mar 2026 21:07:16 -0500 Subject: [PATCH 1/4] initial commit for persistent MLA Signed-off-by: Sathish Sanjeevi --- vllm/_aiter_ops.py | 37 ++++++ vllm/envs.py | 5 + .../attention/backends/mla/rocm_aiter_mla.py | 108 ++++++++++++++++++ 3 files changed, 150 insertions(+) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c8366ecce543..930166cdc97b 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -382,6 +382,12 @@ def _rocm_aiter_mla_decode_fwd_impl( logit_cap: float = 0.0, q_scale: torch.Tensor | None = None, kv_scale: torch.Tensor | None = None, + work_meta_data: torch.Tensor | None = None, + work_indptr: torch.Tensor | None = None, + work_info_set: torch.Tensor | None = None, + reduce_indptr: torch.Tensor | None = None, + reduce_final_map: torch.Tensor | None = None, + reduce_partial_map: torch.Tensor | None = None, ) -> None: from aiter.mla import mla_decode_fwd @@ -395,6 +401,19 @@ def _rocm_aiter_mla_decode_fwd_impl( kwargs["q_scale"] = q_scale kwargs["kv_scale"] = kv_scale + if work_meta_data is not None: + assert work_indptr is not None, "work_indptr must be provided with work_meta_data" + assert work_info_set is not None, "work_info_set must be provided with work_meta_data" + assert reduce_indptr is not None, "reduce_indptr must be provided with work_meta_data" + assert reduce_final_map is not None, "reduce_final_map must be provided with work_meta_data" + assert reduce_partial_map is not None, "reduce_partial_map must be provided with work_meta_data" + kwargs["work_meta_data"] = work_meta_data + kwargs["work_indptr"] = work_indptr + kwargs["work_info_set"] = work_info_set + kwargs["reduce_indptr"] = reduce_indptr + kwargs["reduce_final_map"] = reduce_final_map + kwargs["reduce_partial_map"] = reduce_partial_map + mla_decode_fwd( q, kv_buffer.view(-1, 1, 1, q.shape[-1]), @@ -421,6 +440,12 @@ def _rocm_aiter_mla_decode_fwd_fake( logit_cap: float = 0.0, q_scale: torch.Tensor | None = None, kv_scale: torch.Tensor | None = None, + work_meta_data: torch.Tensor | None = None, + work_indptr: torch.Tensor | None = None, + work_info_set: torch.Tensor | None = None, + reduce_indptr: torch.Tensor | None = None, + reduce_final_map: torch.Tensor | None = None, + reduce_partial_map: torch.Tensor | None = None, ) -> None: pass @@ -1614,6 +1639,12 @@ def mla_decode_fwd( logit_cap: float = 0.0, q_scale: torch.Tensor | None = None, kv_scale: torch.Tensor | None = None, + work_meta_data: torch.Tensor | None = None, + work_indptr: torch.Tensor | None = None, + work_info_set: torch.Tensor | None = None, + reduce_indptr: torch.Tensor | None = None, + reduce_final_map: torch.Tensor | None = None, + reduce_partial_map: torch.Tensor | None = None, ): torch.ops.vllm.rocm_aiter_mla_decode_fwd( q, @@ -1628,6 +1659,12 @@ def mla_decode_fwd( logit_cap=logit_cap, q_scale=q_scale, kv_scale=kv_scale, + work_meta_data=work_meta_data, + work_indptr=work_indptr, + work_info_set=work_info_set, + reduce_indptr=reduce_indptr, + reduce_final_map=reduce_final_map, + reduce_partial_map=reduce_partial_map, ) @staticmethod diff --git a/vllm/envs.py b/vllm/envs.py index 3b7312a4f378..4fa9cd9656e6 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -111,6 +111,7 @@ VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True + VLLM_ROCM_USE_AITER_MLA_PERSISTENT: bool = False VLLM_ROCM_USE_SKINNY_GEMM: bool = True 
VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -931,6 +932,10 @@ def _get_or_set_default() -> str: "VLLM_ROCM_USE_AITER_MLA": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in ("true", "1") ), + "VLLM_ROCM_USE_AITER_MLA_PERSISTENT": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_MLA_PERSISTENT", "False").lower() + in ("true", "1") + ), # Whether to use aiter mha ops. # By default is enabled. "VLLM_ROCM_USE_AITER_MHA": lambda: ( diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 9ded911620d5..8fde21e13d43 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -9,6 +9,7 @@ from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig from vllm.config.cache import CacheDType +from vllm import envs from vllm.model_executor.layers.attention.mla_attention import ( MLACommonBackend, MLACommonDecodeMetadata, @@ -69,6 +70,13 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): # The max query output length: int max_qo_len: int | None = None + work_meta_data: torch.Tensor | None = None + work_indptr: torch.Tensor | None = None + work_info_set: torch.Tensor | None = None + reduce_indptr: torch.Tensor | None = None + reduce_final_map: torch.Tensor | None = None + reduce_partial_map: torch.Tensor | None = None + class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): pass @@ -115,6 +123,59 @@ def __init__( max_num_pages, dtype=torch.int32, device=device ) + if envs.VLLM_ROCM_USE_AITER_MLA_PERSISTENT: + from aiter import dtypes, get_mla_metadata_info_v1 + + num_attention_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config + ) + q_dtype = self.decode_attn_out_dtype + kv_cache_dtype_str = getattr( + vllm_config.cache_config, "cache_dtype", "auto" + ) + if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"): + kv_cache_dtype_str = "fp8" + else: + kv_cache_dtype_str = "bf16" + kv_dtype = dtypes.d_dtypes.get(kv_cache_dtype_str, dtypes.bf16) + ( + (work_meta_data_size, work_meta_data_type), + (work_indptr_size, work_indptr_type), + (work_info_set_size, work_info_set_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_mla_metadata_info_v1( + max_num_reqs, + 1, + num_attention_heads, + q_dtype, + kv_dtype, + is_sparse=False, + fast_mode=True, + ) + self._mla_work_meta_data = torch.empty( + work_meta_data_size, dtype=work_meta_data_type, device=device + ) + self._mla_work_indptr = torch.empty( + work_indptr_size, dtype=work_indptr_type, device=device + ) + self._mla_work_info_set = torch.empty( + work_info_set_size, dtype=work_info_set_type, device=device + ) + self._mla_reduce_indptr = torch.empty( + reduce_indptr_size, dtype=reduce_indptr_type, device=device + ) + self._mla_reduce_final_map = torch.empty( + reduce_final_map_size, dtype=reduce_final_map_type, device=device + ) + self._mla_reduce_partial_map = torch.empty( + reduce_partial_map_size, + dtype=reduce_partial_map_type, + device=device, + ) + self._num_attention_heads = num_attention_heads + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.paged_kv_indptr = torch.zeros( max_num_reqs + 1, dtype=torch.int32, device=device @@ -183,6 +244,41 @@ def _build_decode( 0, num_reqs + 1, step=1, dtype=torch.int32, device=device ) + decode_work_meta_data = None + decode_work_indptr = None + 
decode_work_info_set = None + decode_reduce_indptr = None + decode_reduce_final_map = None + decode_reduce_partial_map = None + if getattr(self, "_mla_work_meta_data", None) is not None: + from aiter import get_mla_metadata_v1 + + get_mla_metadata_v1( + qo_indptr, + paged_kv_indptr, + paged_kv_last_page_len, + self._num_attention_heads, + 1, + True, + self._mla_work_meta_data, + self._mla_work_info_set, + self._mla_work_indptr, + self._mla_reduce_indptr, + self._mla_reduce_final_map, + self._mla_reduce_partial_map, + page_size=1, + kv_granularity=16, + max_seqlen_qo=max_qo_len, + uni_seqlen_qo=max_qo_len, + fast_mode=True, + ) + decode_work_meta_data = self._mla_work_meta_data + decode_work_indptr = self._mla_work_indptr + decode_work_info_set = self._mla_work_info_set + decode_reduce_indptr = self._mla_reduce_indptr + decode_reduce_final_map = self._mla_reduce_final_map + decode_reduce_partial_map = self._mla_reduce_partial_map + attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, @@ -193,6 +289,12 @@ def _build_decode( dcp_tot_seq_lens=dcp_tot_seq_lens_device, max_qo_len=max_qo_len, attn_out_dtype=self.decode_attn_out_dtype, + work_meta_data=decode_work_meta_data, + work_indptr=decode_work_indptr, + work_info_set=decode_work_info_set, + reduce_indptr=decode_reduce_indptr, + reduce_final_map=decode_reduce_final_map, + reduce_partial_map=decode_reduce_partial_map, ) return attn_metadata @@ -337,6 +439,12 @@ def forward_mqa( attn_metadata.decode.paged_kv_last_page_len, q_scale=layer._q_scale, kv_scale=layer._k_scale, + work_meta_data=attn_metadata.decode.work_meta_data, + work_indptr=attn_metadata.decode.work_indptr, + work_info_set=attn_metadata.decode.work_info_set, + reduce_indptr=attn_metadata.decode.reduce_indptr, + reduce_final_map=attn_metadata.decode.reduce_final_map, + reduce_partial_map=attn_metadata.decode.reduce_partial_map, ) if self._needs_head_repeat: From 92e1b6b43fc7a64b2303e7bcad18f55c4d8c1d9e Mon Sep 17 00:00:00 2001 From: Sathish Sanjeevi Date: Wed, 18 Mar 2026 13:47:20 -0500 Subject: [PATCH 2/4] move persistent MLA buffers from decode metadata to AiterMLAMetadata - per review Signed-off-by: Sathish Sanjeevi --- .../attention/backends/mla/rocm_aiter_mla.py | 62 ++++++++++--------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 8fde21e13d43..cc23ff29073e 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -19,7 +19,12 @@ QueryLenSupport, ) from vllm.triton_utils import tl, triton -from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf +from vllm.v1.attention.backend import ( + AttentionCGSupport, + AttentionLayer, + CommonAttentionMetadata, + MultipleOf, +) from vllm.v1.kv_cache_interface import AttentionSpec @@ -70,6 +75,9 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): # The max query output length: int max_qo_len: int | None = None + +@dataclass +class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): work_meta_data: torch.Tensor | None = None work_indptr: torch.Tensor | None = None work_info_set: torch.Tensor | None = None @@ -78,10 +86,6 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata): reduce_partial_map: torch.Tensor | None = None -class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): - pass - - class 
AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): # TODO(luka, lucas): audit this as part of: # https://github.com/vllm-project/vllm/issues/22945 @@ -244,12 +248,6 @@ def _build_decode( 0, num_reqs + 1, step=1, dtype=torch.int32, device=device ) - decode_work_meta_data = None - decode_work_indptr = None - decode_work_info_set = None - decode_reduce_indptr = None - decode_reduce_final_map = None - decode_reduce_partial_map = None if getattr(self, "_mla_work_meta_data", None) is not None: from aiter import get_mla_metadata_v1 @@ -272,12 +270,6 @@ def _build_decode( uni_seqlen_qo=max_qo_len, fast_mode=True, ) - decode_work_meta_data = self._mla_work_meta_data - decode_work_indptr = self._mla_work_indptr - decode_work_info_set = self._mla_work_info_set - decode_reduce_indptr = self._mla_reduce_indptr - decode_reduce_final_map = self._mla_reduce_final_map - decode_reduce_partial_map = self._mla_reduce_partial_map attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, @@ -289,16 +281,28 @@ def _build_decode( dcp_tot_seq_lens=dcp_tot_seq_lens_device, max_qo_len=max_qo_len, attn_out_dtype=self.decode_attn_out_dtype, - work_meta_data=decode_work_meta_data, - work_indptr=decode_work_indptr, - work_info_set=decode_work_info_set, - reduce_indptr=decode_reduce_indptr, - reduce_final_map=decode_reduce_final_map, - reduce_partial_map=decode_reduce_partial_map, ) return attn_metadata + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AiterMLAMetadata: + attn_metadata = super().build( + common_prefix_len, common_attn_metadata, fast_build + ) + if getattr(self, "_mla_work_meta_data", None) is not None: + attn_metadata.work_meta_data = self._mla_work_meta_data + attn_metadata.work_indptr = self._mla_work_indptr + attn_metadata.work_info_set = self._mla_work_info_set + attn_metadata.reduce_indptr = self._mla_reduce_indptr + attn_metadata.reduce_final_map = self._mla_reduce_final_map + attn_metadata.reduce_partial_map = self._mla_reduce_partial_map + return attn_metadata + @triton.jit def _copy_page_indices_kernel( @@ -439,12 +443,12 @@ def forward_mqa( attn_metadata.decode.paged_kv_last_page_len, q_scale=layer._q_scale, kv_scale=layer._k_scale, - work_meta_data=attn_metadata.decode.work_meta_data, - work_indptr=attn_metadata.decode.work_indptr, - work_info_set=attn_metadata.decode.work_info_set, - reduce_indptr=attn_metadata.decode.reduce_indptr, - reduce_final_map=attn_metadata.decode.reduce_final_map, - reduce_partial_map=attn_metadata.decode.reduce_partial_map, + work_meta_data=attn_metadata.work_meta_data, + work_indptr=attn_metadata.work_indptr, + work_info_set=attn_metadata.work_info_set, + reduce_indptr=attn_metadata.reduce_indptr, + reduce_final_map=attn_metadata.reduce_final_map, + reduce_partial_map=attn_metadata.reduce_partial_map, ) if self._needs_head_repeat: From c43dc6068a426adabd2c23540e443b21e72c1a82 Mon Sep 17 00:00:00 2001 From: Sathish Sanjeevi Date: Wed, 18 Mar 2026 23:01:36 -0500 Subject: [PATCH 3/4] satisfy pre-commit checks/criteria Signed-off-by: Sathish Sanjeevi --- vllm/_aiter_ops.py | 20 ++++++++++++++----- .../attention/backends/mla/rocm_aiter_mla.py | 2 +- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 930166cdc97b..d2f5d0ff3947 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -402,11 +402,21 @@ def _rocm_aiter_mla_decode_fwd_impl( kwargs["kv_scale"] = kv_scale if work_meta_data is not None: - 
assert work_indptr is not None, "work_indptr must be provided with work_meta_data" - assert work_info_set is not None, "work_info_set must be provided with work_meta_data" - assert reduce_indptr is not None, "reduce_indptr must be provided with work_meta_data" - assert reduce_final_map is not None, "reduce_final_map must be provided with work_meta_data" - assert reduce_partial_map is not None, "reduce_partial_map must be provided with work_meta_data" + assert work_indptr is not None, ( + "work_indptr must be provided with work_meta_data" + ) + assert work_info_set is not None, ( + "work_info_set must be provided with work_meta_data" + ) + assert reduce_indptr is not None, ( + "reduce_indptr must be provided with work_meta_data" + ) + assert reduce_final_map is not None, ( + "reduce_final_map must be provided with work_meta_data" + ) + assert reduce_partial_map is not None, ( + "reduce_partial_map must be provided with work_meta_data" + ) kwargs["work_meta_data"] = work_meta_data kwargs["work_indptr"] = work_indptr kwargs["work_info_set"] = work_info_set diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index cc23ff29073e..4c2645971e56 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -6,10 +6,10 @@ import torch +from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig from vllm.config.cache import CacheDType -from vllm import envs from vllm.model_executor.layers.attention.mla_attention import ( MLACommonBackend, MLACommonDecodeMetadata, From ab8159d742559b029b731fe183ebfdb41d60b374 Mon Sep 17 00:00:00 2001 From: Sathish Sanjeevi Date: Wed, 18 Mar 2026 23:18:25 -0500 Subject: [PATCH 4/4] remove VLLM_ROCM_USE_AITER_MLA_PERSISTENT env flag Signed-off-by: Sathish Sanjeevi --- vllm/envs.py | 5 - .../attention/backends/mla/rocm_aiter_mla.py | 155 +++++++++--------- 2 files changed, 74 insertions(+), 86 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 4fa9cd9656e6..3b7312a4f378 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -111,7 +111,6 @@ VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS: bool = False VLLM_ROCM_USE_AITER_TRITON_GEMM: bool = True - VLLM_ROCM_USE_AITER_MLA_PERSISTENT: bool = False VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -932,10 +931,6 @@ def _get_or_set_default() -> str: "VLLM_ROCM_USE_AITER_MLA": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in ("true", "1") ), - "VLLM_ROCM_USE_AITER_MLA_PERSISTENT": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_MLA_PERSISTENT", "False").lower() - in ("true", "1") - ), # Whether to use aiter mha ops. # By default is enabled. 
"VLLM_ROCM_USE_AITER_MHA": lambda: ( diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 4c2645971e56..80ede377824f 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -6,7 +6,6 @@ import torch -from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.config import VllmConfig from vllm.config.cache import CacheDType @@ -127,58 +126,54 @@ def __init__( max_num_pages, dtype=torch.int32, device=device ) - if envs.VLLM_ROCM_USE_AITER_MLA_PERSISTENT: - from aiter import dtypes, get_mla_metadata_info_v1 + from aiter import dtypes, get_mla_metadata_info_v1 - num_attention_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config - ) - q_dtype = self.decode_attn_out_dtype - kv_cache_dtype_str = getattr( - vllm_config.cache_config, "cache_dtype", "auto" - ) - if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"): - kv_cache_dtype_str = "fp8" - else: - kv_cache_dtype_str = "bf16" - kv_dtype = dtypes.d_dtypes.get(kv_cache_dtype_str, dtypes.bf16) - ( - (work_meta_data_size, work_meta_data_type), - (work_indptr_size, work_indptr_type), - (work_info_set_size, work_info_set_type), - (reduce_indptr_size, reduce_indptr_type), - (reduce_final_map_size, reduce_final_map_type), - (reduce_partial_map_size, reduce_partial_map_type), - ) = get_mla_metadata_info_v1( - max_num_reqs, - 1, - num_attention_heads, - q_dtype, - kv_dtype, - is_sparse=False, - fast_mode=True, - ) - self._mla_work_meta_data = torch.empty( - work_meta_data_size, dtype=work_meta_data_type, device=device - ) - self._mla_work_indptr = torch.empty( - work_indptr_size, dtype=work_indptr_type, device=device - ) - self._mla_work_info_set = torch.empty( - work_info_set_size, dtype=work_info_set_type, device=device - ) - self._mla_reduce_indptr = torch.empty( - reduce_indptr_size, dtype=reduce_indptr_type, device=device - ) - self._mla_reduce_final_map = torch.empty( - reduce_final_map_size, dtype=reduce_final_map_type, device=device - ) - self._mla_reduce_partial_map = torch.empty( - reduce_partial_map_size, - dtype=reduce_partial_map_type, - device=device, - ) - self._num_attention_heads = num_attention_heads + self._num_attention_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config + ) + q_dtype = self.decode_attn_out_dtype + kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto") + if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"): + kv_cache_dtype_str = "fp8" + else: + kv_cache_dtype_str = "bf16" + kv_dtype = dtypes.d_dtypes.get(kv_cache_dtype_str, dtypes.bf16) + ( + (work_meta_data_size, work_meta_data_type), + (work_indptr_size, work_indptr_type), + (work_info_set_size, work_info_set_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_mla_metadata_info_v1( + max_num_reqs, + 1, + self._num_attention_heads, + q_dtype, + kv_dtype, + is_sparse=False, + fast_mode=True, + ) + self._mla_work_meta_data = torch.empty( + work_meta_data_size, dtype=work_meta_data_type, device=device + ) + self._mla_work_indptr = torch.empty( + work_indptr_size, dtype=work_indptr_type, device=device + ) + self._mla_work_info_set = torch.empty( + work_info_set_size, dtype=work_info_set_type, device=device + ) + self._mla_reduce_indptr = torch.empty( + reduce_indptr_size, dtype=reduce_indptr_type, device=device + ) 
+ self._mla_reduce_final_map = torch.empty( + reduce_final_map_size, dtype=reduce_final_map_type, device=device + ) + self._mla_reduce_partial_map = torch.empty( + reduce_partial_map_size, + dtype=reduce_partial_map_type, + device=device, + ) if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.paged_kv_indptr = torch.zeros( @@ -248,28 +243,27 @@ def _build_decode( 0, num_reqs + 1, step=1, dtype=torch.int32, device=device ) - if getattr(self, "_mla_work_meta_data", None) is not None: - from aiter import get_mla_metadata_v1 - - get_mla_metadata_v1( - qo_indptr, - paged_kv_indptr, - paged_kv_last_page_len, - self._num_attention_heads, - 1, - True, - self._mla_work_meta_data, - self._mla_work_info_set, - self._mla_work_indptr, - self._mla_reduce_indptr, - self._mla_reduce_final_map, - self._mla_reduce_partial_map, - page_size=1, - kv_granularity=16, - max_seqlen_qo=max_qo_len, - uni_seqlen_qo=max_qo_len, - fast_mode=True, - ) + from aiter import get_mla_metadata_v1 + + get_mla_metadata_v1( + qo_indptr, + paged_kv_indptr, + paged_kv_last_page_len, + self._num_attention_heads, + 1, + True, + self._mla_work_meta_data, + self._mla_work_info_set, + self._mla_work_indptr, + self._mla_reduce_indptr, + self._mla_reduce_final_map, + self._mla_reduce_partial_map, + page_size=1, + kv_granularity=16, + max_seqlen_qo=max_qo_len, + uni_seqlen_qo=max_qo_len, + fast_mode=True, + ) attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, @@ -294,13 +288,12 @@ def build( attn_metadata = super().build( common_prefix_len, common_attn_metadata, fast_build ) - if getattr(self, "_mla_work_meta_data", None) is not None: - attn_metadata.work_meta_data = self._mla_work_meta_data - attn_metadata.work_indptr = self._mla_work_indptr - attn_metadata.work_info_set = self._mla_work_info_set - attn_metadata.reduce_indptr = self._mla_reduce_indptr - attn_metadata.reduce_final_map = self._mla_reduce_final_map - attn_metadata.reduce_partial_map = self._mla_reduce_partial_map + attn_metadata.work_meta_data = self._mla_work_meta_data + attn_metadata.work_indptr = self._mla_work_indptr + attn_metadata.work_info_set = self._mla_work_info_set + attn_metadata.reduce_indptr = self._mla_reduce_indptr + attn_metadata.reduce_final_map = self._mla_reduce_final_map + attn_metadata.reduce_partial_map = self._mla_reduce_partial_map return attn_metadata
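
Reviewer notes (appended; not part of the patches): the series splits the persistent MLA decode path into a one-time sizing/allocation step (builder __init__) and a per-batch scheduling step (_build_decode), with the filled buffers handed to the kernel in forward_mqa. The sketch below condenses that flow using only the aiter entry points that appear in the diffs above (get_mla_metadata_info_v1, get_mla_metadata_v1). It is a minimal illustration, not part of the change: the batch size, head count, dtypes, and indptr contents are toy values, and the unnamed positional arguments simply mirror the calls in the patch.

    import torch
    from aiter import dtypes, get_mla_metadata_info_v1, get_mla_metadata_v1

    device = "cuda"
    max_num_reqs = 4       # toy upper bound on decode batch size
    num_heads = 16         # toy attention head count
    q_dtype = torch.bfloat16
    kv_dtype = dtypes.bf16

    # Step 1 (builder __init__): ask aiter for the worst-case size and dtype
    # of each scheduling buffer, then allocate all six once on the device.
    sizes_and_types = get_mla_metadata_info_v1(
        max_num_reqs, 1, num_heads, q_dtype, kv_dtype,
        is_sparse=False, fast_mode=True,
    )
    (work_meta_data, work_indptr, work_info_set,
     reduce_indptr, reduce_final_map, reduce_partial_map) = [
        torch.empty(size, dtype=dtype, device=device)
        for size, dtype in sizes_and_types
    ]

    # Step 2 (_build_decode): fill the buffers in place for the current batch.
    # Toy CSR-style tensors describing 4 single-token decodes with one KV page
    # each; the unnamed positional arguments (1, True) mirror the patch.
    qo_indptr = torch.arange(max_num_reqs + 1, dtype=torch.int32, device=device)
    paged_kv_indptr = torch.arange(max_num_reqs + 1, dtype=torch.int32, device=device)
    paged_kv_last_page_len = torch.ones(max_num_reqs, dtype=torch.int32, device=device)

    get_mla_metadata_v1(
        qo_indptr, paged_kv_indptr, paged_kv_last_page_len,
        num_heads, 1, True,
        work_meta_data, work_info_set, work_indptr,
        reduce_indptr, reduce_final_map, reduce_partial_map,
        page_size=1, kv_granularity=16,
        max_seqlen_qo=1, uni_seqlen_qo=1,
        fast_mode=True,
    )

    # Step 3 (forward_mqa): the six filled tensors ride along as keyword
    # arguments into rocm_aiter_ops.mla_decode_fwd; when work_meta_data is
    # not None, _rocm_aiter_mla_decode_fwd_impl forwards all of them to
    # aiter's mla_decode_fwd, selecting the persistent kernel.

Allocating once at the worst-case size returned by get_mla_metadata_info_v1 keeps _build_decode allocation-free, so the same device buffers are refilled in place and reused on every scheduler step.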