[BugFix] Support online dense model DP without overhead #30739

Merged
youkaichao merged 7 commits into vllm-project:main from njhill:non-moe-dp
Jan 2, 2026

Conversation

@njhill
Member

@njhill njhill commented Dec 16, 2025

Currently, there is unnecessary overhead when running non-MoE models in a data-parallel configuration: steps across the ranks are synchronized with redundant all-reduce ops, and coordination is performed to ensure that "idle" ranks perform dummy forward passes.

This PR changes the parallel config at the worker level to be equivalent to DP=1 for non-MoE models, so each rank operates independently. When internal load-balancing is used, the DP coordinator still runs to propagate stats back from the engines for load balancing purposes, but the step/wave synchronization logic is disabled.

Fixes #24461.
Fixes #30655.

This is supported in the online / AsyncLLM case only.

Offline DP will now fail during startup for non-MoE models (it makes no sense to use DP in that configuration).
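The worker-level change described above can be sketched roughly as follows. This is an illustration only; the class and function names are hypothetical stand-ins, not the PR's actual code. The key idea is that a dense model's worker sees an effective DP size of 1, while the original rank is preserved (as a `data_parallel_index`-style field) so the coordinator can still attribute load-balancing stats.

```python
# Illustrative sketch (names hypothetical, not vLLM's actual API):
# for dense (non-MoE) models, collapse DP at the worker level so each
# rank runs independently, but remember the original rank for stats.
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelConfig:
    data_parallel_size: int
    data_parallel_rank: int
    data_parallel_index: int = 0  # original DP rank, kept for load-balancing stats


def worker_parallel_config(cfg: ParallelConfig, is_moe: bool) -> ParallelConfig:
    """Return the config the worker should actually use."""
    if is_moe or cfg.data_parallel_size == 1:
        return cfg  # MoE models keep real DP coordination
    return ParallelConfig(
        data_parallel_size=1,
        data_parallel_rank=0,
        data_parallel_index=cfg.data_parallel_rank,
    )
```

With this shape, rank 2 of a DP=4 dense deployment would run as if DP=1 while still reporting itself as index 2 to the coordinator.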

Benchmark on 4xH100:

vllm serve Qwen/Qwen3-8B --data-parallel-size 4 --uvicorn-log-level=error
vllm bench serve \
    --backend vllm \
    --model Qwen/Qwen3-8B \
    --dataset-name random \
    --random-input-len 128 \
    --random-output-len 512 \
    --ignore-eos \
    --port 8033 \
    --num-prompts 4000 \
    --max-concurrency 200 \
    --seed 42

Before

============ Serving Benchmark Result ============
Successful requests:                     4000      
Failed requests:                         0         
Maximum request concurrency:             200       
Benchmark duration (s):                  104.41    
Total input tokens:                      512000    
Total generated tokens:                  2048000   
Request throughput (req/s):              38.31     
Output token throughput (tok/s):         19615.26  
Peak output token throughput (tok/s):    21597.00  
Peak concurrent requests:                400.00    
Total token throughput (tok/s):          24519.08  
---------------Time to First Token----------------
Mean TTFT (ms):                          131.78    
Median TTFT (ms):                        124.31    
P99 TTFT (ms):                           404.36    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.93      
Median TPOT (ms):                        9.95      
P99 TPOT (ms):                           10.08     
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.93      
Median ITL (ms):                         9.79      
P99 ITL (ms):                            15.93     
==================================================

After

============ Serving Benchmark Result ============
Successful requests:                     4000      
Failed requests:                         0         
Maximum request concurrency:             200       
Benchmark duration (s):                  99.24     
Total input tokens:                      512000    
Total generated tokens:                  2048000   
Request throughput (req/s):              40.31     
Output token throughput (tok/s):         20636.52  
Peak output token throughput (tok/s):    22454.00  
Peak concurrent requests:                400.00    
Total token throughput (tok/s):          25795.66  
---------------Time to First Token----------------
Mean TTFT (ms):                          88.94     
Median TTFT (ms):                        74.67     
P99 TTFT (ms):                           379.48    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.50      
Median TPOT (ms):                        9.50      
P99 TPOT (ms):                           9.66      
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.50      
Median ITL (ms):                         9.38      
P99 ITL (ms):                            12.86     
==================================================
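Summarizing the deltas between the two runs (numbers copied from the tables above), the win is mostly in TTFT and tail ITL, with a modest throughput gain:

```python
# Compare the before/after benchmark runs above as percentage changes.
def pct_change(before: float, after: float) -> float:
    return (after - before) / before * 100.0


metrics = {
    "duration_s":   (104.41, 99.24),      # benchmark duration
    "output_tok_s": (19615.26, 20636.52), # output token throughput
    "mean_ttft_ms": (131.78, 88.94),      # mean time to first token
    "p99_itl_ms":   (15.93, 12.86),       # P99 inter-token latency
}
for name, (before, after) in metrics.items():
    print(f"{name}: {pct_change(before, after):+.1f}%")
```

Roughly a 5% throughput improvement, a 32% reduction in mean TTFT, and a 19% reduction in P99 ITL.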


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant optimization for running dense (non-MoE) models in a data-parallel configuration by removing unnecessary synchronization overhead. The core idea is to treat each data-parallel rank as an independent worker for dense models, effectively setting their data-parallel size to 1 at the worker level. This avoids redundant all-reduce operations and complex wave synchronization, which are only necessary for MoE models.

The DP coordinator's role is intelligently adapted: for dense models with internal load balancing, it continues to run for statistics propagation, but with wave coordination disabled. For external load balancing, it's disabled entirely for dense models.

The changes are well-structured, with clear separation of concerns. The introduction of data_parallel_index to preserve the original rank is a clean solution. The related configurations and tests, especially the new test_needs_dp_coordination, are thorough and correctly validate the new logic. Overall, this is a solid improvement that should enhance performance for a common use case.
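The coordination decision that the review describes (and that test_needs_dp_coordination validates) can be sketched as below. The function signature and condition names here are an assumption for illustration, not vLLM's actual code:

```python
# Rough illustration of when the DP coordinator is needed (hypothetical
# signature; vLLM's real implementation derives this from its configs):
def needs_dp_coordination(dp_size: int, is_moe: bool, internal_lb: bool) -> bool:
    if dp_size <= 1:
        return False      # no data parallelism, nothing to coordinate
    if is_moe:
        return True       # MoE DP needs step/wave synchronization
    return internal_lb    # dense: coordinator only propagates LB stats
```

Under this sketch, a dense model with external load balancing gets no coordinator at all, matching the review's description.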

@njhill njhill added the ready (ONLY add when PR is ready to merge/full CI is needed) label Dec 16, 2025
Signed-off-by: Nick Hill <nhill@redhat.com>
@youkaichao youkaichao merged commit bd87716 into vllm-project:main Jan 2, 2026
57 checks passed
@njhill njhill deleted the non-moe-dp branch January 2, 2026 16:18
@mgoin
Member

mgoin commented Jan 2, 2026

@njhill do you think we should consider automatically setting api_server_count as we scale DP? For instance, in your benchmarks you seemed to use half the DP size.

wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jan 6, 2026
### What this PR does / why we need it?

Upgrade vllm commit to 0105 (8be6432bdaf6275664d857b1e5e9bf8ed1ce299e)

1. Remove `maybe_padded_num_tokens` arg in `model_runner_v1.py` since
vllm-project/vllm#31517 deleted the unused arg

2. Remove dense `Qwen/Qwen3-0.6B` in
`tests/e2e/multicard/test_aclgraph_capture_replay.py` and
`tests/e2e/multicard/test_data_parallel.py` due to
vllm-project/vllm#30739
where offline data parallel mode will not be supported/useful for dense
models

3. Adapt `vllm_ascend/worker/worker.py` due to
vllm-project/vllm#31584

4. Adapt `self.block_size` calling due to
vllm-project/vllm#31540

5. Modify `test_mla_v1.py` due to
vllm-project/vllm#28454, which refactored
`get_head_size()`

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@7157596

Signed-off-by: wjunLu <wjunlu217@gmail.com>
LucasWilkinson pushed a commit to neuralmagic/vllm that referenced this pull request Jan 6, 2026
…#30739)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: njhill <nickhill123@gmail.com>
Rozwel-dx pushed a commit to Rozwel-dx/vllm-ascend that referenced this pull request Jan 8, 2026
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
…#30739)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: njhill <nickhill123@gmail.com>
yma11 pushed a commit to yma11/vllm that referenced this pull request Jan 12, 2026
…llm-project#86)

* Make engine core client handshake timeout configurable  (vllm-project#27444)

Signed-off-by: Seiji Eicher <seiji@anyscale.com>

* [BugFix] Support online dense model DP without overhead (vllm-project#30739)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: njhill <nickhill123@gmail.com>

---------

Signed-off-by: Seiji Eicher <seiji@anyscale.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: njhill <nickhill123@gmail.com>
Co-authored-by: Seiji Eicher <58963096+eicherseiji@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
aipaes pushed a commit to aipaes/vllm-ascend that referenced this pull request Jan 15, 2026
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
…#30739)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: njhill <nickhill123@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…#30739)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: njhill <nickhill123@gmail.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
nrghosh added a commit to nrghosh/ray that referenced this pull request Jan 22, 2026
- Use a MoE model (Deepseek-V2-Lite) because
vllm-project/vllm#30739 changes how vLLM handles
DP ranks: it overrides dp_size=1 and dp_rank=0 for non-MoE models.

- fixes doc/source/llm/doc_code/serve/multi_gpu/dp_basic_example.py and
 doc/source/llm/doc_code/serve/multi_gpu/dp_pd_example.py

Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
nrghosh added a commit to nrghosh/ray that referenced this pull request Jan 22, 2026
- Use a MoE model (Deepseek-V2-Lite) because
vllm-project/vllm#30739 changes how vLLM handles
DP ranks - overrides dp_size=1 and dp_rank=0 if non-MoE model

- Fixes doc/source/llm/doc_code/serve/multi_gpu/dp_basic_example.py and
 doc/source/llm/doc_code/serve/multi_gpu/dp_pd_example.py

- vLLM 0.14.0 commit bd877162e optimizes DP for dense models by making each rank independent, preserving DP coordination only for MoE models, where it's needed for expert parallelism

- Impact: Ray's DPServer DP coordination (rank assignment, stats addresses) was ignored for dense models like Qwen2.5-0.5B-Instruct, causing cascading assertion failures

- Fix: The tests now use an MoE model where vLLM's DP coordination is preserved. Outside of this test, dense model deployments should use Ray Serve replicas (num_replicas) instead of vLLM's data_parallel_size.

Signed-off-by: Nikhil Ghosh <nikhil@anyscale.com>
jeffreywang-anyscale pushed a commit to nrghosh/ray that referenced this pull request Jan 26, 2026
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…#30739)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: njhill <nickhill123@gmail.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026

Labels

kv-connector · ready (ONLY add when PR is ready to merge/full CI is needed) · v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Performance]: VLLM with DP performing worst
[BugFix]: Avoid unnecessary coordination for non-MoE data parallel

3 participants