20 changes: 16 additions & 4 deletions fastdeploy/engine/request.py
@@ -57,10 +57,10 @@ def __init__(
history: Optional[list[list[str]]],
tools: Optional[list[Dict]],
system: Optional[Union[str, list[str]]],
sampling_params: Optional[SamplingParams],
pooling_params: Optional[PoolingParams],
eos_token_ids: Optional[list[int]],
arrival_time: float,
sampling_params: Optional[SamplingParams] = None,
pooling_params: Optional[PoolingParams] = None,
preprocess_start_time: Optional[float] = None,
preprocess_end_time: Optional[float] = None,
multimodal_inputs: Optional[dict] = None,
@@ -538,6 +538,9 @@ def __repr__(self) -> str:
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and bool((self.data == other.data).all())

def to_dict(self):
return {"data": self.data}


_O = TypeVar("_O", default=PoolingOutput)

@@ -558,21 +561,30 @@ class PoolingRequestOutput(Generic[_O]):
outputs: _O
prompt_token_ids: list[int]
finished: bool
metrics: Optional[RequestMetrics] = None
error_code: Optional[int] = 200
error_msg: Optional[str] = None

def __repr__(self):
return (
f"{type(self).__name__}(request_id={self.request_id!r}, "
f"outputs={self.outputs!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"finished={self.finished})"
f"finished={self.finished}, "
f"metrics={self.metrics}, "
f"error_code={self.error_code}, "
f"error_msg={self.error_msg})"
)

def to_dict(self):
return {
"request_id": self.request_id,
"outputs": {"data": self.outputs.data},
"outputs": None if self.outputs is None else self.outputs.to_dict(),
"prompt_token_ids": self.prompt_token_ids,
"finished": self.finished,
"metrics": None if self.metrics is None else self.metrics.to_dict(),
"error_code": self.error_code,
"error_msg": self.error_msg,
}

@classmethod
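For context, a minimal sketch of how the extended `PoolingRequestOutput` serializes after this change — the field values are illustrative, and `RequestMetrics(arrival_time=...)` is constructed the same way the unit test further below does:

```python
import time

from fastdeploy.engine.request import (
    PoolingOutput,
    PoolingRequestOutput,
    RequestMetrics,
)

# Build a pooling result and serialize it the way the API layer does.
result = PoolingRequestOutput(
    request_id="embd-demo_0",
    outputs=PoolingOutput(data=[0.1, 0.2, 0.3]),
    prompt_token_ids=[1, 2, 3],
    finished=True,
    metrics=RequestMetrics(arrival_time=time.time()),
)

payload = result.to_dict()
# payload["outputs"] == {"data": [0.1, 0.2, 0.3]}  (via PoolingOutput.to_dict)
# payload["metrics"] is the dict form of RequestMetrics, or None when absent
```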
24 changes: 19 additions & 5 deletions fastdeploy/entrypoints/openai/serving_embedding.py
@@ -15,6 +15,7 @@
"""

import base64
from collections.abc import AsyncGenerator
from typing import Literal, Union

import numpy as np
@@ -99,11 +100,13 @@ def _request_to_batch_dicts(self, ctx: ServeContext):

for idx, prompt in enumerate(request_prompts):
request_dict = self._request_to_dict(ctx)
request_dict["request_id"] = f"{ctx.request_id}-{idx}"
request_dict["request_id"] = f"{ctx.request_id}_{idx}"
request_dict["prompt"] = prompt
request_dicts.append(request_dict)
else:
request_dicts = [self._request_to_dict(ctx)]
request_dict = self._request_to_dict(ctx)
request_dict["request_id"] = f"{ctx.request_id}_0"
request_dicts = [request_dict]
return request_dicts

async def create_embedding(self, request: EmbeddingRequest):
@@ -118,9 +121,20 @@ async def create_embedding(self, request: EmbeddingRequest):
request_id=request_id,
)

generation = self.handle(ctx)
async for response in generation:
return response
idx = 0
response: EmbeddingResponse = None
generators: AsyncGenerator[EmbeddingResponse, None] = self.handle(ctx)
async for r in generators:
r.data[0].index = idx
idx += 1
if response is None:
response = r
else:
response.data.append(r.data[0])
response.usage.prompt_tokens += r.usage.prompt_tokens
response.usage.total_tokens += r.usage.total_tokens

return response

@override
def _build_response(self, ctx: ServeContext):
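The new `create_embedding` loop fans the batched sub-requests back into a single response: it re-indexes each per-prompt `EmbeddingResponse`, appends its data row, and sums token usage. A standalone sketch of that aggregation pattern, with simplified stand-ins for the protocol types (the real response and usage classes carry more fields):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Usage:  # stand-in for the usage object on EmbeddingResponse
    prompt_tokens: int = 0
    total_tokens: int = 0


@dataclass
class Row:  # stand-in for one embedding result item
    index: int
    embedding: list[float]


@dataclass
class Response:  # stand-in for EmbeddingResponse
    data: list[Row] = field(default_factory=list)
    usage: Usage = field(default_factory=Usage)


def merge(per_prompt: list[Response]) -> Optional[Response]:
    """Fold one response per prompt into a single response, as create_embedding does."""
    merged: Optional[Response] = None
    for idx, r in enumerate(per_prompt):
        r.data[0].index = idx          # re-index in arrival order
        if merged is None:
            merged = r                 # first prompt's response becomes the carrier
        else:
            merged.data.append(r.data[0])
            merged.usage.prompt_tokens += r.usage.prompt_tokens
            merged.usage.total_tokens += r.usage.total_tokens
    return merged
```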
31 changes: 21 additions & 10 deletions fastdeploy/entrypoints/openai/serving_engine.py
@@ -20,7 +20,7 @@
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar, Union
from typing import Any, ClassVar, Generic, Optional, TypeVar, Union

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import override
@@ -42,10 +42,11 @@ class ServeContext(
):
# Shared across all requests
request: RequestT
request_output: Optional[Union[RequestOutput, PoolingRequestOutput]] = None
model_name: str
request_id: str
created_time: int = Field(default_factory=lambda: int(time.time()))
preprocess_requests: Optional[list[dict]] = None
request_output: Optional[Union[RequestOutput, PoolingRequestOutput]] = None

# `protected_namespaces` resolves Pydantic v2's warning
# on conflict with protected namespace "model_"
@@ -136,7 +137,7 @@ def _validate_request(self, ctx: ServeContext):
pass

@abstractmethod
async def _preprocess(self, ctx: ServeContext) -> Dict:
async def _preprocess(self, ctx: ServeContext):
"""Preprocess the request into engine format"""
pass

@@ -239,9 +240,10 @@ def _request_to_batch_dicts(self, ctx: ServeContext):
return [self._request_to_dict(ctx)]

@override
async def _preprocess(self, ctx: ServeContext) -> Dict:
async def _preprocess(self, ctx: ServeContext):
"""Preprocess the request into engine format"""
request_dicts = self._request_to_batch_dicts(ctx)
ctx.preprocess_requests = request_dicts
for request_dict in request_dicts:
api_server_logger.info(f"batch add request_id: {request_dict['request_id']}, request: {request_dict}")
await self.engine_client.format_and_add_data(request_dict)
@@ -261,17 +263,26 @@ def _process_chat_template_kwargs(self, request_dict):
request_dict["chat_template_kwargs"] = chat_template_kwargs

@override
async def _prepare_generators(self, ctx: ServeContext) -> AsyncGenerator[RequestOutput]:
async def _prepare_generators(self, ctx: ServeContext) -> AsyncGenerator[dict]:
"""Prepare a generator of responses"""
request_id = ctx.request_id
try:
dealer, response_queue = await self.engine_client.connection_manager.get_connection(request_id)
dealer.write([b"", request_id.encode("utf-8")])
num_choices = len(ctx.preprocess_requests)
dealer, request_output_queue = await self.engine_client.connection_manager.get_connection(
request_id, num_choices
)
for pr in ctx.preprocess_requests:
dealer.write([b"", pr["request_id"].encode("utf-8")])
# if self.engine_client.check_model_weight_status():
# raise ValueError("Engine is clearing model weight")
responses = await asyncio.wait_for(response_queue.get(), timeout=60)
for response in responses:
yield response
while num_choices > 0:
request_output_dicts = await asyncio.wait_for(request_output_queue.get(), timeout=60)
for request_output_dict in request_output_dicts:
api_server_logger.debug(f"Received RequestOutput: {request_output_dict}")
if request_output_dict["finished"] is True:
num_choices -= 1
yield request_output_dict

except Exception as e:
raise ValueError(f"Error processing response: {str(e)}")
finally:
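`_prepare_generators` now registers one connection for all sub-requests and keeps draining the shared queue until every sub-request has reported `finished`. A minimal sketch of that fan-in loop, with a plain `asyncio.Queue` standing in for the ZMQ-backed response queue:

```python
import asyncio
from typing import AsyncGenerator


async def drain_outputs(
    request_output_queue: asyncio.Queue, num_choices: int
) -> AsyncGenerator[dict, None]:
    """Yield output dicts until every sub-request has reported finished=True."""
    while num_choices > 0:
        # Each queue item is a list of RequestOutput dicts for this request id.
        request_output_dicts = await asyncio.wait_for(request_output_queue.get(), timeout=60)
        for request_output_dict in request_output_dicts:
            if request_output_dict["finished"] is True:
                num_choices -= 1
            yield request_output_dict
```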
2 changes: 1 addition & 1 deletion fastdeploy/entrypoints/openai/utils.py
@@ -123,7 +123,7 @@ async def _listen_connection(self, dealer, conn_index):
raw_data = await dealer.read()
response = msgpack.unpackb(raw_data[-1])
request_id = response[-1]["request_id"]
if "cmpl" == request_id[:4]:
if request_id[:4] in ["cmpl", "embd"]:
request_id = request_id.rsplit("_", 1)[0]
elif "chatcmpl" == request_id[:8]:
request_id = request_id.rsplit("_", 1)[0]
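The listener keys each response back to its parent request by stripping the per-prompt suffix; with embedding sub-requests now named `"{request_id}_{idx}"`, the prefix check must also accept `embd`. A small sketch of the normalization, assuming request ids follow the prefix conventions shown above:

```python
def normalize_request_id(request_id: str) -> str:
    """Strip the trailing "_<idx>" suffix from completion/embedding/chat sub-request ids."""
    if request_id[:4] in ["cmpl", "embd"] or request_id[:8] == "chatcmpl":
        return request_id.rsplit("_", 1)[0]
    return request_id


assert normalize_request_id("embd-abc_1") == "embd-abc"
assert normalize_request_id("chatcmpl-abc_0") == "chatcmpl-abc"
```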
2 changes: 2 additions & 0 deletions fastdeploy/output/stream_transfer_data.py
@@ -41,3 +41,5 @@ class StreamTransferData:
logprobs: Optional[np.array] = None
accept_tokens: Optional[np.array] = None
accept_num: Optional[np.array] = None
# [num_reqs, hidden_size]
pooler_output: Optional[np.array] = None
65 changes: 42 additions & 23 deletions fastdeploy/output/token_processor.py
@@ -29,7 +29,14 @@
import zmq

from fastdeploy import envs
from fastdeploy.engine.request import CompletionOutput, RequestMetrics, RequestOutput
from fastdeploy.engine.request import (
CompletionOutput,
PoolingOutput,
PoolingRequestOutput,
Request,
RequestMetrics,
RequestOutput,
)
from fastdeploy.inter_communicator import IPCSignal, ZmqIpcServer
from fastdeploy.metrics.metrics import main_process_metrics
from fastdeploy.platforms import current_platform
@@ -49,7 +56,6 @@ class TokenProcessor:
"""

def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_connector):

paddle.device.set_device("cpu")
self.cfg = cfg
self.cached_generated_tokens = cached_generated_tokens
@@ -231,7 +237,7 @@ def _process_batch_output_use_zmq(self, receive_datas):
if self.resource_manager.stop_flags[i]:
continue

task = self.resource_manager.tasks_list[i]
task: Request = self.resource_manager.tasks_list[i]

task_id = task.request_id
token_ids = stream_data.tokens # numpy.array
@@ -254,27 +260,40 @@ def _process_batch_output_use_zmq(self, receive_datas):
request_start_time=task.arrival_time,
)

result = RequestOutput(
request_id=task_id,
outputs=CompletionOutput(
index=i,
send_idx=self.tokens_counter[task_id],
token_ids=[],
draft_token_ids=[],
),
finished=False,
metrics=metrics,
)

if self.tokens_counter[task_id] == 0:
if task.messages is not None:
result.prompt = task.messages
result.num_cached_tokens = task.num_cached_tokens

is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
result = self._process_per_token(task, i, token_ids, result, is_prefill)
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
if task.pooling_params is not None:
pooler_output = stream_data.pooler_output
if isinstance(pooler_output, np.ndarray):
pooler_output = pooler_output.tolist()
result = PoolingRequestOutput(
request_id=task_id,
finished=True,
metrics=metrics,
prompt_token_ids=task.prompt_token_ids,
outputs=PoolingOutput(data=pooler_output),
)
self._recycle_resources(task_id, i, task, result, False)
batch_result.append(result)
else:
result = RequestOutput(
request_id=task_id,
outputs=CompletionOutput(
index=i,
send_idx=self.tokens_counter[task_id],
token_ids=[],
draft_token_ids=[],
),
finished=False,
metrics=metrics,
)
if self.tokens_counter[task_id] == 0:
if task.messages is not None:
result.prompt = task.messages
result.num_cached_tokens = task.num_cached_tokens

is_prefill = task.disaggregate_info is not None and task.disaggregate_info["role"] == "prefill"
result = self._process_per_token(task, i, token_ids, result, is_prefill)
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
batch_result.append(result)

return batch_result

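In short, `_process_batch_output_use_zmq` now branches on `task.pooling_params`: pooling tasks finish in one shot with a `PoolingRequestOutput` built from the `[num_reqs, hidden_size]` `pooler_output`, while generation tasks keep the incremental `RequestOutput` path. A minimal sketch of the pooling branch with placeholder values (real inputs come from `StreamTransferData`):

```python
import numpy as np

from fastdeploy.engine.request import PoolingOutput, PoolingRequestOutput

# One row of StreamTransferData.pooler_output ([num_reqs, hidden_size]) for this task.
pooler_output = np.asarray([0.12, -0.07, 0.33], dtype=np.float32)
if isinstance(pooler_output, np.ndarray):
    pooler_output = pooler_output.tolist()  # make it serialization-friendly

result = PoolingRequestOutput(
    request_id="embd-demo_0",
    finished=True,           # pooling requests complete in a single step
    metrics=None,
    prompt_token_ids=[1, 2, 3],
    outputs=PoolingOutput(data=pooler_output),
)
```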
18 changes: 12 additions & 6 deletions tests/entrypoints/openai/test_serving_embedding.py
@@ -1,7 +1,12 @@
import time
import unittest
from unittest.mock import AsyncMock, MagicMock

from fastdeploy.engine.request import PoolingOutput, PoolingRequestOutput
from fastdeploy.engine.request import (
PoolingOutput,
PoolingRequestOutput,
RequestMetrics,
)
from fastdeploy.entrypoints.openai.protocol import (
EmbeddingChatRequest,
EmbeddingCompletionRequest,
@@ -27,6 +32,7 @@ def setUp(self):
prompt_token_ids=[1, 2, 3],
finished=True,
outputs=PoolingOutput(data=[0.1, 0.2, 0.3]),
metrics=RequestMetrics(arrival_time=time.time()),
)
mock_response_queue.get = AsyncMock(
return_value=[
@@ -69,14 +75,14 @@ async def test_create_embedding_success(self):

def test_request_to_batch_dicts(self):
test_cases = [
("string input", EmbeddingCompletionRequest(input="hello"), ["hello"], ["req-1-0"]),
("list of ints", EmbeddingCompletionRequest(input=[1, 2, 3]), [[1, 2, 3]], ["req-1-0"]),
("list of strings", EmbeddingCompletionRequest(input=["a", "b"]), ["a", "b"], ["req-1-0", "req-1-1"]),
("string input", EmbeddingCompletionRequest(input="hello"), ["hello"], ["req-1_0"]),
("list of ints", EmbeddingCompletionRequest(input=[1, 2, 3]), [[1, 2, 3]], ["req-1_0"]),
("list of strings", EmbeddingCompletionRequest(input=["a", "b"]), ["a", "b"], ["req-1_0", "req-1_1"]),
(
"list of list of ints",
EmbeddingCompletionRequest(input=[[1, 2], [3, 4]]),
[[1, 2], [3, 4]],
["req-1-0", "req-1-1"],
["req-1_0", "req-1_1"],
),
]

@@ -90,7 +96,7 @@ def test_request_to_batch_dicts(self):
result = self.embedding_service._request_to_batch_dicts(ctx)
self.assertEqual(len(result), len(expected_prompts))
for r, prompt, rid in zip(result, expected_prompts, expected_ids):
print(f"assertEqual r:{r} prompt:{prompt} rid:{rid}")
# print(f"assertEqual r:{r} prompt:{prompt} rid:{rid}")
self.assertEqual(r["prompt"], prompt)
self.assertEqual(r["request_id"], rid)
