Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions vllm_ascend/attention/sfa_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
if self.fused_qkv_a_proj is None or not isinstance(
quant_method, AscendW8A8LinearMethod):
reasons.append(
"Currently mlapo only supports W8A8 quantization in MLA scenario."
"Currently mlapo only supports W8A8 quantization in SFA scenario."
"Some layers in your model are not quantized with W8A8,"
"thus mlapo is disabled for these layers.")
if self.enable_sfa_cp:
Expand Down Expand Up @@ -597,8 +597,6 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., :self.q_lora_rank].contiguous()

self.fused_qkv_a_proj.weight = None

kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
Expand Down Expand Up @@ -673,9 +671,12 @@ def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)

if self.vllm_config.kv_transfer_config is not None:
if self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.is_kv_consumer:
self.fused_qkv_a_proj.weight = None
self.fused_qkv_a_proj.deq_scale = None
self.fused_qkv_a_proj.quant_bias = None
self.q_proj.weight = None
self.q_proj.deq_scale = None
self.q_proj.quant_bias = None
torch.npu.empty_cache()
Comment on lines +674 to 682
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The current logic for releasing MLAPO weights covers only the case where the node is a KV consumer in a Prefill-Decode (PD) mixed scenario. This correctly fixes the original bug, but it introduces a memory leak in non-PD (standalone) scenarios.

Previously, self.fused_qkv_a_proj.weight was released unconditionally after its data was processed. With this change, it is no longer released in non-PD scenarios, as self.vllm_config.kv_transfer_config would be None.

The logic should be updated to release the processed weights in both non-PD scenarios and on the consumer side of PD scenarios. This ensures memory is freed correctly in all configurations.

Suggested change
if self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.is_kv_consumer:
self.fused_qkv_a_proj.weight = None
self.fused_qkv_a_proj.deq_scale = None
self.fused_qkv_a_proj.quant_bias = None
self.q_proj.weight = None
self.q_proj.deq_scale = None
self.q_proj.quant_bias = None
torch.npu.empty_cache()
if self.vllm_config.kv_transfer_config is None or \
self.vllm_config.kv_transfer_config.is_kv_consumer:
self.fused_qkv_a_proj.weight = None
self.fused_qkv_a_proj.deq_scale = None
self.fused_qkv_a_proj.quant_bias = None
self.q_proj.weight = None
self.q_proj.deq_scale = None
self.q_proj.quant_bias = None
torch.npu.empty_cache()

Expand Down
Loading