diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index b6f90c717d4..13ea34ab31b 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -909,6 +909,21 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
 
+        # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
+        # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
+        # referenced, so drop them to save memory.
+        ascend_config = get_ascend_config()
+        if self.vllm_config.kv_transfer_config is not None and \
+                self.vllm_config.kv_transfer_config.is_kv_consumer and \
+                ascend_config.recompute_scheduler_enable:
+            self.fused_qkv_a_proj.weight = None
+            self.fused_qkv_a_proj.deq_scale = None
+            self.fused_qkv_a_proj.quant_bias = None
+            self.q_proj.weight = None
+            self.q_proj.deq_scale = None
+            self.q_proj.quant_bias = None
+            torch.npu.empty_cache()
+
     def get_context_seq_len_npu(self, index: int,
                                 attn_metadata: AscendMLAMetadata):
         prefill_metadata = attn_metadata.prefill
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 6588686eb57..12ac00bc09c 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -371,7 +371,7 @@ def __init__(
 
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size
-            #TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
+            # TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
             self._replace_linear_class_for_sfa_cp()
             from vllm_ascend.distributed.parallel_state import \
                 get_shared_weight_group
@@ -537,7 +537,7 @@ def exec_kv(
             cache_mode=cache_mode,
             is_output_kv=True,
         )
-        #TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
+        # TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
         k_pe = get_tp_group().all_gather(k_pe, 0)
         k_nope = get_tp_group().all_gather(k_nope, 0)
 
@@ -659,8 +659,13 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
 
+        # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
+        # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
+        # referenced, so drop them to save memory.
+        ascend_config = get_ascend_config()
         if self.vllm_config.kv_transfer_config is not None and \
-                self.vllm_config.kv_transfer_config.is_kv_consumer:
+                self.vllm_config.kv_transfer_config.is_kv_consumer and \
+                ascend_config.recompute_scheduler_enable:
             self.fused_qkv_a_proj.weight = None
             self.fused_qkv_a_proj.deq_scale = None
             self.fused_qkv_a_proj.quant_bias = None