@@ -26,13 +26,18 @@
 ]


-@pytest.fixture(scope="module")
-def server():
+@pytest.fixture(scope="module", params=[True, False])
+def server(request):
     args = [
         "--max-model-len",
         str(MAX_MODEL_LEN),
     ]

+    # Test run pooling score MaxSim on worker side (GPU)
+    # aka flash-late-interaction
+    if not request.param:
+        args += ["--no-enable-flash-late-interaction"]
+
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server

tests/v1/worker/test_late_interaction_runner.py (1 addition & 1 deletion)
@@ -4,12 +4,12 @@
 import pytest
 import torch

+from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
 from vllm.pooling_params import LateInteractionParams, PoolingParams
 from vllm.v1.pool.late_interaction import (
     LATE_INTERACTION_MODE_CACHE_QUERY,
     build_late_interaction_doc_params,
     build_late_interaction_query_params,
-    compute_maxsim_score,
 )
 from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner

vllm/entrypoints/llm.py (3 additions & 3 deletions)
@@ -1449,14 +1449,14 @@ def score(

         pooling_task = io_processor.pooling_task
         scoring_data = io_processor.valid_inputs(data_1, data_2)
-        offset = len(scoring_data.data_1)
+        n_queries = len(scoring_data.data_1)

         ctx = OfflineInputsContext(
             prompts=scoring_data,
             pooling_params=pooling_params,
             tokenization_kwargs=tokenization_kwargs,
             chat_template=chat_template,
-            offset=offset,
+            n_queries=n_queries,
         )

         processor_inputs = io_processor.pre_process_offline(ctx)
@@ -1487,7 +1487,7 @@

         outputs = self._run_engine(use_tqdm=use_tqdm, output_type=PoolingRequestOutput)
         outputs = io_processor.post_process_offline(
-            ctx=OfflineOutputsContext(outputs=outputs, offset=offset),
+            ctx=OfflineOutputsContext(outputs=outputs, n_queries=n_queries),
         )

         return [ScoringRequestOutput.from_base(item) for item in outputs]
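
For orientation, a minimal offline sketch of the LLM.score path that the offset -> n_queries rename touches. This is an illustration, not code from the PR; the model name is an assumption, and any scoring-capable model works the same way. With one query against N documents, n_queries == 1 and the lone query is paired with every document downstream:

from vllm import LLM

# Illustrative model choice (assumption, not from this PR).
llm = LLM(model="BAAI/bge-reranker-v2-m3")

# 1:N scoring: one query, two documents -> n_queries == 1.
outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "The Eiffel Tower is in Paris."],
)
print([out.outputs.score for out in outputs])
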
vllm/entrypoints/openai/cli_args.py (3 additions & 0 deletions)
@@ -278,6 +278,9 @@ class FrontendArgs(BaseFrontendArgs):
     Enable offline FastAPI documentation for air-gapped environments.
     Uses vendored static assets bundled with vLLM.
     """
+    enable_flash_late_interaction: bool = True
+    """If set, run pooling score MaxSim on GPU in the API server process.
+    Can significantly improve late-interaction scoring performance."""

     @classmethod
     def _customize_cli_kwargs(
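
A hedged end-to-end sketch of what the parametrized test above exercises: a server started with the flash path disabled via the opt-out spelling of this flag, then queried through the score API. The model name, port, and request field names are assumptions based on vLLM's score endpoint, not part of this PR:

import requests

# Assumes a server started roughly as:
#   vllm serve lightonai/GTE-ModernColBERT-v1 --no-enable-flash-late-interaction
resp = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "lightonai/GTE-ModernColBERT-v1",
        "text_1": "What is the capital of France?",
        "text_2": ["Paris is the capital of France.", "Berlin is in Germany."],
    },
)
print(resp.json())
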
vllm/entrypoints/pooling/__init__.py (3 additions & 0 deletions)
@@ -123,6 +123,9 @@ def init_pooling_state(
             chat_template=resolved_chat_template,
             chat_template_content_format=args.chat_template_content_format,
             trust_request_chat_template=args.trust_request_chat_template,
+            enable_flash_late_interaction=getattr(
+                args, "enable_flash_late_interaction", True
+            ),
         )
         if enable_scoring_api(supported_tasks, model_config)
         else None

vllm/entrypoints/pooling/base/serving.py (13 additions & 6 deletions)
@@ -82,9 +82,20 @@ async def __call__(
         request: AnyPoolingRequest,
         raw_request: Request | None = None,
     ) -> Response:
+        ctx = await self._init_ctx(request, raw_request)
+        await self.io_processor.pre_process_online_async(ctx)
+        await self._prepare_generators(ctx)
+        await self._collect_batch(ctx)
+        await self.io_processor.post_process_online_async(ctx)
+        return await self._build_response(ctx)
+
+    async def _init_ctx(
+        self,
+        request: AnyPoolingRequest,
+        raw_request: Request | None = None,
+    ):
         model_name = self.models.model_name()
         request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"

         await self._check_model(request)

         ctx = PoolingServeContext(
@@ -96,11 +107,7 @@

         self._validate_request(ctx)
         self._maybe_get_adapters(ctx)
-        await self.io_processor.pre_process_online_async(ctx)
-        await self._prepare_generators(ctx)
-        await self._collect_batch(ctx)
-        await self.io_processor.post_process_online_async(ctx)
-        return await self._build_response(ctx)
+        return ctx

     async def _prepare_generators(
         self,

vllm/entrypoints/pooling/scoring/io_processor.py (33 additions & 25 deletions)
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import time
 from collections.abc import Sequence
-from typing import Any, TypeAlias, cast
+from typing import Any, TypeAlias

 import torch.nn.functional as F

@@ -16,7 +16,7 @@
 from vllm.inputs import EngineInput
 from vllm.renderers import TokenizeParams
 from vllm.renderers.hf import safe_apply_chat_template
-from vllm.tasks import PoolingTask, ScoreType
+from vllm.tasks import PoolingTask
 from vllm.utils.mistral import is_mistral_tokenizer

 from ...chat_utils import ChatTemplateResolutionError
@@ -34,7 +34,7 @@


 class ScoringIOProcessor(PoolingIOProcessor):
-    name: ScoreType
+    name: str
     pooling_task: PoolingTask

     def __init__(self, *args, **kwargs):
@@ -63,7 +63,7 @@ def valid_inputs(


 class BiEncoderIOProcessor(ScoringIOProcessor):
-    name: ScoreType = "bi-encoder"
+    name = "bi-encoder"
     pooling_task: PoolingTask = "embed"

     #######################################
@@ -94,20 +94,17 @@ def pre_process_online(self, ctx: ScoringServeContext):
         )

         ctx.engine_inputs = engine_inputs
-        ctx.intermediates = len(scoring_data.data_1)
+        ctx.n_queries = len(scoring_data.data_1)

     def post_process_online(
         self,
         ctx: ScoringServeContext,
     ):
-        if ctx.final_res_batch is None:
-            raise ValueError("Final response batch not available")
-
-        if ctx.intermediates is None:
-            raise ValueError("data_1 len not available")
+        assert ctx.final_res_batch is not None
+        assert isinstance(ctx.n_queries, int)

         ctx.final_res_batch = self._post_process(
-            outputs=ctx.final_res_batch, offset=cast(int, ctx.intermediates)
+            outputs=ctx.final_res_batch, n_queries=ctx.n_queries
         )

     #######################################
@@ -124,8 +121,8 @@ def post_process_offline(
         self,
         ctx: OfflineOutputsContext,
     ) -> list[PoolingRequestOutput]:
-        assert ctx.offset is not None
-        return self._post_process(outputs=ctx.outputs, offset=ctx.offset)
+        assert ctx.n_queries is not None
+        return self._post_process(outputs=ctx.outputs, n_queries=ctx.n_queries)

     #######################################
     # helpers
@@ -145,9 +142,9 @@ def _pre_process(
             prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
         )

-    def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
-        emb_data_1 = outputs[:offset]
-        emb_data_2 = outputs[offset:]
+    def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
+        emb_data_1 = outputs[:n_queries]
+        emb_data_2 = outputs[n_queries:]

         if len(emb_data_1) == 1:
             emb_data_1 = emb_data_1 * len(emb_data_2)
@@ -177,13 +174,13 @@


 class LateInteractionIOProcessor(BiEncoderIOProcessor):
-    name: ScoreType = "late-interaction"
+    name = "late-interaction"
     pooling_task: PoolingTask = "token_embed"

-    def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
+    def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
         # Split into query and document embeddings
-        emb_data_1 = outputs[:offset]
-        emb_data_2 = outputs[offset:]
+        emb_data_1 = outputs[:n_queries]
+        emb_data_2 = outputs[n_queries:]

         # Expand queries if 1:N scoring
         if len(emb_data_1) == 1:
@@ -217,8 +214,15 @@ def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
         return final_res_batch


+class FlashLateInteractionIOProcessor(LateInteractionIOProcessor):
+    name = "flash-late-interaction"
+
+    def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
+        return outputs
+
+
 class CrossEncoderIOProcessor(ScoringIOProcessor):
-    name: ScoreType = "cross-encoder"
+    name = "cross-encoder"
     pooling_task: PoolingTask = "classify"

     def __init__(self, *args, **kwargs):
@@ -412,8 +416,12 @@ def default_tokenizer_encode():
     return full_prompt, engine_prompt


-ScoringIOProcessors: dict[ScoreType, type[ScoringIOProcessor]] = {
-    "bi-encoder": BiEncoderIOProcessor,
-    "late-interaction": LateInteractionIOProcessor,
-    "cross-encoder": CrossEncoderIOProcessor,
+ScoringIOProcessors: dict[str, type[ScoringIOProcessor]] = {
+    p.name: p
+    for p in [
+        BiEncoderIOProcessor,
+        LateInteractionIOProcessor,
+        FlashLateInteractionIOProcessor,
+        CrossEncoderIOProcessor,
+    ]
 }
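
For reference, a standalone sketch of ColBERT-style MaxSim scoring, the computation that flash-late-interaction moves off this frontend post-processing path. This illustrates the idea only; it is not the compute_maxsim_score implementation imported in the tests:

import torch
import torch.nn.functional as F


def maxsim_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> float:
    """MaxSim: for each query token, take its best-matching document token,
    then sum over query tokens. Shapes: (n_q, dim) and (n_d, dim)."""
    q = F.normalize(query_emb, dim=-1)
    d = F.normalize(doc_emb, dim=-1)
    sim = q @ d.T  # (n_q, n_d) token-level cosine similarities
    return sim.max(dim=-1).values.sum().item()

The rewritten registry then resolves a processor class from its class-level name, e.g. ScoringIOProcessors["flash-late-interaction"] yields FlashLateInteractionIOProcessor, whose _post_process returns the outputs unchanged because the scores already arrive computed from the worker.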