
[Bugfix] Offload blocking tokenizer ops to shared thread pool to unblock event loop#34789

Merged
vllm-bot merged 7 commits into vllm-project:main from scyyh11:fix/event-loop-blocking-multimodal-main on Mar 27, 2026
Conversation

@scyyh11
Contributor

@scyyh11 scyyh11 commented Feb 18, 2026

Purpose

Fix event loop blocking caused by multimodal request preprocessing (base64 decoding, image transforms, HF processor operations) and chat template rendering. Under high concurrency, these synchronous CPU-bound operations block the asyncio event loop, causing /health, /v1/models, and /metrics endpoints to become unresponsive (P95 latency >200ms, with spikes over 1s).
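To make the failure mode concrete, here is a minimal self-contained repro of the mechanism (illustrative only, not code from this PR): a synchronous CPU-bound call executed directly in a coroutine stalls every other task on the loop, including a /health-style ping.

```python
import asyncio
import time


def heavy_preprocess():
    # Stand-in for base64 decoding / image transforms / HF processing
    time.sleep(0.5)


async def main():
    t0 = time.perf_counter()
    ping = asyncio.create_task(asyncio.sleep(0))  # a "/health"-style task
    heavy_preprocess()  # runs inline on the loop: `ping` cannot be served
    await ping
    print(f"ping served after {(time.perf_counter() - t0) * 1000:.0f} ms")


asyncio.run(main())
```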

Changes:

  • Add a shared ThreadPoolExecutor on BaseRenderer (size controlled by --preprocessing-thread-pool-workers, default 1)
  • Always offload multimodal preprocessing to the shared thread pool to keep the event loop responsive (a minimal sketch of this pattern follows the list)
  • Wrap chat template rendering in HfRenderer, MistralRenderer, DeepseekV32Renderer, and Grok2Renderer with the shared executor via make_async
  • Consolidate MistralRenderer's separate ThreadPoolExecutor into the shared one
  • Serialize clear_mm_cache through the shared executor to avoid races with concurrent process_inputs on the mm_processor_cache
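
For illustration, a minimal sketch of the shared-executor pattern, assuming a class-level pool named `_executor` and a `_run_blocking` helper; the actual attribute and helper names in the PR may differ, and vLLM's `make_async` utility plays the wrapper role here:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial


class BaseRenderer:
    # Shared pool (default max_workers=1): a single worker both serializes
    # access to non-thread-safe HF tokenizers and keeps blocking CPU work
    # off the asyncio event loop.
    _executor = ThreadPoolExecutor(max_workers=1)

    async def _run_blocking(self, fn, *args, **kwargs):
        # Hypothetical helper: run a CPU-bound callable in the shared pool
        # so the loop stays free to serve /health, /v1/models, /metrics.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor, partial(fn, *args, **kwargs)
        )
```

Chat template rendering and multimodal preprocessing would then be awaited through such a helper instead of being called synchronously inside request coroutines.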

Test Plan

Benchmarked on 1x NVIDIA A100-SXM4-80GB using vllm bench serve with --request-rate 20 --num-prompts 200 and a custom high-concurrency benchmark with PaddleOCR-VL-1.5.

Tests performed:

  1. vllm bench serve with Llama-3.1-8B-Instruct (text-only, --request-rate 20 --num-prompts 200)
  2. vllm bench serve with Qwen2.5-VL-7B-Instruct (multimodal, --request-rate 20 --num-prompts 200)
  3. Custom high-concurrency benchmark with PaddleOCR-VL-1.5 (500 real OmniDocBench images, 300 concurrency)
  4. --preprocessing-thread-pool-workers comparison (1 vs 2 vs 4) with PaddleOCR-VL-1.5

Test Results

1. Text-Only (meta-llama/Llama-3.1-8B-Instruct)

`vllm bench serve --request-rate 20 --num-prompts 200`

| Metric | This PR | Main | Diff |
| --- | --- | --- | --- |
| Throughput (req/s) | 14.77 | 14.80 | -0.2% |
| Output tok/s | 1,889.97 | 1,893.27 | -0.2% |
| Mean TTFT (ms) | 278.02 | 273.71 | +1.6% |
| P99 TTFT (ms) | 518.97 | 525.17 | -1.2% |
| Mean TPOT (ms) | 54.14 | 53.79 | +0.7% |

No regression. All metrics within noise.

2. Multimodal (Qwen/Qwen2.5-VL-7B-Instruct)

`vllm bench serve --request-rate 20 --num-prompts 200 --backend openai-chat --dataset-name random-mm`

| Metric | This PR | Main | Diff |
| --- | --- | --- | --- |
| Throughput (req/s) | 6.74 | 6.73 | +0.1% |
| Output tok/s | 862.46 | 861.79 | +0.1% |
| Mean TTFT (ms) | 7,489.43 | 7,483.03 | +0.1% |
| P99 TTFT (ms) | 17,005.28 | 16,974.36 | +0.2% |
| Mean TPOT (ms) | 125.91 | 126.07 | -0.1% |

No regression. All metrics within noise.

3. PaddleOCR-VL-1.5 High Concurrency (500 prompts, 300 concurrency)

Custom benchmark with real OmniDocBench document images.
`--max-num-batched-tokens 131072 --no-enable-prefix-caching --mm-processor-cache-gb 0 --gpu-memory-utilization 0.5`

| Metric | This PR | Main | Diff |
| --- | --- | --- | --- |
| Throughput (req/s) | 5.88 | 5.67 | +3.7% |
| Token throughput (tok/s) | 4,712.97 | 4,504.46 | +4.6% |
| TTFT mean (ms) | 31,117.71 | 33,058.79 | -5.9% |
| TTFT P99 (ms) | 47,273.22 | 51,613.86 | -8.4% |
| /health median (ms) | 0.70 | 222.44 | 318x better |
| /health P99 (ms) | 18.88 | 1,641.53 | 87x better |

Event loop stays fully responsive under high multimodal concurrency. The /health endpoint drops from 222ms to <1ms median.

4. --preprocessing-thread-pool-workers Comparison (PaddleOCR-VL-1.5, 500 prompts, 300 concurrency)

| Metric | workers=1 | workers=2 | workers=4 |
| --- | --- | --- | --- |
| Throughput (req/s) | 5.86 | 5.92 | 5.87 |
| Token throughput (tok/s) | 4,624.78 | 4,559.03 | 4,554.05 |
| TTFT mean (ms) | 31,319 | 30,974 | 31,294 |
| TTFT P99 (ms) | 47,536 | 47,084 | 47,830 |
| /health median (ms) | 0.67 | 0.62 | 0.67 |
| /health P99 (ms) | 17.20 | 17.47 | 20.13 |

All worker counts perform essentially identically, which is consistent with #34789 (comment).
The key improvement comes from moving preprocessing off the event loop (so /health stays responsive), not from parallelizing it; the default of workers=1 is sufficient.
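
For reference, liveness numbers of this kind can be gathered with a simple probe like the sketch below. This is a hypothetical example (URL, sample count, interval, and the httpx dependency are all assumptions), not the benchmark harness used for this PR.

```python
import asyncio
import time

import httpx  # assumed to be installed


async def probe_health(url: str = "http://localhost:8000/health",
                       n: int = 200, interval: float = 0.1) -> None:
    # Sample /health latency while a benchmark drives load elsewhere.
    samples: list[float] = []
    async with httpx.AsyncClient() as client:
        for _ in range(n):
            t0 = time.perf_counter()
            await client.get(url)
            samples.append((time.perf_counter() - t0) * 1000)
            await asyncio.sleep(interval)
    samples.sort()
    print(f"median={samples[len(samples) // 2]:.2f} ms "
          f"p99={samples[int(len(samples) * 0.99)]:.2f} ms")


asyncio.run(probe_health())
```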

Summary

| What | Result |
| --- | --- |
| Event loop liveness (/health) | 318x improvement (222ms → 0.7ms median) |
| Request throughput (high concurrency) | +3.7% (5.67 → 5.88 req/s) |
| TTFT (high concurrency) | -5.9% (33.1s → 31.1s mean) |
| Text-only regression | None (-0.2% throughput, within noise) |
| Multimodal regression | None (+0.1% throughput, within noise) |
| "Already borrowed" errors | Zero across all tests |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify Bot added the v1 label Feb 18, 2026
@mergify
Contributor

mergify Bot commented Feb 18, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @scyyh11.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added needs-rebase bug Something isn't working labels Feb 18, 2026
@scyyh11 scyyh11 force-pushed the fix/event-loop-blocking-multimodal-main branch 2 times, most recently from 88dc973 to 787320b on February 18, 2026 at 08:43
@mergify mergify Bot removed the needs-rebase label Feb 18, 2026
Contributor

@gemini-code-assist Bot left a comment


Code Review

This pull request effectively addresses event loop blocking caused by synchronous, CPU-bound operations like tokenization and multimodal preprocessing. The core change, introducing a shared single-threaded ThreadPoolExecutor in BaseRenderer, is a well-justified and clean solution to serialize access to non-thread-safe HuggingFace tokenizers while unblocking the event loop. The refactoring is consistently applied across various renderers and engine components, centralizing the execution of these blocking tasks. The provided test results clearly demonstrate a significant improvement in API endpoint latency under load. The code is of high quality, and I have no issues to report.

@DarkLight1337
Member

DarkLight1337 commented Feb 18, 2026

This looks like an adaptation of #33337 to the new Renderer code. Can you run the same benchmark as #33337 to see if there are any regressions (for both MM and text-only models)? If so, we should gate this behind a flag. cc @yzhu802

Comment thread on vllm/v1/engine/input_processor.py (outdated)
@mergify
Contributor

mergify Bot commented Feb 19, 2026

Hi @scyyh11, the pre-commit checks have failed. Please run:

```
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
```
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint
```

@scyyh11 scyyh11 force-pushed the fix/event-loop-blocking-multimodal-main branch from 87804ae to c66d990 on February 19, 2026 at 01:27

@scyyh11
Contributor Author

scyyh11 commented Feb 19, 2026

Benchmark Results

Benchmarked using `vllm bench serve` with `--request-rate 20 --num-prompts 200` on a single A100.

Text-only (meta-llama/Llama-3.1-8B-Instruct)

| Metric | main | this PR | diff |
| --- | --- | --- | --- |
| Throughput (req/s) | 13.72 | 14.27 | +4.0% |
| Output tok/s | 1756.47 | 1827.01 | +4.0% |
| Mean TTFT (ms) | 432.02 | 342.51 | -20.7% |
| P99 TTFT (ms) | 972.08 | 669.58 | -31.1% |
| Mean TPOT (ms) | 64.57 | 59.82 | -7.4% |

Multimodal (Qwen/Qwen2.5-VL-7B-Instruct)

| Metric | main | this PR | diff |
| --- | --- | --- | --- |
| Throughput (req/s) | 4.67 | 4.47 | -4.3% |
| Output tok/s | 597.32 | 572.39 | -4.2% |
| Mean TTFT (ms) | 18355.36 | 18376.02 | +0.1% |
| P99 TTFT (ms) | 29880.78 | 31762.12 | +6.3% |
| Mean TPOT (ms) | 143.60 | 158.13 | +10.1% |

I added conditional offloading in _process_for_engine_async (sketched below): only multimodal requests are offloaded to the thread pool, while text-only requests are processed directly on the event loop, since they only do lightweight dict creation. This avoids serializing all requests through a single thread. There is no regression for text-only models. Multimodal shows a small throughput decrease (~4%), which is likely within run-to-run variance given the high absolute TTFT values (~18s). TTFT for multimodal is essentially unchanged (+0.1%).
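
For illustration, a hypothetical shape of that conditional path; the class name comes from the reviewed file vllm/v1/engine/input_processor.py, and the attribute check, `_executor`, and `_process_for_engine` are assumptions for the sketch:

```python
import asyncio


class InputProcessor:
    async def _process_for_engine_async(self, request):
        if getattr(request, "multi_modal_data", None):  # assumed attribute
            # Multimodal: offload CPU-heavy preprocessing to the shared
            # single-worker pool so the event loop is never blocked.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self._executor, self._process_for_engine, request
            )
        # Text-only: lightweight dict creation, cheap enough to run inline.
        return self._process_for_engine(request)
```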

@DarkLight1337


@scyyh11 scyyh11 force-pushed the fix/event-loop-blocking-multimodal-main branch from 6d12527 to 876e55a on February 19, 2026 at 01:55
@DarkLight1337
Member

4% is actually pretty significant. I suggest rerunning the benchmark a few more times to reduce the variance and be more sure.

@scyyh11 scyyh11 force-pushed the fix/event-loop-blocking-multimodal-main branch 2 times, most recently from 8a7c128 to d3abd89 on February 19, 2026 at 03:42
@scyyh11
Contributor Author

scyyh11 commented Feb 19, 2026

> 4% is actually pretty significant. I suggest rerunning the benchmark a few more times to reduce the variance and be more sure.

Multimodal (Qwen/Qwen2.5-VL-7B-Instruct) — 3-run average

| Metric | main (avg ± std) | this PR (avg ± std) | diff |
| --- | --- | --- | --- |
| Throughput (req/s) | 4.77 ± 0.04 | 4.56 ± 0.16 | -4.4% |
| Output tok/s | 611.01 ± 6.35 | 583.28 ± 20.09 | -4.5% |
| Mean TTFT (ms) | 17264.45 ± 476 | 18565.44 ± 1622 | +7.5% |
| P99 TTFT (ms) | 28976.39 ± 385 | 30999.60 ± 1459 | +7.0% |
| Mean TPOT (ms) | 144.32 ± 1.00 | 150.37 ± 2.17 | +4.2% |

I agree that there is a gap, and it probably comes from thread-pool context-switching overhead. What are your implementation suggestions? Should I make this configurable?

@DarkLight1337
Member

DarkLight1337 commented Feb 19, 2026

Yes, like the other PR we should make this opt-in.

MockModelConfig was missing the renderer_num_workers attribute
introduced by the thread pool changes, causing 23 test failures
in CI when BaseRenderer.__init__ tried to read it.

Signed-off-by: Bvicii <yizhanhuang2002@gmail.com>
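
A hedged sketch of what that fix plausibly amounts to; the attribute name comes from the commit message, while the default value of 1 is inferred from the documented default of --preprocessing-thread-pool-workers:

```python
class MockModelConfig:
    # Mirror the field BaseRenderer.__init__ reads; 1 matches the
    # documented default of --preprocessing-thread-pool-workers.
    renderer_num_workers: int = 1
```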
auto-merge was automatically disabled March 26, 2026 16:26

Head branch was pushed to by a user without write access

@scyyh11 scyyh11 requested a review from DarkLight1337 March 26, 2026 16:27
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) March 26, 2026 16:27
@scyyh11
Contributor Author

scyyh11 commented Mar 26, 2026

Both CI failures are known flaky tests on upstream main, tracked in open issues. Neither is related to our changes.

  1. test_run_eagle_dp[FLASH_ATTN]: tracked in #38234 ("Test Failure: test_run_eagle_dp[FLASH_ATTN] produces non-deterministic outputs with EAGLE speculative decoding")

  2. test_abort_during_final_step[False]: tracked in #38221 ("Flaky test: test_abort_during_final_step[False] fails intermittently")

Our PR does not touch speculative decoding, EAGLE, data parallelism, engine abort logic, KV connectors, or scheduler output handling.

@scyyh11
Contributor Author

scyyh11 commented Mar 27, 2026

Hi @DarkLight1337, the CI failed, but not due to our changes. Could you merge this manually?

@noooop
Collaborator

noooop commented Mar 27, 2026

unblock Language Models Test (Extended Pooling) & Multi-Modal Models (Extended Pooling)

@vllm-bot vllm-bot merged commit 999dfc1 into vllm-project:main Mar 27, 2026
63 of 65 checks passed
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
[Bugfix] Offload blocking tokenizer ops to shared thread pool to unblock event loop (vllm-project#34789)

Signed-off-by: Bvicii <yizhanhuang2002@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
@scyyh11 scyyh11 deleted the fix/event-loop-blocking-multimodal-main branch March 27, 2026 23:26
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
[Bugfix] Offload blocking tokenizer ops to shared thread pool to unblock event loop (vllm-project#34789)

Signed-off-by: Bvicii <yizhanhuang2002@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

Labels

bug (Something isn't working), ready (ONLY add when PR is ready to merge/full CI is needed), v1
