diff --git a/vllm/entrypoints/pooling/base/io_processor.py b/vllm/entrypoints/pooling/base/io_processor.py
index fd4c076cdda0..3a9c454681e7 100644
--- a/vllm/entrypoints/pooling/base/io_processor.py
+++ b/vllm/entrypoints/pooling/base/io_processor.py
@@ -82,7 +82,31 @@ def pre_process_online(self, ctx: PoolingServeContext):
         ctx.engine_inputs = engine_inputs

     async def pre_process_online_async(self, ctx: PoolingServeContext):
-        self.pre_process_online(ctx)
+        request = ctx.request
+
+        if isinstance(ctx.request, PoolingChatLikeRequest):
+            self._validate_chat_template(
+                request_chat_template=request.chat_template,
+                chat_template_kwargs=request.chat_template_kwargs,
+                trust_request_chat_template=self.trust_request_chat_template,
+            )
+            _, engine_inputs = await self._preprocess_chat_online_async(
+                request,
+                request.messages,
+                default_template=self.chat_template,
+                default_template_content_format=self.chat_template_content_format,
+                default_template_kwargs=None,
+            )
+        elif isinstance(request, PoolingCompletionLikeRequest):
+            engine_inputs = await self._preprocess_completion_online_async(
+                request,
+                prompt_input=request.input,
+                prompt_embeds=None,
+            )
+        else:
+            raise ValueError(f"Invalid {self.name} request type")
+
+        ctx.engine_inputs = engine_inputs

     def post_process_online(
         self,
         ctx: PoolingServeContext,
@@ -109,7 +133,13 @@ def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput
         )

     async def pre_process_offline_async(self, ctx: OfflineInputsContext):
-        return self.pre_process_offline(ctx)
+        assert not isinstance(ctx.prompts, ScoringData)
+        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
+            **(ctx.tokenization_kwargs or {})
+        )
+        return await self._preprocess_completion_offline_async(
+            prompts=ctx.prompts, tok_params=tok_params
+        )

     def post_process_offline(
         self,
@@ -126,6 +156,104 @@ async def post_process_offline_async(
     #######################################
     # helpers

+    async def _preprocess_completion_online_async(
+        self,
+        request: RendererRequest,
+        prompt_input: str | list[str] | list[int] | list[list[int]] | None,
+        prompt_embeds: bytes | list[bytes] | None,
+    ) -> list[EngineInput]:
+        renderer = self.renderer
+        model_config = self.model_config
+
+        prompts = list[SingletonPrompt | bytes]()
+        if prompt_embeds is not None:  # embeds take higher priority
+            prompts.extend(prompt_to_seq(prompt_embeds))
+        if prompt_input is not None:
+            prompts.extend(prompt_to_seq(prompt_input))
+
+        parsed_prompts = [
+            (
+                prompt
+                if isinstance(prompt, bytes)
+                else parse_model_prompt(model_config, prompt)
+            )
+            for prompt in prompts
+        ]
+        tok_params = request.build_tok_params(model_config)
+
+        return await renderer.render_cmpl_async(
+            parsed_prompts,
+            tok_params,
+            prompt_extras={
+                k: v
+                for k in ("mm_processor_kwargs", "cache_salt")
+                if (v := getattr(request, k, None)) is not None
+            },
+        )
+
+    async def _preprocess_chat_online_async(
+        self,
+        request: RendererChatRequest,
+        messages: list[ChatCompletionMessageParam],
+        default_template: str | None,
+        default_template_content_format: ChatTemplateContentFormatOption,
+        default_template_kwargs: dict[str, Any] | None,
+        tool_dicts: list[dict[str, Any]] | None = None,
+        tool_parser: type[ToolParser] | None = None,
+    ) -> tuple[list[ConversationMessage], list[EngineInput]]:
+        renderer = self.renderer
+
+        default_template_kwargs = merge_kwargs(
+            default_template_kwargs,
+            dict(
+                tools=tool_dicts,
+                tokenize=is_mistral_tokenizer(renderer.tokenizer),
+            ),
+        )
+
+        mm_config = self.model_config.multimodal_config
+
+        tok_params = request.build_tok_params(self.model_config)
+        chat_params = request.build_chat_params(
+            default_template, default_template_content_format
+        ).with_defaults(
+            default_template_kwargs,
+            default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
+        )
+
+        (conversation,), (engine_input,) = await renderer.render_chat_async(
+            [messages],
+            chat_params,
+            tok_params,
+            prompt_extras={
+                k: v
+                for k in ("mm_processor_kwargs", "cache_salt")
+                if (v := getattr(request, k, None)) is not None
+            },
+        )
+
+        return conversation, [engine_input]
+
+    async def _preprocess_completion_offline_async(
+        self,
+        prompts: "PromptType | Sequence[PromptType]",
+        tok_params: TokenizeParams,
+        prompt_extras: dict[str, Any] | None = None,
+    ) -> Sequence[EngineInput]:
+        prompts = prompt_to_seq(prompts)
+        parsed_prompts = [
+            (
+                prompt
+                if isinstance(prompt, bytes)
+                else parse_model_prompt(self.model_config, prompt)
+            )
+            for prompt in prompts
+        ]
+
+        return await self.renderer.render_cmpl_async(
+            parsed_prompts, tok_params, prompt_extras=prompt_extras
+        )
+
     def _preprocess_completion_online(
         self,
         request: RendererRequest,
diff --git a/vllm/entrypoints/pooling/embed/io_processor.py b/vllm/entrypoints/pooling/embed/io_processor.py
index 614f8e0d9d02..cad89c0e3419 100644
--- a/vllm/entrypoints/pooling/embed/io_processor.py
+++ b/vllm/entrypoints/pooling/embed/io_processor.py
@@ -65,6 +65,15 @@ def pre_process_online(self, ctx: PoolingServeContext):
         if self.enable_chunked_processing:
             self._pre_process_chunked(ctx)

+    async def pre_process_online_async(self, ctx: PoolingServeContext):
+        if isinstance(ctx.request, CohereEmbedRequest):
+            await self._pre_process_cohere_online_async(ctx)
+        else:
+            await super().pre_process_online_async(ctx)
+
+        if self.enable_chunked_processing:
+            self._pre_process_chunked(ctx)
+
     def post_process_online(
         self,
         ctx: PoolingServeContext,
@@ -438,6 +447,82 @@ def _pre_process_cohere_online(self, ctx: PoolingServeContext) -> None:
             request, all_messages, truncate_prompt_tokens, truncation_side
         )

+    async def _pre_process_cohere_online_async(self, ctx: PoolingServeContext) -> None:
+        """Async version of ``_pre_process_cohere_online``.
+
+        Uses the async renderer path so that CPU-bound multimodal
+        preprocessing runs in a thread pool instead of blocking the
+        asyncio event loop.
+        """
+        request = ctx.request
+        assert isinstance(request, CohereEmbedRequest)
+
+        if request.texts is None and request.images is None and request.inputs is None:
+            raise ValueError("One of texts, images, or inputs must be provided")
+
+        truncate_prompt_tokens, truncation_side = self._resolve_cohere_truncation(
+            request
+        )
+        input_type = request.input_type
+        self._validate_input_type(input_type)
+
+        if request.images is not None:
+            input: list[CohereEmbedInput] = [
+                CohereEmbedInput(
+                    content=[
+                        CohereEmbedContent(type="image_url", image_url={"url": uri})
+                    ]
+                )
+                for uri in request.images
+            ]
+        elif request.inputs is not None:
+            input = request.inputs
+        else:
+            texts = request.texts or []
+            task_prefix = self._get_task_instruction_prefix(input_type)
+
+            if task_prefix is None:
+                ctx.engine_inputs = await self._preprocess_cohere_text_completion_async(
+                    request,
+                    texts,
+                    truncate_prompt_tokens,
+                    truncation_side,
+                )
+                return
+
+            all_messages = [
+                self._mixed_input_to_messages(
+                    CohereEmbedInput(
+                        content=[CohereEmbedContent(type="text", text=text)]
+                    ),
+                    task_prefix=task_prefix,
+                )
+                for text in texts
+            ]
+            if self._has_chat_template():
+                ctx.engine_inputs = await self._batch_render_chat_async(
+                    request,
+                    all_messages,
+                    truncate_prompt_tokens,
+                    truncation_side,
+                )
+            else:
+                ctx.engine_inputs = await self._preprocess_cohere_text_completion_async(
+                    request,
+                    self._apply_task_instruction(texts, input_type),
+                    truncate_prompt_tokens,
+                    truncation_side,
+                )
+            return
+
+        task_prefix = self._get_task_instruction_prefix(input_type)
+        all_messages = [
+            self._mixed_input_to_messages(inp, task_prefix=task_prefix) for inp in input
+        ]
+        ctx.engine_inputs = await self._batch_render_chat_async(
+            request, all_messages, truncate_prompt_tokens, truncation_side
+        )
+
     def _has_chat_template(self) -> bool:
         return (
             resolve_chat_template(
@@ -509,6 +594,69 @@ def _batch_render_chat(
         _, engine_inputs = renderer.render_chat(all_messages, chat_params, tok_params)
         return engine_inputs

+    async def _batch_render_chat_async(
+        self,
+        request: CohereEmbedRequest,
+        all_messages: Sequence[list[ChatCompletionMessageParam]],
+        truncate_prompt_tokens: int | None,
+        truncation_side: Literal["left", "right"] | None,
+    ) -> list[EngineInput]:
+        """Async version of ``_batch_render_chat``."""
+        if not all_messages:
+            return []
+
+        proxy = EmbeddingChatRequest(
+            model=request.model,
+            messages=list(all_messages[0]),
+            dimensions=request.output_dimension,
+            encoding_format="float",
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            truncation_side=truncation_side,
+        )
+
+        renderer = self.renderer
+        mm_config = self.model_config.multimodal_config
+
+        tok_params = proxy.build_tok_params(self.model_config)
+        chat_params = proxy.build_chat_params(
+            self.chat_template,
+            self.chat_template_content_format,
+        ).with_defaults(
+            merge_kwargs(
+                None,
+                dict(
+                    tools=None,
+                    tokenize=is_mistral_tokenizer(renderer.tokenizer),
+                ),
+            ),
+            default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
+        )
+
+        _, engine_inputs = await renderer.render_chat_async(
+            all_messages, chat_params, tok_params
+        )
+        return engine_inputs
+
+    async def _preprocess_cohere_text_completion_async(
+        self,
+        request: CohereEmbedRequest,
+        texts: list[str],
+        truncate_prompt_tokens: int | None,
+        truncation_side: Literal["left", "right"] | None,
+    ) -> list[EngineInput]:
+        """Async version of ``_preprocess_cohere_text_completion``."""
+        proxy = EmbeddingCompletionRequest(
+            model=request.model,
+            input=texts,
+            dimensions=request.output_dimension,
+            encoding_format="float",
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            truncation_side=truncation_side,
+        )
+        return await self._preprocess_completion_online_async(
+            proxy, prompt_input=proxy.input, prompt_embeds=None
+        )
+
     def _validate_input_type(self, input_type: str | None) -> None:
         """Raise if *input_type* is not supported by this model."""
         if input_type is None:
diff --git a/vllm/entrypoints/pooling/scoring/io_processor.py b/vllm/entrypoints/pooling/scoring/io_processor.py
index 70fe1b221412..468adab64506 100644
--- a/vllm/entrypoints/pooling/scoring/io_processor.py
+++ b/vllm/entrypoints/pooling/scoring/io_processor.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import asyncio
 import time
 from collections.abc import Sequence
 from typing import Any, TypeAlias, cast
@@ -96,6 +97,33 @@ def pre_process_online(self, ctx: ScoringServeContext):
         ctx.engine_inputs = engine_inputs
         ctx.intermediates = len(scoring_data.data_1)

+    async def pre_process_online_async(self, ctx: ScoringServeContext):
+        request = ctx.request
+
+        if isinstance(request, ScoreRequest):
+            data_1 = request.data_1
+            data_2 = request.data_2
+        elif isinstance(request, RerankRequest):
+            data_1 = request.query
+            data_2 = request.documents
+        else:
+            raise ValueError(f"Invalid {self.name} request type")
+
+        scoring_data = self.valid_inputs(data_1, data_2)
+        tok_params = request.build_tok_params(self.model_config)
+        engine_inputs = await self._pre_process_async(
+            scoring_data,
+            tok_params,
+            prompt_extras={
+                k: v
+                for k in ("mm_processor_kwargs", "cache_salt")
+                if (v := getattr(request, k, None)) is not None
+            },
+        )
+
+        ctx.engine_inputs = engine_inputs
+        ctx.intermediates = len(scoring_data.data_1)
+
     def post_process_online(
         self,
         ctx: ScoringServeContext,
@@ -145,6 +173,21 @@ def _pre_process(
         prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
     )

+    async def _pre_process_async(
+        self,
+        scoring_data: ScoringData,
+        tok_params: TokenizeParams,
+        prompt_extras: dict[str, Any] | None = None,
+    ) -> Sequence[EngineInput]:
+        data_1 = score_data_to_prompts(scoring_data.data_1, "query", self.model_config)
+        data_2 = score_data_to_prompts(
+            scoring_data.data_2, "document", self.model_config
+        )
+
+        return await self._preprocess_completion_offline_async(
+            prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
+        )
+
     def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
         emb_data_1 = outputs[:offset]
         emb_data_2 = outputs[offset:]
@@ -269,6 +312,37 @@ def pre_process_online(self, ctx: ScoringServeContext):
         ctx.engine_inputs = engine_inputs
         ctx.pooling_params = pooling_params_list

+    async def pre_process_online_async(self, ctx: ScoringServeContext):
+        request = ctx.request
+
+        if isinstance(request, ScoreRequest):
+            data_1 = request.data_1
+            data_2 = request.data_2
+        elif isinstance(request, RerankRequest):
+            data_1 = request.query
+            data_2 = request.documents
+        else:
+            raise ValueError(f"Invalid {self.name} request type")
+
+        scoring_data = self.valid_inputs(data_1, data_2)
+        tok_params = request.build_tok_params(self.model_config)
+        pooling_params = self.create_pooling_params(request)
+
+        engine_inputs, pooling_params_list = await self._pre_process_async(
+            scoring_data,
+            tok_params,
+            pooling_params,
+            chat_template=self.chat_template,
+            prompt_extras={
+                k: v
+                for k in ("mm_processor_kwargs", "cache_salt")
+                if (v := getattr(request, k, None)) is not None
+            },
+        )
+
+        ctx.engine_inputs = engine_inputs
+        ctx.pooling_params = pooling_params_list
+
     #######################################
     # offline APIs

@@ -332,6 +406,55 @@ def _pre_process(
         )
         return engine_inputs, pooling_params_list

+    async def _pre_process_async(
+        self,
+        scoring_data: ScoringData,
+        tok_params: TokenizeParams,
+        pooling_params: PoolingParams | None,
+        chat_template: str | None = None,
+        prompt_extras: dict[str, Any] | None = None,
+    ) -> tuple[Sequence[EngineInput], list[PoolingParams]]:
+        # todo: support prompt_extras
+        arrival_time = time.time()
+
+        data_1 = scoring_data.data_1
+        data_2 = scoring_data.data_2
+
+        if len(data_1) == 1:
+            data_1 = data_1 * len(data_2)
+
+        if pooling_params is None:
+            pooling_params = PoolingParams(task="classify")
+
+        pooling_params_list = list[PoolingParams]()
+        engine_prompts = list[TokensPrompt]()
+        for q, d in zip(data_1, data_2):
+            _, engine_prompt = self.get_score_prompt(
+                data_1=q,
+                data_2=d,
+                encode_kwargs=tok_params.get_encode_kwargs(),
+                chat_template=chat_template,
+            )
+
+            if token_type_ids := engine_prompt.pop("token_type_ids", None):
+                params = pooling_params.clone()
+                compressed = compress_token_type_ids(token_type_ids)
+                params.extra_kwargs = {"compressed_token_type_ids": compressed}
+                pooling_params_list.append(params)
+            else:
+                pooling_params_list.append(pooling_params)
+
+            tok_params.apply_post_tokenization(self.tokenizer, engine_prompt)
+            engine_prompts.append(engine_prompt)
+
+        engine_inputs = await asyncio.gather(
+            *(
+                self.renderer.process_for_engine_async(prompt, arrival_time)
+                for prompt in engine_prompts
+            )
+        )
+        return list(engine_inputs), pooling_params_list
+
     def get_score_prompt(
         self,
         data_1: ScoreData,