
[BUGFIX] Make HF input preprocessing async to prevent frontend event loop blocking #33337

Closed

yzhu802 wants to merge 3 commits into vllm-project:main from yzhu802:bugfix/async_input_process

Conversation

@yzhu802 commented Jan 29, 2026

Purpose

Fix the issue where the frontend process's event loop becomes blocked under medium-to-high concurrency with multimodal requests, due to synchronous preprocessing in the Hugging Face (HF) processor. This improves the performance and responsiveness of the vLLM frontend in multimodal scenarios.

In addition, this change is critical for concurrent testing in the downstream framework vllm-omni. In multi-stage orchestration, the frontend process is responsible for forwarding the computation results of each stage to the next one. If the event loop is blocked, forwarding is delayed, leading to severe pipeline stalls.
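
In sketch form, the change moves the blocking call off the event loop thread (a simplified illustration, not the actual vLLM code; hf_preprocess stands in for the synchronous HF processor call):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def hf_preprocess(request):
    """Stand-in for the synchronous, CPU-heavy HF processor call."""
    ...

# Before: the blocking call runs on the event loop thread, so every other
# coroutine (new connections, heartbeats, stage forwarding) stalls with it.
async def handle_request_blocking(request):
    return hf_preprocess(request)

# After: the same call is offloaded to a worker thread, so the loop stays
# responsive while preprocessing runs.
_executor = ThreadPoolExecutor()

async def handle_request_offloaded(request):
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_executor, hf_preprocess, request)
```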

Test Plan

The service was deployed on 2× A800 GPUs (tp=2) using Kubernetes, and end-to-end stress tests were conducted under production-like conditions. To eliminate network latency jitter, the client and server communicated via ClusterIP within the cluster.

The test data distribution consisted of 40% text-only, 30% video, and 30% image requests. Samples were drawn with replacement from a large dataset, making cache hits unlikely.

Tests were conducted at QPS = 5, 10, 20, and 50, with the client sending requests at intervals following a Poisson distribution. Each test lasted 10 seconds. The server pod was restarted between tests to eliminate caching effects.

(Test files are not included in this PR. If needed, I can add them to the PR or share them through other means.)
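
For reference, Poisson-distributed send intervals are typically produced by drawing exponential inter-arrival gaps; a minimal client-side sketch (send_request is a placeholder coroutine, not the actual test script):

```python
import asyncio
import random

async def poisson_load(qps: float, duration_s: float, send_request) -> None:
    """Fire requests whose inter-arrival gaps are drawn from
    Exponential(rate=qps), i.e. a Poisson arrival process at the given QPS."""
    elapsed = 0.0
    in_flight = []
    while elapsed < duration_s:
        gap = random.expovariate(qps)  # mean gap = 1 / qps seconds
        await asyncio.sleep(gap)
        elapsed += gap
        in_flight.append(asyncio.create_task(send_request()))
    # Wait for stragglers so TTFT/error stats cover every request sent.
    await asyncio.gather(*in_flight, return_exceptions=True)
```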

Test Result

Under moderate concurrency (QPS = 20), this simple fix resulted in a 25% reduction in average TTFT and a 37% reduction in p99 TTFT.
Under high concurrency (QPS = 50), it eliminated ServerTimeoutError issues caused by frontend process blocking. With a timeout threshold of 10 seconds, 11.83% of requests triggered this network error in the original implementation.

     QPS : 20.0              | before bugfix | after bugfix
   ------------------------------------------------------------
   Throughput (tok/s)        | 139.98        | 152.79
   Avg TTFT (ms)             | 3455.25       | 2611.99
   P99 TTFT (ms)             | 7226.91       | 4604.19
   Error Rate                | 0.00%         | 0.00%
   ------------------------------------------------------------

     QPS : 50.0              | before bugfix | after bugfix
   ------------------------------------------------------------
   Throughput (tok/s)        | 156.83        | 150.32
   Avg TTFT (ms)             | 10621.45      | 10669.77
   P99 TTFT (ms)             | 18253.06      | 18711.43
   Error Rate                | 11.83%        | 0.00%
   ------------------------------------------------------------

This is a minimal-change solution and is not performance-optimal. To ensure thread safety under multi-threaded execution and to keep all changes contained within a single file for easier review, I intentionally chose a relatively coarse lock granularity.
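
Concretely, the coarse-lock variant looks roughly like the following sketch (executor size and names are illustrative, not the exact diff):

```python
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor

_processor_lock = threading.Lock()             # coarse: one lock for all preprocessing
_executor = ThreadPoolExecutor(max_workers=4)  # worker count illustrative

def _locked_process_inputs(processor, request):
    # The HF processor and its cache are not known to be thread-safe, so the
    # lock serializes them even though several pool threads may be scheduled.
    with _processor_lock:
        return processor.process_inputs(request)

async def process_inputs_offloaded(processor, request):
    # The event loop only awaits the future; the blocking work happens
    # on a pool thread behind the lock.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        _executor, _locked_process_inputs, processor, request
    )
```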

If you are interested in further improving performance in this scenario, I would be happy to propose an RFC to address it.

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Yufeng Zhu <yzhu802@gatech.edu>
mergify Bot added the v1 and bug labels on Jan 29, 2026
@gemini-code-assist Bot left a comment

Code Review

This pull request effectively addresses the issue of the frontend event loop being blocked by moving the synchronous process_inputs call to a ThreadPoolExecutor. The implementation is clear, and the use of a lock to maintain thread safety for the InputProcessor is a practical solution for this bugfix. My review includes one suggestion to enhance the shutdown process for the newly introduced thread pool, ensuring a more graceful termination.

Comment on lines +288 to +289:

```python
if executor := getattr(self, "_input_processor_executor", None):
    executor.shutdown(wait=False)
```
Severity: high

Calling shutdown(wait=False) will not cancel pending tasks in the executor. This can delay the process exit if there are long-running preprocessing tasks in the queue.

In Python 3.9+, ThreadPoolExecutor.shutdown() accepts a cancel_futures=True argument, which will cancel any pending tasks that have not started. This allows for a more graceful and faster shutdown.

This suggestion adds a version check to use this feature where available, making the shutdown process more robust.

Suggested change:

```python
if executor := getattr(self, "_input_processor_executor", None):
    executor.shutdown(wait=False)
```

```python
if executor := getattr(self, "_input_processor_executor", None):
    import sys
    if sys.version_info >= (3, 9):
        executor.shutdown(wait=False, cancel_futures=True)
    else:
        executor.shutdown(wait=False)
```

@DarkLight1337 (Member) commented
I tried something similar to this a while ago and found a modest improvement in some cases: #17831

But eventually we decided that API server scale-out is easier to do and more effective.

@yzhu802 (Author) commented Jan 29, 2026

@DarkLight1337, thanks for pointing out #17831! I fully agree that API server scale-out is the most effective way to improve overall throughput and handle CPU-intensive tasks. My PR is intended to complement that approach rather than replace it.

While multi-processing increases the total processing capacity, ensuring the event loop remains non-blocking is still critical for system stability, especially under bursty traffic.

Here is why I believe this patch is necessary:

  1. Preventing Connection Timeouts under Load: Even with multiple API servers, if the arrival rate temporarily exceeds the service rate, a synchronous HF processor will block the asyncio event loop. This doesn't just delay the current request; it also prevents the server from accepting new connections or maintaining heartbeats, leading to immediate client-side timeouts before the request can even be queued.

  2. Evidence from Stress Tests: In my benchmark (QPS=50), the current implementation resulted in an 11.83% ServerTimeoutError rate. These were not model inference delays, but connection timeouts caused by the frozen event loop. Applying this async fix reduced that error rate to 0%.

In short, while scale-out raises the ceiling, this async fix raises the floor for stability. Given that the changes are contained and use a coarse-grained lock for safety, I think it offers a valuable robustness improvement for edge cases.

@ywang96 (Member) commented Feb 1, 2026

Previously, @DarkLight1337, @njhill, and I made some attempts comparing running the HF processor in multithreading (MT) vs. multiprocessing (MP). If we were to add MT for an async processor, we would also need to consider that it will not work with the multimodal feature cache at the moment, IIUC. Have you tried comparing this PR to the main branch with --api-server-count set?

@njhill (Member) commented Feb 1, 2026

Thanks @yzhu802! I agree it's probably good to do something like this.

I don't think the lock is needed, though, and I don't see why more than one thread is needed. We should be able to just use max_workers=1 in the thread pool. Apart from that, the code can hopefully be simplified a bit, and we should double-check for any thread-safety implications.

> Under moderate concurrency (QPS = 20), this simple fix resulted in a 25% reduction in average TTFT and a 37% reduction in p99 TTFT.
> Under high concurrency (QPS = 50), it eliminated ServerTimeoutError issues caused by frontend process blocking. With a timeout threshold of 10 seconds, 11.83% of requests triggered this network error in the original implementation.

Just to have a clear understanding, were these comparisons relative to a baseline with --api-server-count set, and if so to what value?

Signed-off-by: Yufeng Zhu <yzhu802@gatech.edu>
mergify Bot commented Feb 2, 2026

Hi @yzhu802, the pre-commit checks have failed. Please run:

```bash
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
```bash
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint
```

@yzhu802 (Author) commented Feb 2, 2026

Thanks for the valuable feedback!

I have updated the PR to address the complexity concerns and conducted additional benchmarks with "--api-server-count 2" as requested.

1. Code Simplification

@njhill, you are absolutely right. Since the goal is simply to offload the CPU-intensive preprocessing from the asyncio event loop, rather than achieving parallel execution, I have:

  • Removed the lock.
  • Set max_workers=1 for the thread pool.
  • Corrected the legacy threads_num configuration.

Note on thread safety & cache: using max_workers=1 effectively serializes access to the HF processor/cache, mimicking the original serial behavior but without blocking the main event loop. This should mitigate the multimodal feature cache thread-safety concerns mentioned by @ywang96 (see the sketch below).
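
A minimal sketch of the revised shape (names are illustrative; the actual change lives in the diff):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

class InputProcessorOffload:
    """With max_workers=1, submitted tasks run one at a time in submission
    order, so the HF processor and its cache see strictly serial access and
    no explicit lock is needed."""

    def __init__(self, processor) -> None:
        self._processor = processor
        self._executor = ThreadPoolExecutor(max_workers=1)

    async def process(self, request):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor, self._processor.process_inputs, request
        )

    def shutdown(self) -> None:
        # cancel_futures=True (Python 3.9+) drops queued-but-unstarted work
        # so shutdown does not wait behind long preprocessing tasks.
        self._executor.shutdown(wait=False, cancel_futures=True)
```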

2. Benchmark with api-server-count=2

I ran the stress test again comparing main vs. this PR with --api-server-count 2.
This time the original implementation did not encounter TimeoutErrors until the QPS reached 100.

While there is minor overhead in throughput and TTFT due to thread context switching (expected with the introduction of the thread pool), the critical improvement is the elimination of request failures caused by a frozen event loop (0.00% error rate vs. 3.73% at QPS = 100).


     QPS : 50.0              | before bugfix | after bugfix  
   ------------------------------------------------------------
   Throughput (tok/s)        | 164.78        | 161.91  
   Avg TTFT (ms)             | 9975.97       | 10067.30  
   P99 TTFT (ms)             | 19232.83      | 19666.31
   Error Rate                | 0.00%         | 0.00%  
   ------------------------------------------------------------

     QPS : 100.0             | before bugfix | after bugfix  
   ------------------------------------------------------------
   Throughput (tok/s)        | 175.69        | 170.14  
   Avg TTFT (ms)             | 22650.08      | 23163.35 
   P99 TTFT (ms)             | 41113.30      | 43315.94 
   Error Rate                | 3.73%         | 0.00%  
   ------------------------------------------------------------

3. Context on Downstream Orchestration (vllm-omni)

There is an architectural reason why keeping the event loop non-blocking is critical for us. In vllm-omni, the frontend process acts as a generic orchestrator. It must forward intermediate results from Stage 0 to Stage 1 continuously. If the event loop blocks during HF preprocessing, forwarding halts, causing pipeline stalls across the entire multi-stage setup. This fix ensures the frontend remains responsive to coordinate the pipeline, which is a requirement for advanced multimodal orchestration.
We are currently refactoring the entrypoint for vLLM-Omni; this is the current related RFC:
vllm-project/vllm-omni#967 (comment)

I have uploaded the reproduction scripts to https://github.com/yzhu802/vllm-stress-test for reference.

Signed-off-by: Yufeng Zhu <yzhu802@gatech.edu>
@DarkLight1337 (Member) commented Feb 2, 2026

Given the overhead, I think we should make this configurable and have it set to off by default in the main repo.

@njhill (Member) commented Feb 2, 2026

Thanks @yzhu802, yes based on those results this change appears to actually degrade overall performance in all cases?

To clarify - was --api-server-count=2 set in both the before and after runs this time?

Do you think you could repeat the same test with --api-server-count=4?

Also in these runs, what was the attained QPS in each case? If much lower than the load driver QPS then we're kind of testing an overloaded server situation where we will want to do some load-shedding / rate-limiting anyhow (there's other in-flight work for that...).

Other observations:

  • Isn't process_inputs now called in the API server separately from AsyncLLM? (i.e. here). Were you using vllm bench serve for the benchmarks?
  • An idea of something else to try: have an async version of process_inputs and only do the thread offloading lower down the stack if/when appropriate (see the sketch below). However, this may be undesirable if it introduces a bunch of code duplication.
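
To illustrate the second bullet, a rough sketch of what such an async entry point could look like (hypothetical names throughout, including the has_multimodal_data helper; this is not an actual vLLM API):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

def has_multimodal_data(request) -> bool:
    """Assumed helper: detect image/video/audio content on the request."""
    return bool(getattr(request, "multi_modal_data", None))

class ProcessorSketch:
    """Hypothetical shape of an async process_inputs entry point."""

    def __init__(self) -> None:
        self._input_processor_executor = ThreadPoolExecutor(max_workers=1)

    def process_inputs(self, request):
        # Existing synchronous path (stand-in).
        ...

    async def process_inputs_async(self, request):
        # Cheap text-only requests stay inline on the event loop; only the
        # heavy multimodal path is offloaded lower down the stack.
        if not has_multimodal_data(request):
            return self.process_inputs(request)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._input_processor_executor, self.process_inputs, request
        )
```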

@yzhu802 (Author) commented Feb 3, 2026

Thanks @njhill and @DarkLight1337 — these are all very fair points! 🙂

Experiment Clarification

Yes, the comparison is fair: both runs use the same end-to-end stress test, with the only difference being git checkout <branch_name>.

For the second experiment, the key parameters are:

```
--api-server-count 2
--distributed-executor-backend mp
--mm-processor-cache-gb 0
```

All other settings are kept identical across runs. Full scripts are available here: https://github.com/yzhu802/vllm-stress-test

On Performance vs. Saturation

You're absolutely right that these experiments operate in a saturated / overloaded regime. In a typical inference deployment, this is where load-shedding or rate-limiting should eventually kick in, and I agree this PR is not meant to address that.

Similarly, comparing a single background thread (with context-switch overhead) against the main event loop is not intended to demonstrate a throughput win. Achieving real speedups would require finer-grained locking or true parallelism, which is out of scope here.

(Note: in the degenerate case of api-server-count = 1, this change can also incidentally improve TTFT and throughput when a considerable share of the requests are videos, but that is not the primary motivation.)

Intent of this PR: liveness, not throughput

The core goal of this PR is liveness.

In the current main branch, synchronous HF preprocessing can block the asyncio event loop entirely under load. When that happens, the frontend process becomes unresponsive: it cannot accept new connections, maintain heartbeats, or forward intermediate results. This leads to failure modes that are independent of model inference speed.

With this change, the event loop remains alive even when the system is saturated. While there is a small and expected performance overhead, the frontend avoids pathological failures such as connection timeouts caused by a frozen event loop.

This property is particularly important for orchestration-heavy setups like vllm-omni (an official downstream project under the vLLM org), where the frontend event loop acts as a coordinator between multiple stages and must continue forwarding results even when preprocessing large multimodal inputs.

Configurability

Given the architectural nature of this requirement and the fact that it introduces a small overhead for standard users, I fully agree this should be configurable and off by default.

My proposal (see the sketch below):

  • Introduce a flag (e.g. --enable-async-preprocessor, default = False).
  • Default behavior remains exactly the same as today (zero regression).
  • Users who care about frontend liveness under overload can explicitly opt in.
  • The flag only affects where preprocessing runs (event loop vs. background thread), not what preprocessing does.
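
In sketch form (the flag name is the one proposed above; the config wiring is illustrative, not vLLM's actual plumbing):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass

@dataclass
class PreprocessorConfig:
    # Proposed flag, off by default: existing deployments see no change.
    enable_async_preprocessor: bool = False

_executor = ThreadPoolExecutor(max_workers=1)

async def run_preprocessing(cfg: PreprocessorConfig, processor, request):
    if not cfg.enable_async_preprocessor:
        # Default path: identical to today's synchronous, inline call.
        return processor.process_inputs(request)
    # Opt-in path: the same call, moved off the event loop for liveness.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_executor, processor.process_inputs, request)
```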

Future Work

Longer-term, I’m happy to propose an RFC exploring finer-grained locking or making the multimodal cache thread-safe to enable true parallel preprocessing. I intentionally kept this PR minimal and localized to reduce review and correctness risk.

Happy to update the PR accordingly if this direction sounds reasonable.

mergify Bot commented Feb 5, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yzhu802.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify Bot commented Apr 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yzhu802.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify Bot added the needs-rebase label on Apr 23, 2026
@DarkLight1337 (Member) commented
Closing as superseded by #34789


Labels: bug (Something isn't working), needs-rebase, v1