From 188ac882ee094498a97cf852ae173d2fba0e3dee Mon Sep 17 00:00:00 2001
From: whx-sjtu <2952154980@qq.com>
Date: Sat, 28 Jun 2025 11:27:26 +0800
Subject: [PATCH] fix accuracy bug of prefix-caching

Signed-off-by: whx-sjtu <2952154980@qq.com>
---
 vllm_ascend/attention/mla_v1.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index a50c5fdf6ad..9f266fe41a2 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -751,7 +751,8 @@ def _forward_prefill(
         if attn_metadata.attn_state in [
                 AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
         ] and not ascend_config.chunked_prefill_for_mla:
             attn_output_torch = torch.empty(num_tokens,
                                             self.num_heads * self.v_head_dim,
@@ -776,7 +777,8 @@ def _forward_prefill(
                                    causal=True)
         elif attn_metadata.attn_state in [
                 AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
         ]:
             attn_lse = torch.empty(self.num_heads,
                                    num_tokens,
@@ -830,7 +832,8 @@
             [num_tokens, self.num_heads * self.v_head_dim])
         if attn_metadata.attn_state in [
                 AscendAttentionState.ChunkedPrefill,
-                AscendAttentionState.SpecDecoding
+                AscendAttentionState.SpecDecoding,
+                AscendAttentionState.PrefillCacheHit
         ] and not ascend_config.chunked_prefill_for_mla:
             attn_output = attn_output_torch
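
A minimal sketch of why the added state matters. All three hunks extend the
same `attn_state in [...]` membership check inside _forward_prefill, so the
fix reduces to the toy dispatch below. AscendAttentionState here is a
cut-down stand-in (PrefillNoCache is assumed for contrast), and
CACHED_CONTEXT_STATES and pick_prefill_path are names invented for this
sketch, not vllm_ascend APIs.

# Toy model of the dispatch this patch fixes. Only the branch choice is
# modeled; the real branches allocate attn_output_torch / attn_lse as
# shown in the hunks above.
from enum import Enum, auto

class AscendAttentionState(Enum):
    ChunkedPrefill = auto()
    SpecDecoding = auto()
    PrefillCacheHit = auto()
    PrefillNoCache = auto()  # assumed plain-prefill state, for contrast

# States whose prefill must also attend over context already resident in
# the KV cache; the patch adds PrefillCacheHit to the analogous inline
# lists in _forward_prefill.
CACHED_CONTEXT_STATES = [
    AscendAttentionState.ChunkedPrefill,
    AscendAttentionState.SpecDecoding,
    AscendAttentionState.PrefillCacheHit,
]

def pick_prefill_path(state: AscendAttentionState) -> str:
    """Hypothetical mirror of the `attn_state in [...]` checks."""
    if state in CACHED_CONTEXT_STATES:
        # In the real code this branch sets up attn_output_torch and
        # attn_lse (see the hunks above) and accounts for cached context.
        return "cached-context path"
    # Otherwise only the freshly computed tokens are attended over.
    return "new-tokens-only path"

# Before this patch, PrefillCacheHit was missing from the list, so a
# prefix-cache hit silently took the wrong branch and produced wrong
# (not crashing) outputs, which is why it surfaced as an accuracy bug.
assert pick_prefill_path(AscendAttentionState.PrefillCacheHit) \
    == "cached-context path"

One consequence of the inline lists is that the same three states are now
repeated at three call sites; hoisting them into a shared constant, as
sketched above, would keep a future fourth site from drifting out of sync,
though the patch keeps the inline lists, presumably to stay minimal.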