diff --git a/tests/entrypoints/pooling/scoring/test_late_interaction_online.py b/tests/entrypoints/pooling/scoring/test_late_interaction_online.py
index 9eedec6d2b98..7e4501fe8500 100644
--- a/tests/entrypoints/pooling/scoring/test_late_interaction_online.py
+++ b/tests/entrypoints/pooling/scoring/test_late_interaction_online.py
@@ -26,13 +26,18 @@
 ]
 
 
-@pytest.fixture(scope="module")
-def server():
+@pytest.fixture(scope="module", params=[True, False])
+def server(request):
     args = [
         "--max-model-len",
         str(MAX_MODEL_LEN),
     ]
 
+    # Parametrize over flash-late-interaction: when enabled (the default),
+    # the pooling MaxSim score runs on the worker side (GPU).
+    if not request.param:
+        args += ["--no-enable-flash-late-interaction"]
+
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
diff --git a/tests/v1/worker/test_late_interaction_runner.py b/tests/v1/worker/test_late_interaction_runner.py
index 5be3f6e6f10d..9719485cd542 100644
--- a/tests/v1/worker/test_late_interaction_runner.py
+++ b/tests/v1/worker/test_late_interaction_runner.py
@@ -4,12 +4,12 @@
 import pytest
 import torch
 
+from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score
 from vllm.pooling_params import LateInteractionParams, PoolingParams
 from vllm.v1.pool.late_interaction import (
     LATE_INTERACTION_MODE_CACHE_QUERY,
     build_late_interaction_doc_params,
     build_late_interaction_query_params,
-    compute_maxsim_score,
 )
 from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index a576a3f28a8b..b9eea87451bb 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1449,14 +1449,14 @@ def score(
         pooling_task = io_processor.pooling_task
 
         scoring_data = io_processor.valid_inputs(data_1, data_2)
-        offset = len(scoring_data.data_1)
+        n_queries = len(scoring_data.data_1)
 
         ctx = OfflineInputsContext(
             prompts=scoring_data,
             pooling_params=pooling_params,
             tokenization_kwargs=tokenization_kwargs,
             chat_template=chat_template,
-            offset=offset,
+            n_queries=n_queries,
         )
 
         processor_inputs = io_processor.pre_process_offline(ctx)
@@ -1487,7 +1487,7 @@ def score(
 
         outputs = self._run_engine(use_tqdm=use_tqdm, output_type=PoolingRequestOutput)
         outputs = io_processor.post_process_offline(
-            ctx=OfflineOutputsContext(outputs=outputs, offset=offset),
+            ctx=OfflineOutputsContext(outputs=outputs, n_queries=n_queries),
         )
 
         return [ScoringRequestOutput.from_base(item) for item in outputs]
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 7491c41c2713..898e62f7713d 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -278,6 +278,9 @@ class FrontendArgs(BaseFrontendArgs):
     Enable offline FastAPI documentation for air-gapped environments.
    Uses vendored static assets bundled with vLLM.
     """
+    enable_flash_late_interaction: bool = True
+    """If set, run the pooling MaxSim score on the worker side (GPU). Can
+    significantly improve late-interaction scoring performance."""
 
     @classmethod
     def _customize_cli_kwargs(
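The new flag surfaces on the CLI as `--enable-flash-late-interaction` / `--no-enable-flash-late-interaction` (the test fixture above exercises the latter). A minimal usage sketch, assuming a late-interaction model served on the default port and vLLM's score API request fields `text_1`/`text_2`; scoring looks identical in both modes:

# Hedged sketch: start the server first, e.g.
#   vllm serve <late-interaction-model> --no-enable-flash-late-interaction
import requests

resp = requests.post(
    "http://localhost:8000/score",
    json={
        "text_1": "what is the capital of france?",
        "text_2": ["Paris is the capital of France.", "Berlin is in Germany."],
    },
)
print(resp.json())  # one score per (query, document) pair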
diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py
index b843c7913195..fb0c10e6f4f4 100644
--- a/vllm/entrypoints/pooling/__init__.py
+++ b/vllm/entrypoints/pooling/__init__.py
@@ -123,6 +123,9 @@ def init_pooling_state(
             chat_template=resolved_chat_template,
             chat_template_content_format=args.chat_template_content_format,
             trust_request_chat_template=args.trust_request_chat_template,
+            enable_flash_late_interaction=getattr(
+                args, "enable_flash_late_interaction", True
+            ),
         )
         if enable_scoring_api(supported_tasks, model_config)
         else None
diff --git a/vllm/entrypoints/pooling/base/serving.py b/vllm/entrypoints/pooling/base/serving.py
index cf6e01742ce0..90554aa634b4 100644
--- a/vllm/entrypoints/pooling/base/serving.py
+++ b/vllm/entrypoints/pooling/base/serving.py
@@ -82,9 +82,20 @@ async def __call__(
         request: AnyPoolingRequest,
         raw_request: Request | None = None,
     ) -> Response:
+        ctx = await self._init_ctx(request, raw_request)
+        await self.io_processor.pre_process_online_async(ctx)
+        await self._prepare_generators(ctx)
+        await self._collect_batch(ctx)
+        await self.io_processor.post_process_online_async(ctx)
+        return await self._build_response(ctx)
+
+    async def _init_ctx(
+        self,
+        request: AnyPoolingRequest,
+        raw_request: Request | None = None,
+    ):
         model_name = self.models.model_name()
         request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
-
         await self._check_model(request)
 
         ctx = PoolingServeContext(
@@ -96,11 +107,7 @@ async def __call__(
         self._validate_request(ctx)
         self._maybe_get_adapters(ctx)
 
-        await self.io_processor.pre_process_online_async(ctx)
-        await self._prepare_generators(ctx)
-        await self._collect_batch(ctx)
-        await self.io_processor.post_process_online_async(ctx)
-        return await self._build_response(ctx)
+        return ctx
 
     async def _prepare_generators(
         self,
diff --git a/vllm/entrypoints/pooling/scoring/io_processor.py b/vllm/entrypoints/pooling/scoring/io_processor.py
index 70fe1b221412..c520eb5ceb3d 100644
--- a/vllm/entrypoints/pooling/scoring/io_processor.py
+++ b/vllm/entrypoints/pooling/scoring/io_processor.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import time
 from collections.abc import Sequence
-from typing import Any, TypeAlias, cast
+from typing import Any, TypeAlias
 
 import torch.nn.functional as F
 
@@ -16,7 +16,7 @@
 from vllm.inputs import EngineInput
 from vllm.renderers import TokenizeParams
 from vllm.renderers.hf import safe_apply_chat_template
-from vllm.tasks import PoolingTask, ScoreType
+from vllm.tasks import PoolingTask
 from vllm.utils.mistral import is_mistral_tokenizer
 
 from ...chat_utils import ChatTemplateResolutionError
@@ -34,7 +34,7 @@
 
 
 class ScoringIOProcessor(PoolingIOProcessor):
-    name: ScoreType
+    name: str
     pooling_task: PoolingTask
 
     def __init__(self, *args, **kwargs):
@@ -63,7 +63,7 @@ def valid_inputs(
 
 
 class BiEncoderIOProcessor(ScoringIOProcessor):
-    name: ScoreType = "bi-encoder"
+    name = "bi-encoder"
     pooling_task: PoolingTask = "embed"
 
     #######################################
@@ -94,20 +94,17 @@ def pre_process_online(self, ctx: ScoringServeContext):
         )
 
         ctx.engine_inputs = engine_inputs
-        ctx.intermediates = len(scoring_data.data_1)
+        ctx.n_queries = len(scoring_data.data_1)
 
     def post_process_online(
         self,
         ctx: ScoringServeContext,
     ):
-        if ctx.final_res_batch is None:
-            raise ValueError("Final response batch not available")
-
-        if ctx.intermediates is None:
-            raise ValueError("data_1 len not available")
+        assert ctx.final_res_batch is not None
+        assert isinstance(ctx.n_queries, int)
 
         ctx.final_res_batch = self._post_process(
-            outputs=ctx.final_res_batch, offset=cast(int, ctx.intermediates)
+            outputs=ctx.final_res_batch, n_queries=ctx.n_queries
         )
 
     #######################################
@@ -124,8 +121,8 @@ def post_process_offline(
         self,
         ctx: OfflineOutputsContext,
     ) -> list[PoolingRequestOutput]:
-        assert ctx.offset is not None
-        return self._post_process(outputs=ctx.outputs, offset=ctx.offset)
+        assert ctx.n_queries is not None
+        return self._post_process(outputs=ctx.outputs, n_queries=ctx.n_queries)
 
     #######################################
     # helpers
@@ -145,9 +142,9 @@ def _pre_process(
             prompts=data_1 + data_2, tok_params=tok_params, prompt_extras=prompt_extras
         )
 
-    def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
-        emb_data_1 = outputs[:offset]
-        emb_data_2 = outputs[offset:]
+    def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
+        emb_data_1 = outputs[:n_queries]
+        emb_data_2 = outputs[n_queries:]
 
         if len(emb_data_1) == 1:
             emb_data_1 = emb_data_1 * len(emb_data_2)
@@ -177,13 +174,13 @@
 
 
 class LateInteractionIOProcessor(BiEncoderIOProcessor):
-    name: ScoreType = "late-interaction"
+    name = "late-interaction"
     pooling_task: PoolingTask = "token_embed"
 
-    def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
+    def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
         # Split into query and document embeddings
-        emb_data_1 = outputs[:offset]
-        emb_data_2 = outputs[offset:]
+        emb_data_1 = outputs[:n_queries]
+        emb_data_2 = outputs[n_queries:]
 
         # Expand queries if 1:N scoring
         if len(emb_data_1) == 1:
@@ -217,8 +214,15 @@ def _post_process(self, outputs: list[PoolingRequestOutput], offset: int):
         return final_res_batch
 
 
+class FlashLateInteractionIOProcessor(LateInteractionIOProcessor):
+    name = "flash-late-interaction"
+
+    def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
+        return outputs
+
+
 class CrossEncoderIOProcessor(ScoringIOProcessor):
-    name: ScoreType = "cross-encoder"
+    name = "cross-encoder"
     pooling_task: PoolingTask = "classify"
 
     def __init__(self, *args, **kwargs):
@@ -412,8 +416,12 @@ def default_tokenizer_encode():
 
     return full_prompt, engine_prompt
 
 
-ScoringIOProcessors: dict[ScoreType, type[ScoringIOProcessor]] = {
-    "bi-encoder": BiEncoderIOProcessor,
-    "late-interaction": LateInteractionIOProcessor,
-    "cross-encoder": CrossEncoderIOProcessor,
+ScoringIOProcessors: dict[str, type[ScoringIOProcessor]] = {
+    p.name: p
+    for p in [
+        BiEncoderIOProcessor,
+        LateInteractionIOProcessor,
+        FlashLateInteractionIOProcessor,
+        CrossEncoderIOProcessor,
+    ]
 }
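With the offset to n_queries rename above, the post-processing contract reads directly: outputs arrive ordered as [queries..., documents...], and a single query is broadcast across all documents (1:N scoring). A self-contained sketch of that split, using toy strings in place of vLLM output objects:

def split_outputs(outputs: list, n_queries: int) -> list[tuple]:
    emb_data_1 = outputs[:n_queries]  # query outputs
    emb_data_2 = outputs[n_queries:]  # document outputs
    if len(emb_data_1) == 1:
        # 1:N scoring: reuse the single query for every document
        emb_data_1 = emb_data_1 * len(emb_data_2)
    return list(zip(emb_data_1, emb_data_2))

# one query, two documents -> two (query, document) pairs
assert split_outputs(["q", "d1", "d2"], n_queries=1) == [("q", "d1"), ("q", "d2")]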
diff --git a/vllm/entrypoints/pooling/scoring/serving.py b/vllm/entrypoints/pooling/scoring/serving.py
index 57e5684e4e3e..de5b5797ce49 100644
--- a/vllm/entrypoints/pooling/scoring/serving.py
+++ b/vllm/entrypoints/pooling/scoring/serving.py
@@ -1,9 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, Response
 
+from vllm import PoolingParams
 from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import ChatTemplateConfig
 from vllm.entrypoints.openai.engine.protocol import UsageInfo
 from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
@@ -11,6 +13,10 @@
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
 from vllm.renderers import BaseRenderer
+from vllm.v1.pool.late_interaction import (
+    build_late_interaction_doc_params,
+    build_late_interaction_query_params,
+)
 
 from .io_processor import ScoringIOProcessors, ScoringServeContext
 from .protocol import (
@@ -31,13 +37,30 @@
 
 class ServingScores(PoolingServing):
     request_id_prefix = "score"
 
+    def __init__(
+        self,
+        engine_client: EngineClient,
+        *args,
+        enable_flash_late_interaction: bool = True,
+        **kwargs,
+    ):
+        self.score_type = engine_client.model_config.score_type
+        self.enable_flash_late_interaction = (
+            self.score_type == "late-interaction" and enable_flash_late_interaction
+        )
+
+        super().__init__(engine_client, *args, **kwargs)
+
     def init_io_processor(
         self,
         model_config: ModelConfig,
         renderer: BaseRenderer,
         chat_template_config: ChatTemplateConfig,
     ) -> PoolingIOProcessor:
-        score_type = model_config.score_type
+        score_type: str = model_config.score_type
+        if self.enable_flash_late_interaction:
+            score_type = "flash-late-interaction"
+        assert score_type in ScoringIOProcessors
         processor_cls = ScoringIOProcessors[score_type]
         return processor_cls(
@@ -46,6 +69,12 @@
             chat_template_config=chat_template_config,
         )
 
+    async def __call__(self, *args, **kwargs) -> Response:
+        if not self.enable_flash_late_interaction:
+            return await super().__call__(*args, **kwargs)
+
+        return await self.flash_late_interaction(*args, **kwargs)
+
     async def _build_response(
         self,
         ctx: ScoringServeContext,
@@ -158,3 +187,106 @@ def _request_output_to_rerank_response(
         )
 
         return JSONResponse(content=response.model_dump())
+
+    ###################################################################################
+    ### Run the pooling MaxSim score on the worker side (GPU) instead of in the
+    ### API server process. Can significantly improve late-interaction performance.
+
+    async def flash_late_interaction(self, *args, **kwargs) -> Response:
+        ctx = await self._init_ctx(*args, **kwargs)
+        ctx.pooling_params = self.io_processor.create_pooling_params(ctx.request)
+        await self.io_processor.pre_process_online_async(ctx)
+
+        # stage 1: encode queries and cache token embeddings on workers.
+        await self._flash_late_interaction_encode_queries(ctx)
+        # stage 2: encode docs and return scalar scores from workers.
+        await self._flash_late_interaction_encode_docs(ctx)
+
+        await self.io_processor.post_process_online_async(ctx)
+        return await self._build_response(ctx)
+
+    async def _flash_late_interaction_encode_queries(self, ctx: ScoringServeContext):
+        assert ctx.n_queries is not None
+        assert ctx.engine_inputs is not None
+        assert isinstance(ctx.pooling_params, PoolingParams)
+
+        n_queries = ctx.n_queries
+        n_docs = len(ctx.engine_inputs) - n_queries
+        query_engine_inputs = ctx.engine_inputs[:n_queries]
+
+        query_keys = [f"{ctx.request_id}-query-{i}" for i in range(n_queries)]
+        query_uses = [n_docs if n_queries == 1 else 1] * n_queries
+
+        query_pooling_params_list = []
+        for i in range(n_queries):
+            pooling_params = ctx.pooling_params.clone()
+            pooling_params.late_interaction_params = (
+                build_late_interaction_query_params(
+                    query_key=query_keys[i],
+                    query_uses=query_uses[i],
+                )
+            )
+            query_pooling_params_list.append(pooling_params)
+
+        assert (
+            n_queries
+            == len(query_pooling_params_list)
+            == len(query_engine_inputs)
+            == len(query_keys)
+        )
+
+        query_ctx = ScoringServeContext(
+            request=ctx.request,
+            raw_request=ctx.raw_request,
+            model_name=ctx.model_name,
+            request_id=ctx.request_id,
+            pooling_params=query_pooling_params_list,
+            prompt_request_ids=query_keys,
+            engine_inputs=query_engine_inputs,
+        )
+
+        await self._prepare_generators(query_ctx)
+        await self._collect_batch(query_ctx)
+
+    async def _flash_late_interaction_encode_docs(self, ctx: ScoringServeContext):
+        assert ctx.n_queries is not None
+        assert ctx.engine_inputs is not None
+        assert isinstance(ctx.pooling_params, PoolingParams)
+
+        n_queries = ctx.n_queries
+        n_docs = len(ctx.engine_inputs) - n_queries
+        doc_engine_inputs = ctx.engine_inputs[n_queries:]
+
+        query_keys = [f"{ctx.request_id}-query-{i}" for i in range(n_queries)]
+        doc_keys = [f"{ctx.request_id}-doc-{i}" for i in range(n_docs)]
+
+        doc_pooling_params_list = []
+        for i in range(n_docs):
+            query_idx = 0 if n_queries == 1 else i
+            pooling_params = ctx.pooling_params.clone()
+            pooling_params.late_interaction_params = build_late_interaction_doc_params(
+                query_key=query_keys[query_idx]
+            )
+            doc_pooling_params_list.append(pooling_params)
+
+        assert (
+            n_docs
+            == len(doc_pooling_params_list)
+            == len(doc_engine_inputs)
+            == len(doc_keys)
+        )
+
+        doc_ctx = ScoringServeContext(
+            request=ctx.request,
+            raw_request=ctx.raw_request,
+            model_name=ctx.model_name,
+            request_id=ctx.request_id,
+            pooling_params=doc_pooling_params_list,
+            prompt_request_ids=doc_keys,
+            engine_inputs=doc_engine_inputs,
+        )
+
+        await self._prepare_generators(doc_ctx)
+        await self._collect_batch(doc_ctx)
+
+        ctx.final_res_batch = doc_ctx.final_res_batch
diff --git a/vllm/entrypoints/pooling/scoring/utils.py b/vllm/entrypoints/pooling/scoring/utils.py
index 812a75ab806f..01b8514eb7ba 100644
--- a/vllm/entrypoints/pooling/scoring/utils.py
+++ b/vllm/entrypoints/pooling/scoring/utils.py
@@ -36,8 +36,9 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
     Returns:
         MaxSim score (sum over query tokens of max similarity to any doc token)
     """
+    # compute in float32 for numerical stability
     # [query_len, doc_len]
-    token_scores = torch.matmul(q_emb, d_emb.T)
+    token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
 
     # Max over document tokens, sum over query tokens
     return token_scores.amax(dim=-1).sum()
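For reference, a self-contained illustration of the MaxSim computation that compute_maxsim_score performs, with random tensors standing in for real token embeddings:

import torch

q_emb = torch.randn(4, 128, dtype=torch.float16)  # [query_len, dim]
d_emb = torch.randn(9, 128, dtype=torch.float16)  # [doc_len, dim]

# pairwise token similarities in float32: [query_len, doc_len]
token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
# best-matching document token per query token, summed over query tokens
score = token_scores.amax(dim=-1).sum()
print(score)  # a scalar MaxSim score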
diff --git a/vllm/entrypoints/pooling/typing.py b/vllm/entrypoints/pooling/typing.py
index 8ccc5d49c2df..66dd9dd4d2b4 100644
--- a/vllm/entrypoints/pooling/typing.py
+++ b/vllm/entrypoints/pooling/typing.py
@@ -83,6 +83,9 @@ class PoolingServeContext(Generic[PoolingRequestT]):
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
+    ## for bi-encoder & late-interaction
+    n_queries: int | None = None
+
 
 @dataclass
 class OfflineInputsContext:
@@ -92,7 +95,7 @@ class OfflineInputsContext:
     chat_template: str | None = None
 
     ## for bi-encoder & late-interaction
-    offset: int | None = None
+    n_queries: int | None = None
 
 
 @dataclass
@@ -100,4 +103,4 @@ class OfflineOutputsContext:
     outputs: list[PoolingRequestOutput]
 
     ## for bi-encoder & late-interaction
-    offset: int | None = None
+    n_queries: int | None = None
diff --git a/vllm/v1/pool/late_interaction.py b/vllm/v1/pool/late_interaction.py
index 4a465bd2f7d3..554c5947c618 100644
--- a/vllm/v1/pool/late_interaction.py
+++ b/vllm/v1/pool/late_interaction.py
@@ -56,16 +56,7 @@ def build_late_interaction_doc_params(
     )
 
 
-def compute_maxsim_score(
-    q_emb: torch.Tensor,
-    d_emb: torch.Tensor,
-) -> torch.Tensor:
-    # compute in float32 for numerical stability
-    token_scores = torch.matmul(q_emb.float(), d_emb.float().T)
-    return token_scores.amax(dim=-1).sum()
-
-
-def compute_maxsim_scores(
+def compute_maxsim_score_batched(
    q_embs: Sequence[torch.Tensor],
     d_embs: Sequence[torch.Tensor],
     max_batch_size: int = 64,
diff --git a/vllm/v1/worker/gpu/pool/late_interaction_runner.py b/vllm/v1/worker/gpu/pool/late_interaction_runner.py
index 221dee558699..da87c8f05d6b 100644
--- a/vllm/v1/worker/gpu/pool/late_interaction_runner.py
+++ b/vllm/v1/worker/gpu/pool/late_interaction_runner.py
@@ -9,7 +9,7 @@
 from vllm.v1.pool.late_interaction import (
     LATE_INTERACTION_MODE_CACHE_QUERY,
     LATE_INTERACTION_MODE_SCORE_DOC,
-    compute_maxsim_scores,
+    compute_maxsim_score_batched,
 )
 
 
@@ -116,7 +116,7 @@ def postprocess_pooler_output(
             raise ValueError(f"Unsupported late-interaction mode: {mode!r}")
 
         if score_indices:
-            score_values = compute_maxsim_scores(score_queries, score_docs)
+            score_values = compute_maxsim_score_batched(score_queries, score_docs)
             for i, req_id, query_key, score in zip(
                 score_indices, score_req_ids, score_query_keys, score_values
             ):
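End to end, the offline path renamed above (offset -> n_queries) is reached through LLM.score(). A hedged sketch, assuming a recent vLLM that accepts the runner argument; the model name is illustrative:

from vllm import LLM

llm = LLM(model="colbert-ir/colbertv2.0", runner="pooling")  # illustrative model
outputs = llm.score(
    "what is the capital of france?",  # data_1, so n_queries == 1
    ["Paris is the capital of France.", "Berlin is in Germany."],  # data_2
)
for out in outputs:
    print(out.outputs.score)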