[openai embedding] add base64 encoding in EmbeddingResponseData
llmpros committed Jun 30, 2024
1 parent 2be6955 commit dd0bbd9
Showing 4 changed files with 50 additions and 17 deletions.
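
For context, the round trip this commit enables looks roughly like the sketch below. This is an illustration, not part of the diff; it assumes the server serializes the vector with NumPy's default float64 dtype, matching the b64encode / frombuffer calls in the changes that follow.

import base64

import numpy as np

# Server side (sketch): expose the embedding vector's raw bytes as base64 text.
# np.array(...) over a list of Python floats defaults to float64.
vector = [0.1, 0.2, 0.3]
encoded = base64.b64encode(np.array(vector)).decode("utf-8")

# Client side (sketch): decode the base64 payload back into a list of floats.
decoded = np.frombuffer(base64.b64decode(encoded), dtype="float").tolist()
assert decoded == vector

Because the round trip is byte-level, the decoded list compares equal to the original floats.
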
examples/openai_embedding_client.py (2 changes: 1 addition & 1 deletion)
@@ -20,4 +20,4 @@
     model=model)

 for data in responses.data:
-    print(data.embedding)  # list of float of len 4096
+    print(data.embedding)  # list of float of len 4096
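
With this change in place, the example client above could also request base64-encoded embeddings and decode them itself. A minimal sketch, assuming the client and model objects set up earlier in that example script (encoding_format is a standard parameter of the OpenAI embeddings API):

import base64

import numpy as np

# Hypothetical continuation of the example: request base64 payloads instead of
# JSON float lists, then decode them locally.
responses = client.embeddings.create(input=["Hello my name is"],
                                      model=model,
                                      encoding_format="base64")

for data in responses.data:
    vector = np.frombuffer(base64.b64decode(data.embedding),
                           dtype="float").tolist()
    print(len(vector))  # same length as with the default float encoding

The new test below relies on the same client behavior: when base64 is requested explicitly, data.embedding holds the raw base64 string rather than a decoded float list.
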
tests/entrypoints/openai/test_embedding.py (33 changes: 33 additions & 0 deletions)
@@ -1,3 +1,6 @@
+import base64
+
+import numpy as np
 import openai
 import pytest
 import ray
@@ -109,3 +112,33 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
     assert embeddings.usage.completion_tokens == 0
     assert embeddings.usage.prompt_tokens == 17
     assert embeddings.usage.total_tokens == 17
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
+                                       model_name: str):
+    input_texts = [
+        "Hello my name is",
+        "The best thing about vLLM is that it supports many different models"
+    ]
+
+    responses_float = await embedding_client.embeddings.create(
+        input=input_texts, model=model_name, encoding_format="float")
+
+    responses_base64 = await embedding_client.embeddings.create(
+        input=input_texts, model=model_name, encoding_format="base64")
+
+    decoded_responses_base64_data = []
+    for data in responses_base64.data:
+        decoded_responses_base64_data.append(
+            np.frombuffer(base64.b64decode(data.embedding),
+                          dtype="float").tolist())
+
+    assert responses_float.data[0].embedding == decoded_responses_base64_data[
+        0]
+    assert responses_float.data[1].embedding == decoded_responses_base64_data[
+        1]
vllm/entrypoints/openai/protocol.py (2 changes: 1 addition & 1 deletion)
@@ -580,7 +580,7 @@ class CompletionStreamResponse(OpenAIBaseModel):
 class EmbeddingResponseData(BaseModel):
     index: int
     object: str = "embedding"
-    embedding: List[float]
+    embedding: Union[List[float], str]


 class EmbeddingResponse(BaseModel):
vllm/entrypoints/openai/serving_embedding.py (30 changes: 15 additions & 15 deletions)
@@ -1,6 +1,8 @@
+import base64
 import time
 from typing import AsyncIterator, List, Optional, Tuple

+import numpy as np
 from fastapi import Request

 from vllm.config import ModelConfig
@@ -20,19 +22,19 @@


 def request_output_to_embedding_response(
-    final_res_batch: List[EmbeddingRequestOutput],
-    request_id: str,
-    created_time: int,
-    model_name: str,
-) -> EmbeddingResponse:
+        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
+        created_time: int, model_name: str,
+        encoding_format: str) -> EmbeddingResponse:
     data: List[EmbeddingResponseData] = []
     num_prompt_tokens = 0
     for idx, final_res in enumerate(final_res_batch):
         assert final_res is not None
         prompt_token_ids = final_res.prompt_token_ids

-        embedding_data = EmbeddingResponseData(
-            index=idx, embedding=final_res.outputs.embedding)
+        embedding = final_res.outputs.embedding
+        if encoding_format == "base64":
+            embedding = base64.b64encode(np.array(embedding))
+        embedding_data = EmbeddingResponseData(index=idx,
+                                               embedding=[embedding])
         data.append(embedding_data)

         num_prompt_tokens += len(prompt_token_ids)
@@ -72,10 +74,8 @@ async def create_embedding(self, request: EmbeddingRequest,
         if error_check_ret is not None:
             return error_check_ret

-        # Return error for unsupported features.
-        if request.encoding_format == "base64":
-            return self.create_error_response(
-                "base64 encoding is not currently supported")
+        encoding_format = (request.encoding_format
+                           if request.encoding_format else "float")
         if request.dimensions is not None:
             return self.create_error_response(
                 "dimensions is currently not supported")
@@ -89,7 +89,6 @@ async def create_embedding(self, request: EmbeddingRequest,
         try:
             prompt_is_tokens, prompts = parse_prompt_format(request.input)
             pooling_params = request.to_pooling_params()
-
             for i, prompt in enumerate(prompts):
                 if prompt_is_tokens:
                     prompt_formats = self._validate_prompt_and_tokenize(
@@ -129,7 +128,8 @@ async def create_embedding(self, request: EmbeddingRequest,
                     return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
             response = request_output_to_embedding_response(
-                final_res_batch, request_id, created_time, model_name)
+                final_res_batch, request_id, created_time, model_name,
+                encoding_format)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
@@ -141,4 +141,4 @@ def _check_embedding_mode(self, embedding_mode: bool):
             logger.warning(
                 "embedding_mode is False. Embedding API will not work.")
         else:
-            logger.info("Activating the server engine with embedding enabled.")
+            logger.info("Activating the server engine with embedding enabled.")
