12 changes: 7 additions & 5 deletions tests/entrypoints/pooling/embed/test_io_processor.py
@@ -4,6 +4,7 @@

import pytest

from vllm import PoolingParams
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
from vllm.entrypoints.pooling.embed.protocol import (
CohereEmbedContent,
@@ -218,6 +219,7 @@ class TestPreProcessCohereOnline:
def _make_context(**request_kwargs) -> PoolingServeContext[CohereEmbedRequest]:
return PoolingServeContext(
request=CohereEmbedRequest(model="test", **request_kwargs),
pooling_params=PoolingParams(),
model_name="test",
request_id="embd-test",
)
@@ -233,13 +235,13 @@ def test_text_only_without_task_prefix_uses_completion_path(self):
ctx = self._make_context(texts=["hello"])
calls: list[tuple[str, object]] = []

def preprocess_completion(request, prompt_input, prompt_embeds):
def preprocess_cmpl_online(request, prompt_input, prompt_embeds):
calls.append(("completion", prompt_input))
return ["completion"]

handler._get_task_instruction_prefix = lambda _input_type: None
handler._has_chat_template = lambda: False
handler._preprocess_completion_online = preprocess_completion
handler._preprocess_cmpl_online = preprocess_cmpl_online
handler._batch_render_chat = lambda *_args, **_kwargs: (
pytest.fail("text-only request should not require chat rendering")
)
@@ -254,7 +256,7 @@ def test_text_only_falls_back_to_prefixed_completion_without_template(self):
ctx = self._make_context(texts=["hello"], input_type="query")
calls: list[tuple[str, object]] = []

def preprocess_completion(request, prompt_input, prompt_embeds):
def preprocess_cmpl(request, prompt_input, prompt_embeds):
calls.append(("completion", prompt_input))
return ["fallback"]

@@ -263,7 +265,7 @@ def preprocess_completion(request, prompt_input, prompt_embeds):
handler._batch_render_chat = lambda *_args, **_kwargs: (
pytest.fail("chat rendering should be skipped without a template")
)
handler._preprocess_completion_online = preprocess_completion
handler._preprocess_cmpl_online = preprocess_cmpl

handler._pre_process_cohere_online(ctx)

@@ -297,7 +299,7 @@ def batch_render_chat(
handler._get_task_instruction_prefix = lambda _input_type: "query: "
handler._has_chat_template = lambda: True
handler._batch_render_chat = batch_render_chat
handler._preprocess_completion_online = lambda *_args, **_kwargs: (
handler._preprocess_cmpl_online = lambda *_args, **_kwargs: (
pytest.fail("completion path should be skipped when a template exists")
)

28 changes: 4 additions & 24 deletions vllm/entrypoints/pooling/base/io_processor.py
@@ -72,7 +72,7 @@ def pre_process_online(self, ctx: PoolingServeContext):
default_template_kwargs=None,
)
elif isinstance(request, PoolingCompletionLikeRequest):
engine_inputs = self._preprocess_completion_online(
engine_inputs = self._preprocess_cmpl_online(
request,
prompt_input=request.input,
prompt_embeds=None,
@@ -82,21 +82,12 @@ def pre_process_online(self, ctx: PoolingServeContext):

ctx.engine_inputs = engine_inputs

async def pre_process_online_async(self, ctx: PoolingServeContext):
self.pre_process_online(ctx)

def post_process_online(
self,
ctx: PoolingServeContext,
):
pass

async def post_process_online_async(
self,
ctx: PoolingServeContext,
):
self.post_process_online(ctx)

#######################################
# offline APIs

@@ -109,29 +100,18 @@ def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput
tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
**(ctx.tokenization_kwargs or {})
)
return self._preprocess_completion_offline(
prompts=prompts_seq, tok_params=tok_params
)

async def pre_process_offline_async(self, ctx: OfflineInputsContext):
return self.pre_process_offline(ctx)
return self._preprocess_cmpl_offline(prompts=prompts_seq, tok_params=tok_params)

def post_process_offline(
self,
ctx: OfflineOutputsContext,
) -> list[PoolingRequestOutput]:
return ctx.outputs

async def post_process_offline_async(
self,
ctx: OfflineOutputsContext,
) -> list[PoolingRequestOutput]:
return self.post_process_offline(ctx)

#######################################
# helpers

def _preprocess_completion_online(
def _preprocess_cmpl_online(
self,
request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
@@ -209,7 +189,7 @@ def _preprocess_chat_online(

return conversation, [engine_input]

def _preprocess_completion_offline(
def _preprocess_cmpl_offline(
self,
prompts: PromptType | Sequence[PromptType],
tok_params: TokenizeParams,
58 changes: 40 additions & 18 deletions vllm/entrypoints/pooling/base/serving.py
@@ -3,9 +3,11 @@

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Mapping
from concurrent.futures import Executor
from http import HTTPStatus
from typing import ClassVar

import torch
from fastapi import Request
from fastapi.responses import Response
from starlette.datastructures import Headers
@@ -32,7 +34,7 @@
log_tracing_disabled_warning,
)
from vllm.utils import random_uuid
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.async_utils import make_async, merge_async_iterators

from .io_processor import PoolingIOProcessor

@@ -67,27 +69,60 @@ def __init__(
trust_request_chat_template=trust_request_chat_template,
)

@abstractmethod
# Shared thread pool executor for preprocessing and postprocessing.
self._executor: Executor = models.renderer._executor
self._preprocessing_async = make_async(
self._preprocessing, executor=self._executor
)
self._postprocessing_async = make_async(
self._postprocessing, executor=self._executor
)

async def __call__(
self,
request: AnyPoolingRequest,
raw_request: Request | None = None,
) -> Response:
io_processor = self.get_io_processor(request)
ctx = await self._init_ctx(io_processor, request, raw_request)
await self._preprocessing_async(io_processor, ctx)
await self._prepare_generators(ctx)
await self._collect_batch(ctx)
return await self._postprocessing_async(io_processor, ctx)

@abstractmethod
def get_io_processor(self, request: AnyPoolingRequest) -> PoolingIOProcessor:
raise NotImplementedError

@torch.inference_mode()
def _preprocessing(
self, io_processor: PoolingIOProcessor, ctx: PoolingServeContext
):
return io_processor.pre_process_online(ctx)

@torch.inference_mode()
def _postprocessing(
self, io_processor: PoolingIOProcessor, ctx: PoolingServeContext
):
io_processor.post_process_online(ctx)
return self._build_response(ctx)

async def _init_ctx(
self,
io_processor: PoolingIOProcessor,
request: AnyPoolingRequest,
raw_request: Request | None = None,
):
model_name = self.models.model_name()
request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
await self._check_model(request)

pooling_params = io_processor.create_pooling_params(request)
ctx = PoolingServeContext(
request=request,
raw_request=raw_request,
model_name=model_name,
pooling_params=pooling_params,
request_id=request_id,
)

@@ -175,7 +210,7 @@ async def _collect_batch(
ctx.final_res_batch = [res for res in final_res_batch if res is not None]

@abstractmethod
async def _build_response(
def _build_response(
self,
ctx: PoolingServeContext,
) -> Response:
@@ -362,18 +397,5 @@ def init_io_processor(
) -> PoolingIOProcessor:
raise NotImplementedError

async def __call__(
self,
request: AnyPoolingRequest,
raw_request: Request | None = None,
) -> Response:
ctx = await self._init_ctx(request, raw_request)
await self.io_processor.pre_process_online_async(ctx)

if ctx.pooling_params is None:
ctx.pooling_params = self.io_processor.create_pooling_params(request)

await self._prepare_generators(ctx)
await self._collect_batch(ctx)
await self.io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
def get_io_processor(self, request: AnyPoolingRequest) -> PoolingIOProcessor:
return self.io_processor
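
For context on the refactor above: `__call__` is now a single template method (init context → preprocess → prepare generators → collect batch → postprocess), and the synchronous `_preprocessing`/`_postprocessing` hooks are wrapped once with `make_async` so they run on the renderer's shared thread pool instead of blocking the event loop. This is also what lets `@torch.inference_mode()` decorate them directly and lets the trivial per-processor `*_async` delegates be deleted. A minimal, self-contained sketch of the offloading pattern, assuming a simplified stand-in for `vllm.utils.async_utils.make_async` (`blocking_preprocess` is illustrative, not from this PR):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def make_async(fn, executor):
    # Wrap a blocking callable so awaiting it runs the work on the shared
    # executor, keeping the event loop free for other requests.
    async def wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(executor, partial(fn, *args, **kwargs))

    return wrapper


def blocking_preprocess(text: str) -> list[str]:
    # Stand-in for CPU-bound work such as tokenization or prompt rendering.
    return text.split()


async def main():
    executor = ThreadPoolExecutor(max_workers=2)  # one shared pool, as in serving.py
    preprocess_async = make_async(blocking_preprocess, executor)
    print(await preprocess_async("pooling request body"))


asyncio.run(main())
```

One apparent design point: all blocking pre-/post-processing shares the renderer's single executor, so thread concurrency is bounded in one place rather than per endpoint.
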
2 changes: 1 addition & 1 deletion vllm/entrypoints/pooling/classify/serving.py
@@ -31,7 +31,7 @@ class ServingClassification(PoolingServing):
def init_io_processor(self, *args, **kwargs) -> ClassifyIOProcessor:
return ClassifyIOProcessor(*args, **kwargs)

async def _build_response(
def _build_response(
self,
ctx: ClassificationServeContext,
) -> JSONResponse:
4 changes: 2 additions & 2 deletions vllm/entrypoints/pooling/embed/io_processor.py
@@ -470,7 +470,7 @@ def _preprocess_cohere_text_completion(
truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=truncation_side,
)
return self._preprocess_completion_online(
return self._preprocess_cmpl_online(
proxy, prompt_input=proxy.input, prompt_embeds=None
)

@@ -579,7 +579,7 @@ def pre_process_online(self, ctx: PoolingServeContext):
query=text_prompts[-1], docs=text_prompts[:-1]
)

engine_inputs = self._preprocess_completion_online(
engine_inputs = self._preprocess_cmpl_online(
request,
prompt_input=prompt_input,
prompt_embeds=None,
6 changes: 3 additions & 3 deletions vllm/entrypoints/pooling/embed/serving.py
@@ -53,15 +53,15 @@ def __init__(self, *args, **kwargs):
def init_io_processor(self, *args, **kwargs) -> EmbedIOProcessor:
return EmbedIOProcessor(*args, **kwargs)

async def _build_response(
def _build_response(
self,
ctx: PoolingServeContext,
) -> Response:
if isinstance(ctx.request, CohereEmbedRequest):
return self._build_cohere_response_from_ctx(ctx)
return await self._build_openai_response(ctx)
return self._build_openai_response(ctx)

async def _build_openai_response(
def _build_openai_response(
self,
ctx: EmbeddingServeContext,
) -> JSONResponse | StreamingResponse:
25 changes: 4 additions & 21 deletions vllm/entrypoints/pooling/pooling/serving.py
@@ -7,11 +7,11 @@
from functools import partial
from typing import Literal, cast

from fastapi import Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from typing_extensions import assert_never

from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
from vllm.entrypoints.pooling.base.serving import PoolingServingBase
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.pooling.protocol import (
@@ -57,27 +57,10 @@ def __init__(
)
self.json_response_cls = get_json_response_cls()

async def __call__(
self,
request: AnyPoolingRequest,
raw_request: Request | None = None,
) -> Response:
def get_io_processor(self, request: AnyPoolingRequest) -> PoolingIOProcessor:
assert isinstance(request, PoolingRequest)
pooling_task = self._verify_pooling_task(request)

io_processor = self.io_processors[pooling_task]
ctx = await self._init_ctx(request, raw_request)

await io_processor.pre_process_online_async(ctx)

if ctx.pooling_params is None:
ctx.pooling_params = io_processor.create_pooling_params(request)

await self._prepare_generators(ctx)
await self._collect_batch(ctx)

await io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
return self.io_processors[pooling_task]

def _verify_pooling_task(self, request: PoolingRequest) -> str:
if getattr(request, "dimensions", None) is not None:
Expand Down Expand Up @@ -117,7 +100,7 @@ def _verify_pooling_task(self, request: PoolingRequest) -> str:

return pooling_task

async def _build_response(
def _build_response(
self,
ctx: PoolingServeContext,
) -> Response:
4 changes: 2 additions & 2 deletions vllm/entrypoints/pooling/scoring/io_processor.py
@@ -220,7 +220,7 @@ def _pre_process(
scoring_data.data_2, "document", self.model_config
)

return self._preprocess_completion_offline(
return self._preprocess_cmpl_offline(
prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
)

@@ -682,7 +682,7 @@ def _pre_process(
for q, d in zip(queries, docs)
]

return self._preprocess_completion_offline(
return self._preprocess_cmpl_offline(
prompts=prompts, tok_params=tok_params, prompt_extras=prompt_extras
)

10 changes: 4 additions & 6 deletions vllm/entrypoints/pooling/scoring/serving.py
@@ -65,7 +65,7 @@ async def __call__(self, *args, **kwargs) -> Response:

return await self.flash_late_interaction(*args, **kwargs)

async def _build_response(
def _build_response(
self,
ctx: ScoringServeContext,
) -> JSONResponse:
@@ -183,17 +183,15 @@ def _request_output_to_rerank_response(
### Can significantly improve late-interaction scoring performance.

async def flash_late_interaction(self, *args, **kwargs) -> Response:
ctx = await self._init_ctx(*args, **kwargs)
ctx.pooling_params = self.io_processor.create_pooling_params(ctx.request)
await self.io_processor.pre_process_online_async(ctx)
ctx = await self._init_ctx(self.io_processor, *args, **kwargs)
await self._preprocessing_async(self.io_processor, ctx)

# stage 1: encode queries and cache token embeddings on workers.
await self._flash_late_interaction_encode_queries(ctx)
# stage 2: encode docs and return scalar scores from workers.
await self._flash_late_interaction_encode_docs(ctx)

await self.io_processor.post_process_online_async(ctx)
return await self._build_response(ctx)
return await self._postprocessing_async(self.io_processor, ctx)

async def _flash_late_interaction_encode_queries(self, ctx: ScoringServeContext):
assert ctx.n_queries is not None
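
For background on the two stages above: late-interaction scoring compares every query token embedding against every document token embedding (ColBERT-style MaxSim), so encoding a query once in stage 1 and caching its token embeddings on the workers lets stage 2 reuse them across all documents. A hedged sketch of that scoring rule only, not this PR's worker-side implementation (`maxsim_score` is an illustrative name):

```python
import torch


def maxsim_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
    # query_emb: (num_query_tokens, dim); doc_emb: (num_doc_tokens, dim)
    sim = query_emb @ doc_emb.T  # token-level similarity matrix
    # MaxSim: best-matching doc token per query token, summed over the query.
    return sim.max(dim=1).values.sum()


query = torch.randn(8, 128)                      # encoded once (stage 1)
docs = [torch.randn(n, 128) for n in (50, 120)]  # encoded per doc (stage 2)
scores = [maxsim_score(query, d) for d in docs]  # query embeddings reused
```
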