diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index b74efc8d8b9..cc443f55b1a 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -470,7 +470,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase): if self.fused_qkv_a_proj is None or not isinstance( quant_method, AscendW8A8LinearMethod): reasons.append( - "Currently mlapo only supports W8A8 quantization in MLA scenario." + "Currently mlapo only supports W8A8 quantization in SFA scenario." "Some layers in your model are not quantized with W8A8," "thus mlapo is disabled for these layers.") if self.enable_sfa_cp: @@ -597,8 +597,6 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): q_a_proj_wt = self.fused_qkv_a_proj.weight.data[ ..., :self.q_lora_rank].contiguous() - self.fused_qkv_a_proj.weight = None - kv_a_proj_wt = kv_a_proj_wt.t().contiguous() kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) kv_a_proj_wt = kv_a_proj_wt.t().contiguous() @@ -673,9 +671,12 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device) self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device) - if self.vllm_config.kv_transfer_config is not None: + if self.vllm_config.kv_transfer_config is not None and \ + self.vllm_config.kv_transfer_config.is_kv_consumer: + self.fused_qkv_a_proj.weight = None self.fused_qkv_a_proj.deq_scale = None self.fused_qkv_a_proj.quant_bias = None + self.q_proj.weight = None self.q_proj.deq_scale = None self.q_proj.quant_bias = None torch.npu.empty_cache()