
[Core] Pipeline Parallel support for Model Runner V2#33960

Merged
WoosukKwon merged 2 commits into vllm-project:main from ZhanqiuHu:feature/model-runner-v2-pp on Feb 17, 2026

Conversation

@ZhanqiuHu
Contributor

@ZhanqiuHu ZhanqiuHu commented Feb 6, 2026

Summary

Co-authored with @yewentao256

Add Pipeline Parallel (PP) support to Model Runner V2 (vllm/v1/worker/gpu/model_runner.py). This introduces a modular PPHandler class that encapsulates all PP logic, keeping the model runner code clean. Verified correct output and competitive throughput against the V1 baseline.

Related: #32455 (Q1 2026 Roadmap) — PP is listed as a missing feature for Model Runner V2.

Changes

In Model Runner (vllm/v1/worker/gpu/model_runner.py)

  • execute_model:

    • First rank: Run model forward with raw token inputs, send IntermediateTensors to next stage.
    • Middle ranks: Accept IntermediateTensors from previous stage, run model forward, send IntermediateTensors to next stage.
    • Last rank: Accept IntermediateTensors from previous stage, run model forward, store hidden states for sampling.
    • Delegates input/output preparation to PPHandler.
  • sample_tokens:

    • Last rank: Sample tokens, broadcast sampled_token_ids/num_sampled/num_rejected to all other ranks, then return ModelRunnerOutput.
    • Non-last ranks: Receive broadcast tensors, call postprocess to update local state, return None.
  • _dummy_run: Create dummy intermediate tensors for non-first ranks; skip sampler for non-last ranks.

  • capture_model: Skip CUDA graph capture when PP is enabled (eager-only for now).

  • Deterministic request sorting: Use req_id as tie-breaker to ensure consistent batch ordering across PP ranks.

  • Multimodal guard: Only prepare MM embeddings on the first PP rank.
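The per-rank behavior above can be sketched as a small dispatch on the rank's role. This is an illustrative sketch only: `PPRole`, `run_stage`, and the callback arguments are hypothetical names, not the actual vLLM API.

```python
from enum import Enum

class PPRole(Enum):
    NO_PP = "no_pp"
    FIRST = "first"
    MIDDLE = "middle"
    LAST = "last"

def run_stage(role, raw_inputs, recv_intermediate, send, model):
    """One pipeline step: pick inputs by role, forward, then route the output."""
    if role in (PPRole.NO_PP, PPRole.FIRST):
        # First rank (or PP disabled): consume raw token inputs.
        hidden = model(raw_inputs)
    else:
        # Middle/last ranks: consume intermediate tensors from the previous stage.
        hidden = model(recv_intermediate())
    if role in (PPRole.FIRST, PPRole.MIDDLE):
        # Not the final stage: ship intermediate tensors downstream.
        send(hidden)
        return None
    # Last rank (or PP disabled): keep hidden states for sampling.
    return hidden
```

Here `model`, `send`, and `recv_intermediate` stand in for the model forward and the stage-to-stage communication.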

New Class: PPHandler (vllm/v1/worker/gpu/pp_handler.py)

New module with a PPHandler class. All public methods no-op when PP is disabled or called on an inapplicable rank, so callers don't need guard conditions.

| Method | Description |
| --- | --- |
| `maybe_broadcast_sampled_tokens` | Last rank broadcasts `sampled_token_ids`, `num_sampled`, and `num_rejected` to all PP ranks. No-ops on non-last ranks. |
| `maybe_receive_sampled_tokens` | Non-last ranks receive the broadcast tensors. Returns `None` on the last rank. Supports variable `max_sample_len` for future speculative decoding + PP. |
| `prepare_model_inputs` | Builds the `model.forward()` kwargs: raw inputs on the first rank, intermediate tensors on the others. |
| `prepare_output` | Extracts hidden states (last rank) or wraps the output as `IntermediateTensors` (non-last ranks). |

Helper classes: PPConfig (dataclass with rank role, size, index) and PPRole enum (NO_PP, FIRST, MIDDLE, LAST).
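A minimal sketch of how such helpers and the no-op contract might fit together. The field and method bodies here are assumptions for illustration, not the actual `pp_handler.py` code.

```python
from dataclasses import dataclass

@dataclass
class PPConfig:
    """Static PP facts for one rank (illustrative fields)."""
    pp_size: int
    pp_rank: int

    @property
    def is_last(self) -> bool:
        return self.pp_rank == self.pp_size - 1

class PPHandler:
    def __init__(self, cfg: PPConfig):
        self.cfg = cfg

    def maybe_broadcast_sampled_tokens(self, tokens):
        # The rank guard lives inside the method, so the model runner can
        # call it unconditionally without its own PP checks.
        if self.cfg.pp_size == 1 or not self.cfg.is_last:
            return None  # no-op on non-last ranks or when PP is disabled
        return tokens  # stand-in for the actual broadcast to other ranks
```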

Future Work

  • CUDA graph support for PP (currently eager-only)
  • Async (non-blocking) next-sampled-token transfer
  • Async intermediate tensor send/recv
  • Speculative decoding + PP integration

Test Plan

Accuracy — lm_eval gsm8k (5-shot):

```shell
export MODEL="Qwen/Qwen3-30B-A3B-Thinking-2507-FP8"

# V1 baseline
vllm serve $MODEL -pp 2 --port 9256 --enable-expert-parallel --max-num-seqs 128 --enforce-eager

# V2
VLLM_USE_V2_MODEL_RUNNER=1 vllm serve $MODEL -pp 2 --port 9256 --enable-expert-parallel --max-num-seqs 128 --enforce-eager

# Evaluate
lm_eval --model local-completions \
  --model_args "base_url=http://127.0.0.1:9256/v1/completions,model=$MODEL,num_concurrent=1024" \
  --tasks gsm8k
```

Throughput — decode-heavy random workload:

```shell
vllm bench serve --model $MODEL --dataset-name random --host 127.0.0.1 \
  --random-input-len 2 --random-output-len 512 --num-prompts 128 \
  --port 9256 --num-warmups 16
```

Test Results

Accuracy (gsm8k, PP=2, Qwen3-30B-A3B MoE FP8)

| Metric | V1 | V2 |
| --- | --- | --- |
| flexible-extract (exact_match) | 0.6641 ± 0.0130 | 0.6641 ± 0.0130 |
| strict-match (exact_match) | 0.7801 ± 0.0114 | 0.7794 ± 0.0114 |

V2 matches V1 accuracy. The 0.0007 difference in strict-match is within noise (±0.0114 stderr).
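As a quick arithmetic check of that claim, using the numbers from the table above:

```python
# strict-match: V1 vs V2, compared against the reported standard error.
v1_strict, v2_strict, stderr = 0.7801, 0.7794, 0.0114
diff = abs(v1_strict - v2_strict)
assert diff < stderr  # 0.0007 is well under one stderr, i.e. within noise
```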

Throughput (PP=2, Qwen3-30B-A3B MoE FP8, input=2, output=512, 128 prompts)

| Metric | V1 | V2 | Diff |
| --- | --- | --- | --- |
| Output token throughput (tok/s) | 7315.42 | 6956.99 | -4.9% |
| Total token throughput (tok/s) | 7344.00 | 6984.17 | -4.9% |
| Mean TTFT (ms) | 170.33 | 194.20 | +14.0% |
| Mean TPOT (ms) | 17.12 | 17.99 | +5.1% |
| Mean ITL (ms) | 17.12 | 17.99 | +5.1% |

Note: Claude (Anthropic) was used as a coding assistant during development of this PR.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the v1 label Feb 6, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces pipeline parallelism (PP) support for the Model Runner V2 by adding a new PPHandler class. This is a great approach as it encapsulates the PP-related logic, keeping the model runner code cleaner and more maintainable. The changes in GPUModelRunner correctly integrate this handler for different stages of model execution, such as dummy runs, model execution, and token sampling. The distinction between pipeline stages (first, middle, last) is handled well throughout the changes. The current implementation uses blocking communication, which is a reasonable first step, with non-blocking communication planned for future work.

I have one suggestion to improve the robustness of the PPHandler by adding a more explicit check for expected tensor keys, which will improve the developer experience for those implementing PP support in new models.

```python
if self.produces_final_output:
    # Last rank: extract hidden states for sampling
    if isinstance(hidden_states, IntermediateTensors):
        return hidden_states["hidden_states"]
```
Contributor


Severity: high

Accessing hidden_states["hidden_states"] directly can lead to an uninformative KeyError if a model's forward method on the last pipeline stage returns an IntermediateTensors object that doesn't contain the "hidden_states" key. To improve robustness and provide a clearer error message for model developers, it's better to check for the key's existence and raise a ValueError with an explanatory message if it's missing.

Suggested change:

```diff
-return hidden_states["hidden_states"]
+if "hidden_states" not in hidden_states.tensors:
+    raise ValueError(
+        "IntermediateTensors from model on the last PP rank must "
+        "contain 'hidden_states' tensor.")
+return hidden_states["hidden_states"]
```

@mengjian0502

🚀🚀🚀

@ZhanqiuHu ZhanqiuHu force-pushed the feature/model-runner-v2-pp branch from b8c0f49 to c863e21 Compare February 9, 2026 17:26
@ZhanqiuHu ZhanqiuHu marked this pull request as ready for review February 9, 2026 17:27
@ZhanqiuHu ZhanqiuHu requested a review from WoosukKwon as a code owner February 9, 2026 17:28
@ZhanqiuHu ZhanqiuHu changed the title [Core][WIP] Pipeline Parallel support for Model Runner V2 [Core] Pipeline Parallel support for Model Runner V2 Feb 9, 2026
@ZhanqiuHu ZhanqiuHu force-pushed the feature/model-runner-v2-pp branch from e169ff1 to 5afb5ae Compare February 9, 2026 22:55
Collaborator

@WoosukKwon WoosukKwon left a comment


@ZhanqiuHu Thanks for the PR! I like the design of using PP handler to encapsulate the PP-related logics.

That said, I think we could improve the model runner code and make the change even smaller. Please check out my comments.

Comment on lines +939 to +945:

```python
# PP input preparation: handler centralizes all PP input logic.
model_inputs = self.pp_handler.prepare_model_inputs(
    input_batch.input_ids,
    positions,
    input_batch.inputs_embeds,
    intermediate_tensors,
)
```
Collaborator


Can we keep the model inputs explicit? I’m not a fan of encapsulating them into a separate object. For example, when supporting CUDA graphs, it’s critical to reason about all inputs and ensure they use consistent memory addresses. This abstraction makes that harder to see.

```python
self.execute_model_state = hidden_states, input_batch, kv_connector_output
output = self.pp_handler.prepare_output(hidden_states, kv_connector_output)

if isinstance(output, IntermediateTensors):
```
Collaborator


Can we just check the PP rank directly?

Comment on lines +998 to +1001:

```python
# Broadcast to non-last ranks (handles spec decode multi-token)
self.pp_handler.maybe_broadcast_sampled_tokens(
    sampler_output, num_sampled, num_rejected
)
```
Collaborator


I think it'd be nice if we can do something like

Suggested change:

```diff
-# Broadcast to non-last ranks (handles spec decode multi-token)
-self.pp_handler.maybe_broadcast_sampled_tokens(
-    sampler_output, num_sampled, num_rejected
-)
+if self.use_pp:
+    # Broadcast to non-last ranks (handles spec decode multi-token)
+    self.pp_handler.maybe_broadcast_sampled_tokens(
+        sampler_output, num_sampled, num_rejected
+    )
```

…code readiness

Co-authored with @yewentao256

Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
@ZhanqiuHu ZhanqiuHu force-pushed the feature/model-runner-v2-pp branch from 5afb5ae to ef1f640 Compare February 10, 2026 21:27
@mergify

mergify bot commented Feb 10, 2026

Hi @ZhanqiuHu, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
```shell
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint
```

@ZhanqiuHu
Contributor Author

Hi @WoosukKwon,

Thanks for the review! Just pushed the updates:

Changes:

  • Exposed use_pp checking and PP rank checking (get_pp_group().is_first_rank / is_last_rank) explicitly in model runner.
  • Only create pp_handler when PP is enabled.
  • Removed unused code in pp_handler.py, it now only handles token broadcast/receive.
  • Made model inputs and outputs processing explicit in model runner (instead of prepare_model_inputs / prepare_output).

Note:

  • Right now PPHandler doesn't hold any state, only methods, so technically no class is needed. I can remove the class and use plain functions if preferred, although we might need it to hold state in the future (e.g., for async broadcast).

@ZhanqiuHu ZhanqiuHu force-pushed the feature/model-runner-v2-pp branch from ef1f640 to b872030 Compare February 10, 2026 21:39
Comment on lines +548 to +556:

```python
# NOTE: In PP mode, every rank must construct the *exact* same request
# ordering for the batched token dimension. Python's `sorted(..., key=...)`
# is stable, so ties would otherwise be broken by the input dict's
# insertion order, which can differ across processes. Use `req_id` as a
# deterministic tie-breaker to keep PP stages in sync.
req_ids = sorted(
    num_tokens_per_req,
    key=lambda req_id: (num_tokens_per_req[req_id], req_id),
)
```
Member


I don't think this should be needed. The num_tokens_per_req dict received by each rank should be identical, and sorted is stable.

I found that this sort is actually quite a bit faster without the lambda.
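For reference, `sorted()`'s stability guarantee can be demonstrated standalone (illustrative data, not vLLM code):

```python
# Equal keys keep their insertion order under a stable sort. As long as every
# PP rank builds `num_tokens_per_req` in the same order, the plain sort
# already produces identical orderings, with no req_id tie-breaker needed.
num_tokens_per_req = {"req-b": 3, "req-a": 3, "req-c": 1}

plain = sorted(num_tokens_per_req, key=num_tokens_per_req.__getitem__)
tie_broken = sorted(
    num_tokens_per_req, key=lambda r: (num_tokens_per_req[r], r)
)

print(plain)       # ['req-c', 'req-b', 'req-a']: ties keep dict order
print(tie_broken)  # ['req-c', 'req-a', 'req-b']: ties broken by req_id
```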

Contributor Author


Reverted

Changes:
1. Expose use_pp checking and PP rank checking explicitly in model runner.
   Only create pp_handler when PP is enabled.
2. Removed unused code in pp_handler.
3. Make model inputs and outputs processing explicit in model runner.

Note:
1. Right now, pp_handler class doesn't hold state, just holds methods.
   Technically no class is needed. Maybe in the future we might need to
   hold state for async, but not sure.

Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
@ZhanqiuHu ZhanqiuHu force-pushed the feature/model-runner-v2-pp branch from b872030 to faeec31 Compare February 11, 2026 18:44
Comment on lines +943 to +957:

```python
if self.use_pp and not get_pp_group().is_first_rank:
    # Non-first PP rank: forward with intermediate tensors.
    assert intermediate_tensors is not None
    hidden_states = self.model(
        input_ids=None,
        positions=positions,
        inputs_embeds=None,
        intermediate_tensors=intermediate_tensors,
    )
else:
    hidden_states = self.model(
        input_ids=input_batch.input_ids,
        positions=positions,
        inputs_embeds=input_batch.inputs_embeds,
    )
```
Member


suggested:

Suggested change:

```diff
-if self.use_pp and not get_pp_group().is_first_rank:
-    # Non-first PP rank: forward with intermediate tensors.
-    assert intermediate_tensors is not None
-    hidden_states = self.model(
-        input_ids=None,
-        positions=positions,
-        inputs_embeds=None,
-        intermediate_tensors=intermediate_tensors,
-    )
-else:
-    hidden_states = self.model(
-        input_ids=input_batch.input_ids,
-        positions=positions,
-        inputs_embeds=input_batch.inputs_embeds,
-    )
+if get_pp_group().is_first_rank:
+    input_ids = input_batch.input_ids
+    inputs_embeds = input_batch.inputs_embeds
+else:
+    # Non-first PP rank: forward with intermediate tensors.
+    input_ids, inputs_embeds = None, None
+    assert intermediate_tensors is not None
+hidden_states = self.model(
+    input_ids=input_ids,
+    positions=positions,
+    inputs_embeds=inputs_embeds,
+    intermediate_tensors=intermediate_tensors,
+)
```

@njhill
Member

njhill commented Feb 12, 2026

@ZhanqiuHu do you know the reason for the performance drop relative to V1? Is it cudagraphs?

@ZhanqiuHu
Contributor Author

@ZhanqiuHu do you know the reason for the performance drop relative to V1? Is it cudagraphs?

I benchmarked both V1 and V2 with --enforce-eager, so CUDA graphs shouldn't be the issue. In this implementation, at the end of each iteration in sample_tokens(), MRV2 runs a blocking broadcast of the next sampled token on all ranks and then runs postprocess to update request states, which differs from MRV1. I suspect this difference might be a cause of the performance drop.

But I am looking deeper into the issue.
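The blocking flow described above can be sketched with stand-in callbacks. This is a mock of the control flow only; the real code uses collective communication, and these function names are hypothetical.

```python
def sample_tokens_step(is_last_rank, sample, broadcast, receive, postprocess):
    """One sample_tokens() iteration: the broadcast sits on the critical path,
    since every rank must finish it before postprocess and the next step."""
    if is_last_rank:
        token_ids = sample()
        broadcast(token_ids)   # blocking: other ranks wait on this
    else:
        token_ids = receive()  # blocking receive of the broadcast
    postprocess(token_ids)     # update local request state on every rank
    return token_ids if is_last_rank else None
```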

@ZhanqiuHu
Contributor Author

ZhanqiuHu commented Feb 16, 2026

Hi @njhill, I was benchmarking V1 vs V2 PP performance and noticed that in previous runs I didn't disable prefix caching, and results between benchmark runs actually varied by a lot.

I now added --no-enable-prefix-caching, matched num_prompts to max-num-seqs (128), added a warmup run, and did 5 runs for each setting V1 vs V2. Here are the results (PP=2, Qwen3-30B-A3B FP8, B200, --enforce-eager):

| Workload | V2 mean_tpot (ms) | V1 mean_tpot (ms) |
| --- | --- | --- |
| decode-heavy (in=2, out=512) | 17.34 ± 1.10 | 17.64 ± 0.39 |
| prefill-heavy (in=512, out=64) | 22.05 ± 0.35 | 23.82 ± 3.98 |
| mixed (in=512, out=512) | 22.39 ± 3.17 | 20.84 ± 3.10 |

Baseline with no PP, for reference:

| Workload | V2 mean_tpot (ms) | V1 mean_tpot (ms) |
| --- | --- | --- |
| decode-heavy (in=2, out=512) | 34.00 ± 6.28 | 33.19 ± 3.76 |
| prefill-heavy (in=512, out=64) | 36.96 ± 6.55 | 34.38 ± 0.44 |
| mixed (in=512, out=512) | 34.09 ± 3.57 | 33.35 ± 3.03 |

I think the performance results are comparable between V1 and V2. I also checked the PP flow in V1 and V2 and it should be the same, except that since request state is now updated within the model runner, V2 also needs to broadcast the num_sampled and num_rejected tensors (needed for speculative decoding) in addition to the sampled tokens; this doesn't seem to significantly impact throughput.

Collaborator

@WoosukKwon WoosukKwon left a comment


@ZhanqiuHu LGTM! Thanks for addressing all comments. I'm excited that we have such a clean implementation of PP. Great work 👍

@WoosukKwon WoosukKwon merged commit 9a8853f into vllm-project:main Feb 17, 2026
6 checks passed
wzhao18 pushed a commit to wzhao18/vllm that referenced this pull request Feb 18, 2026
eldarkurtic pushed a commit to eldarkurtic/vllm that referenced this pull request Feb 19, 2026
ZJY0516 pushed a commit to ZJY0516/vllm that referenced this pull request Feb 23, 2026
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026