@@ -55,6 +55,7 @@ class MockModelConfig:
skip_tokenizer_init = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
renderer_num_workers: int = 1

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@@ -536,6 +536,7 @@ class MockModelConfig:
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
renderer_num_workers: int = 1

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@@ -54,6 +54,7 @@ class MockModelConfig:
skip_tokenizer_init = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
renderer_num_workers: int = 1

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
@@ -54,6 +54,7 @@ class MockModelConfig:
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
renderer_num_workers: int = 1

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
1 change: 1 addition & 0 deletions tests/renderers/test_completions.py
@@ -38,6 +38,7 @@ class MockModelConfig:
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
renderer_num_workers: int = 1


@dataclass
1 change: 1 addition & 0 deletions tests/renderers/test_mistral.py
@@ -37,6 +37,7 @@ class MockModelConfig:
skip_tokenizer_init: bool = False
is_encoder_decoder: bool = False
is_multimodal_model: bool = False
renderer_num_workers: int = 1


@dataclass
4 changes: 4 additions & 0 deletions vllm/config/model.py
@@ -291,6 +291,10 @@ class ModelConfig:
definitions"""
io_processor_plugin: str | None = None
"""IOProcessor plugin name to load at model startup"""
renderer_num_workers: int = 1
"""Number of worker threads in the renderer thread pool. This pool
handles async tokenization, chat template rendering, and multimodal
preprocessing."""

# Pooler config
pooler_config: PoolerConfig | None = None
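Taken together with the `EngineArgs` plumbing below, the new field can be set from the CLI or programmatically. A minimal usage sketch; the model name is a placeholder, and only the flag and field names come from this diff:

```python
# CLI (flag added in vllm/engine/arg_utils.py below):
#   vllm serve <model> --renderer-num-workers 4

from vllm.engine.arg_utils import EngineArgs

# renderer_num_workers is the field this PR adds to EngineArgs; it is
# threaded into ModelConfig via create_model_config() (see the hunk below).
args = EngineArgs(model="<model>", renderer_num_workers=4)
```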
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
@@ -508,6 +508,7 @@ class EngineArgs:
MultiModalConfig.mm_encoder_attn_backend
)
io_processor_plugin: str | None = None
renderer_num_workers: int = 1
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate
mm_tensor_ipc: MMTensorIPC = MultiModalConfig.mm_tensor_ipc
@@ -767,6 +768,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
model_group.add_argument(
"--io-processor-plugin", **model_kwargs["io_processor_plugin"]
)
model_group.add_argument(
"--renderer-num-workers",
**model_kwargs["renderer_num_workers"],
)

# Model loading arguments
load_kwargs = get_kwargs(LoadConfig)
@@ -1438,6 +1443,7 @@ def create_model_config(self) -> ModelConfig:
video_pruning_rate=self.video_pruning_rate,
mm_tensor_ipc=self.mm_tensor_ipc,
io_processor_plugin=self.io_processor_plugin,
renderer_num_workers=self.renderer_num_workers,
)

def validate_tensorizer_args(self):
124 changes: 118 additions & 6 deletions vllm/renderers/base.py
@@ -5,6 +5,7 @@
import time
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, overload

@@ -38,7 +39,10 @@
from vllm.multimodal.processing import ProcessorInputs as MMProcessorInputs
from vllm.multimodal.registry import MultiModalTimingRegistry
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer,
make_async,
)
from vllm.utils.counter import AtomicCounter
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.metrics.stats import MultiModalCacheStats
@@ -78,11 +82,28 @@ def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:

self.tokenizer = tokenizer

# Shared thread pool executor for blocking tokenizer and
# multimodal preprocessing operations. The multimodal processor
# receives a deep-copied tokenizer (see #36557), so it is safe to
# run tokenization and MM preprocessing concurrently.
pool_workers = config.model_config.renderer_num_workers
self._executor = ThreadPoolExecutor(max_workers=pool_workers)

# Multimodal preprocessing is always offloaded to the thread pool
# to keep the asyncio event loop responsive under concurrent load.
self._mm_executor: Executor = self._executor

Comment on lines +85 to +95

Collaborator: I still think this PoolExecutor should be placed in the entrypoint rather than the renderer. Of course, I don't want this to block the PR.

Collaborator: And I think we should explore using ProcessPoolExecutor.

Collaborator: Regardless, the key improvement comes from offloading preprocessing off the event loop, which is very significant.

# Lazy initialization since offline LLM doesn't use async
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None

self.mm_processor: BaseMultiModalProcessor | None = None
self._mm_cache_stats: MultiModalCacheStats | None = None
self._clear_mm_cache_async = make_async(
self.clear_mm_cache, executor=self._executor
)
self._process_multimodal_async = make_async(
self._process_multimodal, executor=self._mm_executor
)
if config.model_config.is_multimodal_model:
mm_processor_cache = mm_registry.processor_cache_from_config(config)
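For readers unfamiliar with the helper, `make_async` wraps a blocking callable so it executes on an executor instead of the asyncio event loop. A self-contained sketch of the pattern; this mirrors the general shape of `vllm.utils.async_utils.make_async`, not its exact implementation:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def make_async_sketch(func, executor=None):
    """Sketch of the make_async pattern: run `func` on `executor`."""

    async def wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        # run_in_executor keeps the event loop free while func blocks.
        return await loop.run_in_executor(executor, partial(func, *args, **kwargs))

    return wrapper


# Stand-in for a blocking tokenize call, offloaded to a small pool.
pool = ThreadPoolExecutor(max_workers=2)
tokenize_async = make_async_sketch(str.split, executor=pool)
print(asyncio.run(tokenize_async("offloaded off the event loop")))
```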

@@ -119,7 +140,9 @@ def get_tokenizer(self) -> _T:

def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
if self._async_tokenizer is None:
self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
self._async_tokenizer = AsyncMicrobatchTokenizer(
self.get_tokenizer(), executor=self._executor
)

return self._async_tokenizer

@@ -211,11 +234,24 @@ def warmup(self, chat_params: ChatParams) -> None:
finally:
self.clear_mm_cache()

async def clear_mm_cache_async(self) -> None:
"""Serialize clear_mm_cache through the shared executor to avoid
races with concurrent process_inputs on the mm_processor_cache."""
await self._clear_mm_cache_async()

def shutdown(self) -> None:
mm_processor_cache = self.mm_processor_cache
if mm_processor_cache is not None:
mm_processor_cache.close()

if executor := getattr(self, "_executor", None):
executor.shutdown(wait=False)

if (
mm_executor := getattr(self, "_mm_executor", None)
) is not None and mm_executor is not executor:
mm_executor.shutdown(wait=False)

def get_bos_token_id(self) -> int | None:
if self.tokenizer is None:
logger.warning_once(
@@ -621,6 +657,9 @@ def _process_tokens(
self,
prompt: TokensPrompt,
) -> TokensInput | MultiModalInput:
"""Process token inputs, with multimodal preprocessing offloaded
to the shared thread pool in the async variant.
"""
prompt_token_ids = prompt["prompt_token_ids"]

engine_input: TokensInput | MultiModalInput
@@ -670,12 +709,46 @@ def _process_embeds(self, prompt: EmbedsPrompt) -> EmbedsInput:
cache_salt=prompt.get("cache_salt"),
)

async def _process_tokens_async(
self,
prompt: TokensPrompt,
) -> TokensInput | MultiModalInput:
prompt_token_ids = prompt["prompt_token_ids"]

engine_input: TokensInput | MultiModalInput
if multi_modal_data := prompt.get("multi_modal_data"):
engine_input = await self._process_multimodal_async(
prompt_token_ids,
multi_modal_data,
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
tokenization_kwargs=None,
mm_uuids=prompt.get("multi_modal_uuids"),
)
else:
engine_input = tokens_input(prompt_token_ids)

if prompt_text := prompt.get("prompt"):
engine_input["prompt"] = prompt_text
if cache_salt := prompt.get("cache_salt"):
engine_input["cache_salt"] = cache_salt

return engine_input

def _process_singleton(self, prompt: SingletonTokPrompt) -> SingletonInput:
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]

return self._process_tokens(prompt) # type: ignore[arg-type]

async def _process_singleton_async(
self,
prompt: SingletonTokPrompt,
) -> SingletonInput:
if "prompt_embeds" in prompt:
return self._process_embeds(prompt) # type: ignore[arg-type]

return await self._process_tokens_async(prompt) # type: ignore[arg-type]

def _process_enc_dec(
self,
prompt: EncoderDecoderTokPrompt,
@@ -699,6 +772,28 @@ def _process_enc_dec(
skip_decoder_start_token=skip_decoder_start_token,
)

async def _process_enc_dec_async(
self,
prompt: EncoderDecoderTokPrompt,
) -> EncoderDecoderInput:
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]

encoder_input, decoder_input = await asyncio.gather(
self._process_singleton_async(enc_prompt),
(
asyncio.sleep(0)
if dec_prompt is None
else self._process_singleton_async(dec_prompt)
),
)

return build_enc_dec_input(
encoder_input=encoder_input,
decoder_input=decoder_input,
decoder_start_token_id=self.get_dec_start_token_id(),
)

def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput:
engine_input: EngineInput
if "encoder_prompt" in prompt:
@@ -710,6 +805,21 @@ def process_for_engine(self, prompt: TokPrompt, arrival_time: float) -> EngineInput:

return engine_input

async def process_for_engine_async(
self, prompt: TokPrompt, arrival_time: float
) -> EngineInput:
engine_input: EngineInput
if "encoder_prompt" in prompt:
engine_input = await self._process_enc_dec_async(
prompt # type: ignore[arg-type]
)
else:
engine_input = await self._process_singleton_async(prompt)

engine_input["arrival_time"] = arrival_time

return engine_input

# Top-level methods
def render_cmpl(
self,
@@ -747,7 +857,9 @@ async def render_cmpl_async(

self._apply_prompt_extras(tok_prompts, prompt_extras)

return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
return await asyncio.gather(
*(self.process_for_engine_async(p, arrival_time) for p in tok_prompts)
)

def render_chat(
self,
@@ -811,8 +923,8 @@ async def render_chat_async(

self._apply_prompt_extras(tok_prompts, prompt_extras)

eng_prompts = [
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
]
eng_prompts = await asyncio.gather(
*(self.process_for_engine_async(p, arrival_time) for p in tok_prompts)
)

return out_conversations, eng_prompts
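The render changes above replace a sequential list comprehension with `asyncio.gather`, so per-prompt processing overlaps once the blocking pieces run on the thread pool. A sketch of the fan-out, with illustrative names:

```python
import asyncio


async def process_for_engine_sketch(prompt: str, arrival_time: float) -> dict:
    await asyncio.sleep(0.01)  # stand-in for tokenization / MM preprocessing
    return {"prompt": prompt, "arrival_time": arrival_time}


async def render_batch(prompts: list[str], arrival_time: float) -> list[dict]:
    # All prompts are scheduled up front and awaited as a batch.
    return await asyncio.gather(
        *(process_for_engine_sketch(p, arrival_time) for p in prompts)
    )


print(asyncio.run(render_batch(["a", "b", "c"], arrival_time=0.0)))
```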
22 changes: 18 additions & 4 deletions vllm/renderers/deepseek_v32.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
@@ -9,6 +10,7 @@
)
from vllm.logger import init_logger
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from vllm.utils.async_utils import make_async

from .base import BaseRenderer
from .inputs import DictPrompt
@@ -19,12 +21,25 @@


class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
def __init__(
self,
config: VllmConfig,
tokenizer: DeepseekV32Tokenizer | None,
) -> None:
super().__init__(config, tokenizer)

self._apply_chat_template_async = make_async(
self._apply_chat_template, executor=self._executor
)

def _apply_chat_template(self, *args, **kwargs):
return self.get_tokenizer().apply_chat_template(*args, **kwargs)

def render_messages(
self,
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
- tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
self.model_config,
@@ -33,7 +48,7 @@ def render_messages(
mm_processor_kwargs=params.mm_processor_kwargs,
)

- prompt_raw = tokenizer.apply_chat_template(
+ prompt_raw = self._apply_chat_template(
conversation=conversation,
messages=messages,
**params.get_apply_chat_template_kwargs(),
@@ -52,7 +67,6 @@ async def render_messages_async(
messages: list[ChatCompletionMessageParam],
params: ChatParams,
) -> tuple[list[ConversationMessage], DictPrompt]:
- tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
self.model_config,
Expand All @@ -61,7 +75,7 @@ async def render_messages_async(
mm_processor_kwargs=params.mm_processor_kwargs,
)

- prompt_raw = tokenizer.apply_chat_template(
+ prompt_raw = await self._apply_chat_template_async(
conversation=conversation,
messages=messages,
**params.get_apply_chat_template_kwargs(),
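The DeepseekV32Renderer change follows the same recipe as the base class: bind an async wrapper around the blocking chat-template call once in `__init__`, reusing the executor created by `BaseRenderer`. A reduced sketch of that shape; all classes and helpers here are stand-ins, not vLLM APIs:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial


class BaseRendererSketch:
    def __init__(self) -> None:
        # Mirrors the shared pool the base renderer now creates.
        self._executor = ThreadPoolExecutor(max_workers=1)

    def make_async(self, func):
        async def wrapper(*args, **kwargs):
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self._executor, partial(func, *args, **kwargs)
            )

        return wrapper


class ChatRendererSketch(BaseRendererSketch):
    def __init__(self) -> None:
        super().__init__()
        # Bound once; every call is offloaded to the shared pool.
        self._apply_chat_template_async = self.make_async(self._apply_chat_template)

    def _apply_chat_template(self, conversation: list[str]) -> str:
        return "\n".join(conversation)  # stand-in for tokenizer.apply_chat_template


async def main() -> None:
    renderer = ChatRendererSketch()
    print(await renderer._apply_chat_template_async(["user: hi", "assistant: hello"]))


asyncio.run(main())
```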