[Feature][Attention][PCP] Support PCP (Prefill Context Parallel) with MLA#28988

Open
FENP wants to merge 11 commits into vllm-project:main from FENP:prefill-context-parallel-mla
Conversation

@FENP (Contributor) commented Nov 19, 2025

Purpose

Refs issue #25749. Enables PCP for MLA models.
This PR mainly includes the following changes:

  • Add a PCP manager to vllm/v1/worker/cp_utils.py for the model runner.
  • Modify vllm/v1/worker/gpu_model_runner.py to split tokens across PCP ranks.
  • Modify vllm/v1/attention/backends/mla/common.py to adapt the MLA backend to PCP.
  • Add utility functions required by PCP to vllm/v1/attention/backends/utils.py and vllm/attention/ops/common.py.
  • Rename variables and functions shared by both PCP and DCP.

Test Plan

vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat --gpu-memory-utilization 0.9 --tensor-parallel-size 1 --prefill-context-parallel-size 2

Test Result

  • PCP1TP8 (Baseline)
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6277|±  |0.0094|
|     |       |strict-match    |     5|exact_match|↑  |0.6179|±  |0.0095|
  • PCP2TP4
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6259|±  |0.0094|
|     |       |strict-match    |     5|exact_match|↑  |0.6164|±  |0.0095|
  • PCP4TP2
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6308|±  |0.0094|
|     |       |strict-match    |     5|exact_match|↑  |0.6236|±  |0.0094|
  • PCP8TP1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6240|±  |0.0094|
|     |       |strict-match    |     5|exact_match|↑  |0.6168|±  |0.0095|

Benchmark

In addition to reducing GPU memory redundancy and increasing KV-cache capacity, PCP also reduces the all-reduce communication overhead of o_proj and lowers TTFT. We evaluated PCP against TP on DeepSeek-R1 and Kimi-K2 with 4K-token inputs on H20-3e GPUs.

DeepSeek-R1

vllm bench serve --backend vllm  --model deepseek-ai/DeepSeek-R1/  --endpoint /v1/completions   --dataset-name random   --random-input 4096   --random-output 1   --max-concurrency 1  --num-prompt 10 --ignore-eos --metric-percentiles "50,90,99"
| Parallel Config | P50 TTFT (ms) | P90 TTFT (ms) | P99 TTFT (ms) | Performance gain |
|-----------------|--------------:|--------------:|--------------:|-----------------:|
| PCP1TP8         | 414           | 417           | 418           | ---              |
| PCP2TP4         | 400           | 403           | 404           | 3.2%             |
| PCP4TP2         | 392           | 394           | 395           | 5.3%             |

Kimi-K2

vllm bench serve --backend vllm  --model moonshotai/Kimi-K2-Instruct/  --endpoint /v1/completions   --dataset-name random   --random-input 4096   --random-output 1   --max-concurrency 1  --num-prompt 10 --ignore-eos --metric-percentiles "50,90,99"  --trust-remote-code
| Parallel Config | P50 TTFT (ms) | P90 TTFT (ms) | P99 TTFT (ms) | Performance gain |
|-----------------|--------------:|--------------:|--------------:|-----------------:|
| PCP1TP8         | 394           | 396           | 396           | ---              |
| PCP2TP4         | 371           | 373           | 373           | 5.9%             |
| PCP4TP2         | 351           | 353           | 354           | 10.9%            |

Of course, PCP additionally introduces communication overhead from the KV all-gather, plus index-select kernel launch overhead for restoring KV order. Further tuning is still needed to improve performance.
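The restore step mentioned above can be sketched in pure Python. This is illustrative only: the dual-chunk layout and the value of `P` are assumptions, not the PR's actual kernel.

```python
# Sketch of why an index-select restore is needed after the PCP all-gather.
# The dual-chunk layout and P value are illustrative assumptions.
P = 4                                # PCP world size (assumed)
chunks = list(range(2 * P))          # chunk ids in sequence order

# Rank r holds the chunk pair (r, 2P-1-r) of the split prompt.
local = [[r, 2 * P - 1 - r] for r in range(P)]

# all_gather along dim 0 concatenates rank-major, scrambling chunk order.
gathered = [c for rank_chunks in local for c in rank_chunks]
print(gathered)                      # [0, 7, 1, 6, 2, 5, 3, 4]

# Build a restore index once, then index-select back to sequence order.
restore = sorted(range(len(gathered)), key=lambda i: gathered[i])
assert [gathered[i] for i in restore] == chunks
```

The restore index depends only on the parallel layout, so in a real implementation it could be precomputed once rather than rebuilt per step.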

  • Support piecewise graph
  • Support chunk prefill and prefix caching
  • Add ci test
  • Accuracy test
  • Benchmark

Limitations

Although the current PCP logic is fully compatible with decoding, decode tokens are not split, so every PCP rank holds the full set of decode tokens, leading to significant redundant communication and computation (including attention and MoE). We therefore recommend enabling PCP on prefill (P) instances in P/D-disaggregated deployments.

Future work

These items will be tackled in follow-up PRs; community contributions are warmly welcomed.



mergify bot commented Nov 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @FENP.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 19, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds support for Prefill Context Parallelism (PCP) to the Multi-head Latent Attention (MLA) backend. The changes are extensive, involving refactoring of the context-parallelism logic to be more generic, adding new metadata and utility functions for PCP, and implementing the PCP attention logic based on the Dual-Chunk-Swap strategy.

My review has identified a critical issue in the attention correction logic where DCP and PCP corrections are applied in the wrong order, which will lead to incorrect results. I have also pointed out a significant performance issue related to nested communication calls that should be optimized. Overall, the PR is a good step towards enabling PCP, but these critical issues need to be addressed.
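For readers unfamiliar with the Dual-Chunk-Swap idea referenced in this review, here is a minimal pure-Python sketch of the load-balancing intuition. The pairing rule and the value of `P` are assumptions for illustration, not code from the PR.

```python
# Sketch of the Dual-Chunk-Swap balancing idea: split the prompt into 2P
# chunks and give rank r the pair (r, 2P-1-r), so causal-attention work
# is equal across ranks.

P = 4                              # PCP world size (assumed)
num_chunks = 2 * P

def work(chunk_id: int) -> int:
    # With causal masking, chunk i attends to chunks 0..i,
    # so its cost grows linearly with its position.
    return chunk_id + 1

# Pairing an early (cheap) chunk with a late (expensive) chunk balances load.
per_rank = [work(r) + work(num_chunks - 1 - r) for r in range(P)]
print(per_rank)                    # every rank does the same amount of work
assert len(set(per_rank)) == 1     # all equal to 2P + 1
```

A naive contiguous split would instead give the last rank roughly 2P times the work of the first, which is why the swapped pairing matters for prefill throughput.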

Comment on lines 1828 to 1833

```diff
 cur_allgather_kvcache.copy_(
-    get_dcp_group().all_gather(local_gathered_kvcache, dim=0)
+    get_pcp_group().all_gather(
+        get_dcp_group().all_gather(local_gathered_kvcache, dim=0),
+        dim=0,
+    )
 )
```
Severity: high

The nested all_gather calls, first over the DCP group and then over the PCP group, are inefficient as they introduce extra communication overhead and synchronization points. This should be optimized into a single all_gather operation.

To achieve this, a new communication group that combines the ranks from both DCP and PCP should be created during initialization. Then, a single all_gather can be performed over this combined "context parallel" (CP) group. This will be more performant. The TODO comment already acknowledges this, and this comment serves to emphasize its importance for performance.
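The suggested collapse of the two collectives into one can be illustrated with a pure-Python stand-in for `all_gather`. The group sizes and the PCP-major rank layout of the combined group are assumptions for illustration.

```python
# Hedged sketch: a single all_gather over a combined CP group can replace
# the nested DCP-then-PCP gathers, provided the combined group's rank
# order matches the nested concatenation order (PCP-major here, assumed).

DCP, PCP = 2, 2                        # sub-group sizes (assumed)

# Each (pcp_rank, dcp_rank) pair contributes one KV shard.
shards = {(p, d): [f"kv_p{p}d{d}"] for p in range(PCP) for d in range(DCP)}

def allgather(parts):                  # all_gather(dim=0) == concatenation
    return [x for part in parts for x in part]

# Nested: gather over the DCP group first, then over the PCP group.
nested = allgather(
    [allgather([shards[(p, d)] for d in range(DCP)]) for p in range(PCP)]
)

# Single gather over a combined CP group laid out PCP-major.
combined = allgather([shards[(p, d)] for p in range(PCP) for d in range(DCP)])

assert nested == combined              # same bytes, one collective
print(combined)
```

The equivalence holds only if the combined group enumerates ranks in the same order as the nested gathers, which is exactly why the reviewer suggests creating the combined group at initialization time.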

@chatgpt-codex-connector bot left a comment

💡 Codex Review

https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/mla/common.py#L212-L215
P1: Importing undefined get_pcp_group

Lines 212‑215 import get_pcp_group from vllm.distributed.parallel_state, but that module still only exposes get_dcp_group (the commit merely introduced a _CP variable without any getter). Importing common.py will therefore immediately raise ImportError: cannot import name 'get_pcp_group', so none of the new PCP code paths can even be instantiated.


https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/mla/flashattn_mla.py#L79-L86
P1: FlashAttn builder now passes nonexistent kwarg

The call to super().__init__(…, supports_cp_with_varlen=True) in FlashAttnMLAMetadataBuilder.__init__ (lines 79‑86) will raise TypeError: __init__() got an unexpected keyword argument 'supports_cp_with_varlen' because MLACommonMetadataBuilder.__init__ still only accepts supports_dcp_with_varlen. This prevents the FlashAttn MLA backend from constructing at all.


https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/mla/common.py#L572-L574
P1: Referencing cp_kv_cache_interleave_size attribute that does not exist

Lines 572‑574 now read self.cp_local_block_size = parallel_config.cp_kv_cache_interleave_size, but ParallelConfig (vllm/config/parallel.py) defines only dcp_kv_cache_interleave_size. As soon as MLACommonMetadataBuilder is constructed this access raises AttributeError: 'ParallelConfig' object has no attribute 'cp_kv_cache_interleave_size', so the MLA backend cannot even initialize.


https://github.com/vllm-project/vllm/blob/9c4884f9884071f7d36b26df87b69eeb6a08ae26/v1/attention/backends/utils.py#L1118-L1124
P1: New utils annotation causes NameError on import

The new helper pcp_kv_allgather_and_restore (lines 1118‑1124) annotates pcp_group: GroupCoordinator, but GroupCoordinator is only imported inside the TYPE_CHECKING block and there is no from __future__ import annotations. When Python evaluates these annotations at import time it looks up GroupCoordinator, fails to find the name, and raises NameError, breaking vllm.v1.attention.backends.utils for every runtime import.
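One common fix for this class of failure is to add `from __future__ import annotations`, which defers annotation evaluation so a `TYPE_CHECKING`-only name is never looked up at runtime. A minimal sketch (the helper signature here is simplified for illustration):

```python
from __future__ import annotations  # annotations stay strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers, never executed at runtime.
    from vllm.distributed.parallel_state import GroupCoordinator

def pcp_kv_allgather_and_restore(pcp_group: GroupCoordinator) -> None:
    """Defining this no longer evaluates GroupCoordinator: no NameError."""

# The annotation is preserved as a plain string for tooling to resolve.
print(pcp_kv_allgather_and_restore.__annotations__["pcp_group"])
```

Quoting the annotation (`pcp_group: "GroupCoordinator"`) achieves the same effect on a per-signature basis if the module-wide future import is undesirable.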


@mergify mergify bot removed the needs-rebase label Nov 19, 2025
@FENP FENP force-pushed the prefill-context-parallel-mla branch 3 times, most recently from 2d9034a to 8ac9843 on November 20, 2025 07:10
@FENP FENP requested review from mgoin and tjtanaa as code owners November 20, 2025 07:10
@mergify mergify bot added nvidia rocm Related to AMD ROCm labels Nov 20, 2025
@FENP FENP force-pushed the prefill-context-parallel-mla branch from 8ac9843 to 5e79da7 on November 20, 2025 07:38
mergify bot commented Nov 21, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @FENP.

FENP added 6 commits December 28, 2025 14:35
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
@FENP FENP force-pushed the prefill-context-parallel-mla branch from 4b40860 to a0a9181 on December 28, 2025 09:57
@mergify mergify bot removed the needs-rebase label Dec 28, 2025
@FENP (Contributor, Author) commented Dec 28, 2025

> hi @FENP Can you resolve the conflict? I want to test it locally.

Hi @chaunceyjiang , I've resolved the conflicts—feel free to try it out and share your feedback!

mergify bot commented Dec 29, 2025

Hi @FENP, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

…vice tensor

Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
@FENP FENP force-pushed the prefill-context-parallel-mla branch from a2b046a to cf9c2eb on January 4, 2026 03:16
mergify bot commented Jan 4, 2026

Hi @FENP, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

zzzzwwjj pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jan 4, 2026
### What this PR does / why we need it?
Since the [PR](vllm-project/vllm#28988) for PCP
modifications to `GPUModelRunner` has not yet been merged into vLLM,
this PR temporarily requires adjustments to certain buffer sizes. These
changes can be reverted once the original
[PR](vllm-project/vllm#28988) is merged.

### Does this PR introduce _any_ user-facing change?
No

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
@LucasWilkinson (Collaborator) commented Jan 7, 2026

@FENP is this ready for another round of review? happy to start reviewing whenever it is ready

@FENP (Contributor, Author) commented Jan 7, 2026

> @FENP is this ready for another round of review? happy to start reviewing whenever it is ready

@LucasWilkinson It’s ready for your review. Thanks in advance!
BTW, the current pre-commit error is a bit strange — I don't understand why it's reporting "Undefined name num_scheduled_tokens_np" at vllm/v1/worker/gpu_model_runner.py:3252:52.

Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
mergify bot commented Jan 7, 2026

Hi @FENP, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

@LucasWilkinson LucasWilkinson self-assigned this Jan 7, 2026
mergify bot commented Jan 7, 2026

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @FENP.

@mergify mergify bot added the needs-rebase label Jan 7, 2026
Comment on lines +3205 to +3207

```python
if self.pcp_world_size > 1:
    max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
    num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
```

Collaborator

I find it quite messy that _prepare_inputs modifies num_scheduled_tokens_np and scheduler_output internally. It's not very clear to the reader that that's what's happening, or why this recomputation/re-assignment is required.

I think we should try harder to keep PCP more isolated for now. I'm working on an idea here (vibe-coded and not tested yet): FENP#4, which leaves _prepare_inputs untouched. It just shuffles the inputs after preparation to select the tokens this PCP rank cares about. That means some duplicated/wasted work, but I think it's better for the initial implementation; we can do broader refactors on the model runner later to support this more naturally with less duplicated/wasted work, potentially by breaking up prepare-inputs. Thoughts? (Sorry, not a complete review yet; will continue tomorrow.)

Rozwel-dx pushed a commit to Rozwel-dx/vllm-ascend that referenced this pull request Jan 8, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026