[Bugfix] Synchronize only the current stream to avoid device sync#6432

Merged
wangxiyuan merged 1 commit into vllm-project:main from IWantFight:main
Feb 4, 2026

Conversation

@IWantFight
Contributor

@IWantFight IWantFight commented Jan 30, 2026

What this PR does / why we need it?

PR #4233 introduced a synchronization mechanism between steps in asynchronous scheduling with ACL Graph to address a hanging issue. However, full device-level synchronization is unnecessary: only the operations on the current stream need to be synchronized. Otherwise, if other background operations (such as send and recv) are running concurrently on other streams, device-wide synchronization may degrade inference performance for the instance.

Hang problem:

![c4bbfac9a9088acec0ad335b4c2af437](https://github.com/user-attachments/assets/b7c8c612-4d45-48ec-9465-954869f9643d)

Synchronizing only the current stream also resolves the hang:

[screenshot: 9c4feb3200abb407845eb38b3eaf81d6]

Does this PR introduce any user-facing change?

No

How was this patch tested?

- vLLM version: v0.14.1
- vLLM main: vllm-project/vllm@dc917cc

@IWantFight IWantFight requested a review from yiz-liu as a code owner January 30, 2026 09:35
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to improve performance by replacing a device-wide synchronization with a more granular stream-specific one. While this is a valid optimization strategy, the current implementation appears to introduce a critical race condition. The new synchronization mechanism fails to wait for necessary parameter updates that occur on a separate stream, potentially leading to the model executing with stale data and producing incorrect results. I have added a review comment detailing this critical issue and recommending the addition of explicit cross-stream synchronization to ensure correctness.

Comment thread on vllm_ascend/compilation/acl_graph.py (outdated):

```diff
 # To ensure proper ordering, we must call synchronize here before replaying,
 # so that update_attn_params only executes after the previous graph replay has fully completed.
-torch.npu.synchronize()
+torch.npu.current_stream().synchronize()
```
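For readers without NPU hardware, the ordering difference between the two synchronize calls can be sketched with a plain-Python mock of stream semantics. The `MockStream`/`MockDevice` classes below are illustrative stand-ins, not torch.npu APIs: device-wide synchronization drains every stream on the device, including unrelated background transfers, while per-stream synchronization drains only the caller's stream.

```python
# Illustrative mock of device-wide vs. per-stream synchronization.
# These classes are stand-ins for explanation only; they are NOT torch.npu APIs.

class MockStream:
    def __init__(self, name):
        self.name = name
        self.pending = []          # operations queued on this stream

    def launch(self, op):
        self.pending.append(op)

    def synchronize(self):
        # Blocks until only THIS stream's work is drained.
        done, self.pending = self.pending, []
        return done

class MockDevice:
    def __init__(self, streams):
        self.streams = streams

    def synchronize(self):
        # Blocks until ALL streams on the device are drained,
        # including unrelated background work (e.g. send/recv).
        done = []
        for s in self.streams:
            done.extend(s.synchronize())
        return done

compute = MockStream("compute")    # current stream: graph replay
comm = MockStream("comm")          # background send/recv transfers
device = MockDevice([compute, comm])

compute.launch("graph_replay_step_n")
comm.launch("kv_cache_send")

# Device-wide sync (old behavior): the host also waits for send/recv.
assert device.synchronize() == ["graph_replay_step_n", "kv_cache_send"]

compute.launch("graph_replay_step_n_plus_1")
comm.launch("kv_cache_recv")

# Current-stream sync (this PR): only the compute stream is drained;
# the background transfer keeps running.
assert compute.synchronize() == ["graph_replay_step_n_plus_1"]
assert comm.pending == ["kv_cache_recv"]
```

Under this model, the old `torch.npu.synchronize()` corresponds to `device.synchronize()` and stalls the host behind background send/recv work, while the new call corresponds to `compute.synchronize()` and leaves that work in flight.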

Severity: critical

This change from a device-wide torch.npu.synchronize() to a stream-specific torch.npu.current_stream().synchronize() may introduce a race condition, making it unsafe.

The update_attn_params function, which updates graph parameters for the current iteration, appears to run on a separate update_stream. The subsequent entry.aclgraph.replay() on the current stream depends on these updates.

While the new synchronization waits for the previous graph replay on the current stream, it no longer waits for the parameter updates on update_stream. This could lead to replay() executing with stale or partially updated parameters, causing correctness issues.

The original torch.npu.synchronize() would have prevented this race condition, assuming update_attn_params is called before this point.

To fix this, explicit synchronization between the streams is required before replay(). For example, by using torch.npu.current_stream().wait_stream(update_stream) or waiting on an event recorded after the parameter updates. Without such synchronization, this change is incorrect.
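The cross-stream ordering the review asks for can be sketched with another small mock. This is illustrative only: `Stream.wait_stream` below mimics the semantics of `torch.npu.current_stream().wait_stream(...)`, and `update_stream`/`update_attn_params` are names taken from the review comment, not verified against the actual code.

```python
# Mock sketch of cross-stream synchronization via wait_stream.
# Illustrative stand-in; NOT a torch.npu API.

class Stream:
    def __init__(self):
        self.clock = 0        # logical timestamp of the last op enqueued here
        self.waits_until = 0  # producer timestamp this stream must respect

    def launch(self):
        # Enqueue one operation and return its timestamp.
        self.clock += 1
        return self.clock

    def wait_stream(self, other):
        # Future work on this stream is ordered after everything already
        # enqueued on `other` -- the semantics of wait_stream.
        self.waits_until = max(self.waits_until, other.clock)

update_stream = Stream()   # carries the update_attn_params(...) work
current = Stream()         # carries the graph replay

param_update = update_stream.launch()   # parameter update for this iteration
current.wait_stream(update_stream)      # explicit cross-stream barrier
replay_ts = current.launch()            # graph replay

# The replay is now ordered after the parameter update it depends on;
# without the wait_stream call, waits_until would still be 0.
assert current.waits_until >= param_update
```

An event-based variant records an event on the update stream after the parameter updates and has the current stream wait on that event, which gives finer-grained ordering than waiting on the whole stream.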

@github-actions
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message to fulfill the PR description, so reviewers and future developers understand the change.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Signed-off-by: For_YL <zhangtangwei@huawei.com>
@IWantFight IWantFight closed this Feb 3, 2026
@IWantFight IWantFight reopened this Feb 3, 2026
@realliujiaxu realliujiaxu added the ready (read for review) and ready-for-test (start test by label for PR) labels Feb 3, 2026
@realliujiaxu
Collaborator

LGTM, thanks for your contribution!

@wangxiyuan wangxiyuan merged commit e7a13be into vllm-project:main Feb 4, 2026
57 of 60 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Feb 6, 2026
…to qwen3next_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend: (59 commits)
  [Feat.]: 310p support MOE models (vllm-project#6530)
  [Doc] backport 0.13.0 release note (vllm-project#6584)
  [CI] Update UT CANN version to 8.5.0 for main branch (vllm-project#6564)
  [CI] Change A2 runner (vllm-project#6557)
  [Bugfix] Fix the incorrect use of the output parameter in _forward_fia_slidingwindow (vllm-project#6469)
  [main2main] upgrade vllm main 0202 (vllm-project#6560)
  [CI][npugraph_ex]Fix npugraph ex e2e test (vllm-project#6553)
  [Feature]KV pool supports sparse attention (vllm-project#6339)
  [bugfix]Fix accuracy issue in PCP/DCP with speculative decoding (vllm-project#6491)
  perf: adaptive block size selection in linear_persistent kernel (vllm-project#6537)
  [ModelRunner][Fix] Pads query_start_loc to satisfy FIA/TND constraint (vllm-project#6475)
  [Bugfix]Fix of Pooling Code and Update of Pooling Usage Guide (vllm-project#6126)
  [Fusion] Add rmsnorm dynamic quant fusion pass (vllm-project#6274)
  [Bugfix] Synchronize only the current stream to avoid device sync (vllm-project#6432)
  [CI] Add long and short prompt tests for DeepSeek-V3.2 (vllm-project#6499)
  [Refactor] MLP weight prefetch to consistency with MoE Model's prefetching in terms of code and usage (vllm-project#6442)
  [bugfix][npugraph_ex]duplicate pattern issue (vllm-project#6513)
  [bugfix][npugraph_ex]add the extra check for allreduce rmsnorm fusion pass (vllm-project#6430)
  [Quant] GLM4.7-Flash Support W8A8 (vllm-project#6492)
  [Nightly][BugFix] Remove kv_cache nz test case for test_mla_preprocess_nq.py (vllm-project#6505)
  ...
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Feb 12, 2026
@wangxiyuan wangxiyuan mentioned this pull request Feb 24, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
jiangyunfan1 pushed a commit to jiangyunfan1/vllm-ascend that referenced this pull request Apr 9, 2026

Labels

ready (read for review), ready-for-test (start test by label for PR)

3 participants