Skip to content
59 changes: 59 additions & 0 deletions tests/entrypoints/pooling/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
import importlib.util
import json
import warnings
from types import SimpleNamespace

import numpy as np
import pytest
import torch

from vllm.entrypoints.pooling.utils import encode_pooling_output_float_or_ndarray
Comment thread
noooop marked this conversation as resolved.


def _pooling_output(data):
return SimpleNamespace(outputs=SimpleNamespace(data=data))


def test_encode_pooling_output_float_or_ndarray_returns_numpy_array():
    """A float32 tensor should be encoded as a NumPy array holding the same values."""
    expected = [1.0, 2.0, 3.0]
    tensor = torch.tensor(expected, dtype=torch.float32)

    result = encode_pooling_output_float_or_ndarray(_pooling_output(tensor))

    assert isinstance(result, np.ndarray)
    np.testing.assert_allclose(result, expected)


@pytest.mark.skipif(
    importlib.util.find_spec("orjson") is None,
    reason="orjson is not installed",
)
def test_orjson_serializes_numpy_array():
    """The encoder's result should round-trip through ``ORJSONResponse`` as JSON."""
    # Imported lazily so that collecting this module does not require fastapi's
    # orjson response class when the test is skipped.
    from fastapi.responses import ORJSONResponse

    tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    embedding = encode_pooling_output_float_or_ndarray(_pooling_output(tensor))
    payload = {"embedding": embedding}

    # DeprecationWarning raised while building the response is silenced here.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        response = ORJSONResponse(content=payload)

    decoded = json.loads(response.body)
    assert decoded["embedding"] == pytest.approx([1.0, 2.0, 3.0])


def test_encode_pooling_output_float_or_ndarray_falls_back_to_list():
    """When ``numpy()`` raises ``TypeError`` the encoder should return ``tolist()``."""

    class DataWithUnsupportedNumpy:
        # Stub providing only the three methods defined below; ``numpy`` always
        # fails so the list fallback is exercised.
        def is_contiguous(self):
            return True

        def numpy(self):
            raise TypeError("unsupported dtype")

        def tolist(self):
            return [1.0, 2.0, 3.0]

    stub = DataWithUnsupportedNumpy()
    result = encode_pooling_output_float_or_ndarray(_pooling_output(stub))

    assert result == [1.0, 2.0, 3.0]
35 changes: 35 additions & 0 deletions vllm/entrypoints/pooling/embed/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
encode_pooling_bytes,
encode_pooling_output_base64,
encode_pooling_output_float,
encode_pooling_output_float_or_ndarray,
get_json_response_cls,
)
from .io_processor import EmbedIOProcessor
Expand Down Expand Up @@ -104,6 +105,40 @@ def _openai_json_response(
embed_dtype: EmbedDType,
endianness: Endianness,
) -> JSONResponse:
use_ndarray_response = (
encoding_format == "float"
and self.json_response_cls.__name__ == "ORJSONResponse"
)
if use_ndarray_response:
ndarray_items: list[dict[str, object]] = []
ndarray_num_tokens = 0

for idx, final_res in enumerate(final_res_batch):
item_dict = EmbeddingResponseData(
index=idx,
embedding=[],
).model_dump()
item_dict["embedding"] = encode_pooling_output_float_or_ndarray(
final_res
)
ndarray_items.append(item_dict)
ndarray_num_tokens += len(final_res.prompt_token_ids)

ndarray_usage = UsageInfo(
prompt_tokens=ndarray_num_tokens,
total_tokens=ndarray_num_tokens,
)
ndarray_response = EmbeddingResponse(
id=request_id,
created=created_time,
model=model_name,
data=[], # type: ignore[arg-type]
usage=ndarray_usage,
).model_dump()
ndarray_response["data"] = ndarray_items

return self.json_response_cls(content=ndarray_response)

encode_fn = cast(
Callable[[PoolingRequestOutput], list[float] | str],
(
Expand Down
11 changes: 11 additions & 0 deletions vllm/entrypoints/pooling/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,17 @@ def encode_pooling_output_float(output: PoolingRequestOutput) -> list[float]:
return output.outputs.data.tolist()


def encode_pooling_output_float_or_ndarray(output: PoolingRequestOutput) -> Any:
    """Encode pooled data as a NumPy array, falling back to a plain list.

    Intended for response renderers that can serialize NumPy arrays directly.

    Args:
        output: Object whose ``outputs.data`` provides ``is_contiguous()``,
            ``contiguous()``, ``numpy()`` and ``tolist()``.

    Returns:
        The result of ``numpy()`` on a contiguous version of the data, or the
        result of ``tolist()`` when the conversion raises ``RuntimeError`` or
        ``TypeError``.
    """
    try:
        tensor = output.outputs.data
        # Only request a contiguous version when the data is not already one.
        dense = tensor if tensor.is_contiguous() else tensor.contiguous()
        return dense.numpy()
    except (RuntimeError, TypeError):
        # The NumPy conversion is not available for this data; use a list.
        return output.outputs.data.tolist()


def encode_pooling_output_base64(
output: PoolingRequestOutput,
embed_dtype: EmbedDType,
Expand Down
Loading