diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 393ddb25eda..24ffcdd4b69 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -152,6 +152,8 @@ class AscendSFAMetadata: dsa_cp_context: DSACPContext | None = None reshape_cache_event: torch.npu.Event = None sfa_cp_metadata: AscendPCPMetadata | None = None + # Shared top-k indices reused by IndexCache layers in the same forward pass. + shared_topk_indices: torch.Tensor | None = None num_decodes: int = 0 num_decode_tokens: int = 0 num_prefills: int = 0 @@ -447,6 +449,7 @@ def __init__( if self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"]: self.is_rope_neox_style = False self.use_torch_npu_lightning_indexer = True + self.skip_topk = kwargs.get("skip_topk", False) # dsa c8 self.use_sparse_c8_indexer = ascend_config.is_sparse_c8_layer(self.layer_name) @@ -1229,16 +1232,20 @@ def forward( if self.is_kv_producer: attn_metadata.reshape_cache_event.record() - topk_indices = self.indexer_select_post_process( - x=hidden_states, - q_c=q_c, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - cos=cos, - sin=sin, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_key=actual_seq_lengths_key, - ) + if self.skip_topk and attn_metadata.shared_topk_indices is not None: + topk_indices = attn_metadata.shared_topk_indices + else: + topk_indices = self.indexer_select_post_process( + x=hidden_states, + q_c=q_c, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + cos=cos, + sin=sin, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + ) + attn_metadata.shared_topk_indices = topk_indices attn_output = self._execute_sparse_flash_attention_process( ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py index 689ed0cd672..2b9672d6a12 100644 --- a/vllm_ascend/ops/mla.py +++ b/vllm_ascend/ops/mla.py @@ -80,6 +80,7 @@ def __init__( cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", + skip_topk: bool = False, ) -> None: nn.Module.__init__(self) self.hidden_size = hidden_size @@ -122,6 +123,7 @@ def __init__( kv_a_layernorm=mla_modules.kv_a_layernorm, o_proj=mla_modules.o_proj, layer_name=f"{prefix}.attn", + skip_topk=skip_topk, ) original_process_weights = self.mla_attn.process_weights_after_loading