From b4882afedf4e8a21ea5e18aed0502ccc6c2f988e Mon Sep 17 00:00:00 2001 From: Shrinav Loka Date: Mon, 4 May 2026 16:34:34 -0700 Subject: [PATCH 1/5] [Perf] Use numpy zero-copy path for embedding float response serialization When encoding_format=float and ORJSONResponse is available, bypass the per-element .tolist() conversion and pass numpy arrays directly to ORJSON, which serializes them natively. This gives a 5-70x speedup in response construction depending on batch size and embedding dimension (zero-copy view vs O(n) Python iteration). Falls back to .tolist() for unsupported dtypes (bfloat16, float8) or CUDA tensors. Signed-off-by: Shrinav Loka --- tests/entrypoints/pooling/test_utils.py | 46 +++++++++++++++++++++++ vllm/entrypoints/pooling/embed/serving.py | 35 +++++++++++++++++ vllm/entrypoints/pooling/utils.py | 11 ++++++ 3 files changed, 92 insertions(+) create mode 100644 tests/entrypoints/pooling/test_utils.py diff --git a/tests/entrypoints/pooling/test_utils.py b/tests/entrypoints/pooling/test_utils.py new file mode 100644 index 000000000000..57bc9ec9d1a9 --- /dev/null +++ b/tests/entrypoints/pooling/test_utils.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import warnings +from types import SimpleNamespace + +import numpy as np +import pytest +import torch +from fastapi.responses import ORJSONResponse + +from vllm.entrypoints.pooling.utils import encode_pooling_output_float_or_ndarray + + +def _pooling_output(data): + return SimpleNamespace(outputs=SimpleNamespace(data=data)) + + +def test_encode_pooling_output_float_or_ndarray_returns_numpy_array(): + output = _pooling_output(torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)) + + encoded = encode_pooling_output_float_or_ndarray(output) + + assert isinstance(encoded, np.ndarray) + np.testing.assert_allclose(encoded, [1.0, 2.0, 3.0]) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + response = ORJSONResponse(content={"embedding": encoded}) + assert json.loads(response.body)["embedding"] == pytest.approx([1.0, 2.0, 3.0]) + + +def test_encode_pooling_output_float_or_ndarray_falls_back_to_list(): + class DataWithUnsupportedNumpy: + def is_contiguous(self): + return True + + def numpy(self): + raise TypeError("unsupported dtype") + + def tolist(self): + return [1.0, 2.0, 3.0] + + output = _pooling_output(DataWithUnsupportedNumpy()) + + assert encode_pooling_output_float_or_ndarray(output) == [1.0, 2.0, 3.0] diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index fd38598f7c79..06ba5596b4d7 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -19,6 +19,7 @@ encode_pooling_bytes, encode_pooling_output_base64, encode_pooling_output_float, + encode_pooling_output_float_or_ndarray, get_json_response_cls, ) from .io_processor import EmbedIOProcessor @@ -104,6 +105,40 @@ def _openai_json_response( embed_dtype: EmbedDType, endianness: Endianness, ) -> JSONResponse: + use_ndarray_response = ( + encoding_format == "float" + and self.json_response_cls.__name__ == "ORJSONResponse" + ) + if use_ndarray_response: + items: list[dict[str, object]] = [] + num_prompt_tokens = 0 + + for idx, final_res in enumerate(final_res_batch): + item = { + "index": idx, + "object": "embedding", + "embedding": encode_pooling_output_float_or_ndarray(final_res), + } + prompt_token_ids = final_res.prompt_token_ids + + items.append(item) + num_prompt_tokens += len(prompt_token_ids) + + response = { + "id": request_id, + "object": "list", + "created": created_time, + "model": model_name, + "data": items, + "usage": { + "prompt_tokens": num_prompt_tokens, + "total_tokens": num_prompt_tokens, + "completion_tokens": 0, + "prompt_tokens_details": None, + }, + } + return self.json_response_cls(content=response) + encode_fn = cast( Callable[[PoolingRequestOutput], list[float] | str], ( diff --git a/vllm/entrypoints/pooling/utils.py b/vllm/entrypoints/pooling/utils.py index 329a4d189692..2036ed7aae8e 100644 --- a/vllm/entrypoints/pooling/utils.py +++ b/vllm/entrypoints/pooling/utils.py @@ -62,6 +62,17 @@ def encode_pooling_output_float(output: PoolingRequestOutput) -> list[float]: return output.outputs.data.tolist() +def encode_pooling_output_float_or_ndarray(output: PoolingRequestOutput) -> Any: + """Return an ndarray when the response renderer can serialize NumPy.""" + try: + data = output.outputs.data + if not data.is_contiguous(): + data = data.contiguous() + return data.numpy() + except (RuntimeError, TypeError): + return output.outputs.data.tolist() + + def encode_pooling_output_base64( output: PoolingRequestOutput, embed_dtype: EmbedDType, From a64703870d5b1b6d95bbb5b13008b943d916f62f Mon Sep 17 00:00:00 2001 From: Shrinav Loka Date: Mon, 4 May 2026 16:41:47 -0700 Subject: [PATCH 2/5] Use Pydantic models for response structure in fast path Address review feedback: use EmbeddingResponseData, EmbeddingResponse, and UsageInfo models to build the response dict, then inject the numpy arrays. This keeps the fast path in sync with the API contract while preserving the zero-copy serialization benefit. Signed-off-by: Shrinav Loka --- vllm/entrypoints/pooling/embed/serving.py | 42 +++++++++++------------ 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index 06ba5596b4d7..90f9e56fe252 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -114,29 +114,27 @@ def _openai_json_response( num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): - item = { - "index": idx, - "object": "embedding", - "embedding": encode_pooling_output_float_or_ndarray(final_res), - } - prompt_token_ids = final_res.prompt_token_ids - + item = EmbeddingResponseData( + index=idx, embedding=[], + ).model_dump() + item["embedding"] = encode_pooling_output_float_or_ndarray( + final_res) items.append(item) - num_prompt_tokens += len(prompt_token_ids) - - response = { - "id": request_id, - "object": "list", - "created": created_time, - "model": model_name, - "data": items, - "usage": { - "prompt_tokens": num_prompt_tokens, - "total_tokens": num_prompt_tokens, - "completion_tokens": 0, - "prompt_tokens_details": None, - }, - } + num_prompt_tokens += len(final_res.prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + response = EmbeddingResponse( + id=request_id, + created=created_time, + model=model_name, + data=[], # type: ignore[arg-type] + usage=usage, + ).model_dump() + response["data"] = items + return self.json_response_cls(content=response) encode_fn = cast( From 4764a4f2e96e8599f0568a70c3ee99a46236f9d2 Mon Sep 17 00:00:00 2001 From: Shrinav Loka Date: Tue, 5 May 2026 12:38:35 -0700 Subject: [PATCH 3/5] Fix test to not require orjson at import time Split ORJSONResponse serialization test into a separate test with pytest.mark.skipif when orjson is not installed. The core function test (numpy array return) no longer depends on orjson. Addresses review feedback from @noooop. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Shrinav Loka --- tests/entrypoints/pooling/test_utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/pooling/test_utils.py b/tests/entrypoints/pooling/test_utils.py index 57bc9ec9d1a9..60c84ed28ff2 100644 --- a/tests/entrypoints/pooling/test_utils.py +++ b/tests/entrypoints/pooling/test_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import importlib import json import warnings from types import SimpleNamespace @@ -8,7 +9,6 @@ import numpy as np import pytest import torch -from fastapi.responses import ORJSONResponse from vllm.entrypoints.pooling.utils import encode_pooling_output_float_or_ndarray @@ -24,6 +24,18 @@ def test_encode_pooling_output_float_or_ndarray_returns_numpy_array(): assert isinstance(encoded, np.ndarray) np.testing.assert_allclose(encoded, [1.0, 2.0, 3.0]) + + +@pytest.mark.skipif( + importlib.util.find_spec("orjson") is None, + reason="orjson is not installed", +) +def test_orjson_serializes_numpy_array(): + from fastapi.responses import ORJSONResponse + + output = _pooling_output(torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)) + encoded = encode_pooling_output_float_or_ndarray(output) + with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) response = ORJSONResponse(content={"embedding": encoded}) From 0a8e492ec50e2ab116be022b176019ab355fc309 Mon Sep 17 00:00:00 2001 From: Shrinav Loka Date: Wed, 6 May 2026 01:56:07 -0700 Subject: [PATCH 4/5] Fix ruff formatting in serving.py Co-Authored-By: Claude Opus 4.6 Signed-off-by: Shrinav Loka --- vllm/entrypoints/pooling/embed/serving.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index 90f9e56fe252..ffb0cf5184ea 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -115,10 +115,10 @@ def _openai_json_response( for idx, final_res in enumerate(final_res_batch): item = EmbeddingResponseData( - index=idx, embedding=[], + index=idx, + embedding=[], ).model_dump() - item["embedding"] = encode_pooling_output_float_or_ndarray( - final_res) + item["embedding"] = encode_pooling_output_float_or_ndarray(final_res) items.append(item) num_prompt_tokens += len(final_res.prompt_token_ids) From 94a50534e2388f83db0c3b644e8d69e35ddff5f4 Mon Sep 17 00:00:00 2001 From: Shrinav Loka Date: Thu, 7 May 2026 01:57:57 -0700 Subject: [PATCH 5/5] Fix mypy type errors in embedding serving and tests Rename variables in ndarray fast path to avoid no-redef errors, and add explicit importlib.util import for mypy --follow-imports skip. Signed-off-by: Shrinav Loka Co-Authored-By: Claude Opus 4.6 --- tests/entrypoints/pooling/test_utils.py | 1 + vllm/entrypoints/pooling/embed/serving.py | 28 ++++++++++++----------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/entrypoints/pooling/test_utils.py b/tests/entrypoints/pooling/test_utils.py index 60c84ed28ff2..13a89f2520ec 100644 --- a/tests/entrypoints/pooling/test_utils.py +++ b/tests/entrypoints/pooling/test_utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib +import importlib.util import json import warnings from types import SimpleNamespace diff --git a/vllm/entrypoints/pooling/embed/serving.py b/vllm/entrypoints/pooling/embed/serving.py index ffb0cf5184ea..d85d2372387f 100644 --- a/vllm/entrypoints/pooling/embed/serving.py +++ b/vllm/entrypoints/pooling/embed/serving.py @@ -110,32 +110,34 @@ def _openai_json_response( and self.json_response_cls.__name__ == "ORJSONResponse" ) if use_ndarray_response: - items: list[dict[str, object]] = [] - num_prompt_tokens = 0 + ndarray_items: list[dict[str, object]] = [] + ndarray_num_tokens = 0 for idx, final_res in enumerate(final_res_batch): - item = EmbeddingResponseData( + item_dict = EmbeddingResponseData( index=idx, embedding=[], ).model_dump() - item["embedding"] = encode_pooling_output_float_or_ndarray(final_res) - items.append(item) - num_prompt_tokens += len(final_res.prompt_token_ids) + item_dict["embedding"] = encode_pooling_output_float_or_ndarray( + final_res + ) + ndarray_items.append(item_dict) + ndarray_num_tokens += len(final_res.prompt_token_ids) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - total_tokens=num_prompt_tokens, + ndarray_usage = UsageInfo( + prompt_tokens=ndarray_num_tokens, + total_tokens=ndarray_num_tokens, ) - response = EmbeddingResponse( + ndarray_response = EmbeddingResponse( id=request_id, created=created_time, model=model_name, data=[], # type: ignore[arg-type] - usage=usage, + usage=ndarray_usage, ).model_dump() - response["data"] = items + ndarray_response["data"] = ndarray_items - return self.json_response_cls(content=response) + return self.json_response_cls(content=ndarray_response) encode_fn = cast( Callable[[PoolingRequestOutput], list[float] | str],