Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import warnings
from collections.abc import AsyncGenerator, Iterable, Mapping
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from dataclasses import dataclass
from typing import Any
Expand Down Expand Up @@ -140,6 +141,11 @@ def __init__(
self.model_config.io_processor_plugin,
)

# Single-thread for running process_inputs() in background
# to avoid blocking the asyncio event loop during preprocessing.
# max_workers=1 ensures sequential execution, no need for locks.
self._input_processor_executor = ThreadPoolExecutor(max_workers=1)

# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(
self.tokenizer,
Expand Down Expand Up @@ -278,13 +284,52 @@ def shutdown(self):
if input_processor := getattr(self, "input_processor", None):
input_processor.close()

if executor := getattr(self, "_input_processor_executor", None):
executor.shutdown(wait=False)
Comment on lines +287 to +288
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Calling shutdown(wait=False) will not cancel pending tasks in the executor. This can delay the process exit if there are long-running preprocessing tasks in the queue.

In Python 3.9+, ThreadPoolExecutor.shutdown() accepts a cancel_futures=True argument, which will cancel any pending tasks that have not started. This allows for a more graceful and faster shutdown.

This suggestion adds a version check to use this feature where available, making the shutdown process more robust.

Suggested change
if executor := getattr(self, "_input_processor_executor", None):
executor.shutdown(wait=False)
if executor := getattr(self, "_input_processor_executor", None):
import sys
if sys.version_info >= (3, 9):
executor.shutdown(wait=False, cancel_futures=True)
else:
executor.shutdown(wait=False)


handler = getattr(self, "output_handler", None)
if handler is not None:
cancel_task_threadsafe(handler)

async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
    """Ask the engine core which tasks this engine can serve.

    Thin async pass-through to ``EngineCore.get_supported_tasks_async``.
    """
    supported = await self.engine_core.get_supported_tasks_async()
    return supported

async def _process_inputs_async(
    self,
    request_id: str,
    prompt: PromptType,
    params: SamplingParams | PoolingParams,
    arrival_time: float | None = None,
    lora_request: LoRARequest | None = None,
    tokenization_kwargs: dict[str, Any] | None = None,
    trace_headers: Mapping[str, str] | None = None,
    priority: int = 0,
    data_parallel_rank: int | None = None,
    resumable: bool = False,
) -> EngineCoreRequest:
    """Run process_inputs() in a background thread to avoid blocking event loop.

    All arguments are forwarded positionally and unchanged to
    ``self.input_processor.process_inputs``; this wrapper only moves the
    (potentially slow) preprocessing off the asyncio event loop.

    Uses the single-threaded ``_input_processor_executor`` to ensure
    sequential execution, which naturally provides thread safety without
    explicit locks.
    """

    def _run():
        # Executed in the worker thread; the closure captures every argument.
        return self.input_processor.process_inputs(
            request_id,
            prompt,
            params,
            arrival_time,
            lora_request,
            tokenization_kwargs,
            trace_headers,
            priority,
            data_parallel_rank,
            resumable,
        )

    # get_running_loop() is the correct call from inside a coroutine:
    # unlike the deprecated get_event_loop(), it never creates a new loop
    # and fails loudly if invoked off the event-loop thread.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(self._input_processor_executor, _run)

async def add_request(
self,
request_id: str,
Expand Down Expand Up @@ -352,7 +397,7 @@ async def add_request(
raise ValueError(
"should only provide prompt_text with EngineCoreRequest"
)
request = self.input_processor.process_inputs(
request = await self._process_inputs_async(
request_id,
prompt,
params,
Expand Down Expand Up @@ -447,7 +492,7 @@ async def _add_streaming_input_request(

# Create request for validation, also used as the finished signal
# once the input stream is closed.
final_req = self.input_processor.process_inputs(
final_req = await self._process_inputs_async(
request_id=request_id,
prompt=TokensPrompt(prompt_token_ids=[0]),
params=sampling_params,
Expand All @@ -467,7 +512,7 @@ async def handle_inputs():
self._validate_streaming_input_sampling_params(sp)
else:
sp = sampling_params
req = self.input_processor.process_inputs(
req = await self._process_inputs_async(
request_id=internal_req_id,
prompt=input_chunk.prompt,
params=sp,
Expand Down