diff --git a/tests/entrypoints/pooling/embed/test_io_processor.py b/tests/entrypoints/pooling/embed/test_io_processor.py
index f25911b661f5..341ccbd5f0c5 100644
--- a/tests/entrypoints/pooling/embed/test_io_processor.py
+++ b/tests/entrypoints/pooling/embed/test_io_processor.py
@@ -4,6 +4,7 @@
 
 import pytest
 
+from vllm import PoolingParams
 from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
 from vllm.entrypoints.pooling.embed.protocol import (
     CohereEmbedContent,
@@ -218,6 +219,7 @@ class TestPreProcessCohereOnline:
     def _make_context(**request_kwargs) -> PoolingServeContext[CohereEmbedRequest]:
         return PoolingServeContext(
             request=CohereEmbedRequest(model="test", **request_kwargs),
+            pooling_params=PoolingParams(),
             model_name="test",
             request_id="embd-test",
         )
@@ -233,13 +235,13 @@ def test_text_only_without_task_prefix_uses_completion_path(self):
         ctx = self._make_context(texts=["hello"])
         calls: list[tuple[str, object]] = []
 
-        def preprocess_completion(request, prompt_input, prompt_embeds):
+        def preprocess_cmpl_online(request, prompt_input, prompt_embeds):
             calls.append(("completion", prompt_input))
             return ["completion"]
 
         handler._get_task_instruction_prefix = lambda _input_type: None
         handler._has_chat_template = lambda: False
-        handler._preprocess_completion_online = preprocess_completion
+        handler._preprocess_cmpl_online = preprocess_cmpl_online
         handler._batch_render_chat = lambda *_args, **_kwargs: (
             pytest.fail("text-only request should not require chat rendering")
         )
@@ -254,7 +256,7 @@ def test_text_only_falls_back_to_prefixed_completion_without_template(self):
         ctx = self._make_context(texts=["hello"], input_type="query")
         calls: list[tuple[str, object]] = []
 
-        def preprocess_completion(request, prompt_input, prompt_embeds):
+        def preprocess_cmpl(request, prompt_input, prompt_embeds):
             calls.append(("completion", prompt_input))
             return ["fallback"]
 
@@ -263,7 +265,7 @@ def preprocess_completion(request, prompt_input, prompt_embeds):
         handler._batch_render_chat = lambda *_args, **_kwargs: (
             pytest.fail("chat rendering should be skipped without a template")
         )
-        handler._preprocess_completion_online = preprocess_completion
+        handler._preprocess_cmpl_online = preprocess_cmpl
 
         handler._pre_process_cohere_online(ctx)
 
@@ -297,7 +299,7 @@ def batch_render_chat(
         handler._get_task_instruction_prefix = lambda _input_type: "query: "
         handler._has_chat_template = lambda: True
         handler._batch_render_chat = batch_render_chat
-        handler._preprocess_completion_online = lambda *_args, **_kwargs: (
+        handler._preprocess_cmpl_online = lambda *_args, **_kwargs: (
             pytest.fail("completion path should be skipped when a template exists")
         )
 
diff --git a/vllm/entrypoints/pooling/base/io_processor.py b/vllm/entrypoints/pooling/base/io_processor.py
index 79f350382ddd..83e82664ef1b 100644
--- a/vllm/entrypoints/pooling/base/io_processor.py
+++ b/vllm/entrypoints/pooling/base/io_processor.py
@@ -72,7 +72,7 @@ def pre_process_online(self, ctx: PoolingServeContext):
                 default_template_kwargs=None,
             )
         elif isinstance(request, PoolingCompletionLikeRequest):
-            engine_inputs = self._preprocess_completion_online(
+            engine_inputs = self._preprocess_cmpl_online(
                 request,
                 prompt_input=request.input,
                 prompt_embeds=None,
@@ -82,21 +82,12 @@ def pre_process_online(self, ctx: PoolingServeContext):
 
         ctx.engine_inputs = engine_inputs
 
-    async def pre_process_online_async(self, ctx: PoolingServeContext):
-        self.pre_process_online(ctx)
-
     def post_process_online(
         self,
         ctx: PoolingServeContext,
     ):
         pass
 
-    async def post_process_online_async(
-        self,
-        ctx: PoolingServeContext,
-    ):
-        self.post_process_online(ctx)
-
     #######################################
     # offline APIs
 
@@ -109,12 +100,7 @@ def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput
         tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
             **(ctx.tokenization_kwargs or {})
         )
-        return self._preprocess_completion_offline(
-            prompts=prompts_seq, tok_params=tok_params
-        )
-
-    async def pre_process_offline_async(self, ctx: OfflineInputsContext):
-        return self.pre_process_offline(ctx)
+        return self._preprocess_cmpl_offline(prompts=prompts_seq, tok_params=tok_params)
 
     def post_process_offline(
         self,
@@ -122,16 +108,10 @@ def post_process_offline(
     ) -> list[PoolingRequestOutput]:
         return ctx.outputs
 
-    async def post_process_offline_async(
-        self,
-        ctx: OfflineOutputsContext,
-    ) -> list[PoolingRequestOutput]:
-        return self.post_process_offline(ctx)
-
     #######################################
     # helpers
 
-    def _preprocess_completion_online(
+    def _preprocess_cmpl_online(
         self,
         request: RendererRequest,
         prompt_input: str | list[str] | list[int] | list[list[int]] | None,
@@ -209,7 +189,7 @@ def _preprocess_chat_online(
 
         return conversation, [engine_input]
 
-    def _preprocess_completion_offline(
+    def _preprocess_cmpl_offline(
         self,
         prompts: PromptType | Sequence[PromptType],
         tok_params: TokenizeParams,
diff --git a/vllm/entrypoints/pooling/base/serving.py b/vllm/entrypoints/pooling/base/serving.py
index 0b83bf4c381b..c65bedc70f17 100644
--- a/vllm/entrypoints/pooling/base/serving.py
+++ b/vllm/entrypoints/pooling/base/serving.py
@@ -3,9 +3,11 @@
 
 from abc import ABC, abstractmethod
 from collections.abc import AsyncGenerator, Mapping
+from concurrent.futures import Executor
 from http import HTTPStatus
 from typing import ClassVar
 
+import torch
 from fastapi import Request
 from fastapi.responses import Response
 from starlette.datastructures import Headers
@@ -32,7 +34,7 @@
     log_tracing_disabled_warning,
 )
 from vllm.utils import random_uuid
-from vllm.utils.async_utils import merge_async_iterators
+from vllm.utils.async_utils import make_async, merge_async_iterators
 
 from .io_processor import PoolingIOProcessor
 
@@ -67,16 +69,47 @@ def __init__(
             trust_request_chat_template=trust_request_chat_template,
         )
 
-    @abstractmethod
+        # Shared thread pool executor for preprocessing and postprocessing.
+        self._executor: Executor = models.renderer._executor
+        self._preprocessing_async = make_async(
+            self._preprocessing, executor=self._executor
+        )
+        self._postprocessing_async = make_async(
+            self._postprocessing, executor=self._executor
+        )
+
     async def __call__(
         self,
         request: AnyPoolingRequest,
         raw_request: Request | None = None,
     ) -> Response:
+        io_processor = self.get_io_processor(request)
+        ctx = await self._init_ctx(io_processor, request, raw_request)
+        await self._preprocessing_async(io_processor, ctx)
+        await self._prepare_generators(ctx)
+        await self._collect_batch(ctx)
+        return await self._postprocessing_async(io_processor, ctx)
+
+    @abstractmethod
+    def get_io_processor(self, request: AnyPoolingRequest) -> PoolingIOProcessor:
         raise NotImplementedError
 
+    @torch.inference_mode()
+    def _preprocessing(
+        self, io_processor: PoolingIOProcessor, ctx: PoolingServeContext
+    ):
+        return io_processor.pre_process_online(ctx)
+
+    @torch.inference_mode()
+    def _postprocessing(
+        self, io_processor: PoolingIOProcessor, ctx: PoolingServeContext
+    ):
+        io_processor.post_process_online(ctx)
+        return self._build_response(ctx)
+
     async def _init_ctx(
         self,
+        io_processor: PoolingIOProcessor,
         request: AnyPoolingRequest,
         raw_request: Request | None = None,
     ):
@@ -84,10 +117,12 @@ async def _init_ctx(
         request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
 
         await self._check_model(request)
+        pooling_params = io_processor.create_pooling_params(request)
 
         ctx = PoolingServeContext(
             request=request,
             raw_request=raw_request,
             model_name=model_name,
+            pooling_params=pooling_params,
             request_id=request_id,
         )
@@ -175,7 +210,7 @@ async def _collect_batch(
         ctx.final_res_batch = [res for res in final_res_batch if res is not None]
 
     @abstractmethod
-    async def _build_response(
+    def _build_response(
         self,
         ctx: PoolingServeContext,
     ) -> Response:
@@ -362,18 +397,5 @@ def init_io_processor(
     ) -> PoolingIOProcessor:
         raise NotImplementedError
 
-    async def __call__(
-        self,
-        request: AnyPoolingRequest,
-        raw_request: Request | None = None,
-    ) -> Response:
-        ctx = await self._init_ctx(request, raw_request)
-        await self.io_processor.pre_process_online_async(ctx)
-
-        if ctx.pooling_params is None:
-            ctx.pooling_params = self.io_processor.create_pooling_params(request)
-
-        await self._prepare_generators(ctx)
-        await self._collect_batch(ctx)
-        await self.io_processor.post_process_online_async(ctx)
-        return await self._build_response(ctx)
+    def get_io_processor(self, request: AnyPoolingRequest) -> PoolingIOProcessor:
+        return self.io_processor
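The base-serving hunk above is the heart of the patch: `pre_process_online` and `post_process_online` become plain synchronous hooks, `__call__` becomes a fixed template pipeline, and the constructor wraps the two hooks once with `make_async` so they run on the renderer's shared thread pool instead of blocking the event loop (the `@torch.inference_mode()` decorators presumably exist because pre/post-processing can touch tensors and never needs autograd). A minimal, self-contained sketch of the wrapping pattern; it defines a local stand-in for `vllm.utils.async_utils.make_async`, and `SimpleServing` with its string-based hooks is purely illustrative, not a vLLM API:

```python
# Sketch only: a local make_async() stand-in plus a toy serving class that
# wraps its synchronous hooks once at construction time, mirroring what
# PoolingServingBase.__init__ does in the patch above.
import asyncio
import functools
from collections.abc import Awaitable, Callable
from concurrent.futures import Executor, ThreadPoolExecutor


def make_async(
    func: Callable, executor: Executor | None = None
) -> Callable[..., Awaitable]:
    """Wrap a blocking function so callers can await it."""

    async def _wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        # Hand the blocking call to the pool; the event loop stays free.
        return await loop.run_in_executor(
            executor, functools.partial(func, *args, **kwargs)
        )

    return _wrapper


class SimpleServing:
    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=2)
        # Wrap once here, not per request.
        self._preprocessing_async = make_async(self._preprocessing, executor=self._executor)
        self._postprocessing_async = make_async(self._postprocessing, executor=self._executor)

    def _preprocessing(self, request: str) -> str:
        return request.strip().lower()  # stands in for tokenization etc.

    def _postprocessing(self, result: str) -> str:
        return f"response({result})"  # stands in for response building

    async def __call__(self, request: str) -> str:
        pre = await self._preprocessing_async(request)  # runs in the pool
        return await self._postprocessing_async(pre)


print(asyncio.run(SimpleServing()("  Hello  ")))  # -> response(hello)
```

Wrapping at construction time means each request pays only the executor hand-off, and sharing one executor (the patch reuses the renderer's) keeps preprocessing threads bounded across endpoints.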
diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py
index 0a729075bce0..a48ec819b7fd 100644
--- a/vllm/entrypoints/pooling/classify/serving.py
+++ b/vllm/entrypoints/pooling/classify/serving.py
@@ -31,7 +31,7 @@ class ServingClassification(PoolingServing):
     def init_io_processor(self, *args, **kwargs) -> ClassifyIOProcessor:
         return ClassifyIOProcessor(*args, **kwargs)
 
-    async def _build_response(
+    def _build_response(
         self,
         ctx: ClassificationServeContext,
     ) -> JSONResponse:
diff --git a/vllm/entrypoints/pooling/embed/io_processor.py b/vllm/entrypoints/pooling/embed/io_processor.py
index d7ee6ae5813f..09016253f09e 100644
--- a/vllm/entrypoints/pooling/embed/io_processor.py
+++ b/vllm/entrypoints/pooling/embed/io_processor.py
@@ -470,7 +470,7 @@ def _preprocess_cohere_text_completion(
             truncate_prompt_tokens=truncate_prompt_tokens,
             truncation_side=truncation_side,
         )
-        return self._preprocess_completion_online(
+        return self._preprocess_cmpl_online(
             proxy, prompt_input=proxy.input, prompt_embeds=None
         )
 
@@ -579,7 +579,7 @@ def pre_process_online(self, ctx: PoolingServeContext):
                 query=text_prompts[-1], docs=text_prompts[:-1]
             )
 
-            engine_inputs = self._preprocess_completion_online(
+            engine_inputs = self._preprocess_cmpl_online(
                 request,
                 prompt_input=prompt_input,
                 prompt_embeds=None,
diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py
index 3abf0c7f3fd4..9389309efc77 100644
--- a/vllm/entrypoints/pooling/embed/serving.py
+++ b/vllm/entrypoints/pooling/embed/serving.py
@@ -53,15 +53,15 @@ def __init__(self, *args, **kwargs):
     def init_io_processor(self, *args, **kwargs) -> EmbedIOProcessor:
         return EmbedIOProcessor(*args, **kwargs)
 
-    async def _build_response(
+    def _build_response(
         self,
         ctx: PoolingServeContext,
     ) -> Response:
         if isinstance(ctx.request, CohereEmbedRequest):
             return self._build_cohere_response_from_ctx(ctx)
-        return await self._build_openai_response(ctx)
+        return self._build_openai_response(ctx)
 
-    async def _build_openai_response(
+    def _build_openai_response(
         self,
         ctx: EmbeddingServeContext,
     ) -> JSONResponse | StreamingResponse:
diff --git a/vllm/entrypoints/pooling/pooling/serving.py b/vllm/entrypoints/pooling/pooling/serving.py
index 3d16c1f2cc9d..ea8eff6eb30a 100644
--- a/vllm/entrypoints/pooling/pooling/serving.py
+++ b/vllm/entrypoints/pooling/pooling/serving.py
@@ -7,11 +7,11 @@
 from functools import partial
 from typing import Literal, cast
 
-from fastapi import Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from typing_extensions import assert_never
 
 from vllm.entrypoints.openai.engine.protocol import UsageInfo
+from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
 from vllm.entrypoints.pooling.base.serving import PoolingServingBase
 from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
 from vllm.entrypoints.pooling.pooling.protocol import (
@@ -57,27 +57,10 @@ def __init__(
         )
         self.json_response_cls = get_json_response_cls()
 
-    async def __call__(
-        self,
-        request: AnyPoolingRequest,
-        raw_request: Request | None = None,
-    ) -> Response:
+    def get_io_processor(self, request: AnyPoolingRequest) -> PoolingIOProcessor:
         assert isinstance(request, PoolingRequest)
         pooling_task = self._verify_pooling_task(request)
-
-        io_processor = self.io_processors[pooling_task]
-        ctx = await self._init_ctx(request, raw_request)
-
-        await io_processor.pre_process_online_async(ctx)
-
-        if ctx.pooling_params is None:
-            ctx.pooling_params = io_processor.create_pooling_params(request)
-
-        await self._prepare_generators(ctx)
-        await self._collect_batch(ctx)
-
-        await io_processor.post_process_online_async(ctx)
-        return await self._build_response(ctx)
+        return self.io_processors[pooling_task]
 
     def _verify_pooling_task(self, request: PoolingRequest) -> str:
         if getattr(request, "dimensions", None) is not None:
@@ -117,7 +100,7 @@ def _verify_pooling_task(self, request: PoolingRequest) -> str:
 
         return pooling_task
 
-    async def _build_response(
+    def _build_response(
         self,
         ctx: PoolingServeContext,
     ) -> Response:
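With `get_io_processor` as the single abstract hook, the base class owns the whole request pipeline and subclasses differ only in how they select a processor: the single-processor server returns `self.io_processor`, while the multi-task server above indexes `self.io_processors` by the verified pooling task. A toy sketch of that template-method shape (all names here are illustrative stand-ins, not the vLLM classes):

```python
# Toy template-method sketch: the base class fixes the pipeline, subclasses
# only decide which processor serves a given request.
from abc import ABC, abstractmethod


class Processor:
    def __init__(self, task: str) -> None:
        self.task = task

    def run(self, request: str) -> str:
        return f"{self.task}:{request}"


class ServingBase(ABC):
    def handle(self, request: str) -> str:
        # The pipeline is identical for every subclass...
        return self.get_io_processor(request).run(request)

    @abstractmethod
    def get_io_processor(self, request: str) -> Processor:
        ...  # ...only processor selection varies.


class SingleServing(ServingBase):
    """Single-processor case, like the base get_io_processor override."""

    def __init__(self) -> None:
        self.io_processor = Processor("embed")

    def get_io_processor(self, request: str) -> Processor:
        return self.io_processor


class PerTaskServing(ServingBase):
    """Per-task dispatch, like the pooling server's io_processors mapping."""

    def __init__(self) -> None:
        self.io_processors = {t: Processor(t) for t in ("embed", "classify")}

    def get_io_processor(self, request: str) -> Processor:
        task = "classify" if request.startswith("label ") else "embed"
        return self.io_processors[task]


print(SingleServing().handle("hello"))         # -> embed:hello
print(PerTaskServing().handle("label hello"))  # -> classify:label hello
```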
diff --git a/vllm/entrypoints/pooling/scoring/io_processor.py b/vllm/entrypoints/pooling/scoring/io_processor.py
index 01c340d52525..549bae2775d5 100644
--- a/vllm/entrypoints/pooling/scoring/io_processor.py
+++ b/vllm/entrypoints/pooling/scoring/io_processor.py
@@ -220,7 +220,7 @@ def _pre_process(
             scoring_data.data_2, "document", self.model_config
         )
 
-        return self._preprocess_completion_offline(
+        return self._preprocess_cmpl_offline(
             prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
         )
 
@@ -682,7 +682,7 @@ def _pre_process(
             for q, d in zip(queries, docs)
         ]
 
-        return self._preprocess_completion_offline(
+        return self._preprocess_cmpl_offline(
             prompts=prompts, tok_params=tok_params, prompt_extras=prompt_extras
         )
 
diff --git a/vllm/entrypoints/pooling/scoring/serving.py b/vllm/entrypoints/pooling/scoring/serving.py
index aef1a1fd0f1f..df866efd56ea 100644
--- a/vllm/entrypoints/pooling/scoring/serving.py
+++ b/vllm/entrypoints/pooling/scoring/serving.py
@@ -65,7 +65,7 @@ async def __call__(self, *args, **kwargs) -> Response:
 
         return await self.flash_late_interaction(*args, **kwargs)
 
-    async def _build_response(
+    def _build_response(
         self,
         ctx: ScoringServeContext,
     ) -> JSONResponse:
@@ -183,17 +183,15 @@ def _request_output_to_rerank_response(
     ### Can significantly improve late-interaction scoring performance.
     async def flash_late_interaction(self, *args, **kwargs) -> Response:
-        ctx = await self._init_ctx(*args, **kwargs)
-        ctx.pooling_params = self.io_processor.create_pooling_params(ctx.request)
-        await self.io_processor.pre_process_online_async(ctx)
+        ctx = await self._init_ctx(self.io_processor, *args, **kwargs)
+        await self._preprocessing_async(self.io_processor, ctx)
 
         # stage 1: encode queries and cache token embeddings on workers.
         await self._flash_late_interaction_encode_queries(ctx)
 
         # stage 2: encode docs and return scalar scores from workers.
         await self._flash_late_interaction_encode_docs(ctx)
 
-        await self.io_processor.post_process_online_async(ctx)
-        return await self._build_response(ctx)
+        return await self._postprocessing_async(self.io_processor, ctx)
 
     async def _flash_late_interaction_encode_queries(self, ctx: ScoringServeContext):
         assert ctx.n_queries is not None
diff --git a/vllm/entrypoints/pooling/typing.py b/vllm/entrypoints/pooling/typing.py
index e485dd8722d5..0ce9d5c8384c 100644
--- a/vllm/entrypoints/pooling/typing.py
+++ b/vllm/entrypoints/pooling/typing.py
@@ -69,9 +69,9 @@ class PoolingServeContext(Generic[PoolingRequestT]):
     raw_request: Request | None = None
     model_name: str
     request_id: str
+    pooling_params: PoolingParams | list[PoolingParams]
     created_time: int = field(default_factory=lambda: int(time.time()))
     lora_request: LoRARequest | None = None
-    pooling_params: PoolingParams | list[PoolingParams] | None = None
     engine_inputs: Sequence[EngineInput] | None = None
     prompt_request_ids: list[str] | None = None
     intermediates: Any | None = None
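Finally, `pooling_params` on `PoolingServeContext` loses its `None` default and becomes a required field; `_init_ctx` now fills it from `io_processor.create_pooling_params(request)`, which is also why the test helper at the top of the patch passes `pooling_params=PoolingParams()` explicitly. A stand-alone sketch of what dropping the default changes (`MiniContext` is a toy stand-in, not the real context type):

```python
# Once a dataclass field has no default, construction without it fails fast
# with a TypeError instead of silently carrying None through the pipeline.
import time
from dataclasses import dataclass, field


@dataclass
class MiniContext:
    request: str
    model_name: str
    request_id: str
    pooling_params: dict  # required: no default anymore
    created_time: int = field(default_factory=lambda: int(time.time()))


MiniContext(request="r", model_name="test", request_id="embd-test", pooling_params={})
try:
    MiniContext(request="r", model_name="test", request_id="embd-test")
except TypeError as err:
    print(err)  # missing 1 required positional argument: 'pooling_params'
```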