diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 00315137420..50d933a1e65 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -785,7 +785,7 @@ def exec_kv_prefill( # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv_no_split = kv_no_split.view( B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA" + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 80ada4d04a0..1527820e5dc 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -961,7 +961,7 @@ def exec_kv_prefill( kv = self.kv_a_proj_with_mqa(hidden_states)[0] # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA_BLK_NZ" if self.enable_kv_nz else "PA" + cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( kv, self.kv_a_layernorm.weight,