50 changes: 17 additions & 33 deletions fastdeploy/entrypoints/engine_client.py
@@ -31,7 +31,12 @@
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.multimodal.registry import MultimodalRegistry
from fastdeploy.platforms import current_platform
from fastdeploy.utils import EngineError, StatefulSemaphore, api_server_logger
from fastdeploy.utils import (
EngineError,
ParameterError,
StatefulSemaphore,
api_server_logger,
)


class EngineClient:
@@ -218,42 +223,21 @@ async def add_requests(self, task):
def vaild_parameters(self, data):
"""
Validate stream options
The validation of hyperparameters (top_p, seed, frequency_penalty, temperature, presence_penalty)
has been moved up into ChatCompletionRequest/CompletionRequest.
"""

if data.get("n") is not None:
if data["n"] != 1:
raise ValueError("n only support 1.")
raise ParameterError("n", "n only support 1.")

if data.get("max_tokens") is not None:
if data["max_tokens"] < 1 or data["max_tokens"] >= self.max_model_len:
raise ValueError(f"max_tokens can be defined [1, {self.max_model_len}).")
raise ParameterError("max_tokens", f"max_tokens can be defined [1, {self.max_model_len}).")

if data.get("reasoning_max_tokens") is not None:
if data["reasoning_max_tokens"] > data["max_tokens"] or data["reasoning_max_tokens"] < 1:
raise ValueError("reasoning_max_tokens must be between max_tokens and 1")

if data.get("top_p") is not None:
if data["top_p"] > 1 or data["top_p"] < 0:
raise ValueError("top_p value can only be defined [0, 1].")

if data.get("frequency_penalty") is not None:
if not -2.0 <= data["frequency_penalty"] <= 2.0:
raise ValueError("frequency_penalty must be in [-2, 2]")

if data.get("temperature") is not None:
if data["temperature"] < 0:
raise ValueError("temperature must be non-negative")

if data.get("presence_penalty") is not None:
if not -2.0 <= data["presence_penalty"] <= 2.0:
raise ValueError("presence_penalty must be in [-2, 2]")

if data.get("seed") is not None:
if not 0 <= data["seed"] <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580]")

if data.get("stream_options") and not data.get("stream"):
raise ValueError("Stream options can only be defined when `stream=True`.")
raise ParameterError("reasoning_max_tokens", "reasoning_max_tokens must be between max_tokens and 1")

# logprobs
logprobs = data.get("logprobs")
@@ -263,35 +247,35 @@ def vaild_parameters(self, data):
if not self.enable_logprob:
err_msg = "Logprobs is disabled, please enable it in startup config."
api_server_logger.error(err_msg)
raise ValueError(err_msg)
raise ParameterError("logprobs", err_msg)
top_logprobs = data.get("top_logprobs")
elif isinstance(logprobs, int):
top_logprobs = logprobs
elif logprobs:
raise ValueError("Invalid type for 'logprobs'")
raise ParameterError("logprobs", "Invalid type for 'logprobs'")

# enable_logprob
if top_logprobs:
if not self.enable_logprob:
err_msg = "Logprobs is disabled, please enable it in startup config."
api_server_logger.error(err_msg)
raise ValueError(err_msg)
raise ParameterError("logprobs", err_msg)

if not isinstance(top_logprobs, int):
err_type = type(top_logprobs).__name__
err_msg = f"Invalid type for 'top_logprobs': expected int but got {err_type}."
api_server_logger.error(err_msg)
raise ValueError(err_msg)
raise ParameterError("top_logprobs", err_msg)

if top_logprobs < 0:
err_msg = f"Invalid 'top_logprobs': must be >= 0, got {top_logprobs}."
api_server_logger.error(err_msg)
raise ValueError(err_msg)
raise ParameterError("top_logprobs", err_msg)

if top_logprobs > 20:
err_msg = "Invalid value for 'top_logprobs': must be <= 20."
api_server_logger.error(err_msg)
raise ValueError(err_msg)
raise ParameterError("top_logprobs", err_msg)

def check_health(self, time_interval_threashold=30):
"""
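Note: engine_client.py now raises ParameterError with the offending field name instead of a bare ValueError, and drops the range checks that moved into the request models. The definition of ParameterError is not part of this diff; a minimal sketch of the shape implied by its usage here (raise ParameterError("n", ...)) and in serving_chat.py (e.param, e.message) might look like:

# Hypothetical sketch; the real ParameterError lives in fastdeploy/utils and is not shown in this diff.
class ParameterError(Exception):
    """Validation error that records which request parameter was rejected."""

    def __init__(self, param: str, message: str):
        super().__init__(message)
        self.param = param      # e.g. "max_tokens"; surfaced as ErrorInfo.param
        self.message = message  # human-readable reason; surfaced as ErrorInfo.message
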
14 changes: 10 additions & 4 deletions fastdeploy/entrypoints/openai/api_server.py
@@ -26,6 +26,7 @@
import uvicorn
import zmq
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import CONTENT_TYPE_LATEST

@@ -40,6 +41,7 @@
CompletionRequest,
CompletionResponse,
ControlSchedulerRequest,
ErrorInfo,
ErrorResponse,
ModelList,
)
@@ -56,6 +58,7 @@
)
from fastdeploy.metrics.trace_util import fd_start_span, inject_to_metadata, instrument
from fastdeploy.utils import (
ExceptionHandler,
FlexibleArgumentParser,
StatefulSemaphore,
api_server_logger,
@@ -232,6 +235,8 @@ async def lifespan(app: FastAPI):


app = FastAPI(lifespan=lifespan)
app.add_exception_handler(RequestValidationError, ExceptionHandler.handle_request_validation_exception)
app.add_exception_handler(Exception, ExceptionHandler.handle_exception)
instrument(app)


@@ -336,7 +341,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
if isinstance(generator, ErrorResponse):
api_server_logger.debug(f"release: {connection_semaphore.status()}")
connection_semaphore.release()
return JSONResponse(content={"detail": generator.model_dump()}, status_code=generator.code)
return JSONResponse(content=generator.model_dump(), status_code=500)
elif isinstance(generator, ChatCompletionResponse):
api_server_logger.debug(f"release: {connection_semaphore.status()}")
connection_semaphore.release()
@@ -365,7 +370,7 @@ async def create_completion(request: CompletionRequest):
generator = await app.state.completion_handler.create_completion(request)
if isinstance(generator, ErrorResponse):
connection_semaphore.release()
return JSONResponse(content=generator.model_dump(), status_code=generator.code)
return JSONResponse(content=generator.model_dump(), status_code=500)
elif isinstance(generator, CompletionResponse):
connection_semaphore.release()
return JSONResponse(content=generator.model_dump())
@@ -388,7 +393,7 @@ async def list_models() -> Response:

models = await app.state.model_handler.list_models()
if isinstance(models, ErrorResponse):
return JSONResponse(content=models.model_dump(), status_code=models.code)
return JSONResponse(content=models.model_dump())
elif isinstance(models, ModelList):
return JSONResponse(content=models.model_dump())

@@ -502,7 +507,8 @@ def control_scheduler(request: ControlSchedulerRequest):
"""
Control the scheduler behavior with the given parameters.
"""
content = ErrorResponse(object="", message="Scheduler updated successfully", code=0)

content = ErrorResponse(error=ErrorInfo(message="Scheduler updated successfully", code=0))

global llm_engine
if llm_engine is None:
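Note: api_server.py registers ExceptionHandler.handle_request_validation_exception for FastAPI's RequestValidationError, so Pydantic constraint violations in the request models are converted into the new error schema instead of FastAPI's default 422 body. The handler itself lives in fastdeploy/utils and is not shown in this diff; a hedged sketch of one possible implementation, reusing the ErrorResponse/ErrorInfo models from this PR:

# Hypothetical sketch only -- the real ExceptionHandler in fastdeploy/utils is not part of this diff.
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse

from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse


async def handle_request_validation_exception(request: Request, exc: RequestValidationError) -> JSONResponse:
    # Report the first failing field; loc is typically ("body", "<field>").
    first = exc.errors()[0] if exc.errors() else {}
    param = str(first.get("loc", ("unknown",))[-1])
    info = ErrorInfo(
        message=first.get("msg", "Invalid request"),
        type="invalid_request_error",  # stand-in for ErrorType.INVALID_REQUEST_ERROR, which is not shown here
        param=param,
    )
    return JSONResponse(status_code=400, content=ErrorResponse(error=info).model_dump())
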
31 changes: 18 additions & 13 deletions fastdeploy/entrypoints/openai/protocol.py
@@ -32,9 +32,14 @@ class ErrorResponse(BaseModel):
Error response from OpenAI API.
"""

object: str = "error"
error: ErrorInfo


class ErrorInfo(BaseModel):
message: str
code: int
type: Optional[str] = None
param: Optional[str] = None
code: Optional[str] = None


class PromptTokenUsageInfo(BaseModel):
@@ -403,21 +408,21 @@ class CompletionRequest(BaseModel):
prompt: Union[List[int], List[List[int]], str, List[str]]
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = None
frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
logprobs: Optional[int] = None
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
max_tokens: Optional[int] = None
n: int = 1
presence_penalty: Optional[float] = None
seed: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
seed: Optional[int] = Field(default=None, ge=0, le=922337203685477580)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
suffix: Optional[dict] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
temperature: Optional[float] = Field(default=None, ge=0)
top_p: Optional[float] = Field(default=None, ge=0, le=1)
user: Optional[str] = None

# doc: begin-completion-sampling-params
@@ -537,7 +542,7 @@ class ChatCompletionRequest(BaseModel):
messages: Union[List[Any], List[int]]
tools: Optional[List[ChatCompletionToolsParam]] = None
model: Optional[str] = "default"
frequency_penalty: Optional[float] = None
frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0

Expand All @@ -552,13 +557,13 @@ class ChatCompletionRequest(BaseModel):
)
max_completion_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = None
seed: Optional[int] = None
presence_penalty: Optional[float] = Field(None, le=2, ge=-2)
seed: Optional[int] = Field(default=None, ge=0, le=922337203685477580)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
temperature: Optional[float] = Field(None, ge=0)
top_p: Optional[float] = Field(None, le=1, ge=0)
user: Optional[str] = None
metadata: Optional[dict] = None
response_format: Optional[AnyResponseFormat] = None
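Note: with the Field constraints above, out-of-range sampling parameters are now rejected when the request model is constructed, before the request ever reaches EngineClient.vaild_parameters. For illustration only (assuming pydantic v2, which the model_dump calls elsewhere in this PR imply):

from pydantic import ValidationError

from fastdeploy.entrypoints.openai.protocol import ChatCompletionRequest

try:
    ChatCompletionRequest(
        messages=[{"role": "user", "content": "hi"}],
        top_p=1.5,  # violates Field(ge=0, le=1)
    )
except ValidationError as e:
    # Inside the API server this surfaces as a RequestValidationError and is
    # rendered through the ErrorResponse/ErrorInfo schema by the new handler.
    print(e.errors()[0]["loc"], e.errors()[0]["msg"])
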
28 changes: 18 additions & 10 deletions fastdeploy/entrypoints/openai/serving_chat.py
@@ -30,6 +30,7 @@
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
ErrorInfo,
ErrorResponse,
LogProbEntry,
LogProbs,
@@ -38,7 +39,7 @@
)
from fastdeploy.entrypoints.openai.response_processors import ChatResponseProcessor
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.utils import api_server_logger
from fastdeploy.utils import ErrorCode, ErrorType, ParameterError, api_server_logger
from fastdeploy.worker.output import LogprobsLists


@@ -86,14 +87,16 @@ async def create_chat_completion(self, request: ChatCompletionRequest):
f"Only master node can accept completion request, please send request to master node: {self.master_ip}"
)
api_server_logger.error(err_msg)
return ErrorResponse(message=err_msg, code=400)
return ErrorResponse(error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR))

if self.models:
is_supported, request.model = self.models.is_supported_model(request.model)
if not is_supported:
err_msg = f"Unsupported model: [{request.model}], support [{', '.join([x.name for x in self.models.model_paths])}] or default"
api_server_logger.error(err_msg)
return ErrorResponse(message=err_msg, code=400)
return ErrorResponse(
error=ErrorInfo(message=err_msg, type=ErrorType.SERVER_ERROR, code=ErrorCode.MODEL_NOT_SUPPORT)
)

try:
if self.max_waiting_time < 0:
@@ -117,11 +120,17 @@ async def create_chat_completion(self, request: ChatCompletionRequest):
text_after_process = current_req_dict.get("text_after_process")
if isinstance(prompt_token_ids, np.ndarray):
prompt_token_ids = prompt_token_ids.tolist()
except ParameterError as e:
api_server_logger.error(e.message)
self.engine_client.semaphore.release()
return ErrorResponse(
error=ErrorInfo(message=str(e.message), type=ErrorType.INVALID_REQUEST_ERROR, param=e.param)
)
except Exception as e:
error_msg = f"request[{request_id}] generator error: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
self.engine_client.semaphore.release()
return ErrorResponse(code=400, message=error_msg)
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR))
del current_req_dict

if request.stream:
@@ -136,21 +145,20 @@ async def create_chat_completion(self, request: ChatCompletionRequest):
except Exception as e:
error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
return ErrorResponse(code=408, message=error_msg)
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.SERVER_ERROR))
except Exception as e:
error_msg = (
f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, "
f"max waiting time: {self.max_waiting_time}"
)
api_server_logger.error(error_msg)
return ErrorResponse(code=408, message=error_msg)
return ErrorResponse(
error=ErrorInfo(message=error_msg, type=ErrorType.TIMEOUT_ERROR, code=ErrorCode.TIMEOUT)
)

def _create_streaming_error_response(self, message: str) -> str:
api_server_logger.error(message)
error_response = ErrorResponse(
code=400,
message=message,
)
error_response = ErrorResponse(error=ErrorInfo(message=message, type=ErrorType.SERVER_ERROR))
return error_response.model_dump_json()

async def chat_completion_stream_generator(
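Note: taken together, serving_chat.py now wraps every failure in ErrorResponse(error=ErrorInfo(...)) rather than the old flat message/code fields. A hypothetical example of the payload a client would receive (the max_tokens message text and the concrete type string are assumptions, since ErrorType/ErrorCode are defined in fastdeploy/utils and not shown in this diff):

from fastdeploy.entrypoints.openai.protocol import ErrorInfo, ErrorResponse

payload = ErrorResponse(
    error=ErrorInfo(
        message="max_tokens can be defined [1, 8192).",  # 8192 stands in for the server's max_model_len
        type="invalid_request_error",                    # stand-in for ErrorType.INVALID_REQUEST_ERROR
        param="max_tokens",
    )
).model_dump()
# payload == {"error": {"message": "...", "type": "invalid_request_error", "param": "max_tokens", "code": None}}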