
[BugFix] Fix async scheduling for pooling models#31584

Merged
vllm-bot merged 2 commits into vllm-project:main from njhill:fix-async-pooling
Dec 31, 2025

Conversation

@njhill njhill (Member) commented Dec 31, 2025

Fix race condition for pooling model with async scheduling.

Also:

  • Move the output CPU copy to a dedicated stream, which should hopefully unblock async performance gains
  • Some adjacent optimizations and code simplifications
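The race-condition fix can be illustrated with a minimal, stdlib-only sketch (this is not vLLM's actual code; `StepOutput` and the field names are hypothetical). With async scheduling, the scheduler may mutate shared per-step state while the previous step's output is still being packaged on another thread; snapshotting the mutable fields with `.copy()` before handing the output off removes the race:

```python
# Illustrative sketch of the race fixed here, assuming a scheduler that
# reuses a live list of request ids across steps. StepOutput is a
# hypothetical stand-in for vLLM's ModelRunnerOutput.
from dataclasses import dataclass


@dataclass
class StepOutput:
    req_ids: list[str]  # snapshot of the request ids for this step


def package_output(live_req_ids: list[str]) -> StepOutput:
    # .copy() detaches the output from the scheduler's live list, so a
    # later in-place mutation cannot corrupt an output already in flight.
    return StepOutput(req_ids=live_req_ids.copy())


live_req_ids = ["req-0", "req-1"]
out = package_output(live_req_ids)
live_req_ids.append("req-2")  # scheduler moves on to the next step
assert out.req_ids == ["req-0", "req-1"]  # snapshot is unaffected
```

Without the `.copy()`, `out.req_ids` would alias the live list and silently pick up `"req-2"` after the append.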

Fixes #31570

@mergify mergify bot added the v1 label Dec 31, 2025
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces support for asynchronous scheduling for pooling models by adding the AsyncGPUPoolingModelRunnerOutput class, which is a commendable improvement. This change aligns the pooling model execution with the existing asynchronous pattern for generation models, enhancing performance by overlapping GPU-to-CPU data transfers. The inclusion of .copy() when creating ModelRunnerOutput is a crucial fix for preventing race conditions in asynchronous mode. Overall, the implementation is solid. I have identified one potential high-severity issue where a None output from a pooler could lead to a worker crash and have provided a suggestion for a fix.
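The reviewer's high-severity concern about a `None` pooler output can be sketched as follows (stdlib-only; `collect_pooler_outputs` is a hypothetical helper, not vLLM's API). A pooler may yield `None` for a request, and forwarding that unchecked would crash the worker when the output is later indexed or copied to CPU; guarding at collection time keeps the worker alive:

```python
# Hypothetical guard against a None pooler output, assuming raw per-request
# outputs arrive as a list where an entry may be None.
def collect_pooler_outputs(raw_outputs: list) -> list:
    safe = []
    for out in raw_outputs:
        if out is None:
            # Substitute an empty result rather than crashing downstream
            # serialization or the GPU-to-CPU copy.
            safe.append([])
        else:
            safe.append(out)
    return safe


assert collect_pooler_outputs([[0.1, 0.2], None]) == [[0.1, 0.2], []]
```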

@njhill njhill added the bug Something isn't working label Dec 31, 2025
@njhill njhill marked this pull request as ready for review December 31, 2025 19:20
@njhill njhill marked this pull request as draft December 31, 2025 19:36
Signed-off-by: njhill <nickhill123@gmail.com>
@vllm-project vllm-project deleted a comment from mergify bot Dec 31, 2025
@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 31, 2025
@njhill njhill marked this pull request as ready for review December 31, 2025 20:09
@njhill njhill requested a review from WoosukKwon as a code owner December 31, 2025 20:09
@vllm-bot vllm-bot merged commit 6c2cfb6 into vllm-project:main Dec 31, 2025
58 of 63 checks passed
@njhill njhill deleted the fix-async-pooling branch December 31, 2025 22:49
wjunLu added a commit to wjunLu/vllm-ascend that referenced this pull request Jan 4, 2026
Signed-off-by: wjunLu <wjunlu217@gmail.com>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Jan 6, 2026
### What this PR does / why we need it?

Upgrade vllm commit to 0105 (8be6432bdaf6275664d857b1e5e9bf8ed1ce299e)

1. Remove the `maybe_padded_num_tokens` arg in `model_runner_v1.py` since
vllm-project/vllm#31517 deleted the unused arg

2. Remove dense `Qwen/Qwen3-0.6B` in
`tests/e2e/multicard/test_aclgraph_capture_replay.py` and
`tests/e2e/multicard/test_data_parallel.py` due to
vllm-project/vllm#30739
where offline data parallel mode will not be supported/useful for dense
models

3. Adapt `vllm_ascend/worker/worker.py` due to
vllm-project/vllm#31584

4. Adapt the `self.block_size` access due to
vllm-project/vllm#31540

5. Modify `test_mla_v1.py` due to
vllm-project/vllm#28454, which refactored
`get_head_size()`

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@7157596

Signed-off-by: wjunLu <wjunlu217@gmail.com>
Rozwel-dx pushed a commit to Rozwel-dx/vllm-ascend that referenced this pull request Jan 8, 2026
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
aipaes pushed a commit to aipaes/vllm-ascend that referenced this pull request Jan 15, 2026
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: njhill <nickhill123@gmail.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026

Labels

bug (Something isn't working), ready (ONLY add when PR is ready to merge/full CI is needed), v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[CI Failure]: Pooling models (Classification model) + tp failure in Full CI run

3 participants