[perf] Fix MLAPO weight disposal for KV-consumer MLA in PD-mix deploy... #5192

Merged
zzzzwwjj merged 1 commit into vllm-project:main from kiscad:fix-mlapo-mla
Jan 5, 2026

Conversation

kiscad (Contributor) commented Dec 19, 2025

What this PR does / why we need it?

  • Problem: In MLA+MLAPO, KV-consumer deployments keep fused_qkv_a_proj/q_proj weights and quant params even though MLAPO uses the prepacked buffers, increasing memory footprint on decode nodes.
  • Fix: Conditionally drop those tensors only when kv_transfer_config.is_kv_consumer to reclaim memory (consistent with the SFA behavior in #4774, "Fix incorrect MLAPO weight release in PD mixex scenarios").
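The disposal logic described above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: the helper name `dispose_prepacked_weights` and the stand-in `SimpleNamespace` layer are hypothetical, while `fused_qkv_a_proj`, `q_proj`, and `is_kv_consumer` come from the PR description (the real code operates on torch modules and also clears quantization params).

```python
from types import SimpleNamespace

def dispose_prepacked_weights(layer, is_kv_consumer: bool) -> None:
    """Free projection weights that MLAPO has already prepacked.

    Hypothetical sketch: on KV-consumer (decode-only) nodes the MLAPO
    kernel reads its own prepacked buffers, so the original weights are
    dead memory and can be released.
    """
    if not is_kv_consumer:
        # Producer / mixed nodes may still run prefill paths that read
        # the original weights, so nothing is released.
        return
    for name in ("fused_qkv_a_proj", "q_proj"):
        proj = getattr(layer, name, None)
        if proj is not None and getattr(proj, "weight", None) is not None:
            proj.weight = None  # let the allocator reclaim the buffer

# Usage on a stand-in layer:
layer = SimpleNamespace(
    fused_qkv_a_proj=SimpleNamespace(weight=bytearray(1024)),
    q_proj=SimpleNamespace(weight=bytearray(1024)),
)
dispose_prepacked_weights(layer, is_kv_consumer=True)
print(layer.q_proj.weight)  # None
```

On a producer node (`is_kv_consumer=False`) the function is a no-op, which is the conditional behavior the PR adds.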

Does this PR introduce any user-facing change?

How was this patch tested?

gemini-code-assist (bot) left a comment:

Code Review

This pull request aims to fix an issue with MLAPO weight disposal in PD-mix deployments where MLA acts as a KV consumer. The change introduces logic to conditionally clear certain weights and quantization parameters to free up memory on KV consumer nodes after they have been processed for the MLAPO kernel. While the logic for freeing memory seems correct for the intended scenario, it introduces a potential critical issue by making the _process_weights_for_fused_mlapo method non-idempotent. A second invocation on a KV consumer node would lead to a crash. I've added a comment with a suggestion to add a guard to prevent this.
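The guard the review suggests could look like the sketch below. Everything here is hypothetical scaffolding (the flag `_mlapo_processed`, the `SimpleNamespace` layer, the elided prepack step); only the method's purpose and the `is_kv_consumer` condition come from the discussion above.

```python
from types import SimpleNamespace

def process_weights_for_fused_mlapo(layer, is_kv_consumer: bool) -> None:
    # Idempotency guard (the reviewer's suggestion): once the weights are
    # dropped on a KV consumer, a second invocation would try to repack
    # from None and crash, so return early if already processed.
    if getattr(layer, "_mlapo_processed", False):
        return
    layer._mlapo_processed = True

    # ... pack fused_qkv_a_proj / q_proj into the MLAPO buffers (elided) ...

    if is_kv_consumer:
        layer.fused_qkv_a_proj.weight = None
        layer.q_proj.weight = None

layer = SimpleNamespace(
    fused_qkv_a_proj=SimpleNamespace(weight=bytearray(16)),
    q_proj=SimpleNamespace(weight=bytearray(16)),
)
process_weights_for_fused_mlapo(layer, is_kv_consumer=True)
process_weights_for_fused_mlapo(layer, is_kv_consumer=True)  # safe no-op
```

Without the early return, the second call would reach the prepack step with `weight` already set to `None`.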

kiscad changed the title [perf] Fix MLAPO weight disposal for KV-consumer MLA in PD-mix deploy… [perf] Fix MLAPO weight disposal for KV-consumer MLA in PD-mix deploy... Dec 19, 2025
github-actions (bot) commented:
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message and fill in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

jianzs (Collaborator) commented Dec 22, 2025

If this step is taken, it means the decoding node must skip the prefill stage. How is this currently guaranteed? What happens if the request is preempted and needs to be computed?

kiscad (Contributor, author) commented Dec 23, 2025

> If this step is taken, it means the decoding node must skip the prefill stage. How is this currently guaranteed? What happens if the request is preempted and needs to be computed?

A decoding node may never enter the prefill stage.

ZYang6263 (Collaborator) commented:

Similar to SFA, this PR releases unused memory.

jianzs (Collaborator) commented Dec 23, 2025

> A decoding node may never enter the prefill stage.

How can we ensure this assumption holds? What if some requests get preempted due to insufficient KV cache?

Comment on lines -942 to -944

    "ep": {
        "hccl_buffer_size": calculate_ep_buffer_size()
    },

A Collaborator commented: What does this code have to do with the current PR?

kiscad (Contributor, author) replied: Actually, it doesn't. I will move it to another PR.

kiscad force-pushed the fix-mlapo-mla branch 2 times, most recently from 784af55 to 2c6c2e0 (December 23, 2025 07:18)
kiscad (Contributor, author) commented Dec 23, 2025

> How can we ensure this assumption holds? What if some requests get preempted due to insufficient KV cache?

For KV-cache insufficiency/preemption: on consumer nodes we rely on recompute_scheduler_enable. RecomputeScheduler handles allocation failures by dropping the preempted request back to the PD proxy (stop_reason="recomputed") instead of recomputing locally, so the prompt/prefill gets recomputed on a producer and KV is reloaded before decode resumes.
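The fallback described above amounts to a routing decision on allocation failure. The sketch below is purely illustrative (the function name and return shape are invented); only `recompute_scheduler_enable` and the `stop_reason="recomputed"` value come from the comment it follows.

```python
def handle_allocation_failure(request_id: str,
                              recompute_scheduler_enable: bool) -> dict:
    """Hypothetical sketch of the preemption fallback described above.

    On a KV-consumer node the prefill weights have been released, so a
    preempted request cannot be recomputed locally; with the recompute
    scheduler enabled it is handed back to the PD proxy instead.
    """
    if recompute_scheduler_enable:
        # The proxy re-runs prefill on a producer node and reloads KV
        # before decode resumes on the consumer.
        return {"request_id": request_id,
                "action": "return_to_proxy",
                "stop_reason": "recomputed"}
    # Default engine behavior would be a local recompute, which a
    # consumer node with dropped prefill weights cannot serve.
    return {"request_id": request_id, "action": "recompute_locally"}
```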

jianzs (Collaborator) commented Dec 23, 2025

> For KV-cache insufficiency/preemption: on consumer nodes we rely on recompute_scheduler_enable. RecomputeScheduler handles allocation failures by dropping the preempted request back to the PD proxy (stop_reason="recomputed") instead of recomputing locally, so the prompt/prefill gets recomputed on a producer and KV is reloaded before decode resumes.

The recomputed feature is currently only available in the vllm-ascend proxy's examples folder and is not yet a stable solution, so it's not ready for production use.

weijinqian0 added the "ready" (read for review) and "ready-for-test" (start test by label for PR) labels Dec 23, 2025
jianzs (Collaborator) commented Dec 24, 2025

It would be better to adjust the condition for #4774 as well.

kiscad (Contributor, author) commented Dec 24, 2025

> It would be better to adjust the condition for #4774 as well.

Yeah, good suggestion.

github-actions (bot) commented:
This pull request has conflicts, please resolve those before we can evaluate the pull request.

…ments

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
zzzzwwjj merged commit a2daacb into vllm-project:main Jan 5, 2026
19 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Jan 6, 2026
…to FIA_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend: (58 commits)
  [Main2Main] Upgrade vllm commit to 0106 (vllm-project#5617)
  [CI]update bisheng version (vllm-project#5621)
  [UT][PCP&DCP] UT for block_table.py (vllm-project#5032)
  [Main2Main] Upgrade vllm commit to 0105 (vllm-project#5595)
  [CI] mv ops to correct path (vllm-project#5615)
  [BugFix] Fix Smoke Testing Bug for DSR1 longseq (vllm-project#5613)
  Revert "[Feat] enable hierarchical mc2 ops on A2 by default (vllm-project#5545)" (vllm-project#5611)
  [TRITON][TEST]Add nightly test for triton split_qkv_rmsnorm_rope (vllm-project#5267)
  [perf] Fix MLAPO weight disposal for KV-consumer MLA in PD-mix deploy... (vllm-project#5192)
  [docs] Correct image about prefill phase of PCP (vllm-project#5598)
  [CI] update triton-ascend version (vllm-project#5584)
  [P/D]Remove mooncake kvpool unused parameter `local_hostname` (vllm-project#5574)
  [Bugfix] record cos and sin cache in AscendRotaryEmbedding (vllm-project#5516)
  [bugfix] fix test_camem failed with triton-ascend (vllm-project#5492)
  [UT]add triton ops ut :  test_fused_qkvzba_split_reshape_cat (vllm-project#5474)
  [CI] Download models from ms (vllm-project#5405)
  Docs: Add A3 Docker image guidance for Atlas A3 machines (vllm-project#5256)
  [Doc] Add NNAL installation guide and requirements (vllm-project#5235)
  Add the requirement of arctic-inference which  speculative decoding with suffix_decode  (vllm-project#5045)
  [BugFix][Fusion] Fix graph fusion failure problem (vllm-project#5253)
  ...
Rozwel-dx pushed a commit to Rozwel-dx/vllm-ascend that referenced this pull request Jan 8, 2026
…... (vllm-project#5192)

### What this PR does / why we need it?

- Problem: In MLA+MLAPO, KV-consumer deployments keep
fused_qkv_a_proj/q_proj weights and quant params even though MLAPO uses
the prepacked buffers, increasing memory footprint on decode nodes.
- Fix: Conditionally drop those tensors only when
`kv_transfer_config.is_kv_consumer` to reclaim memory (consistent with
the SFA behavior vllm-project#4774 ).

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
aipaes pushed a commit to aipaes/vllm-ascend that referenced this pull request Jan 15, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026

Labels

ready (read for review), ready-for-test (start test by label for PR)


5 participants