diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 393ddb25eda..2f04cdfd141 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -417,6 +417,8 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tp_group().rank_in_group
         self.q_b_proj = kwargs["q_b_proj"]
+        self.skip_topk = kwargs.get("skip_topk", False)
+        self.topk_indices_buffer = kwargs.get("topk_indices_buffer")
         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
@@ -429,6 +431,11 @@ def __init__(
         self.local_num_heads = self.num_heads
         self.vllm_config = get_current_vllm_config()
+        self.use_index_cache = self.skip_topk or getattr(
+            self.vllm_config.model_config.hf_config,
+            "use_index_cache",
+            False,
+        )
         self.is_kv_producer = (
             self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
         )
@@ -1027,6 +1034,26 @@ def indexer_select_post_process(
         )
         return topk_indices
 
+    def _get_indexcache_topk_indices(self, num_tokens: int) -> torch.Tensor:
+        if self.topk_indices_buffer is None:
+            raise RuntimeError("IndexCache requires topk_indices_buffer when skip_topk is enabled.")
+        topk_indices = self.topk_indices_buffer[:num_tokens]
+        if topk_indices.dim() == 2:
+            topk_indices = topk_indices.unsqueeze(1)
+        return topk_indices
+
+    def _update_indexcache_topk_indices(self, topk_indices: torch.Tensor) -> None:
+        if self.topk_indices_buffer is None:
+            return
+        num_tokens = topk_indices.shape[0]
+        topk_tokens = topk_indices.shape[-1]
+        topk_indices_to_cache = topk_indices
+        topk_indices_buffer = self.topk_indices_buffer[:num_tokens, :topk_tokens]
+        if topk_indices_to_cache.dim() == 3 and topk_indices_buffer.dim() == 2:
+            assert topk_indices_to_cache.shape[1] == 1
+            topk_indices_to_cache = topk_indices_to_cache.squeeze(1)
+        topk_indices_buffer.copy_(topk_indices_to_cache)
+
     def _execute_sparse_flash_attention_process(
         self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
     ):
@@ -1229,16 +1256,22 @@ def forward(
         if self.is_kv_producer:
             attn_metadata.reshape_cache_event.record()
 
-        topk_indices = self.indexer_select_post_process(
-            x=hidden_states,
-            q_c=q_c,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
-            cos=cos,
-            sin=sin,
-            actual_seq_lengths_query=actual_seq_lengths_query,
-            actual_seq_lengths_key=actual_seq_lengths_key,
-        )
+        topk_num_tokens = num_input_tokens or hidden_states.shape[0]
+        if self.skip_topk:
+            topk_indices = self._get_indexcache_topk_indices(topk_num_tokens)
+        else:
+            topk_indices = self.indexer_select_post_process(
+                x=hidden_states,
+                q_c=q_c,
+                kv_cache=kv_cache,
+                attn_metadata=attn_metadata,
+                cos=cos,
+                sin=sin,
+                actual_seq_lengths_query=actual_seq_lengths_query,
+                actual_seq_lengths_key=actual_seq_lengths_key,
+            )
+            if self.use_index_cache:
+                self._update_indexcache_topk_indices(topk_indices)
 
         attn_output = self._execute_sparse_flash_attention_process(
             ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py
index 689ed0cd672..8462f76c24a 100644
--- a/vllm_ascend/ops/mla.py
+++ b/vllm_ascend/ops/mla.py
@@ -80,6 +80,7 @@ def __init__(
         cache_config: CacheConfig | None = None,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        skip_topk: bool = False,
     ) -> None:
         nn.Module.__init__(self)
         self.hidden_size = hidden_size
@@ -90,6 +91,7 @@ def __init__(
         self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
         self.v_head_dim = v_head_dim
         self.prefix = prefix
+        self.skip_topk = skip_topk
         hf_config = get_current_vllm_config().model_config.hf_text_config
         self.enable_shared_expert_dp = get_ascend_config().enable_shared_expert_dp
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -112,6 +114,8 @@ def __init__(
             prefix=f"{prefix}.attn",
             use_sparse=mla_modules.is_sparse,
             indexer=ascend_indexer,
+            skip_topk=skip_topk,
+            topk_indices_buffer=getattr(mla_modules, "topk_indices_buffer", None),
             # extra args
             rotary_emb=mla_modules.rotary_emb,
             fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
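A minimal, self-contained sketch of the round trip the two new helpers implement: a layer that computes fresh top-k indices caches them 2-D, and a later layer constructed with skip_topk=True reads them back and restores the singleton head axis. The buffer layout ([max_num_tokens, topk]) and the toy sizes are assumptions for illustration; the squeeze/unsqueeze handling mirrors the diff.

import torch

# Assumed persistent buffer shared across layers: [max_num_tokens, topk].
max_tokens, topk = 8, 4
topk_indices_buffer = torch.zeros(max_tokens, topk, dtype=torch.int32)

# Producer side (mirrors _update_indexcache_topk_indices): a computed
# [num_tokens, 1, topk] result is squeezed to 2-D and copied in place.
fresh = torch.randint(0, 128, (5, 1, topk), dtype=torch.int32)
to_cache = fresh.squeeze(1) if fresh.dim() == 3 else fresh
topk_indices_buffer[: to_cache.shape[0], : to_cache.shape[-1]].copy_(to_cache)

# Consumer side (mirrors _get_indexcache_topk_indices with skip_topk=True):
# slice the first num_tokens rows and restore the head axis.
reused = topk_indices_buffer[:5]
if reused.dim() == 2:
    reused = reused.unsqueeze(1)
assert torch.equal(reused, fresh)

Because copy_ writes through the sliced view, consumer layers see the producer's indices without any reallocation, which is what lets skip_topk layers bypass indexer_select_post_process entirely.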