From 0829f20454559c74f2ce79b47070f43ac0b459d7 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Thu, 26 Feb 2026 17:36:30 +0000 Subject: [PATCH 1/6] fix maxsim cuda platform Signed-off-by: yewentao256 --- vllm/entrypoints/pooling/score/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/pooling/score/utils.py b/vllm/entrypoints/pooling/score/utils.py index 98c24856bf3c..191d6f66b66b 100644 --- a/vllm/entrypoints/pooling/score/utils.py +++ b/vllm/entrypoints/pooling/score/utils.py @@ -25,6 +25,7 @@ from vllm.model_executor.models.interfaces import supports_score_template from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict from vllm.outputs import PoolingRequestOutput +from vllm.platforms import current_platform from vllm.renderers.hf import safe_apply_chat_template from vllm.tokenizers import TokenizerLike @@ -73,7 +74,7 @@ def compute_maxsim_scores( if q_emb.shape[1] != d_emb.shape[1]: raise ValueError("Query and document embeddings must have same dim") - compute_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + compute_device = torch.device("cuda" if current_platform.is_cuda() else "cpu") scores: list[torch.Tensor] = [] start = 0 while start < num_pairs: From 7a8ef90f0aded17ccf5edc45c41e6d7f1f3b6862 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Thu, 26 Feb 2026 18:43:59 +0000 Subject: [PATCH 2/6] add env Signed-off-by: yewentao256 --- vllm/entrypoints/pooling/score/utils.py | 7 ++++++- vllm/envs.py | 8 ++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/pooling/score/utils.py b/vllm/entrypoints/pooling/score/utils.py index 191d6f66b66b..1c8ff43d6c5d 100644 --- a/vllm/entrypoints/pooling/score/utils.py +++ b/vllm/entrypoints/pooling/score/utils.py @@ -7,6 +7,7 @@ from torch.nn import CosineSimilarity from typing_extensions import Required, TypedDict +from vllm import envs from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import ( BaseMultiModalItemTracker, @@ -54,6 +55,10 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens return token_scores.amax(dim=-1).sum() +def _should_use_gpu_for_maxsim() -> bool: + return envs.VLLM_USE_GPU_FOR_POOLING_SCORE and current_platform.is_cuda() + + def compute_maxsim_scores( q_embs: Sequence[torch.Tensor], d_embs: Sequence[torch.Tensor], @@ -74,7 +79,7 @@ def compute_maxsim_scores( if q_emb.shape[1] != d_emb.shape[1]: raise ValueError("Query and document embeddings must have same dim") - compute_device = torch.device("cuda" if current_platform.is_cuda() else "cpu") + compute_device = torch.device("cuda" if _should_use_gpu_for_maxsim() else "cpu") scores: list[torch.Tensor] = [] start = 0 while start < num_pairs: diff --git a/vllm/envs.py b/vllm/envs.py index d560cfc7753c..6d5b844418be 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -26,6 +26,7 @@ VLLM_ENGINE_READY_TIMEOUT_S: int = 600 VLLM_API_KEY: str | None = None VLLM_DEBUG_LOG_API_SERVER_RESPONSE: bool = False + VLLM_USE_GPU_FOR_POOLING_SCORE: bool = False S3_ACCESS_KEY_ID: str | None = None S3_SECRET_ACCESS_KEY: str | None = None S3_ENDPOINT_URL: str | None = None @@ -638,6 +639,12 @@ def _get_or_set_default() -> str: "VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False" ).lower() == "true", + # If set, run pooling score MaxSim on GPU in the API server process. + # Huge performance improvement, https://github.com/vllm-project/vllm/pull/35330 + "VLLM_USE_GPU_FOR_POOLING_SCORE": lambda: ( + os.environ.get("VLLM_USE_GPU_FOR_POOLING_SCORE", "0").strip().lower() + in ("1", "true") + ), # S3 access information, used for tensorizer to load model from S3 "S3_ACCESS_KEY_ID": lambda: os.environ.get("S3_ACCESS_KEY_ID", None), "S3_SECRET_ACCESS_KEY": lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), @@ -1731,6 +1738,7 @@ def compile_factors() -> dict[str, object]: "VLLM_LOGGING_COLOR", "VLLM_LOG_STATS_INTERVAL", "VLLM_DEBUG_LOG_API_SERVER_RESPONSE", + "VLLM_USE_GPU_FOR_POOLING_SCORE", "VLLM_TUNED_CONFIG_FOLDER", "VLLM_ENGINE_ITERATION_TIMEOUT_S", "VLLM_HTTP_TIMEOUT_KEEP_ALIVE", From 9fbf62c43305141d80ab3467e6d5a553dddfe1fc Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Fri, 27 Feb 2026 10:46:42 -0500 Subject: [PATCH 3/6] Update vllm/entrypoints/pooling/score/utils.py Co-authored-by: Cyrus Leung Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> --- vllm/entrypoints/pooling/score/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/pooling/score/utils.py b/vllm/entrypoints/pooling/score/utils.py index 1c8ff43d6c5d..c93fda1045d5 100644 --- a/vllm/entrypoints/pooling/score/utils.py +++ b/vllm/entrypoints/pooling/score/utils.py @@ -56,7 +56,7 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens def _should_use_gpu_for_maxsim() -> bool: - return envs.VLLM_USE_GPU_FOR_POOLING_SCORE and current_platform.is_cuda() + return envs.VLLM_USE_GPU_FOR_POOLING_SCORE and not current_platform.is_cpu() def compute_maxsim_scores( From c8290282e1b59721e56abc443169b0bbbd8c8f63 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Fri, 27 Feb 2026 10:46:52 -0500 Subject: [PATCH 4/6] Update vllm/entrypoints/pooling/score/utils.py Co-authored-by: Cyrus Leung Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> --- vllm/entrypoints/pooling/score/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/pooling/score/utils.py b/vllm/entrypoints/pooling/score/utils.py index c93fda1045d5..3fd17c9d3e51 100644 --- a/vllm/entrypoints/pooling/score/utils.py +++ b/vllm/entrypoints/pooling/score/utils.py @@ -79,7 +79,7 @@ def compute_maxsim_scores( if q_emb.shape[1] != d_emb.shape[1]: raise ValueError("Query and document embeddings must have same dim") - compute_device = torch.device("cuda" if _should_use_gpu_for_maxsim() else "cpu") + compute_device = torch.device(current_platform.device_type if _should_use_gpu_for_maxsim() else "cpu") scores: list[torch.Tensor] = [] start = 0 while start < num_pairs: From f3d133efdd4d6a206d794c7efd0b9fef7dd4a5ed Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 27 Feb 2026 15:54:54 +0000 Subject: [PATCH 5/6] use cli args Signed-off-by: yewentao256 --- vllm/entrypoints/openai/cli_args.py | 4 ++++ vllm/entrypoints/pooling/__init__.py | 1 + vllm/entrypoints/pooling/score/serving.py | 3 +++ vllm/entrypoints/pooling/score/utils.py | 12 ++++++++---- vllm/envs.py | 8 -------- 5 files changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index eac581e5da9b..0cff8aaa71d9 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -277,6 +277,10 @@ class FrontendArgs(BaseFrontendArgs): Enable offline FastAPI documentation for air-gapped environments. Uses vendored static assets bundled with vLLM. """ + use_gpu_for_pooling_score: bool = False + """If set, run pooling score MaxSim on GPU in the API server process. + Can significantly improve late-interaction scoring performance. + https://github.com/vllm-project/vllm/pull/35330""" @classmethod def _customize_cli_kwargs( diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py index 1108be175bc6..3ba131d5f831 100644 --- a/vllm/entrypoints/pooling/__init__.py +++ b/vllm/entrypoints/pooling/__init__.py @@ -115,6 +115,7 @@ def init_pooling_state( request_logger=request_logger, score_template=resolved_chat_template, log_error_stack=args.log_error_stack, + use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False), ) if any(t in supported_tasks for t in ("embed", "score", "token_embed")) else None diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index aec6e909d161..60d6db6a7003 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -56,6 +56,7 @@ def __init__( request_logger: RequestLogger | None, score_template: str | None = None, log_error_stack: bool = False, + use_gpu_for_pooling_score: bool = False, ) -> None: super().__init__( engine_client=engine_client, @@ -64,6 +65,7 @@ def __init__( log_error_stack=log_error_stack, ) self.score_template = score_template + self.use_gpu_for_pooling_score = use_gpu_for_pooling_score self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) @@ -314,6 +316,7 @@ async def _late_interaction_score( maxsim_scores = compute_maxsim_scores( [emb.outputs.data for emb in emb_data_1], [emb.outputs.data for emb in emb_data_2], + use_gpu_for_pooling_score=self.use_gpu_for_pooling_score, ) scores: list[PoolingRequestOutput] = [] diff --git a/vllm/entrypoints/pooling/score/utils.py b/vllm/entrypoints/pooling/score/utils.py index 3fd17c9d3e51..65611dc3aa4f 100644 --- a/vllm/entrypoints/pooling/score/utils.py +++ b/vllm/entrypoints/pooling/score/utils.py @@ -7,7 +7,6 @@ from torch.nn import CosineSimilarity from typing_extensions import Required, TypedDict -from vllm import envs from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import ( BaseMultiModalItemTracker, @@ -55,8 +54,8 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens return token_scores.amax(dim=-1).sum() -def _should_use_gpu_for_maxsim() -> bool: - return envs.VLLM_USE_GPU_FOR_POOLING_SCORE and not current_platform.is_cpu() +def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool: + return use_gpu_for_pooling_score and not current_platform.is_cpu() def compute_maxsim_scores( @@ -64,6 +63,7 @@ def compute_maxsim_scores( d_embs: Sequence[torch.Tensor], max_batch_size: int = 16, max_score_matrix_elements: int = 16_000_000, + use_gpu_for_pooling_score: bool = False, ) -> list[torch.Tensor]: """Compute ColBERT MaxSim scores in padded mini-batches.""" if len(q_embs) != len(d_embs): @@ -79,7 +79,11 @@ def compute_maxsim_scores( if q_emb.shape[1] != d_emb.shape[1]: raise ValueError("Query and document embeddings must have same dim") - compute_device = torch.device(current_platform.device_type if _should_use_gpu_for_maxsim() else "cpu") + compute_device = torch.device( + current_platform.device_type + if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score) + else "cpu" + ) scores: list[torch.Tensor] = [] start = 0 while start < num_pairs: diff --git a/vllm/envs.py b/vllm/envs.py index 6d5b844418be..d560cfc7753c 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -26,7 +26,6 @@ VLLM_ENGINE_READY_TIMEOUT_S: int = 600 VLLM_API_KEY: str | None = None VLLM_DEBUG_LOG_API_SERVER_RESPONSE: bool = False - VLLM_USE_GPU_FOR_POOLING_SCORE: bool = False S3_ACCESS_KEY_ID: str | None = None S3_SECRET_ACCESS_KEY: str | None = None S3_ENDPOINT_URL: str | None = None @@ -639,12 +638,6 @@ def _get_or_set_default() -> str: "VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False" ).lower() == "true", - # If set, run pooling score MaxSim on GPU in the API server process. - # Huge performance improvement, https://github.com/vllm-project/vllm/pull/35330 - "VLLM_USE_GPU_FOR_POOLING_SCORE": lambda: ( - os.environ.get("VLLM_USE_GPU_FOR_POOLING_SCORE", "0").strip().lower() - in ("1", "true") - ), # S3 access information, used for tensorizer to load model from S3 "S3_ACCESS_KEY_ID": lambda: os.environ.get("S3_ACCESS_KEY_ID", None), "S3_SECRET_ACCESS_KEY": lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), @@ -1738,7 +1731,6 @@ def compile_factors() -> dict[str, object]: "VLLM_LOGGING_COLOR", "VLLM_LOG_STATS_INTERVAL", "VLLM_DEBUG_LOG_API_SERVER_RESPONSE", - "VLLM_USE_GPU_FOR_POOLING_SCORE", "VLLM_TUNED_CONFIG_FOLDER", "VLLM_ENGINE_ITERATION_TIMEOUT_S", "VLLM_HTTP_TIMEOUT_KEEP_ALIVE", From 4583e828a26dd3388768f7378525ec90b64c226b Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 27 Feb 2026 16:15:12 +0000 Subject: [PATCH 6/6] fix Signed-off-by: yewentao256 --- vllm/entrypoints/cli/serve.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index c12cc7ff2a0b..5d7bfe5ca136 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -220,6 +220,12 @@ def run_multi_api_server(args: argparse.Namespace): num_api_servers: int = args.api_server_count assert num_api_servers > 0 + if num_api_servers > 1 and getattr(args, "use_gpu_for_pooling_score", False): + # TODO(wentao): remove this once well tested + raise ValueError( + "--use-gpu-for-pooling-score cannot be used with api_server_count > 1 now" + ) + if num_api_servers > 1: setup_multiprocess_prometheus()