Skip to content

fix bugs when token_classify & classify run concurrently#36614

Merged
vllm-bot merged 1 commit intovllm-project:mainfrom
staugust:update_all_pool
Mar 11, 2026
Merged

fix bugs when token_classify & classify run concurrently#36614
vllm-bot merged 1 commit intovllm-project:mainfrom
staugust:update_all_pool

Conversation

@staugust
Copy link
Copy Markdown
Contributor

@staugust staugust commented Mar 10, 2026

Purpose

Fix bugs for *For*Classification, *ClassificationModel models that runs token_classify and classify concurrently.

vllm version 0.17.0

steps to reproduce

pip install vllm==0.17.0
python \
  -m vllm.entrypoints.openai.api_server \
  --model Qwen/Qwen3-Reranker-8B \
  --max-num-batched-tokens 65536 \
  --max-num-seqs 64 \
  --port 31001 \
  --tensor-parallel-size 1 \
  --served-model-name auto \
  --runner pooling \
  --enable-prompt-tokens-details \
  --no-enable-prefix-caching \
  --log-error-stack \
  --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' \
  --trust-remote-code


# then do loop curl 

for i in {1..200}; do
curl -X POST "http://127.0.0.1:31001/pooling" \
-H "Content-Type: application/json" \
-d '{
  "model": "auto",
  "task": "token_classify",
  "input": [ "why", "ok"]
}' &

curl -X POST "http://127.0.0.1:31001/classify" \
-H "Content-Type: application/json" \
-d '{
  "model": "auto",
  "user": "abc",
  "input": ["why", "ok"]
}' &

done
wait

Error log shows that hidden_states and pooling_cursor.num_scheduled_tokens_cpu mismatched. From DispatchPooler.forward passes whole hidden_states of a batch to AllPool, while num_scheduled_tokens_cpu in pooling_metadat set in

num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
and
pooling_metadata[offset : offset + num_items],
is a subset for current pooling_task

(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]   File "/opt/conda/envs/os/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]     return self._call_impl(*args, **kwargs)
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]   File "/opt/conda/envs/os/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]     return forward_call(*args, **kwargs)
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]   File "/opt/conda/envs/os/lib/python3.10/site-packages/vllm/model_executor/layers/pooler/special.py", line 101, in forward
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]     group_output: PoolerOutput = pooler(
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]   File "/opt/conda/envs/os/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]     return self._call_impl(*args, **kwargs)
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]   File "/opt/conda/envs/os/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]     return forward_call(*args, **kwargs)
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]   File "/opt/conda/envs/os/lib/python3.10/site-packages/vllm/model_executor/layers/pooler/tokwise/poolers.py", line 91, in forward
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]     pooled_data = self.pooling(hidden_states, pooling_metadata)
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]   File "/opt/conda/envs/os/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]     return self._call_impl(*args, **kwargs)
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]   File "/opt/conda/envs/os/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]     return forward_call(*args, **kwargs)
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]   File "/opt/conda/envs/os/lib/python3.10/site-packages/vllm/model_executor/layers/pooler/tokwise/methods.py", line 50, in forward
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]     hidden_states_all = hidden_states.split(
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]   File "/opt/conda/envs/os/lib/python3.10/site-packages/torch/_tensor.py", line 1066, in split
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102]     return torch._VF.split_with_sizes(
(EngineCore_DP0 pid=56494) ERROR 03-10 16:43:31 [v1/engine/core.py:1102] RuntimeError: split_with_sizes expects split_sizes to sum exactly to 34 (input tensor's size at dimension 0), but got split_sizes=[1, 1]

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@staugust staugust requested a review from noooop as a code owner March 10, 2026 08:49
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request aims to resolve a critical concurrency bug that caused RuntimeError during simultaneous token_classify and classify operations due to tensor dimension mismatches. While the updated code correctly utilizes first_token_indices_gpu and last_token_indices_gpu for slicing hidden_states to address the crash, it introduces a critical security vulnerability: cross-user data leakage. The current slicing logic uses relative indices, which can lead to tasks incorrectly pooling tokens from the beginning of the batch, potentially exposing one user's data to another. This requires correction by using absolute indices or proper input tensor slicing.

Comment thread vllm/model_executor/layers/pooler/tokwise/methods.py
Comment thread vllm/model_executor/layers/pooler/tokwise/methods.py
Copy link
Copy Markdown
Collaborator

@noooop noooop left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for your fix!

@noooop noooop enabled auto-merge (squash) March 10, 2026 14:32
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 10, 2026
Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
auto-merge was automatically disabled March 10, 2026 14:33

Head branch was pushed to by a user without write access

@noooop noooop enabled auto-merge (squash) March 10, 2026 14:34
@vllm-bot vllm-bot merged commit b386bb3 into vllm-project:main Mar 11, 2026
59 of 63 checks passed
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
…t#36614)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
…t#36614)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
…t#36614)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
…t#36614)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
…t#36614)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
Signed-off-by: EricccYang <yangyang4991@gmail.com>
liuchenbing2026 pushed a commit to liuchenbing2026/vllm that referenced this pull request Apr 4, 2026
…t#36614)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
big-yellow-duck pushed a commit to EmbeddedLLM/vllm that referenced this pull request Apr 8, 2026
…t#36614)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
…t#36614)

Signed-off-by: augusto.yjh <augusto.yjh@antgroup.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants