
[Fix] Fix DP-related padding logic#2582

Merged
wangxiyuan merged 3 commits into vllm-project:main from yiz-liu:fix-padding
Aug 28, 2025

Conversation

@yiz-liu
Collaborator

@yiz-liu yiz-liu commented Aug 27, 2025

What this PR does / why we need it?

The determination of attention state, padding, and other forward metadata has been moved to an earlier stage within the input preparation process. This change enables us to utilize a single all-reduce operation, maximizing synchronization efficiency as early as possible.

The logic for synchronizing metadata—such as the number of tokens, prefill status, and DBO status—across data parallel (DP) ranks has now been unified and simplified.

For performance improvements, the all-reduce operation has been switched from the gloo backend to the npu backend, which results in a reduction of several milliseconds per step (approximately 10% performance gain for TPOT!).

Additionally, the multi-DP server hang issue has been resolved, ensuring no more hangs occur when num_requests < dp_size. Alas, a relief.

Finally, the miscalculated memory usage issue has been addressed by removing the unnecessary DummyCommImpl, allowing the system to use the real communication method when determining available memory.
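The packing trick behind the single all-reduce can be sketched in plain Python (NPU tensors replaced by lists and `dist.all_reduce` by an element-wise sum; the function names are illustrative, not the PR's actual code):

```python
# Sketch of the packed metadata sync, simulated without torch or NPUs.
# Each rank contributes one slot for its own token count plus two shared
# flag slots; summing across ranks emulates dist.all_reduce(op=SUM).

def pack(rank, dp_size, num_tokens, with_prefill, enable_dbo):
    """Build one rank's contribution: dp_size token slots plus two flags."""
    packed = [0] * (dp_size + 2)
    packed[rank] = num_tokens
    packed[-2] = int(with_prefill)    # sum > 0 -> at least one rank prefilling
    packed[-1] = int(not enable_dbo)  # sum > 0 -> at least one rank disabled DBO
    return packed

def sync_metadata(per_rank):
    """Element-wise sum across ranks, then unpack like the model runner does."""
    dp_size = len(per_rank)
    reduced = [sum(vals) for vals in zip(*per_rank)]
    num_tokens_across_dp = reduced[:-2]
    max_tokens = max(num_tokens_across_dp)
    global_with_prefill = bool(reduced[-2])    # logical OR across ranks
    global_enable_dbo = not bool(reduced[-1])  # logical AND across ranks
    padded = [max_tokens] * dp_size            # every rank pads to the max
    return max_tokens, padded, global_with_prefill, global_enable_dbo

ranks = [
    pack(0, 3, num_tokens=5, with_prefill=False, enable_dbo=True),
    pack(1, 3, num_tokens=9, with_prefill=True, enable_dbo=True),
    pack(2, 3, num_tokens=2, with_prefill=False, enable_dbo=False),
]
print(sync_metadata(ranks))  # (9, [9, 9, 9], True, False)
```

Encoding `with_prefill` directly and `enable_dbo` inverted lets a single SUM reduction realize both an OR and an AND over the flags at no extra cost.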

Does this PR introduce any user-facing change?

None.

How was this patch tested?

Maybe we should add a test case for the multi-DP online server? @MengqingCao

@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces significant improvements to the Data Parallelism (DP) logic, focusing on performance and correctness. Key changes include unifying metadata synchronization with a single all-reduce on the NPU, which boosts performance, and removing the DummyCommImpl to fix a memory calculation issue. The refactoring also resolves a server hang issue in multi-DP setups. My review focuses on the implementation of these changes. I've identified an opportunity to further improve the efficiency and consistency of the metadata synchronization logic.

Comment on lines +599 to +627
    # Sync num_tokens, with_prefill, enable_dbo across dp ranks
    num_tokens_tensor = torch.tensor(
        [num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)],
        dtype=torch.int32,
        device="npu")

    flags_tensor = torch.tensor(
        [int(with_prefill), int(not enable_dbo)],
        dtype=torch.int32,
        device="npu")

    packed_tensor = torch.cat([num_tokens_tensor, flags_tensor])

    dist.all_reduce(packed_tensor, group=get_dp_group().device_group)

    # Unpack the results
    num_tokens_across_dp = packed_tensor[:-2]
    synced_flags = packed_tensor[-2:]

    max_tokens_across_dp = torch.max(num_tokens_across_dp).item()
    global_with_prefill = bool(synced_flags[0])
    global_enable_dbo = not bool(synced_flags[1])

    # Create a tensor for num_tokens_after_padding
    num_tokens_after_padding = torch.tensor(
        [max_tokens_across_dp] * self.dp_size,
        device="npu",
        dtype=torch.int32)

    return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo
Contributor


high

The current implementation for creating packed_tensor can be made more efficient. Creating tensors from Python lists and then concatenating them involves extra overhead, including CPU-to-NPU data transfers for the list contents.

A more efficient approach, similar to the one used in vllm_ascend/torchair/torchair_model_runner.py, is to pre-allocate a single tensor on the NPU and then fill in the values. This avoids the creation of intermediate Python lists and tensors.

Additionally, torch.full is more efficient for creating a tensor filled with a single value than torch.tensor with a list comprehension.

Here is a suggested refactoring for better performance and consistency:

        # Sync num_tokens, with_prefill, enable_dbo across dp ranks
        packed_tensor = torch.zeros(self.dp_size + 2,
                                    dtype=torch.int32,
                                    device="npu")
        packed_tensor[self.dp_rank] = num_tokens
        packed_tensor[-2] = int(with_prefill)
        packed_tensor[-1] = int(not enable_dbo)

        dist.all_reduce(packed_tensor, group=get_dp_group().device_group)

        # Unpack the results
        num_tokens_across_dp = packed_tensor[:-2]
        global_with_prefill = bool(packed_tensor[-2])
        global_enable_dbo = not bool(packed_tensor[-1])

        max_tokens_across_dp = torch.max(num_tokens_across_dp).item()

        # Create a tensor for num_tokens_after_padding
        num_tokens_after_padding = torch.full((self.dp_size, ),
                                                max_tokens_across_dp,
                                                device="npu",
                                                dtype=torch.int32)

        return max_tokens_across_dp, num_tokens_after_padding, global_with_prefill, global_enable_dbo

Collaborator Author


The purpose of this change is to enhance clarity and extensibility. The performance suggestion is duly noted; if other reviewers share this view, I will reconsider the approach.

@codecov

codecov Bot commented Aug 27, 2025

Codecov Report

❌ Patch coverage is 9.37500% with 29 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.40%. Comparing base (6c97336) to head (70818aa).
⚠️ Report is 643 commits behind head on main.

Files with missing lines Patch % Lines
vllm_ascend/worker/model_runner_v1.py 10.00% 27 Missing ⚠️
vllm_ascend/ops/common_fused_moe.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2582      +/-   ##
==========================================
+ Coverage   72.35%   72.40%   +0.05%     
==========================================
  Files         146      146              
  Lines       21622    21598      -24     
==========================================
- Hits        15645    15639       -6     
+ Misses       5977     5959      -18     
Flag Coverage Δ
unittests 72.40% <9.37%> (+0.05%) ⬆️


@Angazenn
Copy link
Copy Markdown
Collaborator

I have tested this PR with Qwen3-30B on DP32 across 2 A3 nodes. The server generates responses to requests correctly. Changing the CPU all_reduce to a device all_reduce also reduces latency by ~10 ms in this scenario.

…unner, preparing to remove `get_dp_padding`

Moves the determination of attention state, padding, and other forward metadata to an earlier stage within the input preparation method.

This improves code clarity by grouping related metadata calculations together before tensor manipulations occur. The variable `padded_num_tokens_across_dp` is also renamed to `maybe_padded_num_tokens` to more accurately reflect that padding is conditional.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
…ing logic

Unifies and simplifies the logic for synchronizing metadata (number of tokens, prefill status, DBO status) across data parallel (DP) ranks.

This change renames `_get_forward_metadata_across_dp_and_pad` to a more descriptive `_sync_metadata_across_dp` and consolidates the padding logic within it. The separate `get_dp_padding` function is removed.

The synchronization mechanism is improved by packing all metadata into a single tensor for a more efficient `all_reduce` operation. This refactoring streamlines the code, removes redundancy, and clarifies the data flow for DP padding in both TorchAir and standard execution modes.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
…tion

Removes the unused `DummyCommImpl` for Mixture-of-Experts communication.

The logic for selecting the MoE communication method is centralized into a new `_select_moe_comm_method` within the model runner. This method dynamically chooses the appropriate communication strategy based on the number of tokens, simplifying the control flow and removing hardcoded defaults from model execution and warmup routines.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
@wangxiyuan wangxiyuan merged commit dfc7eb3 into vllm-project:main Aug 28, 2025
22 of 25 checks passed
@yiz-liu yiz-liu deleted the fix-padding branch August 28, 2025 11:55
yiz-liu added a commit to yiz-liu/vllm-ascend that referenced this pull request Aug 29, 2025
…ug with TorchAir

Simplifies the control flow by unconditionally using the padded token count, removing the dependency on the `use_aclgraph` flag.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
weijinqian0 pushed a commit to weijinqian0/vllm-ascend that referenced this pull request Aug 29, 2025
weijinqian0 pushed a commit to weijinqian0/vllm-ascend that referenced this pull request Aug 29, 2025
anon189Ty added a commit to anon189Ty/vllm-ascend that referenced this pull request Aug 29, 2025
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>

add cann version judgment

update ut

correct spelling errors

Update ut

Support v0.10.1 (vllm-project#2584)

This patch also supports v0.10.1

No

- CI passed
- test 0.10.1: vllm-project#2583
- vLLM version: v0.10.1.1
- vLLM main:
vllm-project/vllm@321938e

Signed-off-by: Yikun Jiang <yikunkero@gmail.com>

[Fix] Fix DP-related padding logic (vllm-project#2582)


[CI] Add e2e ci test for A3 (vllm-project#2573)

Add e2e ci test for A3

- vLLM version: v0.10.1.1
- vLLM main:
vllm-project/vllm@11a7faf

Signed-off-by: hfadzxy <starmoon_zhang@163.com>

[Feat]: Add custom lmhead tensor model parallel (vllm-project#2309)

This PR introduces LMHead tensor model parallelism to reduce memory consumption and improve TPOT performance. It supports both eager mode and graph mode.

In a DeepSeek R1 W8A8 PD-disaggregated Decode instance using pure DP with lmhead_tensor_parallel_size = 8, we observed a 1 ms TPOT improvement and saved 1.48 GB of NPU memory per rank.

performance data:
[performance screenshot attached in the original PR]

This PR introduces one new config in `additional_config`.

| Name | Effect | Required | Type | Constraints |
| :--- | :--- | :--- | :--- | :--- |
| lmhead_tensor_parallel_size | Split the lm_head matrix along the column dimension (vocab_size) into lmhead_tensor_parallel_size pieces | No | int | Default is None; once set, the feature is enabled. vocab_size must be divisible by this value. |

example

`--additional_config={"lmhead_tensor_parallel_size": 8}`

- vLLM version: v0.10.1.1
- vLLM main:
vllm-project/vllm@de533ab

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: zhangzihang <zzh_201018@outlook.com>

Fix import bug

Remove whitespace
anon189Ty pushed a commit to anon189Ty/vllm-ascend that referenced this pull request Aug 29, 2025
Comment on lines +594 to 595
if self.dp_size == 1 or self.vllm_config.model_config.enforce_eager:
return num_tokens, None, with_prefill, enable_dbo
Contributor

@JC-ut0 JC-ut0 Sep 2, 2025


When self.vllm_config.model_config.enforce_eager is enabled, dummy_run will always return with_prefill=False, which changes the original logic and causes _get_fused_moe_state to return a different value than before.
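A hypothetical mini-model of this concern (a simplified stand-in for the PR's `_sync_metadata_across_dp`, with an invented signature for illustration; not the actual vllm-ascend code):

```python
# Simplified stand-in: the early return skips the cross-rank OR of
# with_prefill, so the local (dummy-run) value leaks through unsynchronized.

def sync_metadata(dp_size, enforce_eager, num_tokens, with_prefill, enable_dbo,
                  all_ranks_with_prefill):
    if dp_size == 1 or enforce_eager:
        # Early return: local with_prefill is passed through as-is.
        return num_tokens, None, with_prefill, enable_dbo
    # Otherwise the flag is OR-ed across all DP ranks.
    return (num_tokens, [num_tokens] * dp_size,
            any(all_ranks_with_prefill), enable_dbo)

# During a dummy run the local rank reports with_prefill=False. With
# enforce_eager the global flag stays False even though another rank is
# prefilling, which would steer _get_fused_moe_state differently.
_, _, flag_eager, _ = sync_metadata(2, True, 4, False, True, [False, True])
_, _, flag_synced, _ = sync_metadata(2, False, 4, False, True, [False, True])
print(flag_eager, flag_synced)  # False True
```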

wenba0 pushed a commit to wenba0/vllm-ascend that referenced this pull request Sep 5, 2025
wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025
Angazenn pushed a commit to Angazenn/vllm-ascend that referenced this pull request Oct 21, 2025
Clorist33 pushed a commit to Clorist33/vllm-ascend that referenced this pull request Dec 9, 2025