From 1d17465eb3bc1e3416220a1f72a56ea937d7c757 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 14 Jun 2025 01:10:29 +0000 Subject: [PATCH 01/33] Add refactored OpenAI API server modules implementation - Introduced new modules for handling OpenAI-compatible API requests, including chat and completion serving logic. - Implemented request validation rules for chat and completion requests. - Added utility functions for processing and formatting requests and responses. - Included Pydantic models for defining request and response structures. This commit lays the groundwork for integrating OpenAI API functionalities into the SGLang framework. Signed-off-by: Xinyuan Tong --- .../sglang/srt/entrypoints/openai/__init__.py | 23 + .../sglang/srt/entrypoints/openai/protocol.py | 544 +++++++++ .../srt/entrypoints/openai/serving_chat.py | 917 ++++++++++++++ .../entrypoints/openai/serving_completions.py | 484 ++++++++ .../srt/entrypoints/openai/serving_engine.py | 110 ++ python/sglang/srt/entrypoints/openai/utils.py | 506 ++++++++ .../srt/entrypoints/openai/validation.py | 344 ++++++ test/pytest.ini | 2 + test/srt/openai/__init__.py | 14 + test/srt/openai/test_protocol.py | 683 +++++++++++ test/srt/openai/test_serving_chat.py | 845 +++++++++++++ test/srt/openai/test_serving_completions.py | 1055 +++++++++++++++++ 12 files changed, 5527 insertions(+) create mode 100644 python/sglang/srt/entrypoints/openai/__init__.py create mode 100644 python/sglang/srt/entrypoints/openai/protocol.py create mode 100644 python/sglang/srt/entrypoints/openai/serving_chat.py create mode 100644 python/sglang/srt/entrypoints/openai/serving_completions.py create mode 100644 python/sglang/srt/entrypoints/openai/serving_engine.py create mode 100644 python/sglang/srt/entrypoints/openai/utils.py create mode 100644 python/sglang/srt/entrypoints/openai/validation.py create mode 100644 test/pytest.ini create mode 100644 test/srt/openai/__init__.py create mode 100644 test/srt/openai/test_protocol.py create mode 100644 test/srt/openai/test_serving_chat.py create mode 100644 test/srt/openai/test_serving_completions.py diff --git a/python/sglang/srt/entrypoints/openai/__init__.py b/python/sglang/srt/entrypoints/openai/__init__.py new file mode 100644 index 00000000000..17f1eacfb28 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""OpenAI-compatible API server module""" + +from .protocol import * +from .serving_engine import OpenAIServingBase, RequestContext +from .utils import * + +__all__ = [ + "OpenAIServingBase", + "RequestContext", +] diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py new file mode 100644 index 00000000000..a11fc71f93a --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -0,0 +1,544 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pydantic models for OpenAI API protocol""" + +import time +from typing import Dict, List, Optional, Union + +from pydantic import ( + BaseModel, + Field, + field_validator, + model_validator, + root_validator, + validator, +) +from pydantic_core import ValidationError +from typing_extensions import Literal + + +class ModelCard(BaseModel): + """Model cards.""" + + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "sglang" + root: Optional[str] = None + max_model_len: Optional[int] = None + + +class ModelList(BaseModel): + """Model list consists of model cards.""" + + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + +class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, float]]] = Field(default_factory=list) + + +class TopLogprob(BaseModel): + token: str + bytes: List[int] + logprob: float + + +class ChatCompletionTokenLogprob(BaseModel): + token: str + bytes: List[int] + logprob: float + top_logprobs: List[TopLogprob] + + +class ChoiceLogprobs(BaseModel): + # build for v1/chat/completions response + content: List[ChatCompletionTokenLogprob] + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + # only used to return cached tokens when --enable-cache-report is set + prompt_tokens_details: Optional[Dict[str, int]] = None + + +class StreamOptions(BaseModel): + include_usage: Optional[bool] = False + + +class JsonSchemaResponseFormat(BaseModel): + name: str + description: Optional[str] = None + # use alias to workaround pydantic conflict + schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None) + strict: Optional[bool] = False + + +class FileRequest(BaseModel): + # https://platform.openai.com/docs/api-reference/files/create + file: bytes # The File object (not file name) to be uploaded + purpose: str = ( + "batch" # The intended purpose of the uploaded file, default is "batch" + ) + + +class FileResponse(BaseModel): + id: str + object: str = "file" + bytes: int + created_at: int + filename: str + purpose: str + + +class FileDeleteResponse(BaseModel): + id: str + object: str = "file" + deleted: bool + + +class BatchRequest(BaseModel): + input_file_id: ( + str # The ID of an uploaded file that contains requests for the new batch + ) + endpoint: str # The endpoint to be used for all requests in the batch + completion_window: str # The time frame within which the batch should be processed + metadata: Optional[dict] = None # Optional custom metadata for the batch + + +class BatchResponse(BaseModel): + id: str + object: str = "batch" + endpoint: str + errors: Optional[dict] = None + input_file_id: str + completion_window: str + status: str = "validating" + output_file_id: Optional[str] = None + error_file_id: Optional[str] = None + created_at: int + in_progress_at: Optional[int] = None + expires_at: Optional[int] = None + finalizing_at: Optional[int] = None + completed_at: Optional[int] = None + failed_at: Optional[int] = None + expired_at: Optional[int] = None + cancelling_at: Optional[int] = None + cancelled_at: Optional[int] = None + request_counts: Optional[dict] = None + metadata: Optional[dict] = None + + +class CompletionRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/completions/create + model: str + prompt: Union[List[int], List[List[int]], str, List[str]] + best_of: Optional[int] = None + echo: bool = False + frequency_penalty: float = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[int] = None + max_tokens: int = 16 + n: int = 1 + presence_penalty: float = 0.0 + seed: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: bool = False + stream_options: Optional[StreamOptions] = None + suffix: Optional[str] = None + temperature: float = 1.0 + top_p: float = 1.0 + user: Optional[str] = None + + # Extra parameters for SRT backend only and will be ignored by OpenAI models. + top_k: int = -1 + min_p: float = 0.0 + min_tokens: int = 0 + json_schema: Optional[str] = None + regex: Optional[str] = None + ebnf: Optional[str] = None + repetition_penalty: float = 1.0 + stop_token_ids: Optional[List[int]] = None + no_stop_trim: bool = False + ignore_eos: bool = False + skip_special_tokens: bool = True + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + session_params: Optional[Dict] = None + + # For PD disaggregation + bootstrap_host: Optional[str] = None + bootstrap_port: Optional[int] = None + bootstrap_room: Optional[int] = None + + @field_validator("max_tokens") + @classmethod + def validate_max_tokens_positive(cls, v): + if v is not None and v < 0: + raise ValueError("max_tokens must be non-negative") + return v + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Literal["stop", "length", "content_filter", "abort"] + matched_stop: Union[None, int, str] = None + + +class CompletionResponse(BaseModel): + id: str + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None + matched_stop: Union[None, int, str] = None + + +class CompletionStreamResponse(BaseModel): + id: str + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = None + + +class ChatCompletionMessageContentTextPart(BaseModel): + type: Literal["text"] + text: str + + +class ChatCompletionMessageContentImageURL(BaseModel): + url: str + detail: Optional[Literal["auto", "low", "high"]] = "auto" + + +class ChatCompletionMessageContentAudioURL(BaseModel): + url: str + + +class ChatCompletionMessageContentImagePart(BaseModel): + type: Literal["image_url"] + image_url: ChatCompletionMessageContentImageURL + modalities: Optional[Literal["image", "multi-images", "video"]] = "image" + + +class ChatCompletionMessageContentAudioPart(BaseModel): + type: Literal["audio_url"] + audio_url: ChatCompletionMessageContentAudioURL + + +ChatCompletionMessageContentPart = Union[ + ChatCompletionMessageContentTextPart, + ChatCompletionMessageContentImagePart, + ChatCompletionMessageContentAudioPart, +] + + +class FunctionResponse(BaseModel): + """Function response.""" + + name: Optional[str] = None + arguments: Optional[str] = None + + +class ToolCall(BaseModel): + """Tool call response.""" + + id: Optional[str] = None + index: Optional[int] = None + type: Literal["function"] = "function" + function: FunctionResponse + + +class ChatCompletionMessageGenericParam(BaseModel): + role: Literal["system", "assistant", "tool"] + content: Union[str, List[ChatCompletionMessageContentTextPart], None] + tool_call_id: Optional[str] = None + name: Optional[str] = None + reasoning_content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) + + +class ChatCompletionMessageUserParam(BaseModel): + role: Literal["user"] + content: Union[str, List[ChatCompletionMessageContentPart]] + + +ChatCompletionMessageParam = Union[ + ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam +] + + +class ResponseFormat(BaseModel): + type: Literal["text", "json_object", "json_schema"] + json_schema: Optional[JsonSchemaResponseFormat] = None + + +class StructuresResponseFormat(BaseModel): + begin: str + schema_: Optional[Dict[str, object]] = Field(alias="schema", default=None) + end: str + + +class StructuralTagResponseFormat(BaseModel): + type: Literal["structural_tag"] + structures: List[StructuresResponseFormat] + triggers: List[str] + + +class Function(BaseModel): + """Function descriptions.""" + + description: Optional[str] = Field(default=None, examples=[None]) + name: Optional[str] = None + parameters: Optional[object] = None + strict: bool = False + + +class Tool(BaseModel): + """Function wrapper.""" + + type: str = Field(default="function", examples=["function"]) + function: Function + + +class ToolChoiceFuncName(BaseModel): + """The name of tool choice function.""" + + name: Optional[str] = None + + +class ToolChoice(BaseModel): + """The tool choice definition.""" + + function: ToolChoiceFuncName + type: Literal["function"] = Field(default="function", examples=["function"]) + + +class ChatCompletionRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/chat/create + messages: List[ChatCompletionMessageParam] + model: str + frequency_penalty: float = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: bool = False + top_logprobs: Optional[int] = None + max_tokens: Optional[int] = Field( + default=None, + deprecated="max_tokens is deprecated in favor of the max_completion_tokens field", + description="The maximum number of tokens that can be generated in the chat completion. ", + ) + max_completion_tokens: Optional[int] = Field( + default=None, + description="The maximum number of completion tokens for a chat completion request, " + "including visible output tokens and reasoning tokens. Input tokens are not included. ", + ) + n: int = 1 + presence_penalty: float = 0.0 + response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None + seed: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stream: bool = False + stream_options: Optional[StreamOptions] = None + temperature: float = 0.7 + top_p: float = 1.0 + user: Optional[str] = None + tools: Optional[List[Tool]] = Field(default=None, examples=[None]) + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field( + default="auto", examples=["none"] + ) # noqa + + @model_validator(mode="before") + @classmethod + def set_tool_choice_default(cls, values): + if isinstance(values, dict): + if values.get("tool_choice") is None: + if values.get("tools") is None: + values["tool_choice"] = "none" + else: + values["tool_choice"] = "auto" + return values + + @field_validator("messages") + @classmethod + def validate_messages_not_empty(cls, v): + if not v: + raise ValueError("Messages cannot be empty") + return v + + # Extra parameters for SRT backend only and will be ignored by OpenAI models. + top_k: int = -1 + min_p: float = 0.0 + min_tokens: int = 0 + regex: Optional[str] = None + ebnf: Optional[str] = None + repetition_penalty: float = 1.0 + stop_token_ids: Optional[List[int]] = None + no_stop_trim: bool = False + ignore_eos: bool = False + continue_final_message: bool = False + skip_special_tokens: bool = True + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + session_params: Optional[Dict] = None + separate_reasoning: bool = True + stream_reasoning: bool = True + chat_template_kwargs: Optional[Dict] = None + + # The request id. + rid: Optional[str] = None + + # For PD disaggregation + bootstrap_host: Optional[str] = None + bootstrap_port: Optional[int] = None + bootstrap_room: Optional[int] = None + + +class ChatMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + reasoning_content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None + finish_reason: Literal[ + "stop", "length", "tool_calls", "content_filter", "function_call", "abort" + ] + matched_stop: Union[None, int, str] = None + + +class ChatCompletionResponse(BaseModel): + id: str + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + reasoning_content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None + finish_reason: Optional[ + Literal["stop", "length", "tool_calls", "content_filter", "function_call"] + ] = None + matched_stop: Union[None, int, str] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = None + + +class MultimodalEmbeddingInput(BaseModel): + text: Optional[str] = None + image: Optional[str] = None + + +class EmbeddingRequest(BaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings/create + input: Union[ + List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput] + ] + model: str + encoding_format: str = "float" + dimensions: int = None + user: Optional[str] = None + + # The request id. + rid: Optional[str] = None + + +class EmbeddingObject(BaseModel): + embedding: List[float] + index: int + object: str = "embedding" + + +class EmbeddingResponse(BaseModel): + data: List[EmbeddingObject] + model: str + object: str = "list" + usage: Optional[UsageInfo] = None + + +class ScoringRequest(BaseModel): + query: Optional[Union[str, List[int]]] = ( + None # Query text or pre-tokenized token IDs + ) + items: Optional[Union[str, List[str], List[List[int]]]] = ( + None # Item text(s) or pre-tokenized token IDs + ) + label_token_ids: Optional[List[int]] = ( + None # Token IDs to compute probabilities for + ) + apply_softmax: bool = False + item_first: bool = False + model: str + + +class ScoringResponse(BaseModel): + scores: List[ + List[float] + ] # List of lists of probabilities, each in the order of label_token_ids + model: str + usage: Optional[UsageInfo] = None + object: str = "scoring" + + +OpenAIServingRequest = Union[ + ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ScoringRequest +] diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py new file mode 100644 index 00000000000..36652544130 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -0,0 +1,917 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Chat completions serving logic for OpenAI API""" + +import base64 +import json +import logging +import time +import uuid +from typing import Any, Dict, List, Optional, Union + +from fastapi import Request +from fastapi.responses import StreamingResponse + +from sglang.srt.conversation import generate_chat_conv +from sglang.srt.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatCompletionTokenLogprob, + ChatMessage, + ChoiceLogprobs, + DeltaMessage, + ErrorResponse, + FunctionResponse, + ToolCall, + TopLogprob, + UsageInfo, +) +from sglang.srt.entrypoints.openai.serving_engine import ( + OpenAIServingBase, + RequestContext, +) +from sglang.srt.entrypoints.openai.utils import ( + _get_enable_thinking_from_request, + aggregate_token_usage, + build_base_sampling_params, + create_error_response, + create_stream_done, + create_streaming_chunk_data, + create_streaming_error_response, + detect_template_content_format, + process_content_for_template_format, + to_openai_style_logprobs, +) +from sglang.srt.function_call.function_call_parser import FunctionCallParser +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.reasoning_parser import ReasoningParser +from sglang.utils import convert_json_schema_to_str + +logger = logging.getLogger(__name__) + +# Global cache for template content format detection +_cached_chat_template = None +_cached_template_format = None + + +class ChatCompletionHandler(OpenAIServingBase): + """Handler for chat completion requests""" + + def __init__(self, tokenizer_manager: TokenizerManager): + super().__init__(tokenizer_manager) + + async def handle_request( + self, request: ChatCompletionRequest, raw_request: Request + ) -> Union[ChatCompletionResponse, StreamingResponse, ErrorResponse]: + """Handle a chat completion request""" + try: + # Validate request + error = self._validate_request(request) + if error: + return error + + # Create request context + ctx = RequestContext( + raw_request=raw_request, + openai_request=request, + request_id=request.rid or f"chatcmpl-{uuid.uuid4()}", + ) + + # Convert to internal format + adapted_request, processed_request = self._convert_to_internal_request( + [request], [ctx.request_id] + ) + + if request.stream: + return await self._handle_streaming_request( + adapted_request, processed_request, ctx + ) + else: + return await self._handle_non_streaming_request( + adapted_request, processed_request, ctx + ) + + except Exception as e: + logger.error(f"Error in chat completion: {e}") + return create_error_response( + message=f"Internal server error: {str(e)}", + err_type="InternalServerError", + status_code=500, + ) + + def _convert_to_internal_request( + self, + all_requests: List[ChatCompletionRequest], + request_ids: List[str], + ) -> tuple[ + GenerateReqInput, Union[ChatCompletionRequest, List[ChatCompletionRequest]] + ]: + """Convert OpenAI chat completion request to internal format""" + input_ids = [] + prompts = [] + sampling_params_list = [] + image_data_list = [] + audio_data_list = [] + return_logprobs = [] + logprob_start_lens = [] + top_logprobs_nums = [] + modalities_list = [] + lora_paths = [] + + is_multimodal = self.tokenizer_manager.model_config.is_multimodal + + for request in all_requests: + # Process messages and apply chat template + prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + self._process_messages(request, is_multimodal) + ) + + input_ids.append(prompt_ids) + prompts.append(prompt) + return_logprobs.append(request.logprobs) + logprob_start_lens.append(-1) + top_logprobs_nums.append(request.top_logprobs or 0) + lora_paths.append(request.lora_path) + + # Build sampling parameters + sampling_params = self._build_sampling_params(request, stop) + sampling_params_list.append(sampling_params) + + image_data_list.append(image_data) + audio_data_list.append(audio_data) + modalities_list.append(modalities) + + # Handle single vs multiple requests + if len(all_requests) == 1: + if is_multimodal: + prompt_kwargs = {"text": prompts[0]} + else: + if isinstance(input_ids[0], str): + prompt_kwargs = {"text": input_ids[0]} + else: + prompt_kwargs = {"input_ids": input_ids[0]} + + sampling_params_list = sampling_params_list[0] + image_data_list = image_data_list[0] + audio_data_list = audio_data_list[0] + return_logprobs = return_logprobs[0] + logprob_start_lens = logprob_start_lens[0] + top_logprobs_nums = top_logprobs_nums[0] + modalities_list = modalities_list[0] + lora_paths = lora_paths[0] + request_ids = request_ids[0] + else: + if is_multimodal: + prompt_kwargs = {"text": prompts} + else: + if isinstance(input_ids[0], str): + prompt_kwargs = {"text": input_ids} + else: + prompt_kwargs = {"input_ids": input_ids} + + adapted_request = GenerateReqInput( + **prompt_kwargs, + image_data=image_data_list, + audio_data=audio_data_list, + sampling_params=sampling_params_list, + return_logprob=return_logprobs, + logprob_start_len=logprob_start_lens, + top_logprobs_num=top_logprobs_nums, + stream=all_requests[0].stream, + return_text_in_logprobs=True, + rid=request_ids, + modalities=modalities_list, + lora_path=lora_paths, + bootstrap_host=all_requests[0].bootstrap_host, + bootstrap_port=all_requests[0].bootstrap_port, + bootstrap_room=all_requests[0].bootstrap_room, + ) + + return adapted_request, ( + all_requests if len(all_requests) > 1 else all_requests[0] + ) + + def _process_messages( + self, request: ChatCompletionRequest, is_multimodal: bool + ) -> tuple[ + str, Union[str, List[int]], Optional[Any], Optional[Any], List[str], List[str] + ]: + """Process chat messages and apply chat template""" + tool_call_constraint = None + prompt = "" + prompt_ids = [] + + if not isinstance(request.messages, str): + # Apply chat template and its stop strings + tools = None + if request.tools and request.tool_choice != "none": + request.skip_special_tokens = False + if not isinstance(request.tool_choice, str): + tools = [ + item.function.model_dump() + for item in request.tools + if item.function.name == request.tool_choice.function.name + ] + else: + tools = [item.function.model_dump() for item in request.tools] + + tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser + parser = FunctionCallParser(request.tools, tool_call_parser) + tool_call_constraint = parser.get_structure_constraint( + request.tool_choice + ) + + # Use chat template + if ( + hasattr(self.tokenizer_manager, "chat_template_name") + and self.tokenizer_manager.chat_template_name is None + ): + prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + self._apply_jinja_template(request, tools, is_multimodal) + ) + else: + prompt, image_data, audio_data, modalities, stop = ( + self._apply_conversation_template(request) + ) + if not is_multimodal: + prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt) + else: + # Use raw prompt + prompt_ids = request.messages + stop = request.stop or [] + image_data = None + audio_data = None + modalities = [] + prompt = request.messages + + return prompt, prompt_ids, image_data, audio_data, modalities, stop + + def _apply_jinja_template( + self, + request: ChatCompletionRequest, + tools: Optional[List[Dict]], + is_multimodal: bool, + ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]: + """Apply Jinja chat template""" + global _cached_chat_template, _cached_template_format + + openai_compatible_messages = [] + image_data = [] + audio_data = [] + modalities = [] + + # Detect template content format + current_template = self.tokenizer_manager.tokenizer.chat_template + if current_template != _cached_chat_template: + _cached_chat_template = current_template + _cached_template_format = detect_template_content_format(current_template) + logger.info( + f"Detected chat template content format: {_cached_template_format}" + ) + + template_content_format = _cached_template_format + + for message in request.messages: + if message.content is None: + message.content = "" + msg_dict = message.model_dump() + + # Process content based on detected template format + processed_msg = process_content_for_template_format( + msg_dict, + template_content_format, + image_data, + audio_data, + modalities, + ) + openai_compatible_messages.append(processed_msg) + + # Handle assistant prefix for continue_final_message + assistant_prefix = None + if ( + openai_compatible_messages + and openai_compatible_messages[-1]["role"] == "assistant" + ): + if request.continue_final_message: + assistant_prefix = openai_compatible_messages[-1]["content"] + openai_compatible_messages = openai_compatible_messages[:-1] + + try: + prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template( + openai_compatible_messages, + tokenize=True, + add_generation_prompt=True, + tools=tools, + **( + request.chat_template_kwargs if request.chat_template_kwargs else {} + ), + ) + except Exception: + # Handle different tools input format (e.g., Mistral) + tools = ( + [t if "function" in t else {"function": t} for t in tools] + if tools + else None + ) + prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template( + openai_compatible_messages, + tokenize=True, + add_generation_prompt=True, + tools=tools, + **( + request.chat_template_kwargs if request.chat_template_kwargs else {} + ), + ) + + if assistant_prefix: + encoded = self.tokenizer_manager.tokenizer.encode(assistant_prefix) + if encoded and encoded[0] == self.tokenizer_manager.tokenizer.bos_token_id: + encoded = encoded[1:] + prompt_ids += encoded + + if is_multimodal: + prompt = self.tokenizer_manager.tokenizer.decode(prompt_ids) + + stop = request.stop or [] + return prompt, prompt_ids, image_data, audio_data, modalities, stop + + def _apply_conversation_template( + self, request: ChatCompletionRequest + ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str]]: + """Apply conversation template""" + conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name) + + # Handle continue_final_message + if ( + request.continue_final_message + and request.messages + and request.messages[-1].role == "assistant" + ): + if conv.messages and conv.messages[-1][1] is None: + conv.messages.pop() + prompt = conv.get_prompt() + # Strip trailing stop tokens + if isinstance(conv.stop_str, list): + for stop_token in conv.stop_str: + if prompt.endswith(stop_token): + prompt = prompt[: -len(stop_token)] + elif isinstance(conv.stop_str, str) and prompt.endswith(conv.stop_str): + prompt = prompt[: -len(conv.stop_str)] + if conv.sep and prompt.endswith(conv.sep): + prompt = prompt[: -len(conv.sep)] + if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2): + prompt = prompt[: -len(conv.sep2)] + else: + prompt = conv.get_prompt() + + image_data = conv.image_data + audio_data = conv.audio_data + modalities = conv.modalities + stop = conv.stop_str or [] if not request.ignore_eos else [] + + if request.stop: + if isinstance(request.stop, str): + stop.append(request.stop) + else: + stop.extend(request.stop) + + return prompt, image_data, audio_data, modalities, stop + + def _build_sampling_params( + self, request: ChatCompletionRequest, stop: List[str] + ) -> Dict[str, Any]: + """Build sampling parameters for the request""" + # Start with common parameters + sampling_params = build_base_sampling_params(request) + + # Override stop with processed stop sequences + sampling_params["stop"] = stop + + # Handle response format + if request.response_format and request.response_format.type == "json_schema": + sampling_params["json_schema"] = convert_json_schema_to_str( + request.response_format.json_schema.schema_ + ) + elif request.response_format and request.response_format.type == "json_object": + sampling_params["json_schema"] = '{"type": "object"}' + elif ( + request.response_format and request.response_format.type == "structural_tag" + ): + sampling_params["structural_tag"] = convert_json_schema_to_str( + request.response_format.model_dump(by_alias=True) + ) + + # Handle tool call constraints + if hasattr(self, "_tool_call_constraint") and self._tool_call_constraint: + constraint_type, constraint_value = self._tool_call_constraint + if constraint_type == "structural_tag": + sampling_params[constraint_type] = convert_json_schema_to_str( + constraint_value.model_dump(by_alias=True) + ) + else: + sampling_params[constraint_type] = constraint_value + + return sampling_params + + async def _handle_streaming_request( + self, + adapted_request: GenerateReqInput, + request: ChatCompletionRequest, + ctx: RequestContext, + ) -> StreamingResponse: + """Handle streaming chat completion request""" + + async def generate_stream_resp(): + parser_dict = {} + reasoning_parser_dict = {} + tool_call_first = True + is_firsts = {} + stream_buffers = {} + n_prev_tokens = {} + prompt_tokens = {} + completion_tokens = {} + cached_tokens = {} + + try: + async for content in self.tokenizer_manager.generate_request( + adapted_request, ctx.raw_request + ): + index = content.get("index", 0) + text = content["text"] + + is_first = is_firsts.get(index, True) + stream_buffer = stream_buffers.get(index, "") + n_prev_token = n_prev_tokens.get(index, 0) + + prompt_tokens[index] = content["meta_info"]["prompt_tokens"] + completion_tokens[index] = content["meta_info"]["completion_tokens"] + cached_tokens[index] = content["meta_info"].get("cached_tokens", 0) + + # Handle logprobs + choice_logprobs = None + if request.logprobs: + choice_logprobs = self._process_streaming_logprobs( + content, n_prev_token + ) + n_prev_token = len( + content["meta_info"]["output_token_logprobs"] + ) + + finish_reason = content["meta_info"]["finish_reason"] + finish_reason_type = ( + finish_reason["type"] if finish_reason else None + ) + + # First chunk with role + if is_first: + is_first = False + delta = DeltaMessage(role="assistant") + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=delta, + finish_reason=finish_reason_type, + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + logprobs=choice_logprobs, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield create_streaming_chunk_data(chunk.model_dump_json()) + + # Process content delta + delta = text[len(stream_buffer) :] + new_stream_buffer = stream_buffer + delta + + # Handle reasoning content + enable_thinking = _get_enable_thinking_from_request(request) + if ( + self.tokenizer_manager.server_args.reasoning_parser + and request.separate_reasoning + and enable_thinking + ): + reasoning_text, delta = self._process_reasoning_stream( + index, delta, reasoning_parser_dict, content, request + ) + if reasoning_text: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(reasoning_content=reasoning_text), + finish_reason=finish_reason_type, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield create_streaming_chunk_data(chunk.model_dump_json()) + + if not delta or len(delta) == 0: + stream_buffers[index] = new_stream_buffer + is_firsts[index] = is_first + n_prev_tokens[index] = n_prev_token + continue + + # Handle tool calls + if request.tool_choice != "none" and request.tools: + async for chunk in self._process_tool_call_stream( + index, + delta, + parser_dict, + content, + request, + finish_reason_type, + ): + yield chunk + else: + # Regular content + if delta or not ( + request.stream_options + and request.stream_options.include_usage + ): + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=delta if delta else None), + finish_reason=( + None + if request.stream_options + and request.stream_options.include_usage + else finish_reason_type + ), + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + logprobs=choice_logprobs, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield create_streaming_chunk_data(chunk.model_dump_json()) + + stream_buffers[index] = new_stream_buffer + is_firsts[index] = is_first + n_prev_tokens[index] = n_prev_token + + # Final chunk with usage + if request.stream_options and request.stream_options.include_usage: + usage = self._calculate_streaming_usage_base( + prompt_tokens, completion_tokens, cached_tokens, request + ) + else: + usage = None + + final_chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[ + ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(), + finish_reason=finish_reason_type, + ) + ], + model=request.model, + usage=usage, + ) + yield create_streaming_chunk_data(final_chunk.model_dump_json()) + + except Exception as e: + error = create_streaming_error_response(str(e)) + yield create_streaming_chunk_data(error) + + yield create_stream_done() + + return StreamingResponse( + generate_stream_resp(), + media_type="text/event-stream", + background=self.tokenizer_manager.create_abort_task(adapted_request), + ) + + async def _handle_non_streaming_request( + self, + adapted_request: GenerateReqInput, + request: ChatCompletionRequest, + ctx: RequestContext, + ) -> Union[ChatCompletionResponse, ErrorResponse]: + """Handle non-streaming chat completion request""" + try: + ret = await self.tokenizer_manager.generate_request( + adapted_request, ctx.raw_request + ).__anext__() + except ValueError as e: + return create_error_response(str(e)) + + if not isinstance(ret, list): + ret = [ret] + + response = self._build_chat_response( + request, + ret, + int(time.time()), + cache_report=self.tokenizer_manager.server_args.enable_cache_report, + tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser, + reasoning_parser=self.tokenizer_manager.server_args.reasoning_parser, + ) + + return response + + def _build_chat_response( + self, + request: ChatCompletionRequest, + ret: List[Dict[str, Any]], + created: int, + cache_report: bool = False, + tool_call_parser: Optional[str] = None, + reasoning_parser: Optional[str] = None, + ) -> ChatCompletionResponse: + """Build chat completion response from generation results""" + choices = [] + + for idx, ret_item in enumerate(ret): + # Process logprobs + choice_logprobs = None + if request.logprobs: + choice_logprobs = self._process_response_logprobs(ret_item) + + finish_reason = ret_item["meta_info"]["finish_reason"] + text = ret_item["text"] + + # Handle reasoning content + reasoning_text = None + enable_thinking = _get_enable_thinking_from_request(request) + if reasoning_parser and request.separate_reasoning and enable_thinking: + try: + parser = ReasoningParser( + model_type=reasoning_parser, stream_reasoning=False + ) + reasoning_text, text = parser.parse_non_stream(text) + except Exception as e: + logger.error(f"Reasoning parsing error: {e}") + return create_error_response( + "Failed to parse reasoning content", + err_type="InternalServerError", + status_code=500, + ) + + # Handle tool calls + tool_calls = None + if request.tool_choice != "none" and request.tools: + tool_calls, text, finish_reason = self._process_tool_calls( + text, request.tools, tool_call_parser, finish_reason + ) + + choice_data = ChatCompletionResponseChoice( + index=idx, + message=ChatMessage( + role="assistant", + content=text if text else None, + tool_calls=tool_calls, + reasoning_content=reasoning_text if reasoning_text else None, + ), + logprobs=choice_logprobs, + finish_reason=finish_reason["type"] if finish_reason else None, + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + ) + choices.append(choice_data) + + # Calculate usage + usage = aggregate_token_usage(ret, request.n, cache_report) + + return ChatCompletionResponse( + id=ret[0]["meta_info"]["id"], + created=created, + model=request.model, + choices=choices, + usage=usage, + ) + + def _process_response_logprobs(self, ret_item: Dict[str, Any]) -> ChoiceLogprobs: + """Process logprobs for non-streaming response""" + logprobs = to_openai_style_logprobs( + output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], + output_top_logprobs=ret_item["meta_info"].get("output_top_logprobs", None), + ) + + token_logprobs = [] + for token_idx, (token, logprob) in enumerate( + zip(logprobs.tokens, logprobs.token_logprobs) + ): + token_bytes = list(token.encode("utf-8")) + top_logprobs = [] + if logprobs.top_logprobs: + for top_token, top_logprob in logprobs.top_logprobs[token_idx].items(): + top_token_bytes = list(top_token.encode("utf-8")) + top_logprobs.append( + TopLogprob( + token=top_token, + bytes=top_token_bytes, + logprob=top_logprob, + ) + ) + token_logprobs.append( + ChatCompletionTokenLogprob( + token=token, + bytes=token_bytes, + logprob=logprob, + top_logprobs=top_logprobs, + ) + ) + + return ChoiceLogprobs(content=token_logprobs) + + def _process_tool_calls( + self, + text: str, + tools: List[Any], + tool_call_parser: Optional[str], + finish_reason: Dict[str, Any], + ) -> tuple[Optional[List[ToolCall]], str, Dict[str, Any]]: + """Process tool calls in the response""" + parser = FunctionCallParser(tools, tool_call_parser) + if parser.has_tool_call(text): + if finish_reason["type"] == "stop": + finish_reason["type"] = "tool_calls" + finish_reason["matched"] = None + try: + text, call_info_list = parser.parse_non_stream(text) + tool_calls = [ + ToolCall( + id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}", + function=FunctionResponse( + name=call_info.name, arguments=call_info.parameters + ), + ) + for call_info in call_info_list + ] + return tool_calls, text, finish_reason + except Exception as e: + logger.error(f"Tool call parsing error: {e}") + # Return error but don't fail the whole request + return None, text, finish_reason + + return None, text, finish_reason + + def _process_streaming_logprobs( + self, content: Dict[str, Any], n_prev_token: int + ) -> ChoiceLogprobs: + """Process logprobs for streaming response""" + logprobs = to_openai_style_logprobs( + output_token_logprobs=content["meta_info"]["output_token_logprobs"][ + n_prev_token: + ], + output_top_logprobs=content["meta_info"].get("output_top_logprobs", [])[ + n_prev_token: + ], + ) + + token_logprobs = [] + for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs): + token_bytes = list(token.encode("utf-8")) + top_logprobs = [] + if logprobs.top_logprobs: + for top_token, top_logprob in logprobs.top_logprobs[0].items(): + top_token_bytes = list(top_token.encode("utf-8")) + top_logprobs.append( + TopLogprob( + token=top_token, + bytes=top_token_bytes, + logprob=top_logprob, + ) + ) + token_logprobs.append( + ChatCompletionTokenLogprob( + token=token, + bytes=token_bytes, + logprob=logprob, + top_logprobs=top_logprobs, + ) + ) + + return ChoiceLogprobs(content=token_logprobs) + + def _process_reasoning_stream( + self, + index: int, + delta: str, + reasoning_parser_dict: Dict[int, ReasoningParser], + content: Dict[str, Any], + request: ChatCompletionRequest, + ) -> tuple[Optional[str], str]: + """Process reasoning content in streaming response""" + if index not in reasoning_parser_dict: + reasoning_parser_dict[index] = ReasoningParser( + self.tokenizer_manager.server_args.reasoning_parser, + request.stream_reasoning, + ) + reasoning_parser = reasoning_parser_dict[index] + return reasoning_parser.parse_stream_chunk(delta) + + async def _process_tool_call_stream( + self, + index: int, + delta: str, + parser_dict: Dict[int, FunctionCallParser], + content: Dict[str, Any], + request: ChatCompletionRequest, + finish_reason_type: Optional[str], + ): + """Process tool calls in streaming response""" + if index not in parser_dict: + parser_dict[index] = FunctionCallParser( + tools=request.tools, + tool_call_parser=self.tokenizer_manager.server_args.tool_call_parser, + ) + parser = parser_dict[index] + + normal_text, calls = parser.parse_stream_chunk(delta) + + # Yield normal text + if normal_text: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=normal_text), + finish_reason=finish_reason_type, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + # Yield tool calls + for call_item in calls: + if finish_reason_type == "stop": + # Handle remaining arguments + latest_delta_len = 0 + if isinstance(call_item.parameters, str): + latest_delta_len = len(call_item.parameters) + + expected_call = json.dumps( + parser.detector.prev_tool_call_arr[index].get("arguments", {}), + ensure_ascii=False, + ) + actual_call = parser.detector.streamed_args_for_tool[index] + if latest_delta_len > 0: + actual_call = actual_call[:-latest_delta_len] + remaining_call = expected_call.replace(actual_call, "", 1) + call_item.parameters = remaining_call + finish_reason_type = "tool_calls" + + tool_call = ToolCall( + id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}", + index=call_item.tool_index, + function=FunctionResponse( + name=call_item.name, + arguments=call_item.parameters, + ), + ) + + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(tool_calls=[tool_call]), + finish_reason=( + None + if request.stream_options and request.stream_options.include_usage + else finish_reason_type + ), + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + created=int(time.time()), + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py new file mode 100644 index 00000000000..a8816cce032 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -0,0 +1,484 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Completion serving logic for OpenAI API""" + +import json +import logging +import time +import uuid +from typing import Any, Dict, List, Union + +from fastapi import Request +from fastapi.responses import StreamingResponse + +from sglang.srt.code_completion_parser import ( + completion_template_name, + generate_completion_prompt, + is_completion_template_defined, +) +from sglang.srt.entrypoints.openai.protocol import ( + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + ErrorResponse, +) +from sglang.srt.entrypoints.openai.serving_engine import ( + OpenAIServingBase, + RequestContext, +) +from sglang.srt.entrypoints.openai.utils import ( + aggregate_token_usage, + build_base_sampling_params, + create_error_response, + create_stream_done, + create_streaming_chunk_data, + create_streaming_error_response, + to_openai_style_logprobs, +) +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.tokenizer_manager import TokenizerManager + +logger = logging.getLogger(__name__) + + +class CompletionHandler(OpenAIServingBase): + """Handler for completion requests""" + + def __init__(self, tokenizer_manager: TokenizerManager): + super().__init__(tokenizer_manager) + + async def handle_request( + self, request: CompletionRequest, raw_request: Request + ) -> Union[CompletionResponse, StreamingResponse, ErrorResponse]: + """Handle a completion request""" + try: + # Echo + logprobs warning + if request.echo and request.logprobs: + logger.warning( + "Echo is not compatible with logprobs. " + "To compute logprobs of input prompt, please use the native /generate API." + ) + + # Validate request + error = self._validate_request(request) + if error: + return error + + # Create request context + ctx = RequestContext( + raw_request=raw_request, + openai_request=request, + request_id=f"cmpl-{uuid.uuid4()}", + ) + + # Convert to internal format + adapted_request, processed_request = self._convert_to_internal_request( + [request], [ctx.request_id] + ) + + if request.stream: + return await self._handle_streaming_request( + adapted_request, processed_request, ctx + ) + else: + return await self._handle_non_streaming_request( + adapted_request, processed_request, ctx + ) + + except Exception as e: + logger.error(f"Error in completion: {e}") + return create_error_response( + message=f"Internal server error: {str(e)}", + err_type="InternalServerError", + status_code=500, + ) + + def _convert_to_internal_request( + self, + all_requests: List[CompletionRequest], + request_ids: List[str], + ) -> tuple[GenerateReqInput, Union[CompletionRequest, List[CompletionRequest]]]: + """Convert OpenAI completion request to internal format""" + # Validate batch requests + if len(all_requests) > 1: + first_prompt_type = type(all_requests[0].prompt) + for request in all_requests: + assert ( + type(request.prompt) is first_prompt_type + ), "All prompts must be of the same type in file input settings" + if request.n > 1: + raise ValueError( + "Parallel sampling is not supported for completions from files" + ) + + prompts = [] + sampling_params_list = [] + return_logprobs = [] + logprob_start_lens = [] + top_logprobs_nums = [] + lora_paths = [] + + for request in all_requests: + # Process prompt + prompt = request.prompt + if is_completion_template_defined(): + if request.suffix: + prompt = generate_completion_prompt( + str(request.prompt), request.suffix, completion_template_name + ) + prompts.append(prompt) + + lora_paths.append(request.lora_path) + + # Set logprob start length based on echo and logprobs + if request.echo and request.logprobs: + current_logprob_start_len = 0 + else: + current_logprob_start_len = -1 + + # Build sampling parameters + sampling_params = self._build_sampling_params(request) + sampling_params_list.append(sampling_params) + + return_logprobs.append(request.logprobs is not None) + logprob_start_lens.append(current_logprob_start_len) + top_logprobs_nums.append( + request.logprobs if request.logprobs is not None else 0 + ) + + # Handle single vs multiple requests + if len(all_requests) == 1: + if isinstance(prompts[0], str) or isinstance(prompts[0][0], str): + prompt_kwargs = {"text": prompts[0]} + else: + prompt_kwargs = {"input_ids": prompts[0]} + sampling_params_list = sampling_params_list[0] + return_logprobs = return_logprobs[0] + logprob_start_lens = logprob_start_lens[0] + top_logprobs_nums = top_logprobs_nums[0] + lora_paths = lora_paths[0] + request_ids = request_ids[0] + else: + if isinstance(prompts[0], str) or isinstance(prompts[0][0], str): + prompt_kwargs = {"text": prompts} + else: + prompt_kwargs = {"input_ids": prompts} + + adapted_request = GenerateReqInput( + **prompt_kwargs, + sampling_params=sampling_params_list, + return_logprob=return_logprobs, + top_logprobs_num=top_logprobs_nums, + logprob_start_len=logprob_start_lens, + return_text_in_logprobs=True, + stream=all_requests[0].stream, + rid=request_ids, + lora_path=lora_paths, + bootstrap_host=all_requests[0].bootstrap_host, + bootstrap_port=all_requests[0].bootstrap_port, + bootstrap_room=all_requests[0].bootstrap_room, + ) + + return adapted_request, ( + all_requests if len(all_requests) > 1 else all_requests[0] + ) + + def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]: + """Build sampling parameters for the request""" + # Start with common parameters + sampling_params = build_base_sampling_params(request) + + # No additional completion-specific parameters needed currently + # (json_schema is already handled in base method) + + return sampling_params + + async def _handle_streaming_request( + self, + adapted_request: GenerateReqInput, + request: CompletionRequest, + ctx: RequestContext, + ) -> StreamingResponse: + """Handle streaming completion request""" + created = int(time.time()) + + async def generate_stream_resp(): + stream_buffers = {} + n_prev_tokens = {} + prompt_tokens = {} + completion_tokens = {} + cached_tokens = {} + + try: + async for content in self.tokenizer_manager.generate_request( + adapted_request, ctx.raw_request + ): + index = content.get("index", 0) + + stream_buffer = stream_buffers.get(index, "") + n_prev_token = n_prev_tokens.get(index, 0) + + text = content["text"] + prompt_tokens[index] = content["meta_info"]["prompt_tokens"] + completion_tokens[index] = content["meta_info"]["completion_tokens"] + cached_tokens[index] = content["meta_info"].get("cached_tokens", 0) + + # Handle echo for first chunk + if not stream_buffer: # The first chunk + if request.echo: + echo_text = self._get_echo_text(request, index) + text = echo_text + text + + # Handle logprobs + logprobs = None + if request.logprobs is not None: + # The first chunk and echo is enabled. + if not stream_buffer and request.echo: + input_token_logprobs = content["meta_info"][ + "input_token_logprobs" + ] + input_top_logprobs = content["meta_info"][ + "input_top_logprobs" + ] + else: + input_token_logprobs = None + input_top_logprobs = None + + logprobs = to_openai_style_logprobs( + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_token_logprobs=content["meta_info"][ + "output_token_logprobs" + ][n_prev_token:], + output_top_logprobs=content["meta_info"][ + "output_top_logprobs" + ][n_prev_token:], + ) + n_prev_token = len( + content["meta_info"]["output_token_logprobs"] + ) + + # Generate delta + delta = text[len(stream_buffer) :] + stream_buffer = stream_buffer + delta + finish_reason = content["meta_info"]["finish_reason"] + + choice_data = CompletionResponseStreamChoice( + index=index, + text=delta, + logprobs=logprobs, + finish_reason=finish_reason["type"] if finish_reason else None, + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + ) + chunk = CompletionStreamResponse( + id=content["meta_info"]["id"], + created=created, + object="text_completion", + choices=[choice_data], + model=request.model, + ) + + stream_buffers[index] = stream_buffer + n_prev_tokens[index] = n_prev_token + + yield create_streaming_chunk_data(chunk.model_dump_json()) + + # Handle final usage chunk + if request.stream_options and request.stream_options.include_usage: + usage = self._calculate_streaming_usage_base( + prompt_tokens, completion_tokens, cached_tokens, request + ) + final_usage_chunk = CompletionStreamResponse( + id=content["meta_info"]["id"], + created=created, + choices=[], + model=request.model, + usage=usage, + ) + final_usage_data = final_usage_chunk.model_dump_json( + exclude_none=True + ) + yield create_streaming_chunk_data(final_usage_data) + + except Exception as e: + error = create_streaming_error_response(str(e)) + yield f"data: {error}\n\n" + + yield create_stream_done() + + return StreamingResponse( + generate_stream_resp(), + media_type="text/event-stream", + background=self.tokenizer_manager.create_abort_task(adapted_request), + ) + + async def _handle_non_streaming_request( + self, + adapted_request: GenerateReqInput, + request: CompletionRequest, + ctx: RequestContext, + ) -> Union[CompletionResponse, ErrorResponse]: + """Handle non-streaming completion request""" + try: + generator = self.tokenizer_manager.generate_request( + adapted_request, ctx.raw_request + ) + ret = await generator.__anext__() + except ValueError as e: + return create_error_response(str(e)) + + if not isinstance(ret, list): + ret = [ret] + + response = self._build_completion_response( + request, + ret, + int(time.time()), + cache_report=self.tokenizer_manager.server_args.enable_cache_report, + ) + + return response + + def _build_completion_response( + self, + request: CompletionRequest, + ret: List[Dict[str, Any]], + created: int, + cache_report: bool = False, + ) -> CompletionResponse: + """Build completion response from generation results""" + choices = [] + echo = False + + # Prepare echo prompts if needed + echo_prompts = [] + if (not isinstance(request, list)) and request.echo: + echo_prompts = self._prepare_echo_prompts(request) + echo = True + + for idx, ret_item in enumerate(ret): + text = ret_item["text"] + + # Handle echo + if isinstance(request, list) and request[idx].echo: + echo = True + text = request[idx].prompt + text + elif echo and not isinstance(request, list): + prompt_index = idx // request.n + text = echo_prompts[prompt_index] + text + + # Handle logprobs + logprobs = None + if isinstance(request, list) and request[idx].logprobs is not None: + logprobs = True + elif (not isinstance(request, list)) and request.logprobs is not None: + logprobs = True + + if logprobs: + if echo: + input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"] + input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"] + else: + input_token_logprobs = None + input_top_logprobs = None + + logprobs = to_openai_style_logprobs( + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_token_logprobs=ret_item["meta_info"][ + "output_token_logprobs" + ], + output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], + ) + + finish_reason = ret_item["meta_info"]["finish_reason"] + + choice_data = CompletionResponseChoice( + index=idx, + text=text, + logprobs=logprobs, + finish_reason=finish_reason["type"] if finish_reason else None, + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + ) + choices.append(choice_data) + + # Calculate usage + usage = aggregate_token_usage(ret, request.n, cache_report) + + return CompletionResponse( + id=ret[0]["meta_info"]["id"], + model=request.model, + created=created, + choices=choices, + usage=usage, + ) + + def _get_echo_text(self, request: CompletionRequest, index: int) -> str: + """Get echo text for streaming response""" + if isinstance(request.prompt, str): + # for the case of single str prompts + return request.prompt + elif isinstance(request.prompt, list): + if isinstance(request.prompt[0], str): + # for the case of multiple str prompts + return request.prompt[index // request.n] + elif isinstance(request.prompt[0], int): + # for the case of single token ids prompt + return self.tokenizer_manager.tokenizer.decode( + request.prompt, skip_special_tokens=True + ) + elif isinstance(request.prompt[0], list) and isinstance( + request.prompt[0][0], int + ): + # for the case of multiple token ids prompts + return self.tokenizer_manager.tokenizer.decode( + request.prompt[index // request.n], + skip_special_tokens=True, + ) + return "" + + def _prepare_echo_prompts(self, request: CompletionRequest) -> List[str]: + """Prepare echo prompts for non-streaming response""" + # TODO: handle the case prompt is token ids + if isinstance(request.prompt, list) and isinstance(request.prompt[0], str): + # for the case of multiple str prompts + return request.prompt + elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list): + # for the case of multiple token ids prompts + return [ + self.tokenizer_manager.tokenizer.decode( + prompt, skip_special_tokens=True + ) + for prompt in request.prompt + ] + elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int): + # for the case of single token ids prompt + return [ + self.tokenizer_manager.tokenizer.decode( + request.prompt, skip_special_tokens=True + ) + ] + else: + # for the case of single str prompt + return [request.prompt] diff --git a/python/sglang/srt/entrypoints/openai/serving_engine.py b/python/sglang/srt/entrypoints/openai/serving_engine.py new file mode 100644 index 00000000000..54de8250f01 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/serving_engine.py @@ -0,0 +1,110 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +from fastapi import Request + +from sglang.srt.entrypoints.openai.protocol import ( + ErrorResponse, + OpenAIServingRequest, + UsageInfo, +) +from sglang.srt.entrypoints.openai.utils import create_error_response +from sglang.srt.entrypoints.openai.validation import get_validation_rules +from sglang.srt.managers.tokenizer_manager import TokenizerManager + + +class RequestContext: + """Context object for tracking request state throughout the pipeline""" + + def __init__( + self, + raw_request: Request, + openai_request: OpenAIServingRequest, + request_id: str, + ): + self.raw_request = raw_request + self.openai_request = openai_request + self.request_id = request_id + self.start_time = time.time() + self.metadata: Dict[str, Any] = {} + + def elapsed_time(self) -> float: + """Get elapsed time since request started""" + return time.time() - self.start_time + + def add_metadata(self, key: str, value: Any) -> None: + """Add metadata to the request context""" + self.metadata[key] = value + + def get_metadata(self, key: str, default: Any = None) -> Any: + """Get metadata from the request context""" + return self.metadata.get(key, default) + + +# Base class for specific endpoint handlers +class OpenAIServingBase(ABC): + """Abstract base class for OpenAI endpoint handlers""" + + def __init__(self, tokenizer_manager: TokenizerManager): + self.tokenizer_manager = tokenizer_manager + + @abstractmethod + async def handle_request( + self, request: OpenAIServingRequest, raw_request: Request + ) -> Any: + """Handle the specific request type""" + pass + + def _validate_request( + self, request: OpenAIServingRequest + ) -> Optional[ErrorResponse]: + """Validate request""" + validation_rules = get_validation_rules(request) + for rule in validation_rules: + param_value = rule.param_getter(request) + error_msg = rule.validator_func(param_value) + if error_msg: + return create_error_response(error_msg, param=rule.param_name) + return None + + def _calculate_streaming_usage_base( + self, + prompt_tokens: Dict[int, int], + completion_tokens: Dict[int, int], + cached_tokens: Dict[int, int], + n_choices: int, + ) -> UsageInfo: + """Calculate usage information for streaming responses (common logic)""" + total_prompt_tokens = sum( + tokens for i, tokens in prompt_tokens.items() if i % n_choices == 0 + ) + total_completion_tokens = sum(tokens for tokens in completion_tokens.values()) + + cache_report = self.tokenizer_manager.server_args.enable_cache_report + prompt_tokens_details = None + if cache_report: + cached_tokens_sum = sum(tokens for tokens in cached_tokens.values()) + if cached_tokens_sum > 0: + prompt_tokens_details = {"cached_tokens": cached_tokens_sum} + + return UsageInfo( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens, + prompt_tokens_details=prompt_tokens_details, + ) diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py new file mode 100644 index 00000000000..14c27220fcb --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/utils.py @@ -0,0 +1,506 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for OpenAI API server""" + +import json +import logging +import re +from typing import Any, Dict, List, Optional, Union + +import jinja2.nodes +import transformers.utils.chat_template_utils as hf_chat_utils + +from sglang.srt.entrypoints.openai.protocol import ( + ChatCompletionMessageParam, + ChatCompletionRequest, + CompletionRequest, + ErrorResponse, + LogProbs, + OpenAIServingRequest, + UsageInfo, +) +from sglang.srt.entrypoints.openai.validation import ValidationRule + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# JINJA TEMPLATE CONTENT FORMAT DETECTION +# ============================================================================ +# +# This adapts vLLM's approach for detecting chat template content format: +# https://github.com/vllm-project/vllm/blob/02f0c7b220422792f5e53de2a7d51d2d3ff2df28/vllm/entrypoints/chat_utils.py#L296-L313 +# - Analyzes Jinja template AST to detect content iteration patterns +# - 'openai' format: templates with {%- for content in message['content'] -%} loops +# - 'string' format: templates that expect simple string content +# - Processes content accordingly to match template expectations + + +def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: + """Check if node is a variable access like {{ varname }}""" + if isinstance(node, jinja2.nodes.Name): + return node.ctx == "load" and node.name == varname + return False + + +def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: + """Check if node is an attribute access like {{ varname['key'] }} or {{ varname.key }}""" + if isinstance(node, jinja2.nodes.Getitem): + return ( + _is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key + ) + + if isinstance(node, jinja2.nodes.Getattr): + return _is_var_access(node.node, varname) and node.attr == key + + return False + + +def _is_var_or_elems_access( + node: jinja2.nodes.Node, + varname: str, + key: str = None, +) -> bool: + """Check if node accesses varname or varname[key] with filters/tests""" + if isinstance(node, jinja2.nodes.Filter): + return node.node is not None and _is_var_or_elems_access( + node.node, varname, key + ) + if isinstance(node, jinja2.nodes.Test): + return _is_var_or_elems_access(node.node, varname, key) + + if isinstance(node, jinja2.nodes.Getitem) and isinstance( + node.arg, jinja2.nodes.Slice + ): + return _is_var_or_elems_access(node.node, varname, key) + + return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) + + +def _try_extract_ast(chat_template: str): + """Try to parse the Jinja template into an AST""" + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception as e: + logger.debug(f"Error when compiling Jinja template: {e}") + return None + + +def detect_template_content_format(chat_template: str) -> str: + """ + Detect whether a chat template expects 'string' or 'openai' content format. + + - 'string': content is a simple string (like DeepSeek templates) + - 'openai': content is a list of structured dicts (like Llama4 templates) + + Detection logic: + - If template has loops like {%- for content in message['content'] -%} → 'openai' + - Otherwise → 'string' + """ + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return "string" + + try: + # Look for patterns like: {%- for content in message['content'] -%} + for loop_ast in jinja_ast.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + + # Check if iterating over message['content'] or similar + if _is_var_or_elems_access(loop_iter, "message", "content"): + return "openai" # Found content iteration → openai format + + return "string" # No content loops found → string format + except Exception as e: + logger.debug(f"Error when parsing AST of Jinja template: {e}") + return "string" + + +def process_content_for_template_format( + msg_dict: dict, + content_format: str, + image_data: list, + audio_data: list, + modalities: list, +) -> dict: + """ + Process message content based on detected template format. + + Args: + msg_dict: Message dictionary with content + content_format: 'string' or 'openai' (detected via AST analysis) + image_data: List to append extracted image URLs + audio_data: List to append extracted audio URLs + modalities: List to append modalities + + Returns: + Processed message dictionary + """ + if not isinstance(msg_dict.get("content"), list): + # Already a string or None, no processing needed + return {k: v for k, v in msg_dict.items() if v is not None} + + if content_format == "openai": + # OpenAI format: preserve structured content list, normalize types + processed_content_parts = [] + for chunk in msg_dict["content"]: + if isinstance(chunk, dict): + chunk_type = chunk.get("type") + + if chunk_type == "image_url": + image_data.append(chunk["image_url"]["url"]) + if chunk.get("modalities"): + modalities.append(chunk.get("modalities")) + # Normalize to simple 'image' type for template compatibility + processed_content_parts.append({"type": "image"}) + elif chunk_type == "audio_url": + audio_data.append(chunk["audio_url"]["url"]) + # Normalize to simple 'audio' type + processed_content_parts.append({"type": "audio"}) + else: + # Keep other content as-is (text, etc.) + processed_content_parts.append(chunk) + + new_msg = { + k: v for k, v in msg_dict.items() if v is not None and k != "content" + } + new_msg["content"] = processed_content_parts + return new_msg + + else: # content_format == "string" + # String format: flatten to text only (for templates like DeepSeek) + text_parts = [] + for chunk in msg_dict["content"]: + if isinstance(chunk, dict) and chunk.get("type") == "text": + text_parts.append(chunk["text"]) + # Note: For string format, we ignore images/audio since the template + # doesn't expect structured content - multimodal placeholders would + # need to be inserted differently + + new_msg = msg_dict.copy() + new_msg["content"] = " ".join(text_parts) if text_parts else "" + new_msg = {k: v for k, v in new_msg.items() if v is not None} + return new_msg + + +def calculate_token_usage( + prompt_tokens: int, + completion_tokens: int, + cached_tokens: Optional[Dict[str, int]] = None, +) -> UsageInfo: + """Calculate token usage information""" + return UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + prompt_tokens_details=cached_tokens, + ) + + +def aggregate_token_usage( + responses: List[Dict[str, Any]], + n_choices: int = 1, + enable_cache_report: bool = False, +) -> UsageInfo: + """Aggregate token usage from multiple responses + + Args: + responses: List of response dictionaries with meta_info + n_choices: Number of choices per request (for prompt token counting) + enable_cache_report: Whether to include cached token details + + Returns: + Aggregated UsageInfo + """ + # Sum completion tokens from all responses + completion_tokens = sum( + response["meta_info"]["completion_tokens"] for response in responses + ) + + # For prompt tokens, only count every n_choices-th response to avoid double counting + prompt_tokens = sum( + responses[i]["meta_info"]["prompt_tokens"] + for i in range(0, len(responses), n_choices) + ) + + # Handle cached tokens if cache reporting is enabled + cached_tokens_details = None + if enable_cache_report: + cached_tokens_sum = sum( + response["meta_info"].get("cached_tokens", 0) for response in responses + ) + if cached_tokens_sum > 0: + cached_tokens_details = {"cached_tokens": cached_tokens_sum} + + return calculate_token_usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + cached_tokens=cached_tokens_details, + ) + + +def create_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: int = 400, + param: Optional[str] = None, +) -> ErrorResponse: + """Create an error response""" + return ErrorResponse( + object="error", + message=message, + type=err_type, + param=param, + code=status_code, + ) + + +def create_streaming_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: int = 400, +) -> str: + """Create a streaming error response""" + error = create_error_response(message, err_type, status_code) + return json.dumps({"error": error.model_dump()}) + + +def build_base_sampling_params(request: OpenAIServingRequest) -> Dict[str, Any]: + """Build common sampling parameters shared by both chat and completion requests""" + params = {} + + # Define parameter mappings (request_attr -> param_name) + direct_mappings = { + "temperature": "temperature", + "max_tokens": "max_new_tokens", + "min_tokens": "min_new_tokens", + "stop": "stop", + "stop_token_ids": "stop_token_ids", + "top_p": "top_p", + "top_k": "top_k", + "min_p": "min_p", + "presence_penalty": "presence_penalty", + "frequency_penalty": "frequency_penalty", + "repetition_penalty": "repetition_penalty", + "regex": "regex", + "ebnf": "ebnf", + "n": "n", + "no_stop_trim": "no_stop_trim", + "ignore_eos": "ignore_eos", + "logit_bias": "logit_bias", + "skip_special_tokens": "skip_special_tokens", + "json_schema": "json_schema", + } + + # Apply direct mappings + for request_attr, param_name in direct_mappings.items(): + if hasattr(request, request_attr): + params[param_name] = getattr(request, request_attr) + + # Handle special cases + # max_completion_tokens overrides max_tokens for chat requests + if isinstance(request, ChatCompletionRequest) and request.max_completion_tokens: + params["max_new_tokens"] = request.max_completion_tokens + + return params + + +def sanitize_model_name(model: str) -> str: + """Sanitize model name for safe usage + + Args: + model: Model name to sanitize + + Returns: + Sanitized model name + """ + # Remove potentially dangerous characters + sanitized = re.sub(r'[<>:"|?*]', "", model) + + # Limit length + if len(sanitized) > 256: + sanitized = sanitized[:256] + + return sanitized.strip() + + +def extract_error_message(exception: Exception) -> str: + """Extract a clean error message from an exception + + Args: + exception: Exception to extract message from + + Returns: + Clean error message string + """ + error_msg = str(exception) + + # Remove common prefixes that aren't user-friendly + prefixes_to_remove = [ + "ValidationError: ", + "ValueError: ", + "TypeError: ", + "KeyError: ", + ] + + for prefix in prefixes_to_remove: + if error_msg.startswith(prefix): + error_msg = error_msg[len(prefix) :] + break + + # Limit length for safety + if len(error_msg) > 500: + error_msg = error_msg[:500] + "..." + + return error_msg + + +def format_validation_errors(errors: List[Dict[str, Any]]) -> str: + """Format Pydantic validation errors into a user-friendly message + + Args: + errors: List of validation error dictionaries + + Returns: + Formatted error message + """ + if not errors: + return "Unknown validation error" + + messages = [] + for error in errors[:5]: # Limit to first 5 errors + loc = " -> ".join(str(x) for x in error.get("loc", [])) + msg = error.get("msg", "Unknown error") + if loc: + messages.append(f"{loc}: {msg}") + else: + messages.append(msg) + + result = "; ".join(messages) + + if len(errors) > 5: + result += f" (and {len(errors) - 5} more errors)" + + return result + + +def is_multimodal_content(content: Any) -> bool: + """Check if content contains multimodal elements + + Args: + content: Content to check + + Returns: + True if content is multimodal, False otherwise + """ + if isinstance(content, list): + return any( + isinstance(item, dict) and item.get("type") in ["image_url", "audio_url"] + for item in content + ) + return False + + +def count_message_tokens_estimate(messages: List[ChatCompletionMessageParam]) -> int: + """Rough estimate of token count for messages (for validation purposes) + + Args: + messages: List of chat messages + + Returns: + Estimated token count + """ + total_chars = 0 + + for msg in messages: + if isinstance(msg.content, str): + total_chars += len(msg.content) + elif isinstance(msg.content, list): + for item in msg.content: + if isinstance(item, dict) and item.get("type") == "text": + total_chars += len(item.get("text", "")) + + # Add some tokens for role and structure + total_chars += 10 + + # Rough estimate: 1 token ≈ 4 characters for English text + return total_chars // 4 + + +def to_openai_style_logprobs( + input_token_logprobs=None, + output_token_logprobs=None, + input_top_logprobs=None, + output_top_logprobs=None, +): + ret_logprobs = LogProbs() + + def append_token_logprobs(token_logprobs): + for logprob, _, token_text in token_logprobs: + ret_logprobs.tokens.append(token_text) + ret_logprobs.token_logprobs.append(logprob) + + # Not supported yet + ret_logprobs.text_offset.append(-1) + + def append_top_logprobs(top_logprobs): + for tokens in top_logprobs: + if tokens is not None: + ret_logprobs.top_logprobs.append( + {token[2]: token[0] for token in tokens} + ) + else: + ret_logprobs.top_logprobs.append(None) + + if input_token_logprobs is not None: + append_token_logprobs(input_token_logprobs) + if output_token_logprobs is not None: + append_token_logprobs(output_token_logprobs) + if input_top_logprobs is not None: + append_top_logprobs(input_top_logprobs) + if output_top_logprobs is not None: + append_top_logprobs(output_top_logprobs) + + return ret_logprobs + + +def _get_enable_thinking_from_request(request_obj): + """Extracts the 'enable_thinking' flag from request chat_template_kwargs. + + Args: + request_obj: The request object (or an item from a list of requests). + + Returns: + The boolean value of 'enable_thinking' if found and not True, otherwise True. + """ + if ( + hasattr(request_obj, "chat_template_kwargs") + and request_obj.chat_template_kwargs + and request_obj.chat_template_kwargs.get("enable_thinking") is not None + ): + return request_obj.chat_template_kwargs.get("enable_thinking") + return True + + +def create_streaming_chunk_data(chunk_data: str) -> str: + """Create a streaming response chunk in the proper format""" + return f"data: {chunk_data}\n\n" + + +def create_stream_done() -> str: + """Create the final [DONE] message for streaming responses""" + return "data: [DONE]\n\n" diff --git a/python/sglang/srt/entrypoints/openai/validation.py b/python/sglang/srt/entrypoints/openai/validation.py new file mode 100644 index 00000000000..68e33fb90e1 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/validation.py @@ -0,0 +1,344 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pre-built validation rules for OpenAI API parameters""" + +import re +from typing import Any, Callable, List, Optional, Union + +from sglang.srt.entrypoints.openai.protocol import ( + ChatCompletionMessageParam, + ChatCompletionRequest, + CompletionRequest, + OpenAIServingRequest, +) + + +class ValidationRule: + """Represents a validation rule for request parameters""" + + def __init__( + self, + param_name: str, + validator_func: Callable[[Any], Optional[str]], + param_getter: Callable[[OpenAIServingRequest], Any], + ): + self.param_name = param_name + self.validator_func = validator_func + self.param_getter = param_getter + + +def validate_chat_messages(messages: List[ChatCompletionMessageParam]) -> Optional[str]: + """Validate chat messages format and content + + Args: + messages: List of chat messages + + Returns: + Error message if validation fails, None if valid + """ + if not messages: + return "Messages cannot be empty" + + # Check for alternating user/assistant pattern (optional validation) + roles = [msg.role for msg in messages] + + # First message should typically be from user or system + if roles[0] not in ["user", "system"]: + return "First message should be from 'user' or 'system'" + + # Check for consecutive assistant messages (which might indicate an error) + for i in range(1, len(roles)): + if roles[i] == "assistant" and roles[i - 1] == "assistant": + # This is actually allowed in some cases, so just warn + pass + + # Validate message content + for i, msg in enumerate(messages): + if msg.role == "user": + if not msg.content: + return f"User message at index {i} has no content" + elif msg.role == "assistant": + # Assistant messages can have no content if they have tool_calls + if not msg.content and not getattr(msg, "tool_calls", None): + return f"Assistant message at index {i} has no content or tool calls" + + return None + + +def validate_completion_prompt( + prompt: Union[str, List[str], List[int], List[List[int]]] +) -> Optional[str]: + """Validate completion prompt format and content + + Args: + prompt: The prompt to validate + + Returns: + Error message if validation fails, None if valid + """ + if prompt is None: + return "Prompt cannot be None" + + if isinstance(prompt, str): + if not prompt.strip(): + return "Prompt cannot be empty or whitespace only" + elif isinstance(prompt, list): + if not prompt: + return "Prompt list cannot be empty" + + # Check if it's a list of strings + if all(isinstance(item, str) for item in prompt): + for i, item in enumerate(prompt): + if not item.strip(): + return f"Prompt at index {i} cannot be empty or whitespace only" + + # Check if it's a list of token IDs (integers) + elif all(isinstance(item, int) for item in prompt): + if any(item < 0 for item in prompt): + return "Token IDs must be non-negative" + + # Check if it's a list of lists (multiple token sequences) + elif all(isinstance(item, list) for item in prompt): + for i, item in enumerate(prompt): + if not item: + return f"Token sequence at index {i} cannot be empty" + if not all(isinstance(token, int) for token in item): + return f"Token sequence at index {i} must contain only integers" + if any(token < 0 for token in item): + return f"Token sequence at index {i} contains negative token IDs" + else: + return "Prompt must be string, list of strings, list of integers, or list of integer lists" + else: + return "Prompt must be string or list" + + return None + + +def validate_model_name(model: str) -> Optional[str]: + """Validate model name format + + Args: + model: Model name to validate + + Returns: + Error message if validation fails, None if valid + """ + if not model: + return "Model name cannot be empty" + + if not isinstance(model, str): + return "Model name must be a string" + + # Basic validation - model names should be reasonable + if len(model) > 256: + return "Model name too long (maximum 256 characters)" + + # Check for invalid characters (basic validation) + if re.search(r'[<>:"|?*]', model): + return "Model name contains invalid characters" + + return None + + +def validate_temperature(temperature: float) -> Optional[str]: + """Validate temperature parameter + + Args: + temperature: Temperature value to validate + + Returns: + Error message if validation fails, None if valid + """ + if not isinstance(temperature, (int, float)): + return "Temperature must be a number" + + if temperature < 0: + return "Temperature must be non-negative" + + # OpenAI allows up to 2.0, but some models may support higher + if temperature > 2.0: + return "Temperature should typically be between 0 and 2" + + return None + + +def validate_max_tokens(max_tokens: Optional[int]) -> Optional[str]: + """Validate max_tokens parameter + + Args: + max_tokens: Maximum tokens value to validate + + Returns: + Error message if validation fails, None if valid + """ + if max_tokens is None: + return None + + if not isinstance(max_tokens, int): + return "max_tokens must be an integer" + + if max_tokens <= 0: + return "max_tokens must be positive" + + # Reasonable upper limit (can be adjusted based on model capabilities) + if max_tokens > 100000: + return "max_tokens is too large (maximum 100000)" + + return None + + +def validate_stop_sequences(stop: Optional[Union[str, List[str]]]) -> Optional[str]: + """Validate stop sequences + + Args: + stop: Stop sequences to validate + + Returns: + Error message if validation fails, None if valid + """ + if stop is None: + return None + + if isinstance(stop, str): + if len(stop) > 100: + return "Stop sequence too long (maximum 100 characters)" + return None + + if isinstance(stop, list): + if len(stop) > 4: # OpenAI limit + return "Too many stop sequences (maximum 4)" + + for i, seq in enumerate(stop): + if not isinstance(seq, str): + return f"Stop sequence at index {i} must be a string" + if len(seq) > 100: + return f"Stop sequence at index {i} too long (maximum 100 characters)" + + return None + + return "Stop sequences must be string or list of strings" + + +def validate_top_p(top_p: float) -> Optional[str]: + """Validate top_p parameter + + Args: + top_p: Top-p value to validate + + Returns: + Error message if validation fails, None if valid + """ + if not isinstance(top_p, (int, float)): + return "top_p must be a number" + + if top_p <= 0 or top_p > 1: + return "top_p must be between 0 and 1" + + return None + + +def validate_frequency_penalty(frequency_penalty: float) -> Optional[str]: + """Validate frequency_penalty parameter + + Args: + frequency_penalty: Frequency penalty value to validate + + Returns: + Error message if validation fails, None if valid + """ + if not isinstance(frequency_penalty, (int, float)): + return "frequency_penalty must be a number" + + if frequency_penalty < -2.0 or frequency_penalty > 2.0: + return "frequency_penalty must be between -2.0 and 2.0" + + return None + + +def validate_presence_penalty(presence_penalty: float) -> Optional[str]: + """Validate presence_penalty parameter + + Args: + presence_penalty: Presence penalty value to validate + + Returns: + Error message if validation fails, None if valid + """ + if not isinstance(presence_penalty, (int, float)): + return "presence_penalty must be a number" + + if presence_penalty < -2.0 or presence_penalty > 2.0: + return "presence_penalty must be between -2.0 and 2.0" + + return None + + +def get_common_validation_rules() -> List[ValidationRule]: + """Get validation rules common to both chat and completion requests""" + return [ + ValidationRule( + param_name="model", + validator_func=validate_model_name, + param_getter=lambda request: request.model, + ), + ValidationRule( + param_name="temperature", + validator_func=validate_temperature, + param_getter=lambda request: request.temperature, + ), + ValidationRule( + param_name="max_tokens", + validator_func=validate_max_tokens, + param_getter=lambda request: request.max_tokens, + ), + ValidationRule( + param_name="stop", + validator_func=validate_stop_sequences, + param_getter=lambda request: request.stop, + ), + ] + + +def get_chat_specific_validation_rules() -> List[ValidationRule]: + """Get validation rules specific to chat completion requests""" + return [ + ValidationRule( + param_name="messages", + validator_func=validate_chat_messages, + param_getter=lambda request: request.messages, + ), + ] + + +def get_completion_specific_validation_rules() -> List[ValidationRule]: + """Get validation rules specific to completion requests""" + return [ + ValidationRule( + param_name="prompt", + validator_func=validate_completion_prompt, + param_getter=lambda request: request.prompt, + ), + ] + + +def get_validation_rules(request: OpenAIServingRequest) -> List[ValidationRule]: + """Get all validation rules for the request""" + if isinstance(request, ChatCompletionRequest): + return get_common_validation_rules() + get_chat_specific_validation_rules() + elif isinstance(request, CompletionRequest): + return ( + get_common_validation_rules() + get_completion_specific_validation_rules() + ) + else: + raise ValueError(f"Unsupported request type: {type(request)}") diff --git a/test/pytest.ini b/test/pytest.ini new file mode 100644 index 00000000000..2f4c80e3075 --- /dev/null +++ b/test/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/test/srt/openai/__init__.py b/test/srt/openai/__init__.py new file mode 100644 index 00000000000..3379038e77f --- /dev/null +++ b/test/srt/openai/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for OpenAI-compatible API server refactor""" diff --git a/test/srt/openai/test_protocol.py b/test/srt/openai/test_protocol.py new file mode 100644 index 00000000000..a14b3d7179f --- /dev/null +++ b/test/srt/openai/test_protocol.py @@ -0,0 +1,683 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for OpenAI API protocol models""" + +import json +import time +from typing import Dict, List, Optional + +import pytest +from pydantic import ValidationError + +from sglang.srt.entrypoints.openai.protocol import ( + BatchRequest, + BatchResponse, + ChatCompletionMessageContentImagePart, + ChatCompletionMessageContentTextPart, + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatCompletionTokenLogprob, + ChatMessage, + ChoiceLogprobs, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + DeltaMessage, + EmbeddingObject, + EmbeddingRequest, + EmbeddingResponse, + ErrorResponse, + FileDeleteResponse, + FileRequest, + FileResponse, + Function, + FunctionResponse, + JsonSchemaResponseFormat, + LogProbs, + ModelCard, + ModelList, + MultimodalEmbeddingInput, + ResponseFormat, + ScoringRequest, + ScoringResponse, + StreamOptions, + StructuralTagResponseFormat, + Tool, + ToolCall, + ToolChoice, + TopLogprob, + UsageInfo, +) + + +class TestModelCard: + """Test ModelCard protocol model""" + + def test_basic_model_card_creation(self): + """Test basic model card creation with required fields""" + card = ModelCard(id="test-model") + assert card.id == "test-model" + assert card.object == "model" + assert card.owned_by == "sglang" + assert isinstance(card.created, int) + assert card.root is None + assert card.max_model_len is None + + def test_model_card_with_optional_fields(self): + """Test model card with optional fields""" + card = ModelCard( + id="test-model", + root="/path/to/model", + max_model_len=2048, + created=1234567890, + ) + assert card.id == "test-model" + assert card.root == "/path/to/model" + assert card.max_model_len == 2048 + assert card.created == 1234567890 + + def test_model_card_serialization(self): + """Test model card JSON serialization""" + card = ModelCard(id="test-model", max_model_len=4096) + data = card.model_dump() + assert data["id"] == "test-model" + assert data["object"] == "model" + assert data["max_model_len"] == 4096 + + +class TestModelList: + """Test ModelList protocol model""" + + def test_empty_model_list(self): + """Test empty model list creation""" + model_list = ModelList() + assert model_list.object == "list" + assert len(model_list.data) == 0 + + def test_model_list_with_cards(self): + """Test model list with model cards""" + cards = [ + ModelCard(id="model-1"), + ModelCard(id="model-2", max_model_len=2048), + ] + model_list = ModelList(data=cards) + assert len(model_list.data) == 2 + assert model_list.data[0].id == "model-1" + assert model_list.data[1].id == "model-2" + + +class TestErrorResponse: + """Test ErrorResponse protocol model""" + + def test_basic_error_response(self): + """Test basic error response creation""" + error = ErrorResponse( + message="Invalid request", type="BadRequestError", code=400 + ) + assert error.object == "error" + assert error.message == "Invalid request" + assert error.type == "BadRequestError" + assert error.code == 400 + assert error.param is None + + def test_error_response_with_param(self): + """Test error response with parameter""" + error = ErrorResponse( + message="Invalid temperature", + type="ValidationError", + code=422, + param="temperature", + ) + assert error.param == "temperature" + + +class TestUsageInfo: + """Test UsageInfo protocol model""" + + def test_basic_usage_info(self): + """Test basic usage info creation""" + usage = UsageInfo(prompt_tokens=10, completion_tokens=20, total_tokens=30) + assert usage.prompt_tokens == 10 + assert usage.completion_tokens == 20 + assert usage.total_tokens == 30 + assert usage.prompt_tokens_details is None + + def test_usage_info_with_cache_details(self): + """Test usage info with cache details""" + usage = UsageInfo( + prompt_tokens=10, + completion_tokens=20, + total_tokens=30, + prompt_tokens_details={"cached_tokens": 5}, + ) + assert usage.prompt_tokens_details == {"cached_tokens": 5} + + +class TestCompletionRequest: + """Test CompletionRequest protocol model""" + + def test_basic_completion_request(self): + """Test basic completion request""" + request = CompletionRequest(model="test-model", prompt="Hello world") + assert request.model == "test-model" + assert request.prompt == "Hello world" + assert request.max_tokens == 16 # default + assert request.temperature == 1.0 # default + assert request.n == 1 # default + assert not request.stream # default + assert not request.echo # default + + def test_completion_request_with_options(self): + """Test completion request with various options""" + request = CompletionRequest( + model="test-model", + prompt=["Hello", "world"], + max_tokens=100, + temperature=0.7, + top_p=0.9, + n=2, + stream=True, + echo=True, + stop=[".", "!"], + logprobs=5, + ) + assert request.prompt == ["Hello", "world"] + assert request.max_tokens == 100 + assert request.temperature == 0.7 + assert request.top_p == 0.9 + assert request.n == 2 + assert request.stream + assert request.echo + assert request.stop == [".", "!"] + assert request.logprobs == 5 + + def test_completion_request_sglang_extensions(self): + """Test completion request with SGLang-specific extensions""" + request = CompletionRequest( + model="test-model", + prompt="Hello", + top_k=50, + min_p=0.1, + repetition_penalty=1.1, + regex=r"\d+", + json_schema='{"type": "object"}', + lora_path="/path/to/lora", + ) + assert request.top_k == 50 + assert request.min_p == 0.1 + assert request.repetition_penalty == 1.1 + assert request.regex == r"\d+" + assert request.json_schema == '{"type": "object"}' + assert request.lora_path == "/path/to/lora" + + def test_completion_request_validation_errors(self): + """Test completion request validation errors""" + with pytest.raises(ValidationError): + CompletionRequest() # missing required fields + + with pytest.raises(ValidationError): + CompletionRequest(model="test-model") # missing prompt + + +class TestCompletionResponse: + """Test CompletionResponse protocol model""" + + def test_basic_completion_response(self): + """Test basic completion response""" + choice = CompletionResponseChoice( + index=0, text="Hello world!", finish_reason="stop" + ) + usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5) + response = CompletionResponse( + id="test-id", model="test-model", choices=[choice], usage=usage + ) + assert response.id == "test-id" + assert response.object == "text_completion" + assert response.model == "test-model" + assert len(response.choices) == 1 + assert response.choices[0].text == "Hello world!" + assert response.usage.total_tokens == 5 + + +class TestChatCompletionRequest: + """Test ChatCompletionRequest protocol model""" + + def test_basic_chat_completion_request(self): + """Test basic chat completion request""" + messages = [{"role": "user", "content": "Hello"}] + request = ChatCompletionRequest(model="test-model", messages=messages) + assert request.model == "test-model" + assert len(request.messages) == 1 + assert request.messages[0].role == "user" + assert request.messages[0].content == "Hello" + assert request.temperature == 0.7 # default + assert not request.stream # default + assert request.tool_choice == "none" # default when no tools + + def test_chat_completion_with_multimodal_content(self): + """Test chat completion with multimodal content""" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."}, + }, + ], + } + ] + request = ChatCompletionRequest(model="test-model", messages=messages) + assert len(request.messages[0].content) == 2 + assert request.messages[0].content[0].type == "text" + assert request.messages[0].content[1].type == "image_url" + + def test_chat_completion_with_tools(self): + """Test chat completion with tools""" + messages = [{"role": "user", "content": "What's the weather?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + ] + request = ChatCompletionRequest( + model="test-model", messages=messages, tools=tools + ) + assert len(request.tools) == 1 + assert request.tools[0].function.name == "get_weather" + assert request.tool_choice == "auto" # default when tools present + + def test_chat_completion_tool_choice_validation(self): + """Test tool choice validation logic""" + messages = [{"role": "user", "content": "Hello"}] + + # No tools, tool_choice should default to "none" + request1 = ChatCompletionRequest(model="test-model", messages=messages) + assert request1.tool_choice == "none" + + # With tools, tool_choice should default to "auto" + tools = [ + { + "type": "function", + "function": {"name": "test_func", "description": "Test function"}, + } + ] + request2 = ChatCompletionRequest( + model="test-model", messages=messages, tools=tools + ) + assert request2.tool_choice == "auto" + + def test_chat_completion_sglang_extensions(self): + """Test chat completion with SGLang extensions""" + messages = [{"role": "user", "content": "Hello"}] + request = ChatCompletionRequest( + model="test-model", + messages=messages, + top_k=40, + min_p=0.05, + separate_reasoning=False, + stream_reasoning=False, + chat_template_kwargs={"custom_param": "value"}, + ) + assert request.top_k == 40 + assert request.min_p == 0.05 + assert not request.separate_reasoning + assert not request.stream_reasoning + assert request.chat_template_kwargs == {"custom_param": "value"} + + +class TestChatCompletionResponse: + """Test ChatCompletionResponse protocol model""" + + def test_basic_chat_completion_response(self): + """Test basic chat completion response""" + message = ChatMessage(role="assistant", content="Hello there!") + choice = ChatCompletionResponseChoice( + index=0, message=message, finish_reason="stop" + ) + usage = UsageInfo(prompt_tokens=2, completion_tokens=3, total_tokens=5) + response = ChatCompletionResponse( + id="test-id", model="test-model", choices=[choice], usage=usage + ) + assert response.id == "test-id" + assert response.object == "chat.completion" + assert response.model == "test-model" + assert len(response.choices) == 1 + assert response.choices[0].message.content == "Hello there!" + + def test_chat_completion_response_with_tool_calls(self): + """Test chat completion response with tool calls""" + tool_call = ToolCall( + id="call_123", + function=FunctionResponse( + name="get_weather", arguments='{"location": "San Francisco"}' + ), + ) + message = ChatMessage(role="assistant", content=None, tool_calls=[tool_call]) + choice = ChatCompletionResponseChoice( + index=0, message=message, finish_reason="tool_calls" + ) + usage = UsageInfo(prompt_tokens=10, completion_tokens=5, total_tokens=15) + response = ChatCompletionResponse( + id="test-id", model="test-model", choices=[choice], usage=usage + ) + assert response.choices[0].message.tool_calls[0].function.name == "get_weather" + assert response.choices[0].finish_reason == "tool_calls" + + +class TestEmbeddingRequest: + """Test EmbeddingRequest protocol model""" + + def test_basic_embedding_request(self): + """Test basic embedding request""" + request = EmbeddingRequest(model="test-model", input="Hello world") + assert request.model == "test-model" + assert request.input == "Hello world" + assert request.encoding_format == "float" # default + assert request.dimensions is None # default + + def test_embedding_request_with_list_input(self): + """Test embedding request with list input""" + request = EmbeddingRequest( + model="test-model", input=["Hello", "world"], dimensions=512 + ) + assert request.input == ["Hello", "world"] + assert request.dimensions == 512 + + def test_multimodal_embedding_request(self): + """Test multimodal embedding request""" + multimodal_input = [ + MultimodalEmbeddingInput(text="Hello", image="base64_image_data"), + MultimodalEmbeddingInput(text="World", image=None), + ] + request = EmbeddingRequest(model="test-model", input=multimodal_input) + assert len(request.input) == 2 + assert request.input[0].text == "Hello" + assert request.input[0].image == "base64_image_data" + assert request.input[1].text == "World" + assert request.input[1].image is None + + +class TestEmbeddingResponse: + """Test EmbeddingResponse protocol model""" + + def test_basic_embedding_response(self): + """Test basic embedding response""" + embedding_obj = EmbeddingObject(embedding=[0.1, 0.2, 0.3], index=0) + usage = UsageInfo(prompt_tokens=3, total_tokens=3) + response = EmbeddingResponse( + data=[embedding_obj], model="test-model", usage=usage + ) + assert response.object == "list" + assert len(response.data) == 1 + assert response.data[0].embedding == [0.1, 0.2, 0.3] + assert response.data[0].index == 0 + assert response.usage.prompt_tokens == 3 + + +class TestScoringRequest: + """Test ScoringRequest protocol model""" + + def test_basic_scoring_request(self): + """Test basic scoring request""" + request = ScoringRequest( + model="test-model", query="Hello", items=["World", "Earth"] + ) + assert request.model == "test-model" + assert request.query == "Hello" + assert request.items == ["World", "Earth"] + assert not request.apply_softmax # default + assert not request.item_first # default + + def test_scoring_request_with_token_ids(self): + """Test scoring request with token IDs""" + request = ScoringRequest( + model="test-model", + query=[1, 2, 3], + items=[[4, 5], [6, 7]], + label_token_ids=[8, 9], + apply_softmax=True, + item_first=True, + ) + assert request.query == [1, 2, 3] + assert request.items == [[4, 5], [6, 7]] + assert request.label_token_ids == [8, 9] + assert request.apply_softmax + assert request.item_first + + +class TestScoringResponse: + """Test ScoringResponse protocol model""" + + def test_basic_scoring_response(self): + """Test basic scoring response""" + response = ScoringResponse(scores=[[0.1, 0.9], [0.3, 0.7]], model="test-model") + assert response.object == "scoring" + assert response.scores == [[0.1, 0.9], [0.3, 0.7]] + assert response.model == "test-model" + assert response.usage is None # default + + +class TestFileOperations: + """Test file operation protocol models""" + + def test_file_request(self): + """Test file request model""" + file_data = b"test file content" + request = FileRequest(file=file_data, purpose="batch") + assert request.file == file_data + assert request.purpose == "batch" + + def test_file_response(self): + """Test file response model""" + response = FileResponse( + id="file-123", + bytes=1024, + created_at=1234567890, + filename="test.jsonl", + purpose="batch", + ) + assert response.id == "file-123" + assert response.object == "file" + assert response.bytes == 1024 + assert response.filename == "test.jsonl" + + def test_file_delete_response(self): + """Test file delete response model""" + response = FileDeleteResponse(id="file-123", deleted=True) + assert response.id == "file-123" + assert response.object == "file" + assert response.deleted + + +class TestBatchOperations: + """Test batch operation protocol models""" + + def test_batch_request(self): + """Test batch request model""" + request = BatchRequest( + input_file_id="file-123", + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={"custom": "value"}, + ) + assert request.input_file_id == "file-123" + assert request.endpoint == "/v1/chat/completions" + assert request.completion_window == "24h" + assert request.metadata == {"custom": "value"} + + def test_batch_response(self): + """Test batch response model""" + response = BatchResponse( + id="batch-123", + endpoint="/v1/chat/completions", + input_file_id="file-123", + completion_window="24h", + created_at=1234567890, + ) + assert response.id == "batch-123" + assert response.object == "batch" + assert response.status == "validating" # default + assert response.endpoint == "/v1/chat/completions" + + +class TestResponseFormats: + """Test response format protocol models""" + + def test_basic_response_format(self): + """Test basic response format""" + format_obj = ResponseFormat(type="json_object") + assert format_obj.type == "json_object" + assert format_obj.json_schema is None + + def test_json_schema_response_format(self): + """Test JSON schema response format""" + schema = {"type": "object", "properties": {"name": {"type": "string"}}} + json_schema = JsonSchemaResponseFormat( + name="person_schema", description="Person schema", schema=schema + ) + format_obj = ResponseFormat(type="json_schema", json_schema=json_schema) + assert format_obj.type == "json_schema" + assert format_obj.json_schema.name == "person_schema" + assert format_obj.json_schema.schema_ == schema + + def test_structural_tag_response_format(self): + """Test structural tag response format""" + structures = [ + { + "begin": "", + "schema_": {"type": "string"}, + "end": "", + } + ] + format_obj = StructuralTagResponseFormat( + type="structural_tag", structures=structures, triggers=["think"] + ) + assert format_obj.type == "structural_tag" + assert len(format_obj.structures) == 1 + assert format_obj.triggers == ["think"] + + +class TestLogProbs: + """Test LogProbs protocol models""" + + def test_basic_logprobs(self): + """Test basic LogProbs model""" + logprobs = LogProbs( + text_offset=[0, 5, 11], + token_logprobs=[-0.1, -0.2, -0.3], + tokens=["Hello", " ", "world"], + top_logprobs=[{"Hello": -0.1}, {" ": -0.2}, {"world": -0.3}], + ) + assert len(logprobs.tokens) == 3 + assert logprobs.tokens == ["Hello", " ", "world"] + assert logprobs.token_logprobs == [-0.1, -0.2, -0.3] + + def test_choice_logprobs(self): + """Test ChoiceLogprobs model""" + token_logprob = ChatCompletionTokenLogprob( + token="Hello", + bytes=[72, 101, 108, 108, 111], + logprob=-0.1, + top_logprobs=[ + TopLogprob(token="Hello", bytes=[72, 101, 108, 108, 111], logprob=-0.1) + ], + ) + choice_logprobs = ChoiceLogprobs(content=[token_logprob]) + assert len(choice_logprobs.content) == 1 + assert choice_logprobs.content[0].token == "Hello" + + +class TestStreamingModels: + """Test streaming response models""" + + def test_stream_options(self): + """Test StreamOptions model""" + options = StreamOptions(include_usage=True) + assert options.include_usage + + def test_chat_completion_stream_response(self): + """Test ChatCompletionStreamResponse model""" + delta = DeltaMessage(role="assistant", content="Hello") + choice = ChatCompletionResponseStreamChoice(index=0, delta=delta) + response = ChatCompletionStreamResponse( + id="test-id", model="test-model", choices=[choice] + ) + assert response.object == "chat.completion.chunk" + assert response.choices[0].delta.content == "Hello" + + +class TestValidationEdgeCases: + """Test edge cases and validation scenarios""" + + def test_empty_messages_validation(self): + """Test validation with empty messages""" + with pytest.raises(ValidationError): + ChatCompletionRequest(model="test-model", messages=[]) + + def test_invalid_tool_choice_type(self): + """Test invalid tool choice type""" + messages = [{"role": "user", "content": "Hello"}] + with pytest.raises(ValidationError): + ChatCompletionRequest( + model="test-model", messages=messages, tool_choice=123 + ) + + def test_negative_token_limits(self): + """Test negative token limits""" + with pytest.raises(ValidationError): + CompletionRequest(model="test-model", prompt="Hello", max_tokens=-1) + + def test_invalid_temperature_range(self): + """Test invalid temperature values""" + # Note: The current protocol doesn't enforce temperature range, + # but this test documents expected behavior + request = CompletionRequest(model="test-model", prompt="Hello", temperature=5.0) + assert request.temperature == 5.0 # Currently allowed + + def test_model_serialization_roundtrip(self): + """Test that models can be serialized and deserialized""" + original_request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + max_tokens=100, + ) + + # Serialize to dict + data = original_request.model_dump() + + # Deserialize back + restored_request = ChatCompletionRequest(**data) + + assert restored_request.model == original_request.model + assert restored_request.temperature == original_request.temperature + assert restored_request.max_tokens == original_request.max_tokens + assert len(restored_request.messages) == len(original_request.messages) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py new file mode 100644 index 00000000000..69c67118eae --- /dev/null +++ b/test/srt/openai/test_serving_chat.py @@ -0,0 +1,845 @@ +""" +Unit tests for the ChatCompletionHandler class from serving_chat.py. + +These tests ensure that the refactored implementation maintains compatibility +with the original adapter.py functionality. +""" + +import asyncio +import json +import time +import uuid +from typing import Any, Dict, List +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from fastapi import Request +from fastapi.responses import StreamingResponse +from pydantic_core import ValidationError + +from sglang.srt.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionStreamResponse, + ChatMessage, + DeltaMessage, + ErrorResponse, + FunctionResponse, + ToolCall, + UsageInfo, +) +from sglang.srt.entrypoints.openai.serving_chat import ChatCompletionHandler +from sglang.srt.entrypoints.openai.serving_engine import RequestContext +from sglang.srt.entrypoints.openai.utils import ( + build_base_sampling_params, + create_error_response, +) +from sglang.srt.managers.io_struct import GenerateReqInput + + +# Mock TokenizerManager since it may not be directly importable in tests +class MockTokenizerManager: + def __init__(self): + self.model_config = Mock() + self.model_config.is_multimodal = False + self.server_args = Mock() + self.server_args.enable_cache_report = False + self.server_args.tool_call_parser = "hermes" + self.server_args.reasoning_parser = None + self.chat_template_name = "llama-3" + + # Mock tokenizer + self.tokenizer = Mock() + self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5]) + self.tokenizer.decode = Mock(return_value="Test response") + self.tokenizer.chat_template = None + self.tokenizer.bos_token_id = 1 + + # Mock generate_request method + async def mock_generate(): + yield { + "text": "Test response", + "meta_info": { + "id": f"chatcmpl-{uuid.uuid4()}", + "prompt_tokens": 10, + "completion_tokens": 5, + "cached_tokens": 0, + "finish_reason": {"type": "stop", "matched": None}, + "output_token_logprobs": [(0.1, 1, "Test"), (0.2, 2, "response")], + "output_top_logprobs": None, + }, + "index": 0, + } + + self.generate_request = Mock(return_value=mock_generate()) + self.create_abort_task = Mock(return_value=None) + + +@pytest.fixture +def mock_tokenizer_manager(): + """Create a mock tokenizer manager for testing.""" + return MockTokenizerManager() + + +@pytest.fixture +def chat_handler(mock_tokenizer_manager): + """Create a ChatCompletionHandler instance for testing.""" + return ChatCompletionHandler(mock_tokenizer_manager) + + +@pytest.fixture +def mock_request(): + """Create a mock FastAPI request.""" + request = Mock(spec=Request) + request.headers = {} + return request + + +@pytest.fixture +def basic_chat_request(): + """Create a basic chat completion request.""" + return ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=0.7, + max_tokens=100, + stream=False, + ) + + +@pytest.fixture +def streaming_chat_request(): + """Create a streaming chat completion request.""" + return ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello, how are you?"}], + temperature=0.7, + max_tokens=100, + stream=True, + ) + + +class TestChatCompletionHandlerValidation: + """Test validation methods of ChatCompletionHandler.""" + + def test_validate_chat_request_valid(self, chat_handler, basic_chat_request): + """Test validation with a valid request.""" + # Use utility function directly instead of handler method + error = chat_handler._validate_request(basic_chat_request) + assert error is None + + def test_validate_chat_request_empty_messages(self, chat_handler): + """Test validation fails with empty messages.""" + # Since we now have Pydantic validation that prevents creating the request, + # we expect a ValidationError to be raised during object creation + with pytest.raises(ValidationError) as exc_info: + request = ChatCompletionRequest( + model="test-model", + messages=[], + temperature=0.7, + ) + # Check that the error is about empty messages + assert "empty" in str(exc_info.value).lower() + + def test_validate_chat_request_invalid_temperature(self, chat_handler): + """Test validation fails with invalid temperature.""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + temperature=-0.5, # Invalid negative temperature + ) + error = chat_handler._validate_request(request) + assert error is not None + + def test_validate_chat_request_invalid_max_tokens(self, chat_handler): + """Test validation fails with invalid max_tokens.""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=-10, # Invalid negative max_tokens + ) + error = chat_handler._validate_request(request) + assert error is not None + + +class TestChatCompletionHandlerConversion: + """Test request conversion methods.""" + + def test_convert_to_internal_request_single( + self, chat_handler, basic_chat_request, mock_tokenizer_manager + ): + """Test converting single request to internal format.""" + with patch( + "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv" + ) as mock_conv: + mock_conv_instance = Mock() + mock_conv_instance.get_prompt.return_value = "Test prompt" + mock_conv_instance.image_data = None + mock_conv_instance.audio_data = None + mock_conv_instance.modalities = [] + mock_conv_instance.stop_str = [""] + mock_conv.return_value = mock_conv_instance + + # Mock the _process_messages method to return expected values + with patch.object(chat_handler, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + ) + + adapted_request, processed_request = ( + chat_handler._convert_to_internal_request( + [basic_chat_request], ["test-id"] + ) + ) + + assert isinstance(adapted_request, GenerateReqInput) + assert adapted_request.stream == basic_chat_request.stream + assert processed_request == basic_chat_request + + +class TestToolCalls: + """Test tool call functionality from adapter.py""" + + def test_tool_call_request_conversion(self, chat_handler): + """Test request with tool calls""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "What's the weather?"}], + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + ], + tool_choice="auto", + ) + + with patch.object(chat_handler, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + ) + + adapted_request, _ = chat_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.rid == "test-id" + # Tool call constraint should be processed + assert request.tools is not None + + def test_tool_choice_none(self, chat_handler): + """Test tool_choice=none disables tool calls""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + tools=[{"type": "function", "function": {"name": "test_func"}}], + tool_choice="none", + ) + + with patch.object(chat_handler, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + ) + + adapted_request, _ = chat_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + # Tools should not be processed when tool_choice is "none" + assert adapted_request.rid == "test-id" + + def test_tool_call_response_processing(self, chat_handler): + """Test processing tool calls in response""" + mock_ret_item = { + "text": '{"name": "get_weather", "parameters": {"location": "Paris"}}', + "meta_info": { + "output_token_logprobs": [], + "output_top_logprobs": None, + }, + } + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + } + ] + + finish_reason = {"type": "stop", "matched": None} + + # Mock FunctionCallParser + with patch( + "sglang.srt.entrypoints.openai.serving_chat.FunctionCallParser" + ) as mock_parser_class: + mock_parser = Mock() + mock_parser.has_tool_call.return_value = True + + # Create proper mock tool call object + mock_tool_call = Mock() + mock_tool_call.name = "get_weather" + mock_tool_call.parameters = '{"location": "Paris"}' + + mock_parser.parse_non_stream.return_value = ("", [mock_tool_call]) + mock_parser_class.return_value = mock_parser + + tool_calls, text, updated_finish_reason = chat_handler._process_tool_calls( + mock_ret_item["text"], tools, "hermes", finish_reason + ) + + assert tool_calls is not None + assert len(tool_calls) == 1 + assert updated_finish_reason["type"] == "tool_calls" + + +class TestMultimodalContent: + """Test multimodal content handling from adapter.py""" + + def test_multimodal_request_with_images(self, chat_handler): + """Test request with image content""" + request = ChatCompletionRequest( + model="test-model", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,..."}, + }, + ], + } + ], + ) + + # Set multimodal mode + chat_handler.tokenizer_manager.model_config.is_multimodal = True + + with patch.object(chat_handler, "_apply_jinja_template") as mock_apply: + mock_apply.return_value = ( + "prompt", + [1, 2, 3], + ["image_data"], + None, + [], + [], + ) + + with patch.object( + chat_handler, "_apply_conversation_template" + ) as mock_conv: + mock_conv.return_value = ("prompt", ["image_data"], None, [], []) + + prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + chat_handler._process_messages(request, True) + ) + + assert image_data == ["image_data"] + assert prompt == "prompt" + + def test_multimodal_request_with_audio(self, chat_handler): + """Test request with audio content""" + request = ChatCompletionRequest( + model="test-model", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Transcribe this audio"}, + { + "type": "audio_url", + "audio_url": {"url": "data:audio/wav;base64,UklGR..."}, + }, + ], + } + ], + ) + + chat_handler.tokenizer_manager.model_config.is_multimodal = True + + with patch.object(chat_handler, "_apply_jinja_template") as mock_apply: + mock_apply.return_value = ( + "prompt", + [1, 2, 3], + None, + ["audio_data"], + ["audio"], + [], + ) + + with patch.object( + chat_handler, "_apply_conversation_template" + ) as mock_conv: + mock_conv.return_value = ("prompt", None, ["audio_data"], ["audio"], []) + + prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + chat_handler._process_messages(request, True) + ) + + assert audio_data == ["audio_data"] + assert modalities == ["audio"] + + +class TestTemplateHandling: + """Test chat template handling from adapter.py""" + + def test_jinja_template_processing(self, chat_handler): + """Test Jinja template processing""" + request = ChatCompletionRequest( + model="test-model", messages=[{"role": "user", "content": "Hello"}] + ) + + # Mock the template attribute directly + chat_handler.tokenizer_manager.chat_template_name = None + chat_handler.tokenizer_manager.tokenizer.chat_template = "" + + with patch.object(chat_handler, "_apply_jinja_template") as mock_apply: + mock_apply.return_value = ( + "processed_prompt", + [1, 2, 3], + None, + None, + [], + [""], + ) + + # Mock hasattr to simulate the None check + with patch("builtins.hasattr") as mock_hasattr: + mock_hasattr.return_value = True + + prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + chat_handler._process_messages(request, False) + ) + + assert prompt == "processed_prompt" + assert prompt_ids == [1, 2, 3] + + def test_conversation_template_processing(self, chat_handler): + """Test conversation template processing""" + request = ChatCompletionRequest( + model="test-model", messages=[{"role": "user", "content": "Hello"}] + ) + + chat_handler.tokenizer_manager.chat_template_name = "llama-3" + + with patch.object(chat_handler, "_apply_conversation_template") as mock_apply: + mock_apply.return_value = ("conv_prompt", None, None, [], [""]) + + prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + chat_handler._process_messages(request, False) + ) + + assert prompt == "conv_prompt" + assert stop == [""] + + def test_continue_final_message(self, chat_handler): + """Test continue_final_message functionality""" + request = ChatCompletionRequest( + model="test-model", + messages=[ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ], + continue_final_message=True, + ) + + with patch.object(chat_handler, "_apply_conversation_template") as mock_apply: + mock_apply.return_value = ("Hi there", None, None, [], [""]) + + prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + chat_handler._process_messages(request, False) + ) + + # Should handle continue_final_message properly + assert prompt == "Hi there" + + +class TestReasoningContent: + """Test reasoning content separation from adapter.py""" + + def test_reasoning_content_request(self, chat_handler): + """Test request with reasoning content separation""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Solve this math problem"}], + separate_reasoning=True, + stream_reasoning=False, + ) + + with patch.object(chat_handler, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + ) + + adapted_request, _ = chat_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.rid == "test-id" + assert request.separate_reasoning == True + + def test_reasoning_content_response(self, chat_handler): + """Test reasoning content in response""" + mock_ret_item = { + "text": "This is reasoningAnswer: 42", + "meta_info": { + "output_token_logprobs": [], + "output_top_logprobs": None, + }, + } + + # Mock ReasoningParser + with patch( + "sglang.srt.entrypoints.openai.serving_chat.ReasoningParser" + ) as mock_parser_class: + mock_parser = Mock() + mock_parser.parse_non_stream.return_value = ( + "This is reasoning", + "Answer: 42", + ) + mock_parser_class.return_value = mock_parser + + choice_logprobs = None + reasoning_text = None + text = mock_ret_item["text"] + + # Simulate reasoning processing + enable_thinking = True + if enable_thinking: + parser = mock_parser_class(model_type="test", stream_reasoning=False) + reasoning_text, text = parser.parse_non_stream(text) + + assert reasoning_text == "This is reasoning" + assert text == "Answer: 42" + + +class TestSamplingParams: + """Test sampling parameter handling from adapter.py""" + + def test_all_sampling_parameters(self, chat_handler): + """Test all sampling parameters are properly handled""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + temperature=0.8, + max_tokens=150, + max_completion_tokens=200, # Should override max_tokens + min_tokens=5, + top_p=0.9, + top_k=50, + min_p=0.1, + presence_penalty=0.1, + frequency_penalty=0.2, + repetition_penalty=1.1, + stop=["<|endoftext|>"], + stop_token_ids=[13, 14], + regex=r"\d+", + ebnf=" ::= ", + n=2, + no_stop_trim=True, + ignore_eos=True, + skip_special_tokens=False, + logit_bias={"1": 0.5, "2": -0.3}, + ) + + with patch.object(chat_handler, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + ) + + sampling_params = chat_handler._build_sampling_params(request, [""]) + + # Verify all parameters + assert sampling_params["temperature"] == 0.8 + assert ( + sampling_params["max_new_tokens"] == 200 + ) # max_completion_tokens overrides + assert sampling_params["min_new_tokens"] == 5 + assert sampling_params["top_p"] == 0.9 + assert sampling_params["top_k"] == 50 + assert sampling_params["min_p"] == 0.1 + assert sampling_params["presence_penalty"] == 0.1 + assert sampling_params["frequency_penalty"] == 0.2 + assert sampling_params["repetition_penalty"] == 1.1 + assert sampling_params["stop"] == [ + "" + ] # Should be overridden with processed stop + assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3} + + def test_response_format_json_schema(self, chat_handler): + """Test response format with JSON schema""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Generate JSON"}], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "response", + "schema": { + "type": "object", + "properties": {"answer": {"type": "string"}}, + }, + }, + }, + ) + + with patch.object(chat_handler, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + ) + + sampling_params = chat_handler._build_sampling_params(request, [""]) + + assert "json_schema" in sampling_params + assert '"type": "object"' in sampling_params["json_schema"] + + def test_response_format_json_object(self, chat_handler): + """Test response format with JSON object""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Generate JSON"}], + response_format={"type": "json_object"}, + ) + + with patch.object(chat_handler, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + ) + + sampling_params = chat_handler._build_sampling_params(request, [""]) + + assert sampling_params["json_schema"] == '{"type": "object"}' + + +class TestUtilityFunctions: + """Test utility functions that were moved from OpenAIServingBase.""" + + def test_build_base_sampling_params_functionality(self): + """Test that build_base_sampling_params works correctly.""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + temperature=0.8, + max_tokens=150, + top_p=0.9, + top_k=50, + presence_penalty=0.1, + frequency_penalty=0.2, + stop=["<|endoftext|>"], + ) + + sampling_params = build_base_sampling_params(request) + + # Test that parameters are correctly mapped + assert sampling_params["temperature"] == request.temperature + assert sampling_params["max_new_tokens"] == request.max_tokens + assert sampling_params["top_p"] == request.top_p + assert sampling_params["top_k"] == request.top_k + assert sampling_params["presence_penalty"] == request.presence_penalty + assert sampling_params["frequency_penalty"] == request.frequency_penalty + assert sampling_params["stop"] == request.stop + + def test_build_base_sampling_params_max_completion_tokens_override(self): + """Test that max_completion_tokens overrides max_tokens.""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + max_tokens=100, + max_completion_tokens=200, + ) + + sampling_params = build_base_sampling_params(request) + + # max_completion_tokens should override max_tokens + assert sampling_params["max_new_tokens"] == 200 + + def test_create_error_response_functionality(self): + """Test that create_error_response works correctly.""" + error = create_error_response("Test error message") + assert isinstance(error, ErrorResponse) + assert error.message == "Test error message" + assert error.type == "BadRequestError" + assert error.code == 400 + + +class TestChatCompletionHandlerCompatibility: + """Test compatibility with adapter.py functionality.""" + + def test_compatibility_sampling_params(self): + """Test that sampling parameters are built the same way as adapter.py.""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + temperature=0.8, + max_tokens=150, + top_p=0.9, + top_k=50, + presence_penalty=0.1, + frequency_penalty=0.2, + stop=["<|endoftext|>"], + ) + + # Test the utility function directly + sampling_params = build_base_sampling_params(request) + + # These should match the structure used in adapter.py's v1_chat_generate_request + expected_keys = [ + "temperature", + "max_new_tokens", + "top_p", + "top_k", + "min_p", + "presence_penalty", + "frequency_penalty", + "repetition_penalty", + "stop", + "regex", + "ebnf", + "n", + ] + + for key in expected_keys: + assert key in sampling_params + + assert sampling_params["temperature"] == request.temperature + assert sampling_params["max_new_tokens"] == request.max_tokens + assert sampling_params["top_p"] == request.top_p + assert sampling_params["top_k"] == request.top_k + + def test_compatibility_request_structure(self): + """Test that the request structure matches what adapter.py expects.""" + # Test with all the parameters that adapter.py supports + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + temperature=0.8, + max_tokens=150, + top_p=0.9, + top_k=50, + presence_penalty=0.1, + frequency_penalty=0.2, + repetition_penalty=1.1, + stop=["<|endoftext|>"], + stream=False, + logprobs=True, + top_logprobs=5, + n=1, + continue_final_message=False, + separate_reasoning=True, + stream_reasoning=False, + ) + + # Verify that the request can be created without errors + assert request.model == "test-model" + assert request.temperature == 0.8 + assert request.max_tokens == 150 + assert request.top_p == 0.9 + assert request.top_k == 50 + assert request.presence_penalty == 0.1 + assert request.frequency_penalty == 0.2 + assert request.repetition_penalty == 1.1 + assert request.stop == ["<|endoftext|>"] + assert request.stream == False + assert request.logprobs == True + assert request.top_logprobs == 5 + assert request.n == 1 + assert request.continue_final_message == False + assert request.separate_reasoning == True + assert request.stream_reasoning == False + + def test_compatibility_bootstrap_params(self, chat_handler): + """Test that bootstrap parameters are properly supported.""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + bootstrap_host="localhost", + bootstrap_port=8998, + bootstrap_room=12345, + ) + + assert request.bootstrap_host == "localhost" + assert request.bootstrap_port == 8998 + assert request.bootstrap_room == 12345 + + # Mock the _process_messages method to return expected values + with patch.object(chat_handler, "_process_messages") as mock_process: + mock_process.return_value = ( + "Test prompt", + [1, 2, 3], + None, + None, + [], + [""], + ) + + adapted_request, _ = chat_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.bootstrap_host == "localhost" + assert adapted_request.bootstrap_port == 8998 + assert adapted_request.bootstrap_room == 12345 + + def test_compatibility_logit_bias(self): + """Test that logit_bias parameter is properly handled.""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + logit_bias={"1": 0.5, "2": -0.3}, + ) + + sampling_params = build_base_sampling_params(request) + assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3} + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py new file mode 100644 index 00000000000..39ba51c9c3f --- /dev/null +++ b/test/srt/openai/test_serving_completions.py @@ -0,0 +1,1055 @@ +""" +Tests for the refactored completions serving handler +""" + +from unittest.mock import AsyncMock, Mock, patch + +import pytest + +from sglang.srt.entrypoints.openai.protocol import ( + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionStreamResponse, + ErrorResponse, +) +from sglang.srt.entrypoints.openai.serving_completions import CompletionHandler +from sglang.srt.entrypoints.openai.utils import ( + build_base_sampling_params, + create_error_response, + create_streaming_error_response, +) +from sglang.srt.managers.io_struct import GenerateReqInput +from sglang.srt.managers.tokenizer_manager import TokenizerManager + + +@pytest.fixture +def mock_tokenizer_manager(): + """Create a mock tokenizer manager""" + manager = Mock(spec=TokenizerManager) + + # Mock tokenizer + manager.tokenizer = Mock() + manager.tokenizer.encode = Mock(return_value=[1, 2, 3, 4]) + manager.tokenizer.decode = Mock(return_value="decoded text") + manager.tokenizer.bos_token_id = 1 + + # Mock model config + manager.model_config = Mock() + manager.model_config.is_multimodal = False + + # Mock server args + manager.server_args = Mock() + manager.server_args.enable_cache_report = False + + # Mock generation + manager.generate_request = AsyncMock() + manager.create_abort_task = Mock(return_value=None) + + return manager + + +@pytest.fixture +def completion_handler(mock_tokenizer_manager): + """Create a completion handler instance""" + return CompletionHandler(mock_tokenizer_manager) + + +class TestUtilityFunctions: + """Test utility functions that were moved from OpenAIServingBase.""" + + def test_build_base_sampling_params_functionality(self): + """Test that build_base_sampling_params works correctly.""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + temperature=0.8, + max_tokens=150, + top_p=0.9, + top_k=50, + presence_penalty=0.1, + frequency_penalty=0.2, + stop=["<|endoftext|>"], + ) + + sampling_params = build_base_sampling_params(request) + + # Test that parameters are correctly mapped + assert sampling_params["temperature"] == request.temperature + assert sampling_params["max_new_tokens"] == request.max_tokens + assert sampling_params["top_p"] == request.top_p + assert sampling_params["top_k"] == request.top_k + assert sampling_params["presence_penalty"] == request.presence_penalty + assert sampling_params["frequency_penalty"] == request.frequency_penalty + assert sampling_params["stop"] == request.stop + + def test_build_base_sampling_params_logit_bias(self): + """Test that logit_bias parameter is properly handled.""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + logit_bias={"1": 0.5, "2": -0.3}, + ) + + sampling_params = build_base_sampling_params(request) + assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3} + + def test_build_base_sampling_params_all_parameters(self): + """Test that all sampling parameters from adapter.py are handled.""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + temperature=0.8, + max_tokens=150, + min_tokens=5, + top_p=0.9, + top_k=50, + min_p=0.1, + presence_penalty=0.1, + frequency_penalty=0.2, + repetition_penalty=1.1, + stop=["<|endoftext|>"], + stop_token_ids=[13, 14], + regex=r"\d+", + json_schema='{"type": "object"}', + ebnf=" ::= ", + n=2, + no_stop_trim=True, + ignore_eos=True, + skip_special_tokens=False, + logit_bias={"1": 0.5}, + ) + + sampling_params = build_base_sampling_params(request) + + # Verify all parameters are present + expected_keys = { + "temperature", + "max_new_tokens", + "min_new_tokens", + "stop", + "stop_token_ids", + "top_p", + "top_k", + "min_p", + "presence_penalty", + "frequency_penalty", + "repetition_penalty", + "regex", + "json_schema", + "ebnf", + "n", + "no_stop_trim", + "ignore_eos", + "skip_special_tokens", + "logit_bias", + } + + for key in expected_keys: + assert key in sampling_params, f"Missing parameter: {key}" + + # Verify values + assert sampling_params["temperature"] == 0.8 + assert sampling_params["max_new_tokens"] == 150 + assert sampling_params["min_new_tokens"] == 5 + assert sampling_params["logit_bias"] == {"1": 0.5} + + def test_validate_request_functionality(self, completion_handler): + """Test that validate_request works correctly.""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + temperature=0.7, + max_tokens=100, + ) + + # Test with completion validation rules + error = completion_handler._validate_request(request) + assert error is None + + # Test with invalid request + invalid_request = CompletionRequest( + model="", # Invalid empty model + prompt="Hello world", + max_tokens=100, + ) + error = completion_handler._validate_request(invalid_request) + assert error is not None + + def test_create_error_response_functionality(self): + """Test that create_error_response works correctly.""" + error = create_error_response("Test error message") + assert isinstance(error, ErrorResponse) + assert error.message == "Test error message" + assert error.type == "BadRequestError" + assert error.code == 400 + + def test_create_streaming_error_response_functionality(self): + """Test that create_streaming_error_response works correctly.""" + error_json = create_streaming_error_response("Test streaming error") + # Should return JSON string with error structure + import json + + error_data = json.loads(error_json) + assert "error" in error_data + assert error_data["error"]["message"] == "Test streaming error" + + +class TestCompletionValidation: + """Test validation methods""" + + def test_validate_completion_request_valid(self, completion_handler): + """Test validation of valid completion request""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + max_tokens=100, + temperature=0.7, + stream=False, + ) + + # Use utility function directly instead of handler method + error = completion_handler._validate_request(request) + assert error is None + + def test_validate_completion_request_empty_prompt_string(self, completion_handler): + """Test validation fails for empty string prompt""" + request = CompletionRequest(model="test-model", prompt="", max_tokens=100) + + error = completion_handler._validate_request(request) + assert error is not None + assert "prompt" in error.model_dump()["param"] + + def test_validate_completion_request_whitespace_prompt(self, completion_handler): + """Test validation fails for whitespace-only prompt""" + request = CompletionRequest( + model="test-model", prompt=" \n\t ", max_tokens=100 + ) + + error = completion_handler._validate_request(request) + assert error is not None + assert "prompt" in error.model_dump()["param"] + + def test_validate_completion_request_empty_list_prompt(self, completion_handler): + """Test validation fails for empty list prompt""" + request = CompletionRequest(model="test-model", prompt=[], max_tokens=100) + + error = completion_handler._validate_request(request) + assert error is not None + assert "prompt" in error.model_dump()["param"] + + def test_validate_completion_request_invalid_model(self, completion_handler): + """Test validation fails for invalid model""" + request = CompletionRequest(model="", prompt="Hello world", max_tokens=100) + + error = completion_handler._validate_request(request) + assert error is not None + assert "model" in error.model_dump()["param"] + + def test_validate_completion_request_invalid_temperature(self, completion_handler): + """Test validation fails for invalid temperature""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + max_tokens=100, + temperature=-1.0, # Invalid + ) + + error = completion_handler._validate_request(request) + assert error is not None + assert "temperature" in error.model_dump()["param"] + + def test_validate_completion_request_invalid_max_tokens(self, completion_handler): + """Test validation fails for invalid max_tokens""" + request = CompletionRequest( + model="test-model", prompt="Hello world", max_tokens=0 # Invalid + ) + + error = completion_handler._validate_request(request) + assert error is not None + assert "max_tokens" in error.model_dump()["param"] + + +class TestPromptHandling: + """Test different prompt types and formats from adapter.py""" + + def test_single_string_prompt(self, completion_handler): + """Test handling single string prompt""" + request = CompletionRequest( + model="test-model", prompt="Hello world", max_tokens=100 + ) + + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.text == "Hello world" + + def test_single_token_ids_prompt(self, completion_handler): + """Test handling single token IDs prompt""" + request = CompletionRequest( + model="test-model", prompt=[1, 2, 3, 4], max_tokens=100 + ) + + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.input_ids == [1, 2, 3, 4] + + def test_multiple_string_prompts(self, completion_handler): + """Test handling multiple string prompts""" + requests = [ + CompletionRequest(model="test-model", prompt="Hello", max_tokens=50), + CompletionRequest(model="test-model", prompt="World", max_tokens=50), + ] + + adapted_request, _ = completion_handler._convert_to_internal_request( + requests, ["id1", "id2"] + ) + + assert adapted_request.text == ["Hello", "World"] + assert adapted_request.rid == ["id1", "id2"] + + def test_multiple_token_ids_prompts(self, completion_handler): + """Test handling multiple token IDs prompts""" + requests = [ + CompletionRequest(model="test-model", prompt=[1, 2], max_tokens=50), + CompletionRequest(model="test-model", prompt=[3, 4], max_tokens=50), + ] + + adapted_request, _ = completion_handler._convert_to_internal_request( + requests, ["id1", "id2"] + ) + + assert adapted_request.input_ids == [[1, 2], [3, 4]] + + def test_list_of_strings_prompt(self, completion_handler): + """Test handling list of strings as prompt""" + request = CompletionRequest( + model="test-model", prompt=["Hello", "world"], max_tokens=100 + ) + + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.text == ["Hello", "world"] + + def test_completion_template_handling(self, completion_handler): + """Test completion template processing""" + request = CompletionRequest( + model="test-model", + prompt="def hello():", + suffix="return 'world'", + max_tokens=100, + ) + + with patch( + "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined", + return_value=True, + ): + with patch( + "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt", + return_value="processed_prompt", + ): + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.text == "processed_prompt" + + +class TestEchoHandling: + """Test echo functionality from adapter.py""" + + def test_echo_with_string_prompt_streaming(self, completion_handler): + """Test echo handling with string prompt in streaming""" + request = CompletionRequest( + model="test-model", prompt="Hello", max_tokens=100, echo=True + ) + + # Test _get_echo_text method + echo_text = completion_handler._get_echo_text(request, 0) + assert echo_text == "Hello" + + def test_echo_with_list_of_strings_streaming(self, completion_handler): + """Test echo handling with list of strings in streaming""" + request = CompletionRequest( + model="test-model", + prompt=["Hello", "World"], + max_tokens=100, + echo=True, + n=1, + ) + + echo_text = completion_handler._get_echo_text(request, 0) + assert echo_text == "Hello" + + echo_text = completion_handler._get_echo_text(request, 1) + assert echo_text == "World" + + def test_echo_with_token_ids_streaming(self, completion_handler): + """Test echo handling with token IDs in streaming""" + request = CompletionRequest( + model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True + ) + + completion_handler.tokenizer_manager.tokenizer.decode.return_value = ( + "decoded_prompt" + ) + echo_text = completion_handler._get_echo_text(request, 0) + assert echo_text == "decoded_prompt" + + def test_echo_with_multiple_token_ids_streaming(self, completion_handler): + """Test echo handling with multiple token ID prompts in streaming""" + request = CompletionRequest( + model="test-model", prompt=[[1, 2], [3, 4]], max_tokens=100, echo=True, n=1 + ) + + completion_handler.tokenizer_manager.tokenizer.decode.return_value = "decoded" + echo_text = completion_handler._get_echo_text(request, 0) + assert echo_text == "decoded" + + def test_prepare_echo_prompts_non_streaming(self, completion_handler): + """Test prepare echo prompts for non-streaming response""" + # Test with single string + request = CompletionRequest(model="test-model", prompt="Hello", echo=True) + + echo_prompts = completion_handler._prepare_echo_prompts(request) + assert echo_prompts == ["Hello"] + + # Test with list of strings + request = CompletionRequest( + model="test-model", prompt=["Hello", "World"], echo=True + ) + + echo_prompts = completion_handler._prepare_echo_prompts(request) + assert echo_prompts == ["Hello", "World"] + + # Test with token IDs + request = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True) + + completion_handler.tokenizer_manager.tokenizer.decode.return_value = "decoded" + echo_prompts = completion_handler._prepare_echo_prompts(request) + assert echo_prompts == ["decoded"] + + +class TestCompletionRequestConversion: + """Test request conversion to internal format""" + + def test_convert_simple_string_prompt(self, completion_handler): + """Test conversion of simple string prompt""" + request = CompletionRequest( + model="test-model", prompt="Hello world", max_tokens=100, temperature=0.7 + ) + + adapted_request, processed_request = ( + completion_handler._convert_to_internal_request([request], ["test-id"]) + ) + + assert isinstance(adapted_request, GenerateReqInput) + assert adapted_request.text == "Hello world" + assert adapted_request.sampling_params["temperature"] == 0.7 + assert adapted_request.sampling_params["max_new_tokens"] == 100 + assert adapted_request.rid == "test-id" + assert processed_request == request + + def test_convert_token_ids_prompt(self, completion_handler): + """Test conversion of token IDs prompt""" + request = CompletionRequest( + model="test-model", prompt=[1, 2, 3, 4], max_tokens=100 + ) + + adapted_request, processed_request = ( + completion_handler._convert_to_internal_request([request], ["test-id"]) + ) + + assert isinstance(adapted_request, GenerateReqInput) + assert adapted_request.input_ids == [1, 2, 3, 4] + assert adapted_request.sampling_params["max_new_tokens"] == 100 + + def test_convert_logprob_start_len_with_echo_and_logprobs(self, completion_handler): + """Test logprob_start_len setting with echo and logprobs""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + max_tokens=100, + echo=True, + logprobs=5, + ) + + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + # When echo=True and logprobs is set, should be 0 + assert adapted_request.logprob_start_len == 0 + assert adapted_request.return_logprob == True + assert adapted_request.top_logprobs_num == 5 + + def test_convert_logprob_start_len_without_echo(self, completion_handler): + """Test logprob_start_len setting without echo""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + max_tokens=100, + echo=False, + logprobs=3, + ) + + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + # When echo=False, should be -1 + assert adapted_request.logprob_start_len == -1 + assert adapted_request.return_logprob == True + assert adapted_request.top_logprobs_num == 3 + + +class TestCompatibilityWithAdapter: + """Test compatibility with adapter.py functionality""" + + def test_sampling_params_structure_matches_adapter(self): + """Test that sampling params structure matches adapter.py exactly""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + max_tokens=100, + temperature=0.7, + top_p=0.9, + top_k=50, + min_p=0.1, + presence_penalty=0.5, + frequency_penalty=0.3, + repetition_penalty=1.1, + stop=["STOP"], + stop_token_ids=[13], + n=2, + ignore_eos=True, + skip_special_tokens=False, + ) + + # Test the utility function directly + sampling_params = build_base_sampling_params(request) + + # Check all parameters from adapter.py v1_generate_request + expected_params = { + "temperature", + "max_new_tokens", + "min_new_tokens", + "stop", + "stop_token_ids", + "top_p", + "top_k", + "min_p", + "presence_penalty", + "frequency_penalty", + "repetition_penalty", + "regex", + "json_schema", + "ebnf", + "n", + "no_stop_trim", + "ignore_eos", + "skip_special_tokens", + } + + actual_params = set(sampling_params.keys()) + assert expected_params.issubset(actual_params) + + def test_bootstrap_parameters_support(self, completion_handler): + """Test that bootstrap parameters are supported""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + max_tokens=100, + bootstrap_host="localhost", + bootstrap_port=8080, + bootstrap_room=123, + ) + + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.bootstrap_host == "localhost" + assert adapted_request.bootstrap_port == 8080 + assert adapted_request.bootstrap_room == 123 + + def test_lora_path_support(self, completion_handler): + """Test that LoRA path is supported""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + max_tokens=100, + lora_path="/path/to/lora", + ) + + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.lora_path == "/path/to/lora" + + def test_echo_and_logprobs_compatibility(self, completion_handler): + """Test echo and logprobs handling matches adapter behavior""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + max_tokens=100, + echo=True, + logprobs=5, + ) + + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + # When echo=True and logprobs is set, logprob_start_len should be 0 + assert adapted_request.logprob_start_len == 0 + assert adapted_request.return_logprob == True + assert adapted_request.top_logprobs_num == 5 + + def test_no_echo_logprobs_compatibility(self, completion_handler): + """Test no echo but logprobs handling""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + max_tokens=100, + echo=False, + logprobs=3, + ) + + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + # When echo=False, logprob_start_len should be -1 + assert adapted_request.logprob_start_len == -1 + assert adapted_request.return_logprob == True + assert adapted_request.top_logprobs_num == 3 + + def test_return_text_in_logprobs_setting(self, completion_handler): + """Test that return_text_in_logprobs is properly set""" + request = CompletionRequest( + model="test-model", prompt="Hello world", max_tokens=100 + ) + + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + assert adapted_request.return_text_in_logprobs == True + + def test_multiple_requests_batch_handling(self, completion_handler): + """Test handling of multiple requests in batch mode""" + requests = [ + CompletionRequest( + model="test-model", prompt="Hello", max_tokens=50, lora_path="/path1" + ), + CompletionRequest( + model="test-model", prompt="World", max_tokens=50, lora_path="/path2" + ), + ] + + adapted_request, processed_requests = ( + completion_handler._convert_to_internal_request(requests, ["id1", "id2"]) + ) + + assert adapted_request.text == ["Hello", "World"] + assert adapted_request.lora_path == ["/path1", "/path2"] + assert adapted_request.rid == ["id1", "id2"] + assert ( + processed_requests == requests + ) # Should return list for multiple requests + + +class TestResponseBuilding: + """Test response building functionality""" + + def test_build_simple_response(self, completion_handler): + """Test building simple completion response""" + request = CompletionRequest(model="test-model", prompt="Hello", max_tokens=100) + + mock_ret = [ + { + "text": " world!", + "meta_info": { + "id": "test-id", + "prompt_tokens": 5, + "completion_tokens": 10, + "finish_reason": {"type": "stop"}, + }, + } + ] + + response = completion_handler._build_completion_response( + request, mock_ret, 1234567890 + ) + + assert isinstance(response, CompletionResponse) + assert response.id == "test-id" + assert response.model == "test-model" + assert response.created == 1234567890 + assert len(response.choices) == 1 + assert response.choices[0].text == " world!" + assert response.choices[0].finish_reason == "stop" + assert response.usage.prompt_tokens == 5 + assert response.usage.completion_tokens == 10 + assert response.usage.total_tokens == 15 + + def test_build_response_with_echo(self, completion_handler): + """Test building response with echo enabled""" + request = CompletionRequest( + model="test-model", prompt="Hello", max_tokens=100, echo=True + ) + + # Mock echo prompts preparation + completion_handler._prepare_echo_prompts = Mock(return_value=["Hello"]) + + mock_ret = [ + { + "text": " world!", + "meta_info": { + "id": "test-id", + "prompt_tokens": 5, + "completion_tokens": 10, + "finish_reason": {"type": "stop"}, + }, + } + ] + + response = completion_handler._build_completion_response( + request, mock_ret, 1234567890 + ) + + # With echo=True, text should include the prompt + assert response.choices[0].text == "Hello world!" + + def test_build_response_with_logprobs(self, completion_handler): + """Test building response with logprobs""" + request = CompletionRequest( + model="test-model", prompt="Hello", max_tokens=100, logprobs=3 + ) + + mock_ret = [ + { + "text": " world!", + "meta_info": { + "id": "test-id", + "prompt_tokens": 5, + "completion_tokens": 10, + "finish_reason": {"type": "stop"}, + "output_token_logprobs": [(-0.1, 1, " world"), (-0.2, 2, "!")], + "output_top_logprobs": [ + [(-0.1, 1, " world"), (-0.3, 3, " earth")], + [(-0.2, 2, "!"), (-0.4, 4, ".")], + ], + }, + } + ] + + response = completion_handler._build_completion_response( + request, mock_ret, 1234567890 + ) + + assert response.choices[0].logprobs is not None + assert len(response.choices[0].logprobs.tokens) == 2 + assert response.choices[0].logprobs.tokens[0] == " world" + assert response.choices[0].logprobs.tokens[1] == "!" + + def test_build_response_with_echo_and_logprobs(self, completion_handler): + """Test building response with both echo and logprobs""" + request = CompletionRequest( + model="test-model", prompt="Hello", max_tokens=100, echo=True, logprobs=2 + ) + + completion_handler._prepare_echo_prompts = Mock(return_value=["Hello"]) + + mock_ret = [ + { + "text": " world!", + "meta_info": { + "id": "test-id", + "prompt_tokens": 5, + "completion_tokens": 10, + "finish_reason": {"type": "stop"}, + "input_token_logprobs": [(-0.05, 0, "Hello")], + "input_top_logprobs": [[(-0.05, 0, "Hello"), (-0.1, 1, "Hi")]], + "output_token_logprobs": [(-0.1, 1, " world"), (-0.2, 2, "!")], + "output_top_logprobs": [ + [(-0.1, 1, " world"), (-0.3, 3, " earth")], + [(-0.2, 2, "!"), (-0.4, 4, ".")], + ], + }, + } + ] + + response = completion_handler._build_completion_response( + request, mock_ret, 1234567890 + ) + + assert response.choices[0].text == "Hello world!" + assert response.choices[0].logprobs is not None + # Should include both input and output logprobs + assert len(response.choices[0].logprobs.tokens) == 3 # Hello + world + ! + + def test_build_response_with_matched_stop(self, completion_handler): + """Test building response with matched stop token""" + request = CompletionRequest(model="test-model", prompt="Hello", max_tokens=100) + + mock_ret = [ + { + "text": " world!", + "meta_info": { + "id": "test-id", + "prompt_tokens": 5, + "completion_tokens": 10, + "finish_reason": {"type": "stop", "matched": ""}, + }, + } + ] + + response = completion_handler._build_completion_response( + request, mock_ret, 1234567890 + ) + + assert response.choices[0].finish_reason == "stop" + assert response.choices[0].matched_stop == "" + + def test_build_response_with_cache_report(self, completion_handler): + """Test building response with cache reporting enabled""" + request = CompletionRequest(model="test-model", prompt="Hello", max_tokens=100) + + mock_ret = [ + { + "text": " world!", + "meta_info": { + "id": "test-id", + "prompt_tokens": 5, + "completion_tokens": 10, + "cached_tokens": 3, + "finish_reason": {"type": "stop"}, + }, + } + ] + + response = completion_handler._build_completion_response( + request, mock_ret, 1234567890, cache_report=True + ) + + assert response.usage.prompt_tokens_details is not None + assert response.usage.prompt_tokens_details["cached_tokens"] == 3 + + def test_build_response_multiple_choices(self, completion_handler): + """Test building response with multiple choices (n > 1)""" + request = CompletionRequest( + model="test-model", prompt="Hello", max_tokens=100, n=2 + ) + + completion_handler._prepare_echo_prompts = Mock(return_value=["Hello"]) + + mock_ret = [ + { + "text": " world!", + "meta_info": { + "id": "test-id", + "prompt_tokens": 5, + "completion_tokens": 10, + "finish_reason": {"type": "stop"}, + }, + }, + { + "text": " there!", + "meta_info": { + "id": "test-id", + "prompt_tokens": 5, + "completion_tokens": 8, + "finish_reason": {"type": "stop"}, + }, + }, + ] + + response = completion_handler._build_completion_response( + request, mock_ret, 1234567890 + ) + + assert len(response.choices) == 2 + assert response.choices[0].text == " world!" + assert response.choices[1].text == " there!" + assert response.choices[0].index == 0 + assert response.choices[1].index == 1 + # Total tokens should be: prompt_tokens + both completion_tokens + assert response.usage.total_tokens == 5 + 10 + 8 + + +@pytest.mark.asyncio +class TestAsyncMethods: + """Test async handler methods""" + + async def test_handle_request_validation_error(self, completion_handler): + """Test handling request with validation error""" + mock_request = Mock() + request = CompletionRequest( + model="", prompt="Hello world", max_tokens=100 # Invalid model + ) + + response = await completion_handler.handle_request(request, mock_request) + + # Should return error response + assert hasattr(response, "model_dump") + error_data = response.model_dump() + assert error_data["object"] == "error" + assert "model" in error_data["param"] + + async def test_handle_request_non_streaming(self, completion_handler): + """Test handling non-streaming request - simplified test for async flow""" + mock_request = Mock() + request = CompletionRequest( + model="test-model", prompt="Hello world", max_tokens=100, stream=False + ) + + # For now, just test that we can call the method and get some response + # The detailed functionality is tested in the sync tests above + response = await completion_handler.handle_request(request, mock_request) + + # Should return some response (either error or success, depending on mock setup) + assert response is not None + assert hasattr(response, "model_dump") + + async def test_handle_request_streaming(self, completion_handler): + """Test handling streaming request""" + mock_request = Mock() + request = CompletionRequest( + model="test-model", prompt="Hello world", max_tokens=100, stream=True + ) + + response = await completion_handler.handle_request(request, mock_request) + + # Should return StreamingResponse + from fastapi.responses import StreamingResponse + + assert isinstance(response, StreamingResponse) + + async def test_handle_streaming_with_usage(self, completion_handler): + """Test streaming with usage reporting""" + mock_request = Mock() + request = CompletionRequest( + model="test-model", + prompt="Hello world", + max_tokens=100, + stream=True, + stream_options={"include_usage": True}, + ) + + response = await completion_handler.handle_request(request, mock_request) + + from fastapi.responses import StreamingResponse + + assert isinstance(response, StreamingResponse) + + +class TestEdgeCases: + """Test edge cases and error conditions""" + + def test_multiple_requests_different_prompt_types_error(self, completion_handler): + """Test error when multiple requests have different prompt types""" + requests = [ + CompletionRequest(model="test-model", prompt="Hello", max_tokens=50), + CompletionRequest(model="test-model", prompt=[1, 2, 3], max_tokens=50), + ] + + with pytest.raises(AssertionError): + completion_handler._convert_to_internal_request(requests, ["id1", "id2"]) + + def test_multiple_requests_with_n_greater_than_1_error(self, completion_handler): + """Test error when multiple requests have n > 1""" + requests = [ + CompletionRequest(model="test-model", prompt="Hello", max_tokens=50, n=2), + CompletionRequest(model="test-model", prompt="World", max_tokens=50, n=1), + ] + + with pytest.raises(ValueError, match="Parallel sampling is not supported"): + completion_handler._convert_to_internal_request(requests, ["id1", "id2"]) + + def test_empty_prompt_list_validation(self, completion_handler): + """Test validation of empty prompt list""" + request = CompletionRequest(model="test-model", prompt=[], max_tokens=100) + + error = completion_handler._validate_request(request) + assert error is not None + assert "prompt" in error.model_dump()["param"] + + def test_nested_empty_prompt_list_validation(self, completion_handler): + """Test validation of nested empty prompt list""" + request = CompletionRequest(model="test-model", prompt=[[]], max_tokens=100) + + error = completion_handler._validate_request(request) + assert error is not None + assert "prompt" in error.model_dump()["param"] + + @pytest.mark.asyncio + async def test_echo_warning_with_logprobs(self, completion_handler): + """Test warning when echo is used with logprobs""" + request = CompletionRequest( + model="test-model", + prompt="Hello world", + max_tokens=100, + echo=True, + logprobs=5, + ) + + mock_raw_request = Mock() + + with patch( + "sglang.srt.entrypoints.openai.serving_completions.logger" + ) as mock_logger: + # Call handle_request which contains the warning logic + await completion_handler.handle_request(request, mock_raw_request) + # Should log warning about echo + logprobs incompatibility + mock_logger.warning.assert_called_once() + assert "Echo is not compatible with logprobs" in str( + mock_logger.warning.call_args + ) + + def test_suffix_without_completion_template(self, completion_handler): + """Test that suffix is ignored when completion template is not defined""" + request = CompletionRequest( + model="test-model", + prompt="def hello():", + suffix="return 'world'", + max_tokens=100, + ) + + with patch( + "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined", + return_value=False, + ): + adapted_request, _ = completion_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + # Should use original prompt, not processed with suffix + assert adapted_request.text == "def hello():" + + def test_zero_max_tokens_handling(self, completion_handler): + """Test handling of zero max_tokens""" + request = CompletionRequest( + model="test-model", prompt="Hello world", max_tokens=0 + ) + + error = completion_handler._validate_request(request) + assert error is not None + assert "max_tokens" in error.model_dump()["param"] + + def test_negative_temperature_handling(self, completion_handler): + """Test handling of negative temperature""" + request = CompletionRequest( + model="test-model", prompt="Hello world", max_tokens=100, temperature=-0.5 + ) + + error = completion_handler._validate_request(request) + assert error is not None + assert "temperature" in error.model_dump()["param"] From d9ceddd56468f64f3a9b7c07e4aecde44c400538 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 14 Jun 2025 02:06:37 +0000 Subject: [PATCH 02/33] feat: add serving_embedding Signed-off-by: Xinyuan Tong --- .../sglang/srt/entrypoints/openai/protocol.py | 9 +- .../entrypoints/openai/serving_embedding.py | 236 ++++++++ .../srt/entrypoints/openai/validation.py | 64 +++ test/srt/openai/test_serving_embedding.py | 532 ++++++++++++++++++ 4 files changed, 838 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/entrypoints/openai/serving_embedding.py create mode 100644 test/srt/openai/test_serving_embedding.py diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index a11fc71f93a..f083f61fa6c 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -487,12 +487,15 @@ class MultimodalEmbeddingInput(BaseModel): image: Optional[str] = None +EmbeddingInput = Union[ + List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput] +] + + class EmbeddingRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/embeddings/create - input: Union[ - List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput] - ] + input: EmbeddingInput model: str encoding_format: str = "float" dimensions: int = None diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py new file mode 100644 index 00000000000..e7314613ea4 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -0,0 +1,236 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Embedding serving logic for OpenAI API""" + +import logging +import uuid +from typing import Any, Dict, List, Optional, Union + +from fastapi import Request + +from sglang.srt.conversation import generate_embedding_convs +from sglang.srt.entrypoints.openai.protocol import ( + EmbeddingObject, + EmbeddingRequest, + EmbeddingResponse, + ErrorResponse, + MultimodalEmbeddingInput, + UsageInfo, +) +from sglang.srt.entrypoints.openai.serving_engine import ( + OpenAIServingBase, + RequestContext, +) +from sglang.srt.entrypoints.openai.utils import create_error_response +from sglang.srt.managers.io_struct import EmbeddingReqInput +from sglang.srt.managers.tokenizer_manager import TokenizerManager + +logger = logging.getLogger(__name__) + + +class EmbeddingHandler(OpenAIServingBase): + """Handler for embedding requests""" + + def __init__(self, tokenizer_manager: TokenizerManager): + super().__init__(tokenizer_manager) + + async def handle_request( + self, request: EmbeddingRequest, raw_request: Request + ) -> Union[EmbeddingResponse, ErrorResponse]: + """Handle an embedding request""" + try: + # Validate request + error = self._validate_request(request) + if error: + return error + + # Create request context + ctx = RequestContext( + raw_request=raw_request, + openai_request=request, + request_id=request.rid or f"embd-{uuid.uuid4()}", + ) + + # Convert to internal format + adapted_request, processed_request = self._convert_to_internal_request( + [request], [ctx.request_id] + ) + + # Handle the request + return await self._handle_request(adapted_request, processed_request, ctx) + + except Exception as e: + logger.error(f"Error in embedding: {e}") + return create_error_response( + message=f"Internal server error: {str(e)}", + err_type="InternalServerError", + status_code=500, + ) + + def _convert_to_internal_request( + self, + all_requests: List[EmbeddingRequest], + request_ids: List[str], + ) -> tuple[EmbeddingReqInput, Union[EmbeddingRequest, List[EmbeddingRequest]]]: + """Convert OpenAI embedding request to internal format""" + prompts = [request.input for request in all_requests] + + # Handle single vs multiple requests + if len(all_requests) == 1: + prompt = prompts[0] + if isinstance(prompt, str): + # Single string input + prompt_kwargs = {"text": prompt} + elif isinstance(prompt, list): + if len(prompt) > 0 and isinstance(prompt[0], str): + # List of strings + prompt_kwargs = {"text": prompt} + elif len(prompt) > 0 and isinstance( + prompt[0], MultimodalEmbeddingInput + ): + # Handle multimodal embedding inputs + texts = [] + images = [] + for item in prompt: + # Use padding for text if None - this could be improved + texts.append(item.text if item.text is not None else "padding") + images.append(item.image if item.image is not None else None) + + generate_prompts = [] + # Check if we have a chat template for multimodal embeddings + # This would need to be passed in from the server configuration + chat_template_name = getattr( + self.tokenizer_manager, "chat_template_name", None + ) + if chat_template_name is not None: + convs = generate_embedding_convs( + texts, images, chat_template_name + ) + for conv in convs: + generate_prompts.append(conv.get_prompt()) + else: + generate_prompts = texts + + if len(generate_prompts) == 1: + prompt_kwargs = { + "text": generate_prompts[0], + "image_data": images[0], + } + else: + prompt_kwargs = { + "text": generate_prompts, + "image_data": images, + } + else: + # List of integers (token IDs) or empty list + prompt_kwargs = {"input_ids": prompt} + else: + # Other types (should not happen but handle gracefully) + prompt_kwargs = {"input_ids": prompt} + # Use the passed request_ids for single request + final_request_id = request_ids[0] if len(all_requests) == 1 else request_ids + else: + # Handle batch requests + if len(prompts) > 0: + # Validate that all prompts have the same type + first_prompt = prompts[0] + first_type = type(first_prompt) + for i, prompt in enumerate(prompts[1:], 1): + if type(prompt) != first_type: + raise AssertionError( + f"All prompts in batch must have the same type, but prompt at index {i} has different type" + ) + + if isinstance(first_prompt, str): + # Batch of strings + prompt_kwargs = {"text": prompts} + elif isinstance(first_prompt, list): + if len(first_prompt) > 0 and isinstance(first_prompt[0], str): + # Batch of lists of strings + prompt_kwargs = {"text": prompts} + elif len(first_prompt) > 0 and isinstance( + first_prompt[0], MultimodalEmbeddingInput + ): + # Handle multimodal batch requests + raise NotImplementedError( + "Multiple requests with multimodal inputs are not supported yet" + ) + else: + # Batch of token ID lists + prompt_kwargs = {"input_ids": prompts} + else: + # Other types + prompt_kwargs = {"input_ids": prompts} + else: + prompt_kwargs = {"input_ids": prompts} + # Use the passed request_ids for batch requests + final_request_id = request_ids + + adapted_request = EmbeddingReqInput( + rid=final_request_id, + **prompt_kwargs, + ) + + return adapted_request, ( + all_requests[0] if len(all_requests) == 1 else all_requests + ) + + async def _handle_request( + self, + adapted_request: EmbeddingReqInput, + request: EmbeddingRequest, + ctx: RequestContext, + ) -> Union[EmbeddingResponse, ErrorResponse]: + """Handle the embedding request""" + try: + ret = await self.tokenizer_manager.generate_request( + adapted_request, ctx.raw_request + ).__anext__() + except ValueError as e: + return create_error_response(str(e)) + + if not isinstance(ret, list): + ret = [ret] + + response = self._build_embedding_response( + ret, self.tokenizer_manager.model_path + ) + return response + + def _build_embedding_response( + self, ret: List[Dict[str, Any]], model_path: str + ) -> EmbeddingResponse: + """Build the embedding response""" + embedding_objects = [] + prompt_tokens = 0 + + for idx, ret_item in enumerate(ret): + embedding_objects.append( + EmbeddingObject( + embedding=ret_item["embedding"], + index=idx, + ) + ) + # Handle missing prompt_tokens gracefully + meta_info = ret_item.get("meta_info", {}) + prompt_tokens += meta_info.get("prompt_tokens", 0) + + return EmbeddingResponse( + data=embedding_objects, + model=model_path, + usage=UsageInfo( + prompt_tokens=prompt_tokens, + total_tokens=prompt_tokens, + ), + ) diff --git a/python/sglang/srt/entrypoints/openai/validation.py b/python/sglang/srt/entrypoints/openai/validation.py index 68e33fb90e1..e1f0c3f3cff 100644 --- a/python/sglang/srt/entrypoints/openai/validation.py +++ b/python/sglang/srt/entrypoints/openai/validation.py @@ -20,6 +20,8 @@ ChatCompletionMessageParam, ChatCompletionRequest, CompletionRequest, + EmbeddingInput, + EmbeddingRequest, OpenAIServingRequest, ) @@ -284,6 +286,55 @@ def validate_presence_penalty(presence_penalty: float) -> Optional[str]: return None +def validate_embedding_input(input: EmbeddingInput) -> Optional[str]: + """Validate that the input is not empty or whitespace only.""" + if not input: + return "Input cannot be empty" + + # Handle single string + if isinstance(input, str): + if not input.strip(): + return "Input cannot be empty or whitespace only" + return None + + # Handle list inputs + if isinstance(input, list): + if len(input) == 0: + return "Input cannot be empty" + + # Check first element to determine type + first_item = input[0] + + if isinstance(first_item, str): + # List of strings + for i, item in enumerate(input): + if not isinstance(item, str): + return f"All items in input list must be strings" + if not item.strip(): + return f"Input at index {i} cannot be empty or whitespace only" + elif isinstance(first_item, int): + # List of integers (token IDs) + for i, item in enumerate(input): + if not isinstance(item, int): + return f"All items in input list must be integers" + if item < 0: + return f"Token ID at index {i} must be non-negative" + elif isinstance(first_item, list): + # List of lists (multiple token sequences) + for i, item in enumerate(input): + if not isinstance(item, list): + return f"Input at index {i} must be a list" + if not item: + return f"Input at index {i} cannot be empty" + if not all(isinstance(token, int) for token in item): + return f"Input at index {i} must contain only integers" + if any(token < 0 for token in item): + return f"Input at index {i} contains negative token IDs" + # Note: MultimodalEmbeddingInput validation would be handled by Pydantic + + return None + + def get_common_validation_rules() -> List[ValidationRule]: """Get validation rules common to both chat and completion requests""" return [ @@ -332,6 +383,17 @@ def get_completion_specific_validation_rules() -> List[ValidationRule]: ] +def get_embedding_specific_validation_rules() -> List[ValidationRule]: + """Get validation rules specific to embedding requests""" + return [ + ValidationRule( + param_name="input", + validator_func=validate_embedding_input, + param_getter=lambda request: request.input, + ), + ] + + def get_validation_rules(request: OpenAIServingRequest) -> List[ValidationRule]: """Get all validation rules for the request""" if isinstance(request, ChatCompletionRequest): @@ -340,5 +402,7 @@ def get_validation_rules(request: OpenAIServingRequest) -> List[ValidationRule]: return ( get_common_validation_rules() + get_completion_specific_validation_rules() ) + elif isinstance(request, EmbeddingRequest): + return get_embedding_specific_validation_rules() else: raise ValueError(f"Unsupported request type: {type(request)}") diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py new file mode 100644 index 00000000000..087bd35a40c --- /dev/null +++ b/test/srt/openai/test_serving_embedding.py @@ -0,0 +1,532 @@ +""" +Unit tests for the EmbeddingHandler class from serving_embedding.py. + +These tests ensure that the embedding serving implementation maintains compatibility +with the original adapter.py functionality and follows OpenAI API specifications. +""" + +import asyncio +import json +import time +import uuid +from typing import Any, Dict, List +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from fastapi import Request +from pydantic_core import ValidationError + +from sglang.srt.entrypoints.openai.protocol import ( + EmbeddingObject, + EmbeddingRequest, + EmbeddingResponse, + ErrorResponse, + MultimodalEmbeddingInput, + UsageInfo, +) +from sglang.srt.entrypoints.openai.serving_embedding import EmbeddingHandler +from sglang.srt.entrypoints.openai.serving_engine import RequestContext +from sglang.srt.managers.io_struct import EmbeddingReqInput + + +# Mock TokenizerManager for embedding tests +class MockTokenizerManager: + def __init__(self): + self.model_config = Mock() + self.model_config.is_multimodal = False + self.server_args = Mock() + self.server_args.enable_cache_report = False + self.model_path = "test-model" + + # Mock tokenizer + self.tokenizer = Mock() + self.tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5]) + self.tokenizer.decode = Mock(return_value="Test embedding input") + self.tokenizer.chat_template = None + self.tokenizer.bos_token_id = 1 + + # Mock generate_request method for embeddings + async def mock_generate_embedding(): + yield { + "embedding": [0.1, 0.2, 0.3, 0.4, 0.5] * 20, # 100-dim embedding + "meta_info": { + "id": f"embd-{uuid.uuid4()}", + "prompt_tokens": 5, + }, + } + + self.generate_request = Mock(return_value=mock_generate_embedding()) + + +@pytest.fixture +def mock_tokenizer_manager(): + """Create a mock tokenizer manager for testing.""" + return MockTokenizerManager() + + +@pytest.fixture +def embedding_handler(mock_tokenizer_manager): + """Create an EmbeddingHandler instance for testing.""" + return EmbeddingHandler(mock_tokenizer_manager) + + +@pytest.fixture +def mock_request(): + """Create a mock FastAPI request.""" + request = Mock(spec=Request) + request.headers = {} + return request + + +@pytest.fixture +def basic_embedding_request(): + """Create a basic embedding request.""" + return EmbeddingRequest( + model="test-model", + input="Hello, how are you?", + encoding_format="float", + ) + + +@pytest.fixture +def list_embedding_request(): + """Create an embedding request with list input.""" + return EmbeddingRequest( + model="test-model", + input=["Hello, how are you?", "I am fine, thank you!"], + encoding_format="float", + ) + + +@pytest.fixture +def multimodal_embedding_request(): + """Create a multimodal embedding request.""" + return EmbeddingRequest( + model="test-model", + input=[ + MultimodalEmbeddingInput(text="Hello", image="base64_image_data"), + MultimodalEmbeddingInput(text="World", image=None), + ], + encoding_format="float", + ) + + +@pytest.fixture +def token_ids_embedding_request(): + """Create an embedding request with token IDs.""" + return EmbeddingRequest( + model="test-model", + input=[1, 2, 3, 4, 5], + encoding_format="float", + ) + + +class TestEmbeddingHandlerValidation: + """Test validation methods of EmbeddingHandler.""" + + def test_validate_embedding_request_valid( + self, embedding_handler, basic_embedding_request + ): + """Test validation with a valid request.""" + error = embedding_handler._validate_request(basic_embedding_request) + assert error is None + + def test_validate_embedding_request_empty_string(self, embedding_handler): + """Test validation fails with empty string input.""" + request = EmbeddingRequest(model="test-model", input="") + error = embedding_handler._validate_request(request) + assert error is not None + assert "empty" in error.message.lower() + + def test_validate_embedding_request_whitespace_only(self, embedding_handler): + """Test validation fails with whitespace-only input.""" + request = EmbeddingRequest(model="test-model", input=" \n\t ") + error = embedding_handler._validate_request(request) + assert error is not None + assert "whitespace" in error.message.lower() + + def test_validate_embedding_request_empty_list(self, embedding_handler): + """Test validation fails with empty list input.""" + request = EmbeddingRequest(model="test-model", input=[]) + error = embedding_handler._validate_request(request) + assert error is not None + assert "empty" in error.message.lower() + + def test_validate_embedding_request_empty_string_in_list(self, embedding_handler): + """Test validation fails with empty string in list.""" + request = EmbeddingRequest(model="test-model", input=[""]) + error = embedding_handler._validate_request(request) + assert error is not None + assert "empty" in error.message.lower() + + def test_validate_embedding_request_valid_list( + self, embedding_handler, list_embedding_request + ): + """Test validation passes with valid list input.""" + error = embedding_handler._validate_request(list_embedding_request) + assert error is None + + +class TestEmbeddingHandlerConversion: + """Test request conversion methods.""" + + def test_convert_single_string_request( + self, embedding_handler, basic_embedding_request + ): + """Test converting single string request to internal format.""" + adapted_request, processed_request = ( + embedding_handler._convert_to_internal_request( + [basic_embedding_request], ["test-id"] + ) + ) + + assert isinstance(adapted_request, EmbeddingReqInput) + assert adapted_request.text == "Hello, how are you?" + assert adapted_request.rid == "test-id" + assert processed_request == basic_embedding_request + + def test_convert_list_string_request( + self, embedding_handler, list_embedding_request + ): + """Test converting list of strings request to internal format.""" + adapted_request, processed_request = ( + embedding_handler._convert_to_internal_request( + [list_embedding_request], ["test-id"] + ) + ) + + assert isinstance(adapted_request, EmbeddingReqInput) + assert adapted_request.text == ["Hello, how are you?", "I am fine, thank you!"] + assert adapted_request.rid == "test-id" + assert processed_request == list_embedding_request + + def test_convert_token_ids_request( + self, embedding_handler, token_ids_embedding_request + ): + """Test converting token IDs request to internal format.""" + adapted_request, processed_request = ( + embedding_handler._convert_to_internal_request( + [token_ids_embedding_request], ["test-id"] + ) + ) + + assert isinstance(adapted_request, EmbeddingReqInput) + assert adapted_request.input_ids == [1, 2, 3, 4, 5] + assert adapted_request.rid == "test-id" + assert processed_request == token_ids_embedding_request + + def test_convert_multimodal_request( + self, embedding_handler, multimodal_embedding_request + ): + """Test converting multimodal request to internal format.""" + adapted_request, processed_request = ( + embedding_handler._convert_to_internal_request( + [multimodal_embedding_request], ["test-id"] + ) + ) + + assert isinstance(adapted_request, EmbeddingReqInput) + # Should extract text and images separately + assert len(adapted_request.text) == 2 + assert "Hello" in adapted_request.text + assert "World" in adapted_request.text + assert adapted_request.image_data[0] == "base64_image_data" + assert adapted_request.image_data[1] is None + assert adapted_request.rid == "test-id" + + def test_convert_batch_requests(self, embedding_handler): + """Test converting multiple requests (batch) to internal format.""" + request1 = EmbeddingRequest(model="test-model", input="First text") + request2 = EmbeddingRequest(model="test-model", input="Second text") + + adapted_request, processed_requests = ( + embedding_handler._convert_to_internal_request( + [request1, request2], ["id1", "id2"] + ) + ) + + assert isinstance(adapted_request, EmbeddingReqInput) + assert adapted_request.text == ["First text", "Second text"] + assert adapted_request.rid == ["id1", "id2"] + assert processed_requests == [request1, request2] + + def test_convert_batch_requests_type_mismatch_error(self, embedding_handler): + """Test that batch requests with different input types raise error.""" + request1 = EmbeddingRequest(model="test-model", input="String input") + request2 = EmbeddingRequest(model="test-model", input=[1, 2, 3]) # Token IDs + + with pytest.raises(AssertionError, match="same type"): + embedding_handler._convert_to_internal_request( + [request1, request2], ["id1", "id2"] + ) + + +class TestEmbeddingResponseBuilding: + """Test response building methods.""" + + def test_build_single_embedding_response(self, embedding_handler): + """Test building response for single embedding.""" + ret_data = [ + { + "embedding": [0.1, 0.2, 0.3, 0.4, 0.5], + "meta_info": {"prompt_tokens": 5}, + } + ] + + response = embedding_handler._build_embedding_response(ret_data, "test-model") + + assert isinstance(response, EmbeddingResponse) + assert response.model == "test-model" + assert len(response.data) == 1 + assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5] + assert response.data[0].index == 0 + assert response.data[0].object == "embedding" + assert response.usage.prompt_tokens == 5 + assert response.usage.total_tokens == 5 + assert response.usage.completion_tokens == 0 + + def test_build_multiple_embedding_response(self, embedding_handler): + """Test building response for multiple embeddings.""" + ret_data = [ + { + "embedding": [0.1, 0.2, 0.3], + "meta_info": {"prompt_tokens": 3}, + }, + { + "embedding": [0.4, 0.5, 0.6], + "meta_info": {"prompt_tokens": 4}, + }, + ] + + response = embedding_handler._build_embedding_response(ret_data, "test-model") + + assert isinstance(response, EmbeddingResponse) + assert len(response.data) == 2 + assert response.data[0].embedding == [0.1, 0.2, 0.3] + assert response.data[0].index == 0 + assert response.data[1].embedding == [0.4, 0.5, 0.6] + assert response.data[1].index == 1 + assert response.usage.prompt_tokens == 7 # 3 + 4 + assert response.usage.total_tokens == 7 + + +@pytest.mark.asyncio +class TestEmbeddingHandlerAsyncMethods: + """Test async methods of EmbeddingHandler.""" + + async def test_handle_request_success( + self, embedding_handler, basic_embedding_request, mock_request + ): + """Test successful embedding request handling.""" + + # Mock the generate_request to return expected data + async def mock_generate(): + yield { + "embedding": [0.1, 0.2, 0.3, 0.4, 0.5], + "meta_info": {"prompt_tokens": 5}, + } + + embedding_handler.tokenizer_manager.generate_request = Mock( + return_value=mock_generate() + ) + + response = await embedding_handler.handle_request( + basic_embedding_request, mock_request + ) + + assert isinstance(response, EmbeddingResponse) + assert len(response.data) == 1 + assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5] + + async def test_handle_request_validation_error( + self, embedding_handler, mock_request + ): + """Test handling request with validation error.""" + invalid_request = EmbeddingRequest(model="test-model", input="") + + response = await embedding_handler.handle_request(invalid_request, mock_request) + + assert isinstance(response, ErrorResponse) + assert "empty" in response.message.lower() + + async def test_handle_request_generation_error( + self, embedding_handler, basic_embedding_request, mock_request + ): + """Test handling request with generation error.""" + + # Mock generate_request to raise an error + async def mock_generate_error(): + raise ValueError("Generation failed") + yield # This won't be reached but needed for async generator + + embedding_handler.tokenizer_manager.generate_request = Mock( + return_value=mock_generate_error() + ) + + response = await embedding_handler.handle_request( + basic_embedding_request, mock_request + ) + + assert isinstance(response, ErrorResponse) + assert "Generation failed" in response.message + + async def test_handle_request_internal_error( + self, embedding_handler, basic_embedding_request, mock_request + ): + """Test handling request with internal server error.""" + # Mock _convert_to_internal_request to raise an exception + with patch.object( + embedding_handler, + "_convert_to_internal_request", + side_effect=Exception("Internal error"), + ): + response = await embedding_handler.handle_request( + basic_embedding_request, mock_request + ) + + assert isinstance(response, ErrorResponse) + assert "Internal server error" in response.message + assert response.code == 500 + + +class TestCompatibilityWithAdapter: + """Test compatibility with original adapter.py implementation.""" + + def test_embedding_request_structure_matches_adapter(self, embedding_handler): + """Test that EmbeddingReqInput structure matches adapter expectations.""" + request = EmbeddingRequest(model="test-model", input="Test text") + + adapted_request, _ = embedding_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + # Check that adapted_request has expected fields from adapter.py + assert hasattr(adapted_request, "rid") + assert hasattr(adapted_request, "text") or hasattr(adapted_request, "input_ids") + assert adapted_request.rid == "test-id" + + def test_multimodal_embedding_processing_compatibility(self, embedding_handler): + """Test multimodal processing matches adapter patterns.""" + multimodal_input = [ + MultimodalEmbeddingInput(text="Hello", image="image_data"), + MultimodalEmbeddingInput(text="World", image=None), + ] + request = EmbeddingRequest(model="test-model", input=multimodal_input) + + adapted_request, _ = embedding_handler._convert_to_internal_request( + [request], ["test-id"] + ) + + # Should have text and image_data fields like adapter + assert hasattr(adapted_request, "text") + assert hasattr(adapted_request, "image_data") + assert len(adapted_request.text) == 2 + assert len(adapted_request.image_data) == 2 + + def test_response_format_matches_adapter(self, embedding_handler): + """Test response format matches adapter.py output.""" + ret_data = [ + { + "embedding": [0.1, 0.2, 0.3], + "meta_info": {"prompt_tokens": 3}, + } + ] + + response = embedding_handler._build_embedding_response(ret_data, "test-model") + + # Check response structure matches adapter output + assert response.object == "list" + assert isinstance(response.data, list) + assert len(response.data) == 1 + assert response.data[0].object == "embedding" + assert isinstance(response.data[0].embedding, list) + assert isinstance(response.data[0].index, int) + assert isinstance(response.usage, UsageInfo) + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_request_id_generation( + self, embedding_handler, basic_embedding_request, mock_request + ): + """Test that request IDs are properly generated when not provided.""" + # Request without rid + request_without_id = EmbeddingRequest(model="test-model", input="Test") + assert request_without_id.rid is None + + # Should generate ID during handling + with patch.object(embedding_handler, "_handle_request") as mock_handle: + mock_handle.return_value = EmbeddingResponse( + data=[], + model="test-model", + usage=UsageInfo(prompt_tokens=0, total_tokens=0), + ) + + asyncio.run( + embedding_handler.handle_request(request_without_id, mock_request) + ) + + # Check that context was created with generated ID + args, kwargs = mock_handle.call_args + ctx = args[2] # Third argument is context + assert ctx.request_id.startswith("embd-") + + def test_request_id_preservation(self, embedding_handler, mock_request): + """Test that provided request IDs are preserved.""" + request_with_id = EmbeddingRequest( + model="test-model", input="Test", rid="custom-id" + ) + + with patch.object(embedding_handler, "_handle_request") as mock_handle: + mock_handle.return_value = EmbeddingResponse( + data=[], + model="test-model", + usage=UsageInfo(prompt_tokens=0, total_tokens=0), + ) + + asyncio.run(embedding_handler.handle_request(request_with_id, mock_request)) + + # Check that custom ID was preserved + args, kwargs = mock_handle.call_args + ctx = args[2] # Third argument is context + assert ctx.request_id == "custom-id" + + def test_multimodal_batch_not_implemented(self, embedding_handler): + """Test that multimodal batch requests raise NotImplementedError.""" + request1 = EmbeddingRequest( + model="test-model", + input=[MultimodalEmbeddingInput(text="Hello", image="img1")], + ) + request2 = EmbeddingRequest( + model="test-model", + input=[MultimodalEmbeddingInput(text="World", image="img2")], + ) + + with pytest.raises(NotImplementedError, match="multimodal.*not supported"): + embedding_handler._convert_to_internal_request( + [request1, request2], ["id1", "id2"] + ) + + def test_empty_return_data_handling(self, embedding_handler): + """Test handling of empty return data from generation.""" + # Test with empty list + response = embedding_handler._build_embedding_response([], "test-model") + assert len(response.data) == 0 + assert response.usage.prompt_tokens == 0 + assert response.usage.total_tokens == 0 + + def test_missing_meta_info_handling(self, embedding_handler): + """Test handling of missing meta_info in return data.""" + ret_data = [ + { + "embedding": [0.1, 0.2, 0.3], + "meta_info": {}, # Missing prompt_tokens + } + ] + + # Should handle missing prompt_tokens gracefully + response = embedding_handler._build_embedding_response(ret_data, "test-model") + assert len(response.data) == 1 + # Should default to 0 for missing prompt_tokens + assert response.usage.prompt_tokens == 0 From f8d604b3517d161922266a0d84b65e7360f99b58 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 14 Jun 2025 02:30:06 +0000 Subject: [PATCH 03/33] Refactors request handling in OpenAI endpoints Consolidates request handling logic into the base class to reduce code duplication. Moves the common request validation, context creation, and request dispatch logic to the OpenAIServingBase class. This change streamlines the structure of the handlers for chat completions, completions, and embeddings. The individual handler classes now only need to implement the conversion to internal format and the specific streaming and non-streaming handling logic. Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_chat.py | 42 -------- .../entrypoints/openai/serving_completions.py | 56 ----------- .../entrypoints/openai/serving_embedding.py | 56 +++-------- .../srt/entrypoints/openai/serving_engine.py | 96 ++++++++++++++++++- .../srt/entrypoints/openai/validation.py | 9 ++ test/srt/openai/test_serving_completions.py | 4 +- test/srt/openai/test_serving_embedding.py | 8 +- 7 files changed, 121 insertions(+), 150 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 36652544130..d01a2f15e10 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -72,48 +72,6 @@ class ChatCompletionHandler(OpenAIServingBase): """Handler for chat completion requests""" - def __init__(self, tokenizer_manager: TokenizerManager): - super().__init__(tokenizer_manager) - - async def handle_request( - self, request: ChatCompletionRequest, raw_request: Request - ) -> Union[ChatCompletionResponse, StreamingResponse, ErrorResponse]: - """Handle a chat completion request""" - try: - # Validate request - error = self._validate_request(request) - if error: - return error - - # Create request context - ctx = RequestContext( - raw_request=raw_request, - openai_request=request, - request_id=request.rid or f"chatcmpl-{uuid.uuid4()}", - ) - - # Convert to internal format - adapted_request, processed_request = self._convert_to_internal_request( - [request], [ctx.request_id] - ) - - if request.stream: - return await self._handle_streaming_request( - adapted_request, processed_request, ctx - ) - else: - return await self._handle_non_streaming_request( - adapted_request, processed_request, ctx - ) - - except Exception as e: - logger.error(f"Error in chat completion: {e}") - return create_error_response( - message=f"Internal server error: {str(e)}", - err_type="InternalServerError", - status_code=500, - ) - def _convert_to_internal_request( self, all_requests: List[ChatCompletionRequest], diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index a8816cce032..f492e27a394 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -13,13 +13,9 @@ # ============================================================================== """Completion serving logic for OpenAI API""" -import json -import logging import time -import uuid from typing import Any, Dict, List, Union -from fastapi import Request from fastapi.responses import StreamingResponse from sglang.srt.code_completion_parser import ( @@ -49,63 +45,11 @@ to_openai_style_logprobs, ) from sglang.srt.managers.io_struct import GenerateReqInput -from sglang.srt.managers.tokenizer_manager import TokenizerManager - -logger = logging.getLogger(__name__) class CompletionHandler(OpenAIServingBase): """Handler for completion requests""" - def __init__(self, tokenizer_manager: TokenizerManager): - super().__init__(tokenizer_manager) - - async def handle_request( - self, request: CompletionRequest, raw_request: Request - ) -> Union[CompletionResponse, StreamingResponse, ErrorResponse]: - """Handle a completion request""" - try: - # Echo + logprobs warning - if request.echo and request.logprobs: - logger.warning( - "Echo is not compatible with logprobs. " - "To compute logprobs of input prompt, please use the native /generate API." - ) - - # Validate request - error = self._validate_request(request) - if error: - return error - - # Create request context - ctx = RequestContext( - raw_request=raw_request, - openai_request=request, - request_id=f"cmpl-{uuid.uuid4()}", - ) - - # Convert to internal format - adapted_request, processed_request = self._convert_to_internal_request( - [request], [ctx.request_id] - ) - - if request.stream: - return await self._handle_streaming_request( - adapted_request, processed_request, ctx - ) - else: - return await self._handle_non_streaming_request( - adapted_request, processed_request, ctx - ) - - except Exception as e: - logger.error(f"Error in completion: {e}") - return create_error_response( - message=f"Internal server error: {str(e)}", - err_type="InternalServerError", - status_code=500, - ) - def _convert_to_internal_request( self, all_requests: List[CompletionRequest], diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index e7314613ea4..2136d3f220a 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -14,10 +14,9 @@ """Embedding serving logic for OpenAI API""" import logging -import uuid -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union -from fastapi import Request +from fastapi.responses import StreamingResponse from sglang.srt.conversation import generate_embedding_convs from sglang.srt.entrypoints.openai.protocol import ( @@ -34,50 +33,11 @@ ) from sglang.srt.entrypoints.openai.utils import create_error_response from sglang.srt.managers.io_struct import EmbeddingReqInput -from sglang.srt.managers.tokenizer_manager import TokenizerManager - -logger = logging.getLogger(__name__) class EmbeddingHandler(OpenAIServingBase): """Handler for embedding requests""" - def __init__(self, tokenizer_manager: TokenizerManager): - super().__init__(tokenizer_manager) - - async def handle_request( - self, request: EmbeddingRequest, raw_request: Request - ) -> Union[EmbeddingResponse, ErrorResponse]: - """Handle an embedding request""" - try: - # Validate request - error = self._validate_request(request) - if error: - return error - - # Create request context - ctx = RequestContext( - raw_request=raw_request, - openai_request=request, - request_id=request.rid or f"embd-{uuid.uuid4()}", - ) - - # Convert to internal format - adapted_request, processed_request = self._convert_to_internal_request( - [request], [ctx.request_id] - ) - - # Handle the request - return await self._handle_request(adapted_request, processed_request, ctx) - - except Exception as e: - logger.error(f"Error in embedding: {e}") - return create_error_response( - message=f"Internal server error: {str(e)}", - err_type="InternalServerError", - status_code=500, - ) - def _convert_to_internal_request( self, all_requests: List[EmbeddingRequest], @@ -186,7 +146,17 @@ def _convert_to_internal_request( all_requests[0] if len(all_requests) == 1 else all_requests ) - async def _handle_request( + async def _handle_streaming_request( + self, + adapted_request: EmbeddingReqInput, + request: EmbeddingRequest, + ctx: RequestContext, + ) -> StreamingResponse: + """Handle streaming embedding request (not supported)""" + # Embeddings don't support streaming + raise NotImplementedError("Embedding requests do not support streaming") + + async def _handle_non_streaming_request( self, adapted_request: EmbeddingReqInput, request: EmbeddingRequest, diff --git a/python/sglang/srt/entrypoints/openai/serving_engine.py b/python/sglang/srt/entrypoints/openai/serving_engine.py index 54de8250f01..aa4f6a7fa67 100644 --- a/python/sglang/srt/entrypoints/openai/serving_engine.py +++ b/python/sglang/srt/entrypoints/openai/serving_engine.py @@ -12,21 +12,30 @@ # limitations under the License. # ============================================================================== +import logging import time +import uuid from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Union from fastapi import Request +from fastapi.responses import StreamingResponse from sglang.srt.entrypoints.openai.protocol import ( + ChatCompletionRequest, + CompletionRequest, + EmbeddingRequest, ErrorResponse, OpenAIServingRequest, UsageInfo, ) from sglang.srt.entrypoints.openai.utils import create_error_response from sglang.srt.entrypoints.openai.validation import get_validation_rules +from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager +logger = logging.getLogger(__name__) + class RequestContext: """Context object for tracking request state throughout the pipeline""" @@ -63,11 +72,90 @@ class OpenAIServingBase(ABC): def __init__(self, tokenizer_manager: TokenizerManager): self.tokenizer_manager = tokenizer_manager - @abstractmethod async def handle_request( self, request: OpenAIServingRequest, raw_request: Request - ) -> Any: - """Handle the specific request type""" + ) -> Union[Any, StreamingResponse, ErrorResponse]: + """Handle the specific request type with common pattern""" + try: + # Validate request + error = self._validate_request(request) + if error: + return error + + # Create request context + ctx = RequestContext( + raw_request=raw_request, + openai_request=request, + request_id=self._generate_request_id(request), + ) + + # Convert to internal format + adapted_request, processed_request = self._convert_to_internal_request( + [request], [ctx.request_id] + ) + + # Check if this handler supports streaming + if hasattr(request, "stream") and request.stream: + return await self._handle_streaming_request( + adapted_request, processed_request, ctx + ) + else: + return await self._handle_non_streaming_request( + adapted_request, processed_request, ctx + ) + + except Exception as e: + logger.error(f"Error in request: {e}") + return create_error_response( + message=f"Internal server error: {str(e)}", + err_type="InternalServerError", + status_code=500, + ) + + def _generate_request_id(self, request: OpenAIServingRequest) -> str: + """Generate request ID based on request type""" + # Default implementation - can be overridden + if rid := getattr(request, "rid", None): + return rid + + # Determine prefix based on request type + prefix_mapping = { + ChatCompletionRequest: "chatcmpl", + CompletionRequest: "cmpl", + EmbeddingRequest: "embd", + } + prefix = prefix_mapping.get(type(request), "req") + return f"{prefix}-{uuid.uuid4()}" + + @abstractmethod + def _convert_to_internal_request( + self, + all_requests: List[OpenAIServingRequest], + request_ids: List[str], + ) -> tuple[ + GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]] + ]: + """Convert OpenAI request to internal format""" + pass + + @abstractmethod + async def _handle_streaming_request( + self, + adapted_request: GenerateReqInput, + request: OpenAIServingRequest, + ctx: RequestContext, + ) -> StreamingResponse: + """Handle streaming request""" + pass + + @abstractmethod + async def _handle_non_streaming_request( + self, + adapted_request: GenerateReqInput, + request: OpenAIServingRequest, + ctx: RequestContext, + ) -> Union[Any, ErrorResponse]: + """Handle non-streaming request""" pass def _validate_request( diff --git a/python/sglang/srt/entrypoints/openai/validation.py b/python/sglang/srt/entrypoints/openai/validation.py index e1f0c3f3cff..1e65f28410d 100644 --- a/python/sglang/srt/entrypoints/openai/validation.py +++ b/python/sglang/srt/entrypoints/openai/validation.py @@ -13,6 +13,7 @@ # ============================================================================== """Pre-built validation rules for OpenAI API parameters""" +import logging import re from typing import Any, Callable, List, Optional, Union @@ -25,6 +26,8 @@ OpenAIServingRequest, ) +logger = logging.getLogger(__name__) + class ValidationRule: """Represents a validation rule for request parameters""" @@ -399,6 +402,12 @@ def get_validation_rules(request: OpenAIServingRequest) -> List[ValidationRule]: if isinstance(request, ChatCompletionRequest): return get_common_validation_rules() + get_chat_specific_validation_rules() elif isinstance(request, CompletionRequest): + # Echo + logprobs warning + if request.echo and request.logprobs: + logger.warning( + "Echo is not compatible with logprobs. " + "To compute logprobs of input prompt, please use the native /generate API." + ) return ( get_common_validation_rules() + get_completion_specific_validation_rules() ) diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py index 39ba51c9c3f..0abe3a30a8e 100644 --- a/test/srt/openai/test_serving_completions.py +++ b/test/srt/openai/test_serving_completions.py @@ -1003,9 +1003,7 @@ async def test_echo_warning_with_logprobs(self, completion_handler): mock_raw_request = Mock() - with patch( - "sglang.srt.entrypoints.openai.serving_completions.logger" - ) as mock_logger: + with patch("sglang.srt.entrypoints.openai.validation.logger") as mock_logger: # Call handle_request which contains the warning logic await completion_handler.handle_request(request, mock_raw_request) # Should log warning about echo + logprobs incompatibility diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py index 087bd35a40c..62c2c0e5459 100644 --- a/test/srt/openai/test_serving_embedding.py +++ b/test/srt/openai/test_serving_embedding.py @@ -456,7 +456,9 @@ def test_request_id_generation( assert request_without_id.rid is None # Should generate ID during handling - with patch.object(embedding_handler, "_handle_request") as mock_handle: + with patch.object( + embedding_handler, "_handle_non_streaming_request" + ) as mock_handle: mock_handle.return_value = EmbeddingResponse( data=[], model="test-model", @@ -478,7 +480,9 @@ def test_request_id_preservation(self, embedding_handler, mock_request): model="test-model", input="Test", rid="custom-id" ) - with patch.object(embedding_handler, "_handle_request") as mock_handle: + with patch.object( + embedding_handler, "_handle_non_streaming_request" + ) as mock_handle: mock_handle.return_value = EmbeddingResponse( data=[], model="test-model", From a86bf2791547db53f4820e5c364d9141a48678c9 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 14 Jun 2025 02:47:08 +0000 Subject: [PATCH 04/33] Adds documentation to OpenAI API endpoints Adds comprehensive docstrings to the OpenAI API endpoint modules, including descriptions of key features, processing pipelines, and architecture. This improves code maintainability and provides better understanding of the purpose and functionality of each module. Signed-off-by: Xinyuan Tong --- .../sglang/srt/entrypoints/openai/__init__.py | 41 +++++++++++++++---- .../sglang/srt/entrypoints/openai/protocol.py | 10 +---- .../srt/entrypoints/openai/serving_chat.py | 29 ++++++++++++- .../entrypoints/openai/serving_completions.py | 29 ++++++++++++- .../entrypoints/openai/serving_embedding.py | 34 ++++++++++++++- .../srt/entrypoints/openai/serving_engine.py | 23 +++++++++++ python/sglang/srt/entrypoints/openai/utils.py | 21 +++++++++- .../srt/entrypoints/openai/validation.py | 28 ++++++++++++- 8 files changed, 193 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/__init__.py b/python/sglang/srt/entrypoints/openai/__init__.py index 17f1eacfb28..2df1cce2270 100644 --- a/python/sglang/srt/entrypoints/openai/__init__.py +++ b/python/sglang/srt/entrypoints/openai/__init__.py @@ -11,13 +11,38 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""OpenAI-compatible API server module""" +""" +OpenAI-compatible API server module for SGLang. -from .protocol import * -from .serving_engine import OpenAIServingBase, RequestContext -from .utils import * +This module provides OpenAI-compatible API endpoints that allow existing OpenAI client +applications to seamlessly work with SGLang models. The implementation includes: -__all__ = [ - "OpenAIServingBase", - "RequestContext", -] +Key Features: +- Full OpenAI API compatibility for chat completions, text completions, and embeddings +- Streaming support for real-time response generation +- Batch processing capabilities for multiple requests +- Function calling and tool use support +- Multimodal input support (text, images, audio) +- Advanced reasoning capabilities with separate reasoning content +- Custom sampling parameters and constraints (regex, JSON schema, EBNF) +- LoRA adapter support for fine-tuned models +- Cache reporting and token usage tracking + +Supported Endpoints: +- /v1/chat/completions - Chat-based completions with conversation history +- /v1/completions - Text completions for single prompts +- /v1/embeddings - Text/multimodal embeddings generation +- /v1/models - Model listing and information + +The module is structured with separate handlers for each endpoint type, all inheriting +from a common base class that provides shared functionality like request validation, +error handling, and response formatting. + +Architecture: +- OpenAIServingBase: Abstract base class for all endpoint handlers +- ChatCompletionHandler: Handles chat completion requests +- CompletionHandler: Handles text completion requests +- EmbeddingHandler: Handles embedding requests +- Protocol classes: Pydantic models for request/response validation +- Utility functions: Shared helpers for formatting and validation +""" diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index f083f61fa6c..e60dfe385be 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -16,15 +16,7 @@ import time from typing import Dict, List, Optional, Union -from pydantic import ( - BaseModel, - Field, - field_validator, - model_validator, - root_validator, - validator, -) -from pydantic_core import ValidationError +from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Literal diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index d01a2f15e10..f580c4bc948 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -11,7 +11,34 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Chat completions serving logic for OpenAI API""" +""" +Chat completions serving logic for OpenAI API. + +This module implements the /v1/chat/completions endpoint, providing full OpenAI +compatibility for chat-based interactions with conversation history support. + +Key Features: +- Full OpenAI chat completions API compatibility +- Streaming and non-streaming response modes +- Multimodal support (text, images, audio) +- Function calling and tool use +- Advanced reasoning with separate reasoning content +- Chat template processing for different model types +- Custom sampling parameters and output constraints +- LoRA adapter support + +Processing Pipeline: +1. Request validation and preprocessing +2. Message processing and chat template application +3. Tool/function call setup if applicable +4. Internal request generation with SGLang extensions +5. Model inference with streaming or batch processing +6. Response formatting and postprocessing +7. Tool call parsing and reasoning extraction + +The implementation handles both string-based and structured content formats, +automatically detecting the appropriate format based on the model's chat template. +""" import base64 import json diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index f492e27a394..26dbef324fa 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -11,7 +11,34 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Completion serving logic for OpenAI API""" +""" +Text completion serving logic for OpenAI API. + +This module implements the /v1/completions endpoint, providing OpenAI-compatible +text completion functionality for single prompts without conversation context. + +Key Features: +- Full OpenAI text completions API compatibility +- Streaming and non-streaming response modes +- Echo support to include the prompt in the response +- Custom completion templates for specialized use cases +- Advanced sampling parameters and output constraints +- Batch processing support for multiple prompts +- LoRA adapter support for fine-tuned models +- Token-level logprobs with configurable detail levels + +Processing Pipeline: +1. Request validation and prompt preprocessing +2. Completion template application if configured +3. Sampling parameter configuration +4. Internal request generation with SGLang extensions +5. Model inference with optional streaming +6. Response formatting with echo handling +7. Logprobs processing and token usage calculation + +The implementation supports various prompt formats including strings, token IDs, +and batched inputs, with automatic type detection and validation. +""" import time from typing import Any, Dict, List, Union diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index 2136d3f220a..1c5ae7057d7 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -11,7 +11,39 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Embedding serving logic for OpenAI API""" +""" +Embedding serving logic for OpenAI API. + +This module implements the /v1/embeddings endpoint, providing OpenAI-compatible +text and multimodal embedding generation capabilities. + +Key Features: +- Full OpenAI embeddings API compatibility +- Text embedding generation for single and batch inputs +- Multimodal embedding support (text + image combinations) +- Chat template integration for embedding-specific formatting +- Batch processing for multiple inputs +- Comprehensive input validation and error handling +- Token usage tracking and reporting + +Supported Input Types: +- Single string: Direct text input +- List of strings: Batch text embedding +- List of MultimodalEmbeddingInput: Text+image combinations +- Token IDs: Pre-tokenized input sequences + +Processing Pipeline: +1. Input validation and type detection +2. Multimodal content processing (if applicable) +3. Chat template application for embedding context +4. Internal request generation +5. Model inference for embedding generation +6. Response formatting with usage statistics + +The implementation handles various input formats gracefully and provides +detailed error messages for invalid inputs. Multimodal embeddings use +padding for missing text content when needed. +""" import logging from typing import Any, Dict, List, Union diff --git a/python/sglang/srt/entrypoints/openai/serving_engine.py b/python/sglang/srt/entrypoints/openai/serving_engine.py index aa4f6a7fa67..abe09bba3df 100644 --- a/python/sglang/srt/entrypoints/openai/serving_engine.py +++ b/python/sglang/srt/entrypoints/openai/serving_engine.py @@ -11,6 +11,29 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +""" +Base serving engine for OpenAI API endpoints. + +This module provides the foundational classes and request handling patterns +used by all OpenAI API endpoint implementations. It establishes a common +architecture for request processing, validation, and response generation. + +Key Components: +- RequestContext: Tracks request state and metadata throughout processing +- OpenAIServingBase: Abstract base class for all endpoint handlers +- Common request handling patterns with proper error handling +- Validation integration for request parameters +- Streaming and non-streaming response support + +Architecture Pattern: +All endpoint handlers inherit from OpenAIServingBase and implement: +1. _convert_to_internal_request: Transform OpenAI request to SGLang format +2. _handle_streaming_request: Process streaming requests +3. _handle_non_streaming_request: Process non-streaming requests + +This ensures consistent behavior across all endpoints while allowing +endpoint-specific customization. +""" import logging import time diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py index 14c27220fcb..8bfd4899abe 100644 --- a/python/sglang/srt/entrypoints/openai/utils.py +++ b/python/sglang/srt/entrypoints/openai/utils.py @@ -11,7 +11,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utility functions for OpenAI API server""" +""" +Utility functions for OpenAI API server. + +This module provides shared utility functions used across the OpenAI API implementation, +including template processing, validation, error handling, and response formatting. + +Key Components: +- Template Format Detection: Analyzes Jinja templates to determine content format +- Content Processing: Handles multimodal content based on template requirements +- Token Usage Calculation: Aggregates token usage across requests and responses +- Error Response Generation: Creates standardized error responses +- Logprobs Formatting: Converts internal logprobs to OpenAI format +- Validation Helpers: Common validation functions for requests +- Streaming Utilities: Helpers for streaming response formatting + +Template Format Detection: +The module includes sophisticated logic to detect whether a chat template expects +'string' or 'openai' content format by analyzing the Jinja template AST. This enables +proper content processing for different model types (e.g., DeepSeek vs Llama). +""" import json import logging diff --git a/python/sglang/srt/entrypoints/openai/validation.py b/python/sglang/srt/entrypoints/openai/validation.py index 1e65f28410d..523de26b79b 100644 --- a/python/sglang/srt/entrypoints/openai/validation.py +++ b/python/sglang/srt/entrypoints/openai/validation.py @@ -11,7 +11,33 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Pre-built validation rules for OpenAI API parameters""" +""" +Pre-built validation rules for OpenAI API parameters. + +This module provides comprehensive validation for all OpenAI API request parameters, +ensuring requests meet both OpenAI standards and SGLang-specific requirements. + +Key Components: +- ValidationRule: Encapsulates parameter validation logic +- Parameter Validators: Specific validation functions for different parameter types +- Request Type Handlers: Validation rule sets for different endpoint types +- Common Validators: Shared validation logic across endpoints + +Validation Categories: +- Basic Types: String, number, boolean validation +- Ranges: Min/max validation for numeric parameters +- Formats: Pattern matching for structured data +- Content: Message and prompt content validation +- Constraints: Cross-parameter dependency validation + +The validation system is designed to provide clear, actionable error messages +that help users understand and fix request issues quickly. + +Usage: +Validation rules are automatically applied based on request type. Each rule +specifies the parameter name, validation function, and parameter accessor, +allowing for flexible and comprehensive validation coverage. +""" import logging import re From 5ddc8fc11615c170e4df840890f9960b6773dd6e Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 14 Jun 2025 02:52:09 +0000 Subject: [PATCH 05/33] Simplifies getting enable_thinking value Signed-off-by: Xinyuan Tong --- python/sglang/srt/entrypoints/openai/utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py index 8bfd4899abe..5ab2e8804fd 100644 --- a/python/sglang/srt/entrypoints/openai/utils.py +++ b/python/sglang/srt/entrypoints/openai/utils.py @@ -506,13 +506,7 @@ def _get_enable_thinking_from_request(request_obj): Returns: The boolean value of 'enable_thinking' if found and not True, otherwise True. """ - if ( - hasattr(request_obj, "chat_template_kwargs") - and request_obj.chat_template_kwargs - and request_obj.chat_template_kwargs.get("enable_thinking") is not None - ): - return request_obj.chat_template_kwargs.get("enable_thinking") - return True + return getattr(request_obj, "chat_template_kwargs", {}).get("enable_thinking", True) def create_streaming_chunk_data(chunk_data: str) -> str: From 2ddbb40c6663012a149d381690c9184041cefce6 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 14 Jun 2025 02:59:41 +0000 Subject: [PATCH 06/33] rename serving_engine to serving_base Signed-off-by: Xinyuan Tong --- .../openai/{serving_engine.py => serving_base.py} | 0 python/sglang/srt/entrypoints/openai/serving_chat.py | 5 +---- python/sglang/srt/entrypoints/openai/serving_completions.py | 5 +---- python/sglang/srt/entrypoints/openai/serving_embedding.py | 5 +---- test/srt/openai/test_serving_chat.py | 1 - test/srt/openai/test_serving_embedding.py | 1 - 6 files changed, 3 insertions(+), 14 deletions(-) rename python/sglang/srt/entrypoints/openai/{serving_engine.py => serving_base.py} (100%) diff --git a/python/sglang/srt/entrypoints/openai/serving_engine.py b/python/sglang/srt/entrypoints/openai/serving_base.py similarity index 100% rename from python/sglang/srt/entrypoints/openai/serving_engine.py rename to python/sglang/srt/entrypoints/openai/serving_base.py diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index f580c4bc948..c75230f3981 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -67,10 +67,7 @@ TopLogprob, UsageInfo, ) -from sglang.srt.entrypoints.openai.serving_engine import ( - OpenAIServingBase, - RequestContext, -) +from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase, RequestContext from sglang.srt.entrypoints.openai.utils import ( _get_enable_thinking_from_request, aggregate_token_usage, diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 26dbef324fa..65da1a60ddd 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -58,10 +58,7 @@ CompletionStreamResponse, ErrorResponse, ) -from sglang.srt.entrypoints.openai.serving_engine import ( - OpenAIServingBase, - RequestContext, -) +from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase, RequestContext from sglang.srt.entrypoints.openai.utils import ( aggregate_token_usage, build_base_sampling_params, diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index 1c5ae7057d7..ec0e9f38dfc 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -59,10 +59,7 @@ MultimodalEmbeddingInput, UsageInfo, ) -from sglang.srt.entrypoints.openai.serving_engine import ( - OpenAIServingBase, - RequestContext, -) +from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase, RequestContext from sglang.srt.entrypoints.openai.utils import create_error_response from sglang.srt.managers.io_struct import EmbeddingReqInput diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py index 69c67118eae..11ce7d63aa5 100644 --- a/test/srt/openai/test_serving_chat.py +++ b/test/srt/openai/test_serving_chat.py @@ -30,7 +30,6 @@ UsageInfo, ) from sglang.srt.entrypoints.openai.serving_chat import ChatCompletionHandler -from sglang.srt.entrypoints.openai.serving_engine import RequestContext from sglang.srt.entrypoints.openai.utils import ( build_base_sampling_params, create_error_response, diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py index 62c2c0e5459..1aac7370cca 100644 --- a/test/srt/openai/test_serving_embedding.py +++ b/test/srt/openai/test_serving_embedding.py @@ -25,7 +25,6 @@ UsageInfo, ) from sglang.srt.entrypoints.openai.serving_embedding import EmbeddingHandler -from sglang.srt.entrypoints.openai.serving_engine import RequestContext from sglang.srt.managers.io_struct import EmbeddingReqInput From 4596b520330dd70be0f178ac561111bff7599698 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 14 Jun 2025 04:53:14 +0000 Subject: [PATCH 07/33] Makes chat template caching instance-specific Improves thread safety by making the chat template caching mechanism instance-specific, moving it from a global scope to the ChatCompletionHandler class. This ensures that each handler instance maintains its own cache, preventing potential conflicts when multiple instances are used concurrently. Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_chat.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index c75230f3981..8f786a76f24 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -88,14 +88,16 @@ logger = logging.getLogger(__name__) -# Global cache for template content format detection -_cached_chat_template = None -_cached_template_format = None - class ChatCompletionHandler(OpenAIServingBase): """Handler for chat completion requests""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Instance-specific cache for template content format detection + self._cached_chat_template = None + self._cached_template_format = None + def _convert_to_internal_request( self, all_requests: List[ChatCompletionRequest], @@ -250,8 +252,6 @@ def _apply_jinja_template( is_multimodal: bool, ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]: """Apply Jinja chat template""" - global _cached_chat_template, _cached_template_format - openai_compatible_messages = [] image_data = [] audio_data = [] @@ -259,14 +259,16 @@ def _apply_jinja_template( # Detect template content format current_template = self.tokenizer_manager.tokenizer.chat_template - if current_template != _cached_chat_template: - _cached_chat_template = current_template - _cached_template_format = detect_template_content_format(current_template) + if current_template != self._cached_chat_template: + self._cached_chat_template = current_template + self._cached_template_format = detect_template_content_format( + current_template + ) logger.info( - f"Detected chat template content format: {_cached_template_format}" + f"Detected chat template content format: {self._cached_template_format}" ) - template_content_format = _cached_template_format + template_content_format = self._cached_template_format for message in request.messages: if message.content is None: From 47d54dc8500fab15f1f0097c942bbd431f7f8709 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 14 Jun 2025 05:29:54 +0000 Subject: [PATCH 08/33] Refactors logprobs processing Refactors the logprobs processing logic into a common helper function to avoid duplication between streaming and non-streaming responses. This change improves code maintainability and reduces the risk of inconsistencies in logprobs handling. Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_chat.py | 57 +++++++++---------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 8f786a76f24..29ea85ebdd8 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -63,6 +63,7 @@ DeltaMessage, ErrorResponse, FunctionResponse, + LogProbs, ToolCall, TopLogprob, UsageInfo, @@ -700,21 +701,29 @@ def _build_chat_response( usage=usage, ) - def _process_response_logprobs(self, ret_item: Dict[str, Any]) -> ChoiceLogprobs: - """Process logprobs for non-streaming response""" - logprobs = to_openai_style_logprobs( - output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], - output_top_logprobs=ret_item["meta_info"].get("output_top_logprobs", None), - ) + def _process_logprobs_tokens( + self, logprobs: LogProbs, use_token_index: bool = False + ) -> List[ChatCompletionTokenLogprob]: + """Common helper to process logprobs tokens for both streaming and non-streaming + Args: + logprobs: LogProbs data from model + use_token_index: True for non-streaming (use token_idx), False for streaming (use index 0) + """ token_logprobs = [] + for token_idx, (token, logprob) in enumerate( zip(logprobs.tokens, logprobs.token_logprobs) ): token_bytes = list(token.encode("utf-8")) top_logprobs = [] if logprobs.top_logprobs: - for top_token, top_logprob in logprobs.top_logprobs[token_idx].items(): + # - Non-streaming (use_token_index=True): uses token_idx for full data + # - Streaming (use_token_index=False): uses index 0 for pre-sliced data + top_logprobs_idx = token_idx if use_token_index else 0 + for top_token, top_logprob in logprobs.top_logprobs[ + top_logprobs_idx + ].items(): top_token_bytes = list(top_token.encode("utf-8")) top_logprobs.append( TopLogprob( @@ -732,6 +741,16 @@ def _process_response_logprobs(self, ret_item: Dict[str, Any]) -> ChoiceLogprobs ) ) + return token_logprobs + + def _process_response_logprobs(self, ret_item: Dict[str, Any]) -> ChoiceLogprobs: + """Process logprobs for non-streaming response""" + logprobs = to_openai_style_logprobs( + output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], + output_top_logprobs=ret_item["meta_info"].get("output_top_logprobs", None), + ) + + token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=True) return ChoiceLogprobs(content=token_logprobs) def _process_tool_calls( @@ -779,29 +798,7 @@ def _process_streaming_logprobs( ], ) - token_logprobs = [] - for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs): - token_bytes = list(token.encode("utf-8")) - top_logprobs = [] - if logprobs.top_logprobs: - for top_token, top_logprob in logprobs.top_logprobs[0].items(): - top_token_bytes = list(top_token.encode("utf-8")) - top_logprobs.append( - TopLogprob( - token=top_token, - bytes=top_token_bytes, - logprob=top_logprob, - ) - ) - token_logprobs.append( - ChatCompletionTokenLogprob( - token=token, - bytes=token_bytes, - logprob=logprob, - top_logprobs=top_logprobs, - ) - ) - + token_logprobs = self._process_logprobs_tokens(logprobs, use_token_index=False) return ChoiceLogprobs(content=token_logprobs) def _process_reasoning_stream( From 8ac4349bce2f7559cde7dc74a20bd3e7f7937f42 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Date: Sat, 14 Jun 2025 13:42:30 +0800 Subject: [PATCH 09/33] Update python/sglang/srt/entrypoints/openai/protocol.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/sglang/srt/entrypoints/openai/protocol.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index e60dfe385be..c7423ed1b43 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -191,8 +191,8 @@ class CompletionRequest(BaseModel): @field_validator("max_tokens") @classmethod def validate_max_tokens_positive(cls, v): - if v is not None and v < 0: - raise ValueError("max_tokens must be non-negative") + if v is not None and v <= 0: + raise ValueError("max_tokens must be positive") return v From 00b202ce82a8cf1845447a02cf97d91730df59e3 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 13 Jun 2025 22:25:13 -0700 Subject: [PATCH 10/33] Improve test cases for eagle infer (#7173) --- test/srt/run_suite.py | 4 +- test/srt/test_eagle_infer_a.py | 2 +- test/srt/test_eagle_infer_b.py | 74 ++++++++++++++++++++-------------- 3 files changed, 46 insertions(+), 34 deletions(-) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 6eb175cb142..8cdde63e27f 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -31,8 +31,8 @@ class TestFile: TestFile("test_block_int8.py", 22), TestFile("test_create_kvindices.py", 2), TestFile("test_chunked_prefill.py", 313), - TestFile("test_eagle_infer_a.py", 300), - TestFile("test_eagle_infer_b.py", 300), + TestFile("test_eagle_infer_a.py", 370), + TestFile("test_eagle_infer_b.py", 270), TestFile("test_ebnf_constrained.py", 108), TestFile("test_enable_thinking.py", 70), TestFile("test_embedding_openai_server.py", 141), diff --git a/test/srt/test_eagle_infer_a.py b/test/srt/test_eagle_infer_a.py index 298f1073e1b..c19f0c22f08 100644 --- a/test/srt/test_eagle_infer_a.py +++ b/test/srt/test_eagle_infer_a.py @@ -129,7 +129,7 @@ def _test_acc_length(self, engine): output["meta_info"]["completion_tokens"] / output["meta_info"]["e2e_latency"] ) - print(f"{acc_length=}") + print(f"{acc_length=:.4f}, {speed=}") if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST: self.assertGreater(acc_length, 3.6) diff --git a/test/srt/test_eagle_infer_b.py b/test/srt/test_eagle_infer_b.py index 72a69864fe4..f71feb15a77 100644 --- a/test/srt/test_eagle_infer_b.py +++ b/test/srt/test_eagle_infer_b.py @@ -10,7 +10,6 @@ import numpy as np import requests -import torch from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval @@ -24,10 +23,6 @@ run_logprob_check, ) -torch_dtype = torch.float16 -prefill_tolerance = 5e-2 -decode_tolerance: float = 5e-2 - class TestEAGLEServer(CustomTestCase): PROMPTS = [ @@ -202,7 +197,11 @@ def test_logprob_match(self): """Test the output logprobs are close to the input logprobs if we run a prefill again.""" def run_generate( - prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1 + prompt, + return_logprob=False, + max_new_tokens=512, + logprob_start_len=-1, + temperature=1.0, ): if isinstance(prompt, str): @@ -215,45 +214,58 @@ def run_generate( json={ **prompt_kwargs, "sampling_params": { - "temperature": 1.0, + "temperature": temperature, "max_new_tokens": max_new_tokens, "ignore_eos": True, }, "return_logprob": return_logprob, "return_text_in_logprobs": True, "logprob_start_len": logprob_start_len, + "temp_scaled_logprobs": True, }, ) return response.json() prompt = "I have a very good idea on how to" - gen = run_generate(prompt, return_logprob=True, logprob_start_len=0) - output_logprobs = np.array( - [x[0] for x in gen["meta_info"]["output_token_logprobs"]] - ) - num_prompts_tokens = gen["meta_info"]["prompt_tokens"] - - input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]] - output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]] - - new_prompt = input_tokens + output_tokens - score = run_generate( - new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0 - ) - output_logprobs_score = np.array( - [ - x[0] - for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:] - ] - ) + for temperature in [1.0]: + gen = run_generate( + prompt, + return_logprob=True, + logprob_start_len=0, + temperature=temperature, + ) + output_logprobs = np.array( + [x[0] for x in gen["meta_info"]["output_token_logprobs"]] + ) + num_prompts_tokens = gen["meta_info"]["prompt_tokens"] + + input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]] + output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]] + + new_prompt = input_tokens + output_tokens + score = run_generate( + new_prompt, + return_logprob=True, + logprob_start_len=0, + max_new_tokens=0, + temperature=temperature, + ) + output_logprobs_score = np.array( + [ + x[0] + for x in score["meta_info"]["input_token_logprobs"][ + num_prompts_tokens: + ] + ] + ) - print(f"{output_logprobs[-10:]=}") - print(f"{output_logprobs_score[-10:]=}") + print(f"{output_logprobs[-10:]=}") + print(f"{output_logprobs_score[-10:]=}") - diff = np.abs(output_logprobs - output_logprobs_score) - max_diff = np.max(diff) - self.assertLess(max_diff, 0.25) + diff = np.abs(output_logprobs - output_logprobs_score) + max_diff = np.max(diff) + self.assertLess(max_diff, 0.255) def test_logprob_mixed(self): args = [] From fb4ae0589f4909f5d94a5a545240b2d2b9409204 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 14 Jun 2025 06:09:32 +0000 Subject: [PATCH 11/33] fix CI Signed-off-by: Xinyuan Tong --- test/srt/openai/__init__.py | 14 -------------- 1 file changed, 14 deletions(-) delete mode 100644 test/srt/openai/__init__.py diff --git a/test/srt/openai/__init__.py b/test/srt/openai/__init__.py deleted file mode 100644 index 3379038e77f..00000000000 --- a/test/srt/openai/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for OpenAI-compatible API server refactor""" From 3b28fdb4c1754a37e3bfb23f68b1ac6841381808 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sat, 14 Jun 2025 23:33:35 +0000 Subject: [PATCH 12/33] Removes unused utility functions Signed-off-by: Xinyuan Tong --- python/sglang/srt/entrypoints/openai/utils.py | 103 ------------------ 1 file changed, 103 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py index 5ab2e8804fd..e6f0813ca79 100644 --- a/python/sglang/srt/entrypoints/openai/utils.py +++ b/python/sglang/srt/entrypoints/openai/utils.py @@ -357,109 +357,6 @@ def sanitize_model_name(model: str) -> str: return sanitized.strip() -def extract_error_message(exception: Exception) -> str: - """Extract a clean error message from an exception - - Args: - exception: Exception to extract message from - - Returns: - Clean error message string - """ - error_msg = str(exception) - - # Remove common prefixes that aren't user-friendly - prefixes_to_remove = [ - "ValidationError: ", - "ValueError: ", - "TypeError: ", - "KeyError: ", - ] - - for prefix in prefixes_to_remove: - if error_msg.startswith(prefix): - error_msg = error_msg[len(prefix) :] - break - - # Limit length for safety - if len(error_msg) > 500: - error_msg = error_msg[:500] + "..." - - return error_msg - - -def format_validation_errors(errors: List[Dict[str, Any]]) -> str: - """Format Pydantic validation errors into a user-friendly message - - Args: - errors: List of validation error dictionaries - - Returns: - Formatted error message - """ - if not errors: - return "Unknown validation error" - - messages = [] - for error in errors[:5]: # Limit to first 5 errors - loc = " -> ".join(str(x) for x in error.get("loc", [])) - msg = error.get("msg", "Unknown error") - if loc: - messages.append(f"{loc}: {msg}") - else: - messages.append(msg) - - result = "; ".join(messages) - - if len(errors) > 5: - result += f" (and {len(errors) - 5} more errors)" - - return result - - -def is_multimodal_content(content: Any) -> bool: - """Check if content contains multimodal elements - - Args: - content: Content to check - - Returns: - True if content is multimodal, False otherwise - """ - if isinstance(content, list): - return any( - isinstance(item, dict) and item.get("type") in ["image_url", "audio_url"] - for item in content - ) - return False - - -def count_message_tokens_estimate(messages: List[ChatCompletionMessageParam]) -> int: - """Rough estimate of token count for messages (for validation purposes) - - Args: - messages: List of chat messages - - Returns: - Estimated token count - """ - total_chars = 0 - - for msg in messages: - if isinstance(msg.content, str): - total_chars += len(msg.content) - elif isinstance(msg.content, list): - for item in msg.content: - if isinstance(item, dict) and item.get("type") == "text": - total_chars += len(item.get("text", "")) - - # Add some tokens for role and structure - total_chars += 10 - - # Rough estimate: 1 token ≈ 4 characters for English text - return total_chars // 4 - - def to_openai_style_logprobs( input_token_logprobs=None, output_token_logprobs=None, From 012bcb5f987e8b1469b9dd4e471874ff0ab0444a Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 15 Jun 2025 00:06:30 +0000 Subject: [PATCH 13/33] Refactors request validation for OpenAI endpoints This commit refactors the request validation process for OpenAI-compatible endpoints. It removes the centralized validation logic and instead implements request-specific validation directly within each handler. The previous validation approach relied on a generic set of rules and validators, making it difficult to customize validation for specific request types. By moving the validation logic into the individual handlers (ChatCompletion, Completions, and Embeddings), this commit improves code organization, simplifies the validation process, and allows for more precise error messages. Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_base.py | 20 +- .../srt/entrypoints/openai/serving_chat.py | 33 +- .../entrypoints/openai/serving_completions.py | 43 +- .../entrypoints/openai/serving_embedding.py | 50 +- python/sglang/srt/entrypoints/openai/utils.py | 5 +- .../srt/entrypoints/openai/validation.py | 443 ------------------ test/srt/openai/test_serving_chat.py | 43 -- test/srt/openai/test_serving_completions.py | 170 ------- test/srt/openai/test_serving_embedding.py | 46 -- 9 files changed, 130 insertions(+), 723 deletions(-) delete mode 100644 python/sglang/srt/entrypoints/openai/validation.py diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index abe09bba3df..ad582d9e2ae 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -53,7 +53,6 @@ UsageInfo, ) from sglang.srt.entrypoints.openai.utils import create_error_response -from sglang.srt.entrypoints.openai.validation import get_validation_rules from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -101,9 +100,9 @@ async def handle_request( """Handle the specific request type with common pattern""" try: # Validate request - error = self._validate_request(request) - if error: - return error + error_msg = self._validate_request(request) + if error_msg: + return create_error_response(error_msg) # Create request context ctx = RequestContext( @@ -181,17 +180,10 @@ async def _handle_non_streaming_request( """Handle non-streaming request""" pass - def _validate_request( - self, request: OpenAIServingRequest - ) -> Optional[ErrorResponse]: + @abstractmethod + def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]: """Validate request""" - validation_rules = get_validation_rules(request) - for rule in validation_rules: - param_value = rule.param_getter(request) - error_msg = rule.validator_func(param_value) - if error_msg: - return create_error_response(error_msg, param=rule.param_name) - return None + pass def _calculate_streaming_usage_base( self, diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 29ea85ebdd8..8ddc504ae94 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -83,7 +83,6 @@ ) from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import GenerateReqInput -from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.reasoning_parser import ReasoningParser from sglang.utils import convert_json_schema_to_str @@ -99,6 +98,38 @@ def __init__(self, *args, **kwargs): self._cached_chat_template = None self._cached_template_format = None + def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]: + """Validate chat messages format and content""" + if not (messages := request.messages): + return "Messages cannot be empty" + + # Check for alternating user/assistant pattern (optional validation) + roles = [msg.role for msg in messages] + + # First message should typically be from user or system + if roles[0] not in ["user", "system"]: + return "First message should be from 'user' or 'system'" + + # Check for consecutive assistant messages (which might indicate an error) + for i in range(1, len(roles)): + if roles[i] == "assistant" and roles[i - 1] == "assistant": + # This is actually allowed in some cases, so just warn + pass + + # Validate message content + for i, msg in enumerate(messages): + if msg.role == "user": + if not msg.content: + return f"User message at index {i} has no content" + elif msg.role == "assistant": + # Assistant messages can have no content if they have tool_calls + if not msg.content and not getattr(msg, "tool_calls", None): + return ( + f"Assistant message at index {i} has no content or tool calls" + ) + + return None + def _convert_to_internal_request( self, all_requests: List[ChatCompletionRequest], diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 65da1a60ddd..c791d024e79 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -41,7 +41,7 @@ """ import time -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from fastapi.responses import StreamingResponse @@ -74,6 +74,47 @@ class CompletionHandler(OpenAIServingBase): """Handler for completion requests""" + def _validate_request(self, request: CompletionRequest) -> Optional[str]: + """Validate completion prompt format and content""" + if not (prompt := request.prompt): + return "Prompt cannot be None" + + if isinstance(prompt, str): + if not prompt.strip(): + return "Prompt cannot be empty or whitespace only" + elif isinstance(prompt, list): + if not prompt: + return "Prompt list cannot be empty" + + # Check if it's a list of strings + if all(isinstance(item, str) for item in prompt): + for i, item in enumerate(prompt): + if not item.strip(): + return f"Prompt at index {i} cannot be empty or whitespace only" + + # Check if it's a list of token IDs (integers) + elif all(isinstance(item, int) for item in prompt): + if any(item < 0 for item in prompt): + return "Token IDs must be non-negative" + + # Check if it's a list of lists (multiple token sequences) + elif all(isinstance(item, list) for item in prompt): + for i, item in enumerate(prompt): + if not item: + return f"Token sequence at index {i} cannot be empty" + if not all(isinstance(token, int) for token in item): + return f"Token sequence at index {i} must contain only integers" + if any(token < 0 for token in item): + return ( + f"Token sequence at index {i} contains negative token IDs" + ) + else: + return "Prompt must be string, list of strings, list of integers, or list of integer lists" + else: + return "Prompt must be string or list" + + return None + def _convert_to_internal_request( self, all_requests: List[CompletionRequest], diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index ec0e9f38dfc..ebb49697c39 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -46,7 +46,7 @@ """ import logging -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from fastapi.responses import StreamingResponse @@ -67,6 +67,54 @@ class EmbeddingHandler(OpenAIServingBase): """Handler for embedding requests""" + def _validate_request(self, request: EmbeddingRequest) -> Optional[str]: + """Validate that the input is not empty or whitespace only.""" + if not (input := request.input): + return "Input cannot be empty" + + # Handle single string + if isinstance(input, str): + if not input.strip(): + return "Input cannot be empty or whitespace only" + return None + + # Handle list inputs + if isinstance(input, list): + if len(input) == 0: + return "Input cannot be empty" + + # Check first element to determine type + first_item = input[0] + + if isinstance(first_item, str): + # List of strings + for i, item in enumerate(input): + if not isinstance(item, str): + return f"All items in input list must be strings" + if not item.strip(): + return f"Input at index {i} cannot be empty or whitespace only" + elif isinstance(first_item, int): + # List of integers (token IDs) + for i, item in enumerate(input): + if not isinstance(item, int): + return f"All items in input list must be integers" + if item < 0: + return f"Token ID at index {i} must be non-negative" + elif isinstance(first_item, list): + # List of lists (multiple token sequences) + for i, item in enumerate(input): + if not isinstance(item, list): + return f"Input at index {i} must be a list" + if not item: + return f"Input at index {i} cannot be empty" + if not all(isinstance(token, int) for token in item): + return f"Input at index {i} must contain only integers" + if any(token < 0 for token in item): + return f"Input at index {i} contains negative token IDs" + # Note: MultimodalEmbeddingInput validation would be handled by Pydantic + + return None + def _convert_to_internal_request( self, all_requests: List[EmbeddingRequest], diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py index e6f0813ca79..4f4482bf1bd 100644 --- a/python/sglang/srt/entrypoints/openai/utils.py +++ b/python/sglang/srt/entrypoints/openai/utils.py @@ -35,21 +35,18 @@ import json import logging import re -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import jinja2.nodes import transformers.utils.chat_template_utils as hf_chat_utils from sglang.srt.entrypoints.openai.protocol import ( - ChatCompletionMessageParam, ChatCompletionRequest, - CompletionRequest, ErrorResponse, LogProbs, OpenAIServingRequest, UsageInfo, ) -from sglang.srt.entrypoints.openai.validation import ValidationRule logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/entrypoints/openai/validation.py b/python/sglang/srt/entrypoints/openai/validation.py deleted file mode 100644 index 523de26b79b..00000000000 --- a/python/sglang/srt/entrypoints/openai/validation.py +++ /dev/null @@ -1,443 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Pre-built validation rules for OpenAI API parameters. - -This module provides comprehensive validation for all OpenAI API request parameters, -ensuring requests meet both OpenAI standards and SGLang-specific requirements. - -Key Components: -- ValidationRule: Encapsulates parameter validation logic -- Parameter Validators: Specific validation functions for different parameter types -- Request Type Handlers: Validation rule sets for different endpoint types -- Common Validators: Shared validation logic across endpoints - -Validation Categories: -- Basic Types: String, number, boolean validation -- Ranges: Min/max validation for numeric parameters -- Formats: Pattern matching for structured data -- Content: Message and prompt content validation -- Constraints: Cross-parameter dependency validation - -The validation system is designed to provide clear, actionable error messages -that help users understand and fix request issues quickly. - -Usage: -Validation rules are automatically applied based on request type. Each rule -specifies the parameter name, validation function, and parameter accessor, -allowing for flexible and comprehensive validation coverage. -""" - -import logging -import re -from typing import Any, Callable, List, Optional, Union - -from sglang.srt.entrypoints.openai.protocol import ( - ChatCompletionMessageParam, - ChatCompletionRequest, - CompletionRequest, - EmbeddingInput, - EmbeddingRequest, - OpenAIServingRequest, -) - -logger = logging.getLogger(__name__) - - -class ValidationRule: - """Represents a validation rule for request parameters""" - - def __init__( - self, - param_name: str, - validator_func: Callable[[Any], Optional[str]], - param_getter: Callable[[OpenAIServingRequest], Any], - ): - self.param_name = param_name - self.validator_func = validator_func - self.param_getter = param_getter - - -def validate_chat_messages(messages: List[ChatCompletionMessageParam]) -> Optional[str]: - """Validate chat messages format and content - - Args: - messages: List of chat messages - - Returns: - Error message if validation fails, None if valid - """ - if not messages: - return "Messages cannot be empty" - - # Check for alternating user/assistant pattern (optional validation) - roles = [msg.role for msg in messages] - - # First message should typically be from user or system - if roles[0] not in ["user", "system"]: - return "First message should be from 'user' or 'system'" - - # Check for consecutive assistant messages (which might indicate an error) - for i in range(1, len(roles)): - if roles[i] == "assistant" and roles[i - 1] == "assistant": - # This is actually allowed in some cases, so just warn - pass - - # Validate message content - for i, msg in enumerate(messages): - if msg.role == "user": - if not msg.content: - return f"User message at index {i} has no content" - elif msg.role == "assistant": - # Assistant messages can have no content if they have tool_calls - if not msg.content and not getattr(msg, "tool_calls", None): - return f"Assistant message at index {i} has no content or tool calls" - - return None - - -def validate_completion_prompt( - prompt: Union[str, List[str], List[int], List[List[int]]] -) -> Optional[str]: - """Validate completion prompt format and content - - Args: - prompt: The prompt to validate - - Returns: - Error message if validation fails, None if valid - """ - if prompt is None: - return "Prompt cannot be None" - - if isinstance(prompt, str): - if not prompt.strip(): - return "Prompt cannot be empty or whitespace only" - elif isinstance(prompt, list): - if not prompt: - return "Prompt list cannot be empty" - - # Check if it's a list of strings - if all(isinstance(item, str) for item in prompt): - for i, item in enumerate(prompt): - if not item.strip(): - return f"Prompt at index {i} cannot be empty or whitespace only" - - # Check if it's a list of token IDs (integers) - elif all(isinstance(item, int) for item in prompt): - if any(item < 0 for item in prompt): - return "Token IDs must be non-negative" - - # Check if it's a list of lists (multiple token sequences) - elif all(isinstance(item, list) for item in prompt): - for i, item in enumerate(prompt): - if not item: - return f"Token sequence at index {i} cannot be empty" - if not all(isinstance(token, int) for token in item): - return f"Token sequence at index {i} must contain only integers" - if any(token < 0 for token in item): - return f"Token sequence at index {i} contains negative token IDs" - else: - return "Prompt must be string, list of strings, list of integers, or list of integer lists" - else: - return "Prompt must be string or list" - - return None - - -def validate_model_name(model: str) -> Optional[str]: - """Validate model name format - - Args: - model: Model name to validate - - Returns: - Error message if validation fails, None if valid - """ - if not model: - return "Model name cannot be empty" - - if not isinstance(model, str): - return "Model name must be a string" - - # Basic validation - model names should be reasonable - if len(model) > 256: - return "Model name too long (maximum 256 characters)" - - # Check for invalid characters (basic validation) - if re.search(r'[<>:"|?*]', model): - return "Model name contains invalid characters" - - return None - - -def validate_temperature(temperature: float) -> Optional[str]: - """Validate temperature parameter - - Args: - temperature: Temperature value to validate - - Returns: - Error message if validation fails, None if valid - """ - if not isinstance(temperature, (int, float)): - return "Temperature must be a number" - - if temperature < 0: - return "Temperature must be non-negative" - - # OpenAI allows up to 2.0, but some models may support higher - if temperature > 2.0: - return "Temperature should typically be between 0 and 2" - - return None - - -def validate_max_tokens(max_tokens: Optional[int]) -> Optional[str]: - """Validate max_tokens parameter - - Args: - max_tokens: Maximum tokens value to validate - - Returns: - Error message if validation fails, None if valid - """ - if max_tokens is None: - return None - - if not isinstance(max_tokens, int): - return "max_tokens must be an integer" - - if max_tokens <= 0: - return "max_tokens must be positive" - - # Reasonable upper limit (can be adjusted based on model capabilities) - if max_tokens > 100000: - return "max_tokens is too large (maximum 100000)" - - return None - - -def validate_stop_sequences(stop: Optional[Union[str, List[str]]]) -> Optional[str]: - """Validate stop sequences - - Args: - stop: Stop sequences to validate - - Returns: - Error message if validation fails, None if valid - """ - if stop is None: - return None - - if isinstance(stop, str): - if len(stop) > 100: - return "Stop sequence too long (maximum 100 characters)" - return None - - if isinstance(stop, list): - if len(stop) > 4: # OpenAI limit - return "Too many stop sequences (maximum 4)" - - for i, seq in enumerate(stop): - if not isinstance(seq, str): - return f"Stop sequence at index {i} must be a string" - if len(seq) > 100: - return f"Stop sequence at index {i} too long (maximum 100 characters)" - - return None - - return "Stop sequences must be string or list of strings" - - -def validate_top_p(top_p: float) -> Optional[str]: - """Validate top_p parameter - - Args: - top_p: Top-p value to validate - - Returns: - Error message if validation fails, None if valid - """ - if not isinstance(top_p, (int, float)): - return "top_p must be a number" - - if top_p <= 0 or top_p > 1: - return "top_p must be between 0 and 1" - - return None - - -def validate_frequency_penalty(frequency_penalty: float) -> Optional[str]: - """Validate frequency_penalty parameter - - Args: - frequency_penalty: Frequency penalty value to validate - - Returns: - Error message if validation fails, None if valid - """ - if not isinstance(frequency_penalty, (int, float)): - return "frequency_penalty must be a number" - - if frequency_penalty < -2.0 or frequency_penalty > 2.0: - return "frequency_penalty must be between -2.0 and 2.0" - - return None - - -def validate_presence_penalty(presence_penalty: float) -> Optional[str]: - """Validate presence_penalty parameter - - Args: - presence_penalty: Presence penalty value to validate - - Returns: - Error message if validation fails, None if valid - """ - if not isinstance(presence_penalty, (int, float)): - return "presence_penalty must be a number" - - if presence_penalty < -2.0 or presence_penalty > 2.0: - return "presence_penalty must be between -2.0 and 2.0" - - return None - - -def validate_embedding_input(input: EmbeddingInput) -> Optional[str]: - """Validate that the input is not empty or whitespace only.""" - if not input: - return "Input cannot be empty" - - # Handle single string - if isinstance(input, str): - if not input.strip(): - return "Input cannot be empty or whitespace only" - return None - - # Handle list inputs - if isinstance(input, list): - if len(input) == 0: - return "Input cannot be empty" - - # Check first element to determine type - first_item = input[0] - - if isinstance(first_item, str): - # List of strings - for i, item in enumerate(input): - if not isinstance(item, str): - return f"All items in input list must be strings" - if not item.strip(): - return f"Input at index {i} cannot be empty or whitespace only" - elif isinstance(first_item, int): - # List of integers (token IDs) - for i, item in enumerate(input): - if not isinstance(item, int): - return f"All items in input list must be integers" - if item < 0: - return f"Token ID at index {i} must be non-negative" - elif isinstance(first_item, list): - # List of lists (multiple token sequences) - for i, item in enumerate(input): - if not isinstance(item, list): - return f"Input at index {i} must be a list" - if not item: - return f"Input at index {i} cannot be empty" - if not all(isinstance(token, int) for token in item): - return f"Input at index {i} must contain only integers" - if any(token < 0 for token in item): - return f"Input at index {i} contains negative token IDs" - # Note: MultimodalEmbeddingInput validation would be handled by Pydantic - - return None - - -def get_common_validation_rules() -> List[ValidationRule]: - """Get validation rules common to both chat and completion requests""" - return [ - ValidationRule( - param_name="model", - validator_func=validate_model_name, - param_getter=lambda request: request.model, - ), - ValidationRule( - param_name="temperature", - validator_func=validate_temperature, - param_getter=lambda request: request.temperature, - ), - ValidationRule( - param_name="max_tokens", - validator_func=validate_max_tokens, - param_getter=lambda request: request.max_tokens, - ), - ValidationRule( - param_name="stop", - validator_func=validate_stop_sequences, - param_getter=lambda request: request.stop, - ), - ] - - -def get_chat_specific_validation_rules() -> List[ValidationRule]: - """Get validation rules specific to chat completion requests""" - return [ - ValidationRule( - param_name="messages", - validator_func=validate_chat_messages, - param_getter=lambda request: request.messages, - ), - ] - - -def get_completion_specific_validation_rules() -> List[ValidationRule]: - """Get validation rules specific to completion requests""" - return [ - ValidationRule( - param_name="prompt", - validator_func=validate_completion_prompt, - param_getter=lambda request: request.prompt, - ), - ] - - -def get_embedding_specific_validation_rules() -> List[ValidationRule]: - """Get validation rules specific to embedding requests""" - return [ - ValidationRule( - param_name="input", - validator_func=validate_embedding_input, - param_getter=lambda request: request.input, - ), - ] - - -def get_validation_rules(request: OpenAIServingRequest) -> List[ValidationRule]: - """Get all validation rules for the request""" - if isinstance(request, ChatCompletionRequest): - return get_common_validation_rules() + get_chat_specific_validation_rules() - elif isinstance(request, CompletionRequest): - # Echo + logprobs warning - if request.echo and request.logprobs: - logger.warning( - "Echo is not compatible with logprobs. " - "To compute logprobs of input prompt, please use the native /generate API." - ) - return ( - get_common_validation_rules() + get_completion_specific_validation_rules() - ) - elif isinstance(request, EmbeddingRequest): - return get_embedding_specific_validation_rules() - else: - raise ValueError(f"Unsupported request type: {type(request)}") diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py index 11ce7d63aa5..b965b071c9c 100644 --- a/test/srt/openai/test_serving_chat.py +++ b/test/srt/openai/test_serving_chat.py @@ -119,49 +119,6 @@ def streaming_chat_request(): ) -class TestChatCompletionHandlerValidation: - """Test validation methods of ChatCompletionHandler.""" - - def test_validate_chat_request_valid(self, chat_handler, basic_chat_request): - """Test validation with a valid request.""" - # Use utility function directly instead of handler method - error = chat_handler._validate_request(basic_chat_request) - assert error is None - - def test_validate_chat_request_empty_messages(self, chat_handler): - """Test validation fails with empty messages.""" - # Since we now have Pydantic validation that prevents creating the request, - # we expect a ValidationError to be raised during object creation - with pytest.raises(ValidationError) as exc_info: - request = ChatCompletionRequest( - model="test-model", - messages=[], - temperature=0.7, - ) - # Check that the error is about empty messages - assert "empty" in str(exc_info.value).lower() - - def test_validate_chat_request_invalid_temperature(self, chat_handler): - """Test validation fails with invalid temperature.""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello"}], - temperature=-0.5, # Invalid negative temperature - ) - error = chat_handler._validate_request(request) - assert error is not None - - def test_validate_chat_request_invalid_max_tokens(self, chat_handler): - """Test validation fails with invalid max_tokens.""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello"}], - max_tokens=-10, # Invalid negative max_tokens - ) - error = chat_handler._validate_request(request) - assert error is not None - - class TestChatCompletionHandlerConversion: """Test request conversion methods.""" diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py index 0abe3a30a8e..81be68c7099 100644 --- a/test/srt/openai/test_serving_completions.py +++ b/test/srt/openai/test_serving_completions.py @@ -154,28 +154,6 @@ def test_build_base_sampling_params_all_parameters(self): assert sampling_params["min_new_tokens"] == 5 assert sampling_params["logit_bias"] == {"1": 0.5} - def test_validate_request_functionality(self, completion_handler): - """Test that validate_request works correctly.""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - temperature=0.7, - max_tokens=100, - ) - - # Test with completion validation rules - error = completion_handler._validate_request(request) - assert error is None - - # Test with invalid request - invalid_request = CompletionRequest( - model="", # Invalid empty model - prompt="Hello world", - max_tokens=100, - ) - error = completion_handler._validate_request(invalid_request) - assert error is not None - def test_create_error_response_functionality(self): """Test that create_error_response works correctly.""" error = create_error_response("Test error message") @@ -195,81 +173,6 @@ def test_create_streaming_error_response_functionality(self): assert error_data["error"]["message"] == "Test streaming error" -class TestCompletionValidation: - """Test validation methods""" - - def test_validate_completion_request_valid(self, completion_handler): - """Test validation of valid completion request""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - max_tokens=100, - temperature=0.7, - stream=False, - ) - - # Use utility function directly instead of handler method - error = completion_handler._validate_request(request) - assert error is None - - def test_validate_completion_request_empty_prompt_string(self, completion_handler): - """Test validation fails for empty string prompt""" - request = CompletionRequest(model="test-model", prompt="", max_tokens=100) - - error = completion_handler._validate_request(request) - assert error is not None - assert "prompt" in error.model_dump()["param"] - - def test_validate_completion_request_whitespace_prompt(self, completion_handler): - """Test validation fails for whitespace-only prompt""" - request = CompletionRequest( - model="test-model", prompt=" \n\t ", max_tokens=100 - ) - - error = completion_handler._validate_request(request) - assert error is not None - assert "prompt" in error.model_dump()["param"] - - def test_validate_completion_request_empty_list_prompt(self, completion_handler): - """Test validation fails for empty list prompt""" - request = CompletionRequest(model="test-model", prompt=[], max_tokens=100) - - error = completion_handler._validate_request(request) - assert error is not None - assert "prompt" in error.model_dump()["param"] - - def test_validate_completion_request_invalid_model(self, completion_handler): - """Test validation fails for invalid model""" - request = CompletionRequest(model="", prompt="Hello world", max_tokens=100) - - error = completion_handler._validate_request(request) - assert error is not None - assert "model" in error.model_dump()["param"] - - def test_validate_completion_request_invalid_temperature(self, completion_handler): - """Test validation fails for invalid temperature""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - max_tokens=100, - temperature=-1.0, # Invalid - ) - - error = completion_handler._validate_request(request) - assert error is not None - assert "temperature" in error.model_dump()["param"] - - def test_validate_completion_request_invalid_max_tokens(self, completion_handler): - """Test validation fails for invalid max_tokens""" - request = CompletionRequest( - model="test-model", prompt="Hello world", max_tokens=0 # Invalid - ) - - error = completion_handler._validate_request(request) - assert error is not None - assert "max_tokens" in error.model_dump()["param"] - - class TestPromptHandling: """Test different prompt types and formats from adapter.py""" @@ -889,21 +792,6 @@ def test_build_response_multiple_choices(self, completion_handler): class TestAsyncMethods: """Test async handler methods""" - async def test_handle_request_validation_error(self, completion_handler): - """Test handling request with validation error""" - mock_request = Mock() - request = CompletionRequest( - model="", prompt="Hello world", max_tokens=100 # Invalid model - ) - - response = await completion_handler.handle_request(request, mock_request) - - # Should return error response - assert hasattr(response, "model_dump") - error_data = response.model_dump() - assert error_data["object"] == "error" - assert "model" in error_data["param"] - async def test_handle_request_non_streaming(self, completion_handler): """Test handling non-streaming request - simplified test for async flow""" mock_request = Mock() @@ -974,44 +862,6 @@ def test_multiple_requests_with_n_greater_than_1_error(self, completion_handler) with pytest.raises(ValueError, match="Parallel sampling is not supported"): completion_handler._convert_to_internal_request(requests, ["id1", "id2"]) - def test_empty_prompt_list_validation(self, completion_handler): - """Test validation of empty prompt list""" - request = CompletionRequest(model="test-model", prompt=[], max_tokens=100) - - error = completion_handler._validate_request(request) - assert error is not None - assert "prompt" in error.model_dump()["param"] - - def test_nested_empty_prompt_list_validation(self, completion_handler): - """Test validation of nested empty prompt list""" - request = CompletionRequest(model="test-model", prompt=[[]], max_tokens=100) - - error = completion_handler._validate_request(request) - assert error is not None - assert "prompt" in error.model_dump()["param"] - - @pytest.mark.asyncio - async def test_echo_warning_with_logprobs(self, completion_handler): - """Test warning when echo is used with logprobs""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - max_tokens=100, - echo=True, - logprobs=5, - ) - - mock_raw_request = Mock() - - with patch("sglang.srt.entrypoints.openai.validation.logger") as mock_logger: - # Call handle_request which contains the warning logic - await completion_handler.handle_request(request, mock_raw_request) - # Should log warning about echo + logprobs incompatibility - mock_logger.warning.assert_called_once() - assert "Echo is not compatible with logprobs" in str( - mock_logger.warning.call_args - ) - def test_suffix_without_completion_template(self, completion_handler): """Test that suffix is ignored when completion template is not defined""" request = CompletionRequest( @@ -1031,23 +881,3 @@ def test_suffix_without_completion_template(self, completion_handler): # Should use original prompt, not processed with suffix assert adapted_request.text == "def hello():" - - def test_zero_max_tokens_handling(self, completion_handler): - """Test handling of zero max_tokens""" - request = CompletionRequest( - model="test-model", prompt="Hello world", max_tokens=0 - ) - - error = completion_handler._validate_request(request) - assert error is not None - assert "max_tokens" in error.model_dump()["param"] - - def test_negative_temperature_handling(self, completion_handler): - """Test handling of negative temperature""" - request = CompletionRequest( - model="test-model", prompt="Hello world", max_tokens=100, temperature=-0.5 - ) - - error = completion_handler._validate_request(request) - assert error is not None - assert "temperature" in error.model_dump()["param"] diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py index 1aac7370cca..b03bdac8f00 100644 --- a/test/srt/openai/test_serving_embedding.py +++ b/test/srt/openai/test_serving_embedding.py @@ -120,52 +120,6 @@ def token_ids_embedding_request(): ) -class TestEmbeddingHandlerValidation: - """Test validation methods of EmbeddingHandler.""" - - def test_validate_embedding_request_valid( - self, embedding_handler, basic_embedding_request - ): - """Test validation with a valid request.""" - error = embedding_handler._validate_request(basic_embedding_request) - assert error is None - - def test_validate_embedding_request_empty_string(self, embedding_handler): - """Test validation fails with empty string input.""" - request = EmbeddingRequest(model="test-model", input="") - error = embedding_handler._validate_request(request) - assert error is not None - assert "empty" in error.message.lower() - - def test_validate_embedding_request_whitespace_only(self, embedding_handler): - """Test validation fails with whitespace-only input.""" - request = EmbeddingRequest(model="test-model", input=" \n\t ") - error = embedding_handler._validate_request(request) - assert error is not None - assert "whitespace" in error.message.lower() - - def test_validate_embedding_request_empty_list(self, embedding_handler): - """Test validation fails with empty list input.""" - request = EmbeddingRequest(model="test-model", input=[]) - error = embedding_handler._validate_request(request) - assert error is not None - assert "empty" in error.message.lower() - - def test_validate_embedding_request_empty_string_in_list(self, embedding_handler): - """Test validation fails with empty string in list.""" - request = EmbeddingRequest(model="test-model", input=[""]) - error = embedding_handler._validate_request(request) - assert error is not None - assert "empty" in error.message.lower() - - def test_validate_embedding_request_valid_list( - self, embedding_handler, list_embedding_request - ): - """Test validation passes with valid list input.""" - error = embedding_handler._validate_request(list_embedding_request) - assert error is None - - class TestEmbeddingHandlerConversion: """Test request conversion methods.""" From 27341aeb921cf020b362ebe5397f3060f1f972fb Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 15 Jun 2025 00:23:53 +0000 Subject: [PATCH 14/33] Improves OpenAI serving base class logic Updates the OpenAI serving base class to provide default implementations for streaming and non-streaming request handlers. This change simplifies the implementation of derived classes by providing a default behavior (returning a "NotImplementedError") if streaming or non-streaming requests are not supported. It also removes the `abstractmethod` decorator from the `_handle_streaming_request` and `_handle_non_streaming_request` methods, making them optional to override. Additionally, this change removes the unnecessary implementation of `_handle_streaming_request` method from `EmbeddingHandler` since the base class now provides a default implementation. Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_base.py | 25 +++++++++++++------ .../entrypoints/openai/serving_embedding.py | 13 ---------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index ad582d9e2ae..e2b0c28c751 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -160,27 +160,38 @@ def _convert_to_internal_request( """Convert OpenAI request to internal format""" pass - @abstractmethod async def _handle_streaming_request( self, adapted_request: GenerateReqInput, request: OpenAIServingRequest, ctx: RequestContext, ) -> StreamingResponse: - """Handle streaming request""" - pass + """Handle streaming request + + Override this method in child classes that support streaming requests. + """ + return create_error_response( + message=f"{self.__class__.__name__} does not support streaming requests", + err_type="NotImplementedError", + status_code=501, + ) - @abstractmethod async def _handle_non_streaming_request( self, adapted_request: GenerateReqInput, request: OpenAIServingRequest, ctx: RequestContext, ) -> Union[Any, ErrorResponse]: - """Handle non-streaming request""" - pass + """Handle non-streaming request + + Override this method in child classes that support non-streaming requests. + """ + return create_error_response( + message=f"{self.__class__.__name__} does not support non-streaming requests", + err_type="NotImplementedError", + status_code=501, + ) - @abstractmethod def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]: """Validate request""" pass diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index ebb49697c39..640a0752cd3 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -45,11 +45,8 @@ padding for missing text content when needed. """ -import logging from typing import Any, Dict, List, Optional, Union -from fastapi.responses import StreamingResponse - from sglang.srt.conversation import generate_embedding_convs from sglang.srt.entrypoints.openai.protocol import ( EmbeddingObject, @@ -223,16 +220,6 @@ def _convert_to_internal_request( all_requests[0] if len(all_requests) == 1 else all_requests ) - async def _handle_streaming_request( - self, - adapted_request: EmbeddingReqInput, - request: EmbeddingRequest, - ctx: RequestContext, - ) -> StreamingResponse: - """Handle streaming embedding request (not supported)""" - # Embeddings don't support streaming - raise NotImplementedError("Embedding requests do not support streaming") - async def _handle_non_streaming_request( self, adapted_request: EmbeddingReqInput, From 286751a1187d75a65efc49028ae275b59d2fb5f0 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 15 Jun 2025 02:39:21 +0000 Subject: [PATCH 15/33] Refactors error handling for OpenAI endpoints Consolidates error response creation logic into the base class. This change refactors the error handling mechanism for OpenAI-compatible endpoints. It removes duplicated error response creation functions from `utils.py` and the handlers, and instead, implements centralized `create_error_response` and `create_streaming_error_response` methods in the `OpenAIServingBase` class. The handlers are updated to use these base class methods. This promotes code reuse and ensures consistency in error response formatting across all OpenAI endpoints. Also updates the streaming response to properly format the data. Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_base.py | 36 +++++++++++++++--- .../srt/entrypoints/openai/serving_chat.py | 22 +++++------ .../entrypoints/openai/serving_completions.py | 14 +++---- .../entrypoints/openai/serving_embedding.py | 3 +- python/sglang/srt/entrypoints/openai/utils.py | 38 ------------------- test/srt/openai/test_serving_chat.py | 9 ++--- test/srt/openai/test_serving_completions.py | 16 ++++---- 7 files changed, 56 insertions(+), 82 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index e2b0c28c751..8eec23746b4 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -35,6 +35,7 @@ endpoint-specific customization. """ +import json import logging import time import uuid @@ -52,7 +53,6 @@ OpenAIServingRequest, UsageInfo, ) -from sglang.srt.entrypoints.openai.utils import create_error_response from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -102,7 +102,7 @@ async def handle_request( # Validate request error_msg = self._validate_request(request) if error_msg: - return create_error_response(error_msg) + return self.create_error_response(error_msg) # Create request context ctx = RequestContext( @@ -128,7 +128,7 @@ async def handle_request( except Exception as e: logger.error(f"Error in request: {e}") - return create_error_response( + return self.create_error_response( message=f"Internal server error: {str(e)}", err_type="InternalServerError", status_code=500, @@ -170,7 +170,7 @@ async def _handle_streaming_request( Override this method in child classes that support streaming requests. """ - return create_error_response( + return self.create_error_response( message=f"{self.__class__.__name__} does not support streaming requests", err_type="NotImplementedError", status_code=501, @@ -186,7 +186,7 @@ async def _handle_non_streaming_request( Override this method in child classes that support non-streaming requests. """ - return create_error_response( + return self.create_error_response( message=f"{self.__class__.__name__} does not support non-streaming requests", err_type="NotImplementedError", status_code=501, @@ -222,3 +222,29 @@ def _calculate_streaming_usage_base( total_tokens=total_prompt_tokens + total_completion_tokens, prompt_tokens_details=prompt_tokens_details, ) + + def create_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: int = 400, + param: Optional[str] = None, + ) -> ErrorResponse: + """Create an error response""" + return ErrorResponse( + object="error", + message=message, + type=err_type, + param=param, + code=status_code, + ) + + def create_streaming_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: int = 400, + ) -> str: + """Create a streaming error response""" + error = self.create_error_response(message, err_type, status_code) + return json.dumps({"error": error.model_dump()}) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 8ddc504ae94..dc5ae1a1e99 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -73,10 +73,6 @@ _get_enable_thinking_from_request, aggregate_token_usage, build_base_sampling_params, - create_error_response, - create_stream_done, - create_streaming_chunk_data, - create_streaming_error_response, detect_template_content_format, process_content_for_template_format, to_openai_style_logprobs, @@ -514,7 +510,7 @@ async def generate_stream_resp(): choices=[choice_data], model=request.model, ) - yield create_streaming_chunk_data(chunk.model_dump_json()) + yield f"data: {chunk.model_dump_json()}\n\n" # Process content delta delta = text[len(stream_buffer) :] @@ -542,7 +538,7 @@ async def generate_stream_resp(): choices=[choice_data], model=request.model, ) - yield create_streaming_chunk_data(chunk.model_dump_json()) + yield f"data: {chunk.model_dump_json()}\n\n" if not delta or len(delta) == 0: stream_buffers[index] = new_stream_buffer @@ -589,7 +585,7 @@ async def generate_stream_resp(): choices=[choice_data], model=request.model, ) - yield create_streaming_chunk_data(chunk.model_dump_json()) + yield f"data: {chunk.model_dump_json()}\n\n" stream_buffers[index] = new_stream_buffer is_firsts[index] = is_first @@ -616,13 +612,13 @@ async def generate_stream_resp(): model=request.model, usage=usage, ) - yield create_streaming_chunk_data(final_chunk.model_dump_json()) + yield f"data: {final_chunk.model_dump_json()}\n\n" except Exception as e: - error = create_streaming_error_response(str(e)) - yield create_streaming_chunk_data(error) + error = self.create_streaming_error_response(str(e)) + yield f"data: {error}\n\n" - yield create_stream_done() + yield "data: [DONE]\n\n" return StreamingResponse( generate_stream_resp(), @@ -642,7 +638,7 @@ async def _handle_non_streaming_request( adapted_request, ctx.raw_request ).__anext__() except ValueError as e: - return create_error_response(str(e)) + return self.create_error_response(str(e)) if not isinstance(ret, list): ret = [ret] @@ -690,7 +686,7 @@ def _build_chat_response( reasoning_text, text = parser.parse_non_stream(text) except Exception as e: logger.error(f"Reasoning parsing error: {e}") - return create_error_response( + return self.create_error_response( "Failed to parse reasoning content", err_type="InternalServerError", status_code=500, diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index c791d024e79..4d3fb322cb4 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -62,10 +62,6 @@ from sglang.srt.entrypoints.openai.utils import ( aggregate_token_usage, build_base_sampling_params, - create_error_response, - create_stream_done, - create_streaming_chunk_data, - create_streaming_error_response, to_openai_style_logprobs, ) from sglang.srt.managers.io_struct import GenerateReqInput @@ -307,7 +303,7 @@ async def generate_stream_resp(): stream_buffers[index] = stream_buffer n_prev_tokens[index] = n_prev_token - yield create_streaming_chunk_data(chunk.model_dump_json()) + yield f"data: {chunk.model_dump_json()}\n\n" # Handle final usage chunk if request.stream_options and request.stream_options.include_usage: @@ -324,13 +320,13 @@ async def generate_stream_resp(): final_usage_data = final_usage_chunk.model_dump_json( exclude_none=True ) - yield create_streaming_chunk_data(final_usage_data) + yield f"data: {final_usage_data}\n\n" except Exception as e: - error = create_streaming_error_response(str(e)) + error = self.create_streaming_error_response(str(e)) yield f"data: {error}\n\n" - yield create_stream_done() + yield "data: [DONE]\n\n" return StreamingResponse( generate_stream_resp(), @@ -351,7 +347,7 @@ async def _handle_non_streaming_request( ) ret = await generator.__anext__() except ValueError as e: - return create_error_response(str(e)) + return self.create_error_response(str(e)) if not isinstance(ret, list): ret = [ret] diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index 640a0752cd3..c9ddc4f7370 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -57,7 +57,6 @@ UsageInfo, ) from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase, RequestContext -from sglang.srt.entrypoints.openai.utils import create_error_response from sglang.srt.managers.io_struct import EmbeddingReqInput @@ -232,7 +231,7 @@ async def _handle_non_streaming_request( adapted_request, ctx.raw_request ).__anext__() except ValueError as e: - return create_error_response(str(e)) + return self.create_error_response(str(e)) if not isinstance(ret, list): ret = [ret] diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py index 4f4482bf1bd..6526ce4be1b 100644 --- a/python/sglang/srt/entrypoints/openai/utils.py +++ b/python/sglang/srt/entrypoints/openai/utils.py @@ -32,7 +32,6 @@ proper content processing for different model types (e.g., DeepSeek vs Llama). """ -import json import logging import re from typing import Any, Dict, List, Optional @@ -42,7 +41,6 @@ from sglang.srt.entrypoints.openai.protocol import ( ChatCompletionRequest, - ErrorResponse, LogProbs, OpenAIServingRequest, UsageInfo, @@ -269,32 +267,6 @@ def aggregate_token_usage( ) -def create_error_response( - message: str, - err_type: str = "BadRequestError", - status_code: int = 400, - param: Optional[str] = None, -) -> ErrorResponse: - """Create an error response""" - return ErrorResponse( - object="error", - message=message, - type=err_type, - param=param, - code=status_code, - ) - - -def create_streaming_error_response( - message: str, - err_type: str = "BadRequestError", - status_code: int = 400, -) -> str: - """Create a streaming error response""" - error = create_error_response(message, err_type, status_code) - return json.dumps({"error": error.model_dump()}) - - def build_base_sampling_params(request: OpenAIServingRequest) -> Dict[str, Any]: """Build common sampling parameters shared by both chat and completion requests""" params = {} @@ -401,13 +373,3 @@ def _get_enable_thinking_from_request(request_obj): The boolean value of 'enable_thinking' if found and not True, otherwise True. """ return getattr(request_obj, "chat_template_kwargs", {}).get("enable_thinking", True) - - -def create_streaming_chunk_data(chunk_data: str) -> str: - """Create a streaming response chunk in the proper format""" - return f"data: {chunk_data}\n\n" - - -def create_stream_done() -> str: - """Create the final [DONE] message for streaming responses""" - return "data: [DONE]\n\n" diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py index b965b071c9c..8982046918b 100644 --- a/test/srt/openai/test_serving_chat.py +++ b/test/srt/openai/test_serving_chat.py @@ -30,10 +30,7 @@ UsageInfo, ) from sglang.srt.entrypoints.openai.serving_chat import ChatCompletionHandler -from sglang.srt.entrypoints.openai.utils import ( - build_base_sampling_params, - create_error_response, -) +from sglang.srt.entrypoints.openai.utils import build_base_sampling_params from sglang.srt.managers.io_struct import GenerateReqInput @@ -658,9 +655,9 @@ def test_build_base_sampling_params_max_completion_tokens_override(self): # max_completion_tokens should override max_tokens assert sampling_params["max_new_tokens"] == 200 - def test_create_error_response_functionality(self): + def test_create_error_response_functionality(self, chat_handler): """Test that create_error_response works correctly.""" - error = create_error_response("Test error message") + error = chat_handler.create_error_response("Test error message") assert isinstance(error, ErrorResponse) assert error.message == "Test error message" assert error.type == "BadRequestError" diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py index 81be68c7099..064b6efe953 100644 --- a/test/srt/openai/test_serving_completions.py +++ b/test/srt/openai/test_serving_completions.py @@ -14,11 +14,7 @@ ErrorResponse, ) from sglang.srt.entrypoints.openai.serving_completions import CompletionHandler -from sglang.srt.entrypoints.openai.utils import ( - build_base_sampling_params, - create_error_response, - create_streaming_error_response, -) +from sglang.srt.entrypoints.openai.utils import build_base_sampling_params from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -154,17 +150,19 @@ def test_build_base_sampling_params_all_parameters(self): assert sampling_params["min_new_tokens"] == 5 assert sampling_params["logit_bias"] == {"1": 0.5} - def test_create_error_response_functionality(self): + def test_create_error_response_functionality(self, completion_handler): """Test that create_error_response works correctly.""" - error = create_error_response("Test error message") + error = completion_handler.create_error_response("Test error message") assert isinstance(error, ErrorResponse) assert error.message == "Test error message" assert error.type == "BadRequestError" assert error.code == 400 - def test_create_streaming_error_response_functionality(self): + def test_create_streaming_error_response_functionality(self, completion_handler): """Test that create_streaming_error_response works correctly.""" - error_json = create_streaming_error_response("Test streaming error") + error_json = completion_handler.create_streaming_error_response( + "Test streaming error" + ) # Should return JSON string with error structure import json From 50d57d1e4601463f773f292b3b8aad2c9a873dc7 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 15 Jun 2025 02:47:59 +0000 Subject: [PATCH 16/33] Refactors request ID generation Moves request ID prefix logic to subclasses for better organization and extensibility. Introduces an abstract `_request_id_prefix` method in the base class and implements it in the handlers for chat, completions, and embeddings. This change simplifies the base class and allows each handler to define its own prefix. Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_base.py | 22 +++++++------------ .../srt/entrypoints/openai/serving_chat.py | 3 +++ .../entrypoints/openai/serving_completions.py | 3 +++ .../entrypoints/openai/serving_embedding.py | 3 +++ 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index 8eec23746b4..72cbc28a8a7 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -46,9 +46,6 @@ from fastapi.responses import StreamingResponse from sglang.srt.entrypoints.openai.protocol import ( - ChatCompletionRequest, - CompletionRequest, - EmbeddingRequest, ErrorResponse, OpenAIServingRequest, UsageInfo, @@ -108,7 +105,7 @@ async def handle_request( ctx = RequestContext( raw_request=raw_request, openai_request=request, - request_id=self._generate_request_id(request), + request_id=self._generate_request_id_base(request), ) # Convert to internal format @@ -134,20 +131,17 @@ async def handle_request( status_code=500, ) - def _generate_request_id(self, request: OpenAIServingRequest) -> str: + @abstractmethod + def _request_id_prefix(self) -> str: + """Generate request ID based on request type""" + pass + + def _generate_request_id_base(self, request: OpenAIServingRequest) -> str: """Generate request ID based on request type""" - # Default implementation - can be overridden if rid := getattr(request, "rid", None): return rid - # Determine prefix based on request type - prefix_mapping = { - ChatCompletionRequest: "chatcmpl", - CompletionRequest: "cmpl", - EmbeddingRequest: "embd", - } - prefix = prefix_mapping.get(type(request), "req") - return f"{prefix}-{uuid.uuid4()}" + return f"{self._request_id_prefix()}{uuid.uuid4().hex}" @abstractmethod def _convert_to_internal_request( diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index dc5ae1a1e99..108004426c4 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -94,6 +94,9 @@ def __init__(self, *args, **kwargs): self._cached_chat_template = None self._cached_template_format = None + def _request_id_prefix(self) -> str: + return "chatcmpl-" + def _validate_request(self, request: ChatCompletionRequest) -> Optional[str]: """Validate chat messages format and content""" if not (messages := request.messages): diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 4d3fb322cb4..1fb08bf087e 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -70,6 +70,9 @@ class CompletionHandler(OpenAIServingBase): """Handler for completion requests""" + def _request_id_prefix(self) -> str: + return "cmpl-" + def _validate_request(self, request: CompletionRequest) -> Optional[str]: """Validate completion prompt format and content""" if not (prompt := request.prompt): diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index c9ddc4f7370..ad529e51374 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -63,6 +63,9 @@ class EmbeddingHandler(OpenAIServingBase): """Handler for embedding requests""" + def _request_id_prefix(self) -> str: + return "embd-" + def _validate_request(self, request: EmbeddingRequest) -> Optional[str]: """Validate that the input is not empty or whitespace only.""" if not (input := request.input): From 960f9176e9c9ecffa64d3473d36095c5395bf0ef Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 15 Jun 2025 03:03:06 +0000 Subject: [PATCH 17/33] Removes RequestContext Removes the `RequestContext` class and its associated logic. This change simplifies the code by removing the explicit request context object and instead passing the raw request directly to the relevant functions. The raw request is primarily used for detecting the client connection. Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_base.py | 49 +++---------------- .../srt/entrypoints/openai/serving_chat.py | 10 ++-- .../entrypoints/openai/serving_completions.py | 11 +++-- .../entrypoints/openai/serving_embedding.py | 8 +-- test/srt/openai/test_serving_embedding.py | 49 ------------------- 5 files changed, 22 insertions(+), 105 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index 72cbc28a8a7..656443e4d48 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -19,7 +19,6 @@ architecture for request processing, validation, and response generation. Key Components: -- RequestContext: Tracks request state and metadata throughout processing - OpenAIServingBase: Abstract base class for all endpoint handlers - Common request handling patterns with proper error handling - Validation integration for request parameters @@ -37,7 +36,6 @@ import json import logging -import time import uuid from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Union @@ -56,34 +54,6 @@ logger = logging.getLogger(__name__) -class RequestContext: - """Context object for tracking request state throughout the pipeline""" - - def __init__( - self, - raw_request: Request, - openai_request: OpenAIServingRequest, - request_id: str, - ): - self.raw_request = raw_request - self.openai_request = openai_request - self.request_id = request_id - self.start_time = time.time() - self.metadata: Dict[str, Any] = {} - - def elapsed_time(self) -> float: - """Get elapsed time since request started""" - return time.time() - self.start_time - - def add_metadata(self, key: str, value: Any) -> None: - """Add metadata to the request context""" - self.metadata[key] = value - - def get_metadata(self, key: str, default: Any = None) -> Any: - """Get metadata from the request context""" - return self.metadata.get(key, default) - - # Base class for specific endpoint handlers class OpenAIServingBase(ABC): """Abstract base class for OpenAI endpoint handlers""" @@ -101,26 +71,19 @@ async def handle_request( if error_msg: return self.create_error_response(error_msg) - # Create request context - ctx = RequestContext( - raw_request=raw_request, - openai_request=request, - request_id=self._generate_request_id_base(request), - ) - # Convert to internal format adapted_request, processed_request = self._convert_to_internal_request( - [request], [ctx.request_id] + [request], [self._generate_request_id_base(request)] ) - # Check if this handler supports streaming + # Note(Xinyuan): raw_request below is only used for detecting the connection of the client if hasattr(request, "stream") and request.stream: return await self._handle_streaming_request( - adapted_request, processed_request, ctx + adapted_request, processed_request, raw_request ) else: return await self._handle_non_streaming_request( - adapted_request, processed_request, ctx + adapted_request, processed_request, raw_request ) except Exception as e: @@ -158,7 +121,7 @@ async def _handle_streaming_request( self, adapted_request: GenerateReqInput, request: OpenAIServingRequest, - ctx: RequestContext, + raw_request: Request, ) -> StreamingResponse: """Handle streaming request @@ -174,7 +137,7 @@ async def _handle_non_streaming_request( self, adapted_request: GenerateReqInput, request: OpenAIServingRequest, - ctx: RequestContext, + raw_request: Request, ) -> Union[Any, ErrorResponse]: """Handle non-streaming request diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 108004426c4..96d33ee1893 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -68,7 +68,7 @@ TopLogprob, UsageInfo, ) -from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase, RequestContext +from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.utils import ( _get_enable_thinking_from_request, aggregate_token_usage, @@ -447,7 +447,7 @@ async def _handle_streaming_request( self, adapted_request: GenerateReqInput, request: ChatCompletionRequest, - ctx: RequestContext, + raw_request: Request, ) -> StreamingResponse: """Handle streaming chat completion request""" @@ -464,7 +464,7 @@ async def generate_stream_resp(): try: async for content in self.tokenizer_manager.generate_request( - adapted_request, ctx.raw_request + adapted_request, raw_request ): index = content.get("index", 0) text = content["text"] @@ -633,12 +633,12 @@ async def _handle_non_streaming_request( self, adapted_request: GenerateReqInput, request: ChatCompletionRequest, - ctx: RequestContext, + raw_request: Request, ) -> Union[ChatCompletionResponse, ErrorResponse]: """Handle non-streaming chat completion request""" try: ret = await self.tokenizer_manager.generate_request( - adapted_request, ctx.raw_request + adapted_request, raw_request ).__anext__() except ValueError as e: return self.create_error_response(str(e)) diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 1fb08bf087e..d388820becf 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -43,6 +43,7 @@ import time from typing import Any, Dict, List, Optional, Union +from fastapi import Request from fastapi.responses import StreamingResponse from sglang.srt.code_completion_parser import ( @@ -58,7 +59,7 @@ CompletionStreamResponse, ErrorResponse, ) -from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase, RequestContext +from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.utils import ( aggregate_token_usage, build_base_sampling_params, @@ -218,7 +219,7 @@ async def _handle_streaming_request( self, adapted_request: GenerateReqInput, request: CompletionRequest, - ctx: RequestContext, + raw_request: Request, ) -> StreamingResponse: """Handle streaming completion request""" created = int(time.time()) @@ -232,7 +233,7 @@ async def generate_stream_resp(): try: async for content in self.tokenizer_manager.generate_request( - adapted_request, ctx.raw_request + adapted_request, raw_request ): index = content.get("index", 0) @@ -341,12 +342,12 @@ async def _handle_non_streaming_request( self, adapted_request: GenerateReqInput, request: CompletionRequest, - ctx: RequestContext, + raw_request: Request, ) -> Union[CompletionResponse, ErrorResponse]: """Handle non-streaming completion request""" try: generator = self.tokenizer_manager.generate_request( - adapted_request, ctx.raw_request + adapted_request, raw_request ) ret = await generator.__anext__() except ValueError as e: diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index ad529e51374..aa5e1c5d4c6 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -47,6 +47,8 @@ from typing import Any, Dict, List, Optional, Union +from fastapi import Request + from sglang.srt.conversation import generate_embedding_convs from sglang.srt.entrypoints.openai.protocol import ( EmbeddingObject, @@ -56,7 +58,7 @@ MultimodalEmbeddingInput, UsageInfo, ) -from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase, RequestContext +from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.managers.io_struct import EmbeddingReqInput @@ -226,12 +228,12 @@ async def _handle_non_streaming_request( self, adapted_request: EmbeddingReqInput, request: EmbeddingRequest, - ctx: RequestContext, + raw_request: Request, ) -> Union[EmbeddingResponse, ErrorResponse]: """Handle the embedding request""" try: ret = await self.tokenizer_manager.generate_request( - adapted_request, ctx.raw_request + adapted_request, raw_request ).__anext__() except ValueError as e: return self.create_error_response(str(e)) diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py index b03bdac8f00..00bc95c6a02 100644 --- a/test/srt/openai/test_serving_embedding.py +++ b/test/srt/openai/test_serving_embedding.py @@ -400,55 +400,6 @@ def test_response_format_matches_adapter(self, embedding_handler): class TestEdgeCases: """Test edge cases and error conditions.""" - def test_request_id_generation( - self, embedding_handler, basic_embedding_request, mock_request - ): - """Test that request IDs are properly generated when not provided.""" - # Request without rid - request_without_id = EmbeddingRequest(model="test-model", input="Test") - assert request_without_id.rid is None - - # Should generate ID during handling - with patch.object( - embedding_handler, "_handle_non_streaming_request" - ) as mock_handle: - mock_handle.return_value = EmbeddingResponse( - data=[], - model="test-model", - usage=UsageInfo(prompt_tokens=0, total_tokens=0), - ) - - asyncio.run( - embedding_handler.handle_request(request_without_id, mock_request) - ) - - # Check that context was created with generated ID - args, kwargs = mock_handle.call_args - ctx = args[2] # Third argument is context - assert ctx.request_id.startswith("embd-") - - def test_request_id_preservation(self, embedding_handler, mock_request): - """Test that provided request IDs are preserved.""" - request_with_id = EmbeddingRequest( - model="test-model", input="Test", rid="custom-id" - ) - - with patch.object( - embedding_handler, "_handle_non_streaming_request" - ) as mock_handle: - mock_handle.return_value = EmbeddingResponse( - data=[], - model="test-model", - usage=UsageInfo(prompt_tokens=0, total_tokens=0), - ) - - asyncio.run(embedding_handler.handle_request(request_with_id, mock_request)) - - # Check that custom ID was preserved - args, kwargs = mock_handle.call_args - ctx = args[2] # Third argument is context - assert ctx.request_id == "custom-id" - def test_multimodal_batch_not_implemented(self, embedding_handler): """Test that multimodal batch requests raise NotImplementedError.""" request1 = EmbeddingRequest( From 30663a5904ff5074128e862ac226a22e9fbe176a Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 15 Jun 2025 03:12:46 +0000 Subject: [PATCH 18/33] Simplifies enable_thinking handling and remove unused functions Removes the dedicated function for extracting the `enable_thinking` flag from the request object. It now directly accesses the `enable_thinking` value from the `chat_template_kwargs` attribute of the request, defaulting to `True` if not present. This streamlines the code and reduces redundancy. Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_chat.py | 9 ++++-- python/sglang/srt/entrypoints/openai/utils.py | 32 ------------------- 2 files changed, 6 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 96d33ee1893..84644486c4d 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -70,7 +70,6 @@ ) from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.utils import ( - _get_enable_thinking_from_request, aggregate_token_usage, build_base_sampling_params, detect_template_content_format, @@ -520,7 +519,9 @@ async def generate_stream_resp(): new_stream_buffer = stream_buffer + delta # Handle reasoning content - enable_thinking = _get_enable_thinking_from_request(request) + enable_thinking = getattr(request, "chat_template_kwargs", {}).get( + "enable_thinking", True + ) if ( self.tokenizer_manager.server_args.reasoning_parser and request.separate_reasoning @@ -680,7 +681,9 @@ def _build_chat_response( # Handle reasoning content reasoning_text = None - enable_thinking = _get_enable_thinking_from_request(request) + enable_thinking = getattr(request, "chat_template_kwargs", {}).get( + "enable_thinking", True + ) if reasoning_parser and request.separate_reasoning and enable_thinking: try: parser = ReasoningParser( diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py index 6526ce4be1b..88c4ceed8ab 100644 --- a/python/sglang/srt/entrypoints/openai/utils.py +++ b/python/sglang/srt/entrypoints/openai/utils.py @@ -33,7 +33,6 @@ """ import logging -import re from typing import Any, Dict, List, Optional import jinja2.nodes @@ -307,25 +306,6 @@ def build_base_sampling_params(request: OpenAIServingRequest) -> Dict[str, Any]: return params -def sanitize_model_name(model: str) -> str: - """Sanitize model name for safe usage - - Args: - model: Model name to sanitize - - Returns: - Sanitized model name - """ - # Remove potentially dangerous characters - sanitized = re.sub(r'[<>:"|?*]', "", model) - - # Limit length - if len(sanitized) > 256: - sanitized = sanitized[:256] - - return sanitized.strip() - - def to_openai_style_logprobs( input_token_logprobs=None, output_token_logprobs=None, @@ -361,15 +341,3 @@ def append_top_logprobs(top_logprobs): append_top_logprobs(output_top_logprobs) return ret_logprobs - - -def _get_enable_thinking_from_request(request_obj): - """Extracts the 'enable_thinking' flag from request chat_template_kwargs. - - Args: - request_obj: The request object (or an item from a list of requests). - - Returns: - The boolean value of 'enable_thinking' if found and not True, otherwise True. - """ - return getattr(request_obj, "chat_template_kwargs", {}).get("enable_thinking", True) From eb6784dabaab77f21c7c1731c5d9710a13d56b1f Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 15 Jun 2025 03:48:35 +0000 Subject: [PATCH 19/33] Refactors sampling parameter building Removes the `build_base_sampling_params` function and inlines the parameter construction logic directly into the `_build_sampling_params` methods of `ChatCompletionHandler` and `CompletionHandler`. This change simplifies the code and allows for more direct control and clarity over how sampling parameters are created for each request type. Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_chat.py | 57 +++++-- .../entrypoints/openai/serving_completions.py | 23 ++- python/sglang/srt/entrypoints/openai/utils.py | 40 ----- test/srt/openai/test_serving_chat.py | 125 +-------------- test/srt/openai/test_serving_completions.py | 145 ------------------ 5 files changed, 67 insertions(+), 323 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 84644486c4d..4ffa0c364a3 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -71,7 +71,6 @@ from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.utils import ( aggregate_token_usage, - build_base_sampling_params, detect_template_content_format, process_content_for_template_format, to_openai_style_logprobs, @@ -226,7 +225,7 @@ def _process_messages( str, Union[str, List[int]], Optional[Any], Optional[Any], List[str], List[str] ]: """Process chat messages and apply chat template""" - tool_call_constraint = None + tool_call_constraint = None # TODO: how to pass this to the sampling params? prompt = "" prompt_ids = [] @@ -410,13 +409,28 @@ def _build_sampling_params( self, request: ChatCompletionRequest, stop: List[str] ) -> Dict[str, Any]: """Build sampling parameters for the request""" - # Start with common parameters - sampling_params = build_base_sampling_params(request) - # Override stop with processed stop sequences - sampling_params["stop"] = stop + sampling_params = { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens or request.max_completion_tokens, + "min_new_tokens": request.min_tokens, + "stop": stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "top_k": request.top_k, + "min_p": request.min_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "ebnf": request.ebnf, + "n": request.n, + "no_stop_trim": request.no_stop_trim, + "ignore_eos": request.ignore_eos, + "skip_special_tokens": request.skip_special_tokens, + "logit_bias": request.logit_bias, + } - # Handle response format if request.response_format and request.response_format.type == "json_schema": sampling_params["json_schema"] = convert_json_schema_to_str( request.response_format.json_schema.schema_ @@ -430,16 +444,25 @@ def _build_sampling_params( request.response_format.model_dump(by_alias=True) ) - # Handle tool call constraints - if hasattr(self, "_tool_call_constraint") and self._tool_call_constraint: - constraint_type, constraint_value = self._tool_call_constraint - if constraint_type == "structural_tag": - sampling_params[constraint_type] = convert_json_schema_to_str( - constraint_value.model_dump(by_alias=True) - ) - else: - sampling_params[constraint_type] = constraint_value - + # TODO: how to handle tool call constraint? + # Check if there are already existing output constraints + # has_existing_constraints = ( + # sampling_params.get("regex") + # or sampling_params.get("ebnf") + # or sampling_params.get("structural_tag") + # or sampling_params.get("json_schema") + # ) + + # if tool_call_constraint and has_existing_constraints: + # logger.warning("Constrained decoding is not compatible with tool calls.") + # elif tool_call_constraint: + # constraint_type, constraint_value = tool_call_constraint + # if constraint_type == "structural_tag": + # sampling_params[constraint_type] = convert_json_schema_to_str( + # constraint_value.model_dump(by_alias=True) + # ) + # else: + # sampling_params[constraint_type] = constraint_value return sampling_params async def _handle_streaming_request( diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index d388820becf..544cd238f6b 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -62,7 +62,6 @@ from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.utils import ( aggregate_token_usage, - build_base_sampling_params, to_openai_style_logprobs, ) from sglang.srt.managers.io_struct import GenerateReqInput @@ -208,7 +207,27 @@ def _convert_to_internal_request( def _build_sampling_params(self, request: CompletionRequest) -> Dict[str, Any]: """Build sampling parameters for the request""" # Start with common parameters - sampling_params = build_base_sampling_params(request) + sampling_params = { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "min_new_tokens": request.min_tokens, + "stop": request.stop, + "stop_token_ids": request.stop_token_ids, + "top_p": request.top_p, + "top_k": request.top_k, + "min_p": request.min_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "regex": request.regex, + "json_schema": request.json_schema, + "ebnf": request.ebnf, + "n": request.n, + "no_stop_trim": request.no_stop_trim, + "ignore_eos": request.ignore_eos, + "skip_special_tokens": request.skip_special_tokens, + "logit_bias": request.logit_bias, + } # No additional completion-specific parameters needed currently # (json_schema is already handled in base method) diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py index 88c4ceed8ab..5522e0a68d4 100644 --- a/python/sglang/srt/entrypoints/openai/utils.py +++ b/python/sglang/srt/entrypoints/openai/utils.py @@ -266,46 +266,6 @@ def aggregate_token_usage( ) -def build_base_sampling_params(request: OpenAIServingRequest) -> Dict[str, Any]: - """Build common sampling parameters shared by both chat and completion requests""" - params = {} - - # Define parameter mappings (request_attr -> param_name) - direct_mappings = { - "temperature": "temperature", - "max_tokens": "max_new_tokens", - "min_tokens": "min_new_tokens", - "stop": "stop", - "stop_token_ids": "stop_token_ids", - "top_p": "top_p", - "top_k": "top_k", - "min_p": "min_p", - "presence_penalty": "presence_penalty", - "frequency_penalty": "frequency_penalty", - "repetition_penalty": "repetition_penalty", - "regex": "regex", - "ebnf": "ebnf", - "n": "n", - "no_stop_trim": "no_stop_trim", - "ignore_eos": "ignore_eos", - "logit_bias": "logit_bias", - "skip_special_tokens": "skip_special_tokens", - "json_schema": "json_schema", - } - - # Apply direct mappings - for request_attr, param_name in direct_mappings.items(): - if hasattr(request, request_attr): - params[param_name] = getattr(request, request_attr) - - # Handle special cases - # max_completion_tokens overrides max_tokens for chat requests - if isinstance(request, ChatCompletionRequest) and request.max_completion_tokens: - params["max_new_tokens"] = request.max_completion_tokens - - return params - - def to_openai_style_logprobs( input_token_logprobs=None, output_token_logprobs=None, diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py index 8982046918b..f20d42c4676 100644 --- a/test/srt/openai/test_serving_chat.py +++ b/test/srt/openai/test_serving_chat.py @@ -5,32 +5,14 @@ with the original adapter.py functionality. """ -import asyncio -import json -import time import uuid -from typing import Any, Dict, List -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import Mock, patch import pytest from fastapi import Request -from fastapi.responses import StreamingResponse -from pydantic_core import ValidationError - -from sglang.srt.entrypoints.openai.protocol import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseChoice, - ChatCompletionStreamResponse, - ChatMessage, - DeltaMessage, - ErrorResponse, - FunctionResponse, - ToolCall, - UsageInfo, -) + +from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse from sglang.srt.entrypoints.openai.serving_chat import ChatCompletionHandler -from sglang.srt.entrypoints.openai.utils import build_base_sampling_params from sglang.srt.managers.io_struct import GenerateReqInput @@ -510,7 +492,7 @@ def test_all_sampling_parameters(self, chat_handler): messages=[{"role": "user", "content": "Hello"}], temperature=0.8, max_tokens=150, - max_completion_tokens=200, # Should override max_tokens + max_completion_tokens=200, min_tokens=5, top_p=0.9, top_k=50, @@ -543,9 +525,7 @@ def test_all_sampling_parameters(self, chat_handler): # Verify all parameters assert sampling_params["temperature"] == 0.8 - assert ( - sampling_params["max_new_tokens"] == 200 - ) # max_completion_tokens overrides + assert sampling_params["max_new_tokens"] == 150 assert sampling_params["min_new_tokens"] == 5 assert sampling_params["top_p"] == 0.9 assert sampling_params["top_k"] == 50 @@ -553,9 +533,7 @@ def test_all_sampling_parameters(self, chat_handler): assert sampling_params["presence_penalty"] == 0.1 assert sampling_params["frequency_penalty"] == 0.2 assert sampling_params["repetition_penalty"] == 1.1 - assert sampling_params["stop"] == [ - "" - ] # Should be overridden with processed stop + assert sampling_params["stop"] == [""] assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3} def test_response_format_json_schema(self, chat_handler): @@ -616,45 +594,6 @@ def test_response_format_json_object(self, chat_handler): class TestUtilityFunctions: """Test utility functions that were moved from OpenAIServingBase.""" - def test_build_base_sampling_params_functionality(self): - """Test that build_base_sampling_params works correctly.""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello"}], - temperature=0.8, - max_tokens=150, - top_p=0.9, - top_k=50, - presence_penalty=0.1, - frequency_penalty=0.2, - stop=["<|endoftext|>"], - ) - - sampling_params = build_base_sampling_params(request) - - # Test that parameters are correctly mapped - assert sampling_params["temperature"] == request.temperature - assert sampling_params["max_new_tokens"] == request.max_tokens - assert sampling_params["top_p"] == request.top_p - assert sampling_params["top_k"] == request.top_k - assert sampling_params["presence_penalty"] == request.presence_penalty - assert sampling_params["frequency_penalty"] == request.frequency_penalty - assert sampling_params["stop"] == request.stop - - def test_build_base_sampling_params_max_completion_tokens_override(self): - """Test that max_completion_tokens overrides max_tokens.""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello"}], - max_tokens=100, - max_completion_tokens=200, - ) - - sampling_params = build_base_sampling_params(request) - - # max_completion_tokens should override max_tokens - assert sampling_params["max_new_tokens"] == 200 - def test_create_error_response_functionality(self, chat_handler): """Test that create_error_response works correctly.""" error = chat_handler.create_error_response("Test error message") @@ -667,47 +606,6 @@ def test_create_error_response_functionality(self, chat_handler): class TestChatCompletionHandlerCompatibility: """Test compatibility with adapter.py functionality.""" - def test_compatibility_sampling_params(self): - """Test that sampling parameters are built the same way as adapter.py.""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello"}], - temperature=0.8, - max_tokens=150, - top_p=0.9, - top_k=50, - presence_penalty=0.1, - frequency_penalty=0.2, - stop=["<|endoftext|>"], - ) - - # Test the utility function directly - sampling_params = build_base_sampling_params(request) - - # These should match the structure used in adapter.py's v1_chat_generate_request - expected_keys = [ - "temperature", - "max_new_tokens", - "top_p", - "top_k", - "min_p", - "presence_penalty", - "frequency_penalty", - "repetition_penalty", - "stop", - "regex", - "ebnf", - "n", - ] - - for key in expected_keys: - assert key in sampling_params - - assert sampling_params["temperature"] == request.temperature - assert sampling_params["max_new_tokens"] == request.max_tokens - assert sampling_params["top_p"] == request.top_p - assert sampling_params["top_k"] == request.top_k - def test_compatibility_request_structure(self): """Test that the request structure matches what adapter.py expects.""" # Test with all the parameters that adapter.py supports @@ -782,17 +680,6 @@ def test_compatibility_bootstrap_params(self, chat_handler): assert adapted_request.bootstrap_port == 8998 assert adapted_request.bootstrap_room == 12345 - def test_compatibility_logit_bias(self): - """Test that logit_bias parameter is properly handled.""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello"}], - logit_bias={"1": 0.5, "2": -0.3}, - ) - - sampling_params = build_base_sampling_params(request) - assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3} - if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py index 064b6efe953..970ae3b7486 100644 --- a/test/srt/openai/test_serving_completions.py +++ b/test/srt/openai/test_serving_completions.py @@ -14,7 +14,6 @@ ErrorResponse, ) from sglang.srt.entrypoints.openai.serving_completions import CompletionHandler -from sglang.srt.entrypoints.openai.utils import build_base_sampling_params from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -54,102 +53,6 @@ def completion_handler(mock_tokenizer_manager): class TestUtilityFunctions: """Test utility functions that were moved from OpenAIServingBase.""" - def test_build_base_sampling_params_functionality(self): - """Test that build_base_sampling_params works correctly.""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - temperature=0.8, - max_tokens=150, - top_p=0.9, - top_k=50, - presence_penalty=0.1, - frequency_penalty=0.2, - stop=["<|endoftext|>"], - ) - - sampling_params = build_base_sampling_params(request) - - # Test that parameters are correctly mapped - assert sampling_params["temperature"] == request.temperature - assert sampling_params["max_new_tokens"] == request.max_tokens - assert sampling_params["top_p"] == request.top_p - assert sampling_params["top_k"] == request.top_k - assert sampling_params["presence_penalty"] == request.presence_penalty - assert sampling_params["frequency_penalty"] == request.frequency_penalty - assert sampling_params["stop"] == request.stop - - def test_build_base_sampling_params_logit_bias(self): - """Test that logit_bias parameter is properly handled.""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - logit_bias={"1": 0.5, "2": -0.3}, - ) - - sampling_params = build_base_sampling_params(request) - assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3} - - def test_build_base_sampling_params_all_parameters(self): - """Test that all sampling parameters from adapter.py are handled.""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - temperature=0.8, - max_tokens=150, - min_tokens=5, - top_p=0.9, - top_k=50, - min_p=0.1, - presence_penalty=0.1, - frequency_penalty=0.2, - repetition_penalty=1.1, - stop=["<|endoftext|>"], - stop_token_ids=[13, 14], - regex=r"\d+", - json_schema='{"type": "object"}', - ebnf=" ::= ", - n=2, - no_stop_trim=True, - ignore_eos=True, - skip_special_tokens=False, - logit_bias={"1": 0.5}, - ) - - sampling_params = build_base_sampling_params(request) - - # Verify all parameters are present - expected_keys = { - "temperature", - "max_new_tokens", - "min_new_tokens", - "stop", - "stop_token_ids", - "top_p", - "top_k", - "min_p", - "presence_penalty", - "frequency_penalty", - "repetition_penalty", - "regex", - "json_schema", - "ebnf", - "n", - "no_stop_trim", - "ignore_eos", - "skip_special_tokens", - "logit_bias", - } - - for key in expected_keys: - assert key in sampling_params, f"Missing parameter: {key}" - - # Verify values - assert sampling_params["temperature"] == 0.8 - assert sampling_params["max_new_tokens"] == 150 - assert sampling_params["min_new_tokens"] == 5 - assert sampling_params["logit_bias"] == {"1": 0.5} - def test_create_error_response_functionality(self, completion_handler): """Test that create_error_response works correctly.""" error = completion_handler.create_error_response("Test error message") @@ -412,54 +315,6 @@ def test_convert_logprob_start_len_without_echo(self, completion_handler): class TestCompatibilityWithAdapter: """Test compatibility with adapter.py functionality""" - def test_sampling_params_structure_matches_adapter(self): - """Test that sampling params structure matches adapter.py exactly""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - max_tokens=100, - temperature=0.7, - top_p=0.9, - top_k=50, - min_p=0.1, - presence_penalty=0.5, - frequency_penalty=0.3, - repetition_penalty=1.1, - stop=["STOP"], - stop_token_ids=[13], - n=2, - ignore_eos=True, - skip_special_tokens=False, - ) - - # Test the utility function directly - sampling_params = build_base_sampling_params(request) - - # Check all parameters from adapter.py v1_generate_request - expected_params = { - "temperature", - "max_new_tokens", - "min_new_tokens", - "stop", - "stop_token_ids", - "top_p", - "top_k", - "min_p", - "presence_penalty", - "frequency_penalty", - "repetition_penalty", - "regex", - "json_schema", - "ebnf", - "n", - "no_stop_trim", - "ignore_eos", - "skip_special_tokens", - } - - actual_params = set(sampling_params.keys()) - assert expected_params.issubset(actual_params) - def test_bootstrap_parameters_support(self, completion_handler): """Test that bootstrap parameters are supported""" request = CompletionRequest( From 47da102fadd802d819ac72f67fefd6b4029ef968 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 15 Jun 2025 03:54:50 +0000 Subject: [PATCH 20/33] Renames OpenAI serving handler classes Updates the naming convention for OpenAI serving handler classes. This change renames `ChatCompletionHandler`, `CompletionHandler`, and `EmbeddingHandler` to `OpenAIServingChat`, `OpenAIServingCompletion`, and `OpenAIServingEmbedding` respectively, aligning the class names with the base class and improving clarity. Signed-off-by: Xinyuan Tong --- python/sglang/srt/entrypoints/openai/__init__.py | 6 +++--- .../sglang/srt/entrypoints/openai/serving_chat.py | 2 +- .../srt/entrypoints/openai/serving_completions.py | 2 +- .../srt/entrypoints/openai/serving_embedding.py | 2 +- test/srt/openai/test_serving_chat.py | 12 ++++++------ test/srt/openai/test_serving_completions.py | 4 ++-- test/srt/openai/test_serving_embedding.py | 14 +++++++------- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/__init__.py b/python/sglang/srt/entrypoints/openai/__init__.py index 2df1cce2270..437df552b6a 100644 --- a/python/sglang/srt/entrypoints/openai/__init__.py +++ b/python/sglang/srt/entrypoints/openai/__init__.py @@ -40,9 +40,9 @@ Architecture: - OpenAIServingBase: Abstract base class for all endpoint handlers -- ChatCompletionHandler: Handles chat completion requests -- CompletionHandler: Handles text completion requests -- EmbeddingHandler: Handles embedding requests +- OpenAIServingChat: Handles chat completion requests +- OpenAIServingCompletion: Handles text completion requests +- OpenAIServingEmbedding: Handles embedding requests - Protocol classes: Pydantic models for request/response validation - Utility functions: Shared helpers for formatting and validation """ diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 4ffa0c364a3..98834e86b9c 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -83,7 +83,7 @@ logger = logging.getLogger(__name__) -class ChatCompletionHandler(OpenAIServingBase): +class OpenAIServingChat(OpenAIServingBase): """Handler for chat completion requests""" def __init__(self, *args, **kwargs): diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 544cd238f6b..5416f65b48c 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -67,7 +67,7 @@ from sglang.srt.managers.io_struct import GenerateReqInput -class CompletionHandler(OpenAIServingBase): +class OpenAIServingCompletion(OpenAIServingBase): """Handler for completion requests""" def _request_id_prefix(self) -> str: diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index aa5e1c5d4c6..854a28fb2c5 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -62,7 +62,7 @@ from sglang.srt.managers.io_struct import EmbeddingReqInput -class EmbeddingHandler(OpenAIServingBase): +class OpenAIServingEmbedding(OpenAIServingBase): """Handler for embedding requests""" def _request_id_prefix(self) -> str: diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py index f20d42c4676..7bfa50d2532 100644 --- a/test/srt/openai/test_serving_chat.py +++ b/test/srt/openai/test_serving_chat.py @@ -1,5 +1,5 @@ """ -Unit tests for the ChatCompletionHandler class from serving_chat.py. +Unit tests for the OpenAIServingChat class from serving_chat.py. These tests ensure that the refactored implementation maintains compatibility with the original adapter.py functionality. @@ -12,7 +12,7 @@ from fastapi import Request from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse -from sglang.srt.entrypoints.openai.serving_chat import ChatCompletionHandler +from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat from sglang.srt.managers.io_struct import GenerateReqInput @@ -62,8 +62,8 @@ def mock_tokenizer_manager(): @pytest.fixture def chat_handler(mock_tokenizer_manager): - """Create a ChatCompletionHandler instance for testing.""" - return ChatCompletionHandler(mock_tokenizer_manager) + """Create a OpenAIServingChat instance for testing.""" + return OpenAIServingChat(mock_tokenizer_manager) @pytest.fixture @@ -98,7 +98,7 @@ def streaming_chat_request(): ) -class TestChatCompletionHandlerConversion: +class TestOpenAIServingChatConversion: """Test request conversion methods.""" def test_convert_to_internal_request_single( @@ -603,7 +603,7 @@ def test_create_error_response_functionality(self, chat_handler): assert error.code == 400 -class TestChatCompletionHandlerCompatibility: +class TestOpenAIServingChatCompatibility: """Test compatibility with adapter.py functionality.""" def test_compatibility_request_structure(self): diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py index 970ae3b7486..5a87332055f 100644 --- a/test/srt/openai/test_serving_completions.py +++ b/test/srt/openai/test_serving_completions.py @@ -13,7 +13,7 @@ CompletionStreamResponse, ErrorResponse, ) -from sglang.srt.entrypoints.openai.serving_completions import CompletionHandler +from sglang.srt.entrypoints.openai.serving_completions import OpenAIServingCompletion from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager @@ -47,7 +47,7 @@ def mock_tokenizer_manager(): @pytest.fixture def completion_handler(mock_tokenizer_manager): """Create a completion handler instance""" - return CompletionHandler(mock_tokenizer_manager) + return OpenAIServingCompletion(mock_tokenizer_manager) class TestUtilityFunctions: diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py index 00bc95c6a02..dce273963be 100644 --- a/test/srt/openai/test_serving_embedding.py +++ b/test/srt/openai/test_serving_embedding.py @@ -1,5 +1,5 @@ """ -Unit tests for the EmbeddingHandler class from serving_embedding.py. +Unit tests for the OpenAIServingEmbedding class from serving_embedding.py. These tests ensure that the embedding serving implementation maintains compatibility with the original adapter.py functionality and follows OpenAI API specifications. @@ -24,7 +24,7 @@ MultimodalEmbeddingInput, UsageInfo, ) -from sglang.srt.entrypoints.openai.serving_embedding import EmbeddingHandler +from sglang.srt.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from sglang.srt.managers.io_struct import EmbeddingReqInput @@ -65,8 +65,8 @@ def mock_tokenizer_manager(): @pytest.fixture def embedding_handler(mock_tokenizer_manager): - """Create an EmbeddingHandler instance for testing.""" - return EmbeddingHandler(mock_tokenizer_manager) + """Create an OpenAIServingEmbedding instance for testing.""" + return OpenAIServingEmbedding(mock_tokenizer_manager) @pytest.fixture @@ -120,7 +120,7 @@ def token_ids_embedding_request(): ) -class TestEmbeddingHandlerConversion: +class TestOpenAIServingEmbeddingConversion: """Test request conversion methods.""" def test_convert_single_string_request( @@ -264,8 +264,8 @@ def test_build_multiple_embedding_response(self, embedding_handler): @pytest.mark.asyncio -class TestEmbeddingHandlerAsyncMethods: - """Test async methods of EmbeddingHandler.""" +class TestOpenAIServingEmbeddingAsyncMethods: + """Test async methods of OpenAIServingEmbedding.""" async def test_handle_request_success( self, embedding_handler, basic_embedding_request, mock_request From c5a60e0bc85ac70f2d54fb3311be8f746e85bb12 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 15 Jun 2025 03:59:18 +0000 Subject: [PATCH 21/33] cleanup docs and imports Signed-off-by: Xinyuan Tong --- .../sglang/srt/entrypoints/openai/__init__.py | 48 ------------------- .../srt/entrypoints/openai/serving_base.py | 36 -------------- .../srt/entrypoints/openai/serving_chat.py | 43 ----------------- .../entrypoints/openai/serving_completions.py | 42 ---------------- .../entrypoints/openai/serving_embedding.py | 47 ------------------ python/sglang/srt/entrypoints/openai/utils.py | 41 +--------------- 6 files changed, 1 insertion(+), 256 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/__init__.py b/python/sglang/srt/entrypoints/openai/__init__.py index 437df552b6a..e69de29bb2d 100644 --- a/python/sglang/srt/entrypoints/openai/__init__.py +++ b/python/sglang/srt/entrypoints/openai/__init__.py @@ -1,48 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -OpenAI-compatible API server module for SGLang. - -This module provides OpenAI-compatible API endpoints that allow existing OpenAI client -applications to seamlessly work with SGLang models. The implementation includes: - -Key Features: -- Full OpenAI API compatibility for chat completions, text completions, and embeddings -- Streaming support for real-time response generation -- Batch processing capabilities for multiple requests -- Function calling and tool use support -- Multimodal input support (text, images, audio) -- Advanced reasoning capabilities with separate reasoning content -- Custom sampling parameters and constraints (regex, JSON schema, EBNF) -- LoRA adapter support for fine-tuned models -- Cache reporting and token usage tracking - -Supported Endpoints: -- /v1/chat/completions - Chat-based completions with conversation history -- /v1/completions - Text completions for single prompts -- /v1/embeddings - Text/multimodal embeddings generation -- /v1/models - Model listing and information - -The module is structured with separate handlers for each endpoint type, all inheriting -from a common base class that provides shared functionality like request validation, -error handling, and response formatting. - -Architecture: -- OpenAIServingBase: Abstract base class for all endpoint handlers -- OpenAIServingChat: Handles chat completion requests -- OpenAIServingCompletion: Handles text completion requests -- OpenAIServingEmbedding: Handles embedding requests -- Protocol classes: Pydantic models for request/response validation -- Utility functions: Shared helpers for formatting and validation -""" diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index 656443e4d48..fe27806bf96 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -1,39 +1,3 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Base serving engine for OpenAI API endpoints. - -This module provides the foundational classes and request handling patterns -used by all OpenAI API endpoint implementations. It establishes a common -architecture for request processing, validation, and response generation. - -Key Components: -- OpenAIServingBase: Abstract base class for all endpoint handlers -- Common request handling patterns with proper error handling -- Validation integration for request parameters -- Streaming and non-streaming response support - -Architecture Pattern: -All endpoint handlers inherit from OpenAIServingBase and implement: -1. _convert_to_internal_request: Transform OpenAI request to SGLang format -2. _handle_streaming_request: Process streaming requests -3. _handle_non_streaming_request: Process non-streaming requests - -This ensures consistent behavior across all endpoints while allowing -endpoint-specific customization. -""" - import json import logging import uuid diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 98834e86b9c..5a36a1349d1 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -1,45 +1,3 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Chat completions serving logic for OpenAI API. - -This module implements the /v1/chat/completions endpoint, providing full OpenAI -compatibility for chat-based interactions with conversation history support. - -Key Features: -- Full OpenAI chat completions API compatibility -- Streaming and non-streaming response modes -- Multimodal support (text, images, audio) -- Function calling and tool use -- Advanced reasoning with separate reasoning content -- Chat template processing for different model types -- Custom sampling parameters and output constraints -- LoRA adapter support - -Processing Pipeline: -1. Request validation and preprocessing -2. Message processing and chat template application -3. Tool/function call setup if applicable -4. Internal request generation with SGLang extensions -5. Model inference with streaming or batch processing -6. Response formatting and postprocessing -7. Tool call parsing and reasoning extraction - -The implementation handles both string-based and structured content formats, -automatically detecting the appropriate format based on the model's chat template. -""" - import base64 import json import logging @@ -66,7 +24,6 @@ LogProbs, ToolCall, TopLogprob, - UsageInfo, ) from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase from sglang.srt.entrypoints.openai.utils import ( diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 5416f65b48c..142ed91956f 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -1,45 +1,3 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Text completion serving logic for OpenAI API. - -This module implements the /v1/completions endpoint, providing OpenAI-compatible -text completion functionality for single prompts without conversation context. - -Key Features: -- Full OpenAI text completions API compatibility -- Streaming and non-streaming response modes -- Echo support to include the prompt in the response -- Custom completion templates for specialized use cases -- Advanced sampling parameters and output constraints -- Batch processing support for multiple prompts -- LoRA adapter support for fine-tuned models -- Token-level logprobs with configurable detail levels - -Processing Pipeline: -1. Request validation and prompt preprocessing -2. Completion template application if configured -3. Sampling parameter configuration -4. Internal request generation with SGLang extensions -5. Model inference with optional streaming -6. Response formatting with echo handling -7. Logprobs processing and token usage calculation - -The implementation supports various prompt formats including strings, token IDs, -and batched inputs, with automatic type detection and validation. -""" - import time from typing import Any, Dict, List, Optional, Union diff --git a/python/sglang/srt/entrypoints/openai/serving_embedding.py b/python/sglang/srt/entrypoints/openai/serving_embedding.py index 854a28fb2c5..33d1e591878 100644 --- a/python/sglang/srt/entrypoints/openai/serving_embedding.py +++ b/python/sglang/srt/entrypoints/openai/serving_embedding.py @@ -1,50 +1,3 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Embedding serving logic for OpenAI API. - -This module implements the /v1/embeddings endpoint, providing OpenAI-compatible -text and multimodal embedding generation capabilities. - -Key Features: -- Full OpenAI embeddings API compatibility -- Text embedding generation for single and batch inputs -- Multimodal embedding support (text + image combinations) -- Chat template integration for embedding-specific formatting -- Batch processing for multiple inputs -- Comprehensive input validation and error handling -- Token usage tracking and reporting - -Supported Input Types: -- Single string: Direct text input -- List of strings: Batch text embedding -- List of MultimodalEmbeddingInput: Text+image combinations -- Token IDs: Pre-tokenized input sequences - -Processing Pipeline: -1. Input validation and type detection -2. Multimodal content processing (if applicable) -3. Chat template application for embedding context -4. Internal request generation -5. Model inference for embedding generation -6. Response formatting with usage statistics - -The implementation handles various input formats gracefully and provides -detailed error messages for invalid inputs. Multimodal embeddings use -padding for missing text content when needed. -""" - from typing import Any, Dict, List, Optional, Union from fastapi import Request diff --git a/python/sglang/srt/entrypoints/openai/utils.py b/python/sglang/srt/entrypoints/openai/utils.py index 5522e0a68d4..53c67831cdb 100644 --- a/python/sglang/srt/entrypoints/openai/utils.py +++ b/python/sglang/srt/entrypoints/openai/utils.py @@ -1,49 +1,10 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -Utility functions for OpenAI API server. - -This module provides shared utility functions used across the OpenAI API implementation, -including template processing, validation, error handling, and response formatting. - -Key Components: -- Template Format Detection: Analyzes Jinja templates to determine content format -- Content Processing: Handles multimodal content based on template requirements -- Token Usage Calculation: Aggregates token usage across requests and responses -- Error Response Generation: Creates standardized error responses -- Logprobs Formatting: Converts internal logprobs to OpenAI format -- Validation Helpers: Common validation functions for requests -- Streaming Utilities: Helpers for streaming response formatting - -Template Format Detection: -The module includes sophisticated logic to detect whether a chat template expects -'string' or 'openai' content format by analyzing the Jinja template AST. This enables -proper content processing for different model types (e.g., DeepSeek vs Llama). -""" - import logging from typing import Any, Dict, List, Optional import jinja2.nodes import transformers.utils.chat_template_utils as hf_chat_utils -from sglang.srt.entrypoints.openai.protocol import ( - ChatCompletionRequest, - LogProbs, - OpenAIServingRequest, - UsageInfo, -) +from sglang.srt.entrypoints.openai.protocol import LogProbs, UsageInfo logger = logging.getLogger(__name__) From d433e43b91b5668cd659ab51a8e5a6fa0565f04a Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Sun, 15 Jun 2025 04:06:45 +0000 Subject: [PATCH 22/33] Fixes usage calculation in streaming mode Corrects the usage calculation for streaming responses by passing the correct argument to the base function. It ensures accurate token counting when `n` > 1 is requested, preventing potential discrepancies in billing or rate limiting. Signed-off-by: Xinyuan Tong --- python/sglang/srt/entrypoints/openai/serving_chat.py | 2 +- python/sglang/srt/entrypoints/openai/serving_completions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 5a36a1349d1..76626f4f267 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -578,7 +578,7 @@ async def generate_stream_resp(): # Final chunk with usage if request.stream_options and request.stream_options.include_usage: usage = self._calculate_streaming_usage_base( - prompt_tokens, completion_tokens, cached_tokens, request + prompt_tokens, completion_tokens, cached_tokens, request.n ) else: usage = None diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 142ed91956f..2fd26f7cc5e 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -289,7 +289,7 @@ async def generate_stream_resp(): # Handle final usage chunk if request.stream_options and request.stream_options.include_usage: usage = self._calculate_streaming_usage_base( - prompt_tokens, completion_tokens, cached_tokens, request + prompt_tokens, completion_tokens, cached_tokens, request.n ) final_usage_chunk = CompletionStreamResponse( id=content["meta_info"]["id"], From ba42ea1513f7834a047082dbdae84143ce1439bf Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Jun 2025 08:28:21 +0000 Subject: [PATCH 23/33] Refactors error response handling in OpenAIServingBase Updates the error response creation methods to return ORJSONResponse instead of ErrorResponse. This change enhances the response formatting for error handling, ensuring consistency and improved performance in API responses. Signed-off-by: Xinyuan Tong --- .../sglang/srt/entrypoints/openai/serving_base.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_base.py b/python/sglang/srt/entrypoints/openai/serving_base.py index fe27806bf96..718378f5e0f 100644 --- a/python/sglang/srt/entrypoints/openai/serving_base.py +++ b/python/sglang/srt/entrypoints/openai/serving_base.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Union from fastapi import Request -from fastapi.responses import StreamingResponse +from fastapi.responses import ORJSONResponse, StreamingResponse from sglang.srt.entrypoints.openai.protocol import ( ErrorResponse, @@ -150,15 +150,16 @@ def create_error_response( err_type: str = "BadRequestError", status_code: int = 400, param: Optional[str] = None, - ) -> ErrorResponse: + ) -> ORJSONResponse: """Create an error response""" - return ErrorResponse( + error = ErrorResponse( object="error", message=message, type=err_type, param=param, code=status_code, ) + return ORJSONResponse(content=error.model_dump(), status_code=status_code) def create_streaming_error_response( self, @@ -167,5 +168,11 @@ def create_streaming_error_response( status_code: int = 400, ) -> str: """Create a streaming error response""" - error = self.create_error_response(message, err_type, status_code) + error = ErrorResponse( + object="error", + message=message, + type=err_type, + param=None, + code=status_code, + ) return json.dumps({"error": error.model_dump()}) From 48586bf89fefe5f3fdd3778d85dd9a01c34bd015 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Date: Mon, 16 Jun 2025 16:42:34 +0800 Subject: [PATCH 24/33] Apply suggestions from code review Co-authored-by: Chang Su --- python/sglang/srt/entrypoints/openai/serving_chat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 76626f4f267..d4cf5080364 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -326,16 +326,18 @@ def _apply_conversation_template( """Apply conversation template""" conv = generate_chat_conv(request, self.tokenizer_manager.chat_template_name) - # Handle continue_final_message + # If we should continue the final assistant message, adjust the conversation. if ( request.continue_final_message and request.messages and request.messages[-1].role == "assistant" ): + # Remove the auto-added blank assistant turn, if present. if conv.messages and conv.messages[-1][1] is None: conv.messages.pop() + # Rebuild the prompt from the conversation. prompt = conv.get_prompt() - # Strip trailing stop tokens + # Strip trailing stop tokens or separators that indicate end-of-assistant. if isinstance(conv.stop_str, list): for stop_token in conv.stop_str: if prompt.endswith(stop_token): From 3e03b742f29e29eb79b6ddaa08f66499b1c2d287 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Jun 2025 08:43:38 +0000 Subject: [PATCH 25/33] Refactors test fixtures for clarity and remove some tests Updates test files to rename the handler fixtures to `serving_*` for better clarity and consistency across the tests. This change improves the readability of the tests and makes it easier to understand which object is being tested in each test case. Signed-off-by: Xinyuan Tong --- test/srt/openai/test_serving_chat.py | 188 ++---- test/srt/openai/test_serving_completions.py | 608 +------------------- test/srt/openai/test_serving_embedding.py | 172 +----- 3 files changed, 95 insertions(+), 873 deletions(-) diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py index 7bfa50d2532..6f2655e5865 100644 --- a/test/srt/openai/test_serving_chat.py +++ b/test/srt/openai/test_serving_chat.py @@ -61,7 +61,7 @@ def mock_tokenizer_manager(): @pytest.fixture -def chat_handler(mock_tokenizer_manager): +def serving_chat(mock_tokenizer_manager): """Create a OpenAIServingChat instance for testing.""" return OpenAIServingChat(mock_tokenizer_manager) @@ -102,7 +102,7 @@ class TestOpenAIServingChatConversion: """Test request conversion methods.""" def test_convert_to_internal_request_single( - self, chat_handler, basic_chat_request, mock_tokenizer_manager + self, serving_chat, basic_chat_request, mock_tokenizer_manager ): """Test converting single request to internal format.""" with patch( @@ -117,7 +117,7 @@ def test_convert_to_internal_request_single( mock_conv.return_value = mock_conv_instance # Mock the _process_messages method to return expected values - with patch.object(chat_handler, "_process_messages") as mock_process: + with patch.object(serving_chat, "_process_messages") as mock_process: mock_process.return_value = ( "Test prompt", [1, 2, 3], @@ -128,7 +128,7 @@ def test_convert_to_internal_request_single( ) adapted_request, processed_request = ( - chat_handler._convert_to_internal_request( + serving_chat._convert_to_internal_request( [basic_chat_request], ["test-id"] ) ) @@ -141,7 +141,7 @@ def test_convert_to_internal_request_single( class TestToolCalls: """Test tool call functionality from adapter.py""" - def test_tool_call_request_conversion(self, chat_handler): + def test_tool_call_request_conversion(self, serving_chat): """Test request with tool calls""" request = ChatCompletionRequest( model="test-model", @@ -162,7 +162,7 @@ def test_tool_call_request_conversion(self, chat_handler): tool_choice="auto", ) - with patch.object(chat_handler, "_process_messages") as mock_process: + with patch.object(serving_chat, "_process_messages") as mock_process: mock_process.return_value = ( "Test prompt", [1, 2, 3], @@ -172,7 +172,7 @@ def test_tool_call_request_conversion(self, chat_handler): [""], ) - adapted_request, _ = chat_handler._convert_to_internal_request( + adapted_request, _ = serving_chat._convert_to_internal_request( [request], ["test-id"] ) @@ -180,7 +180,7 @@ def test_tool_call_request_conversion(self, chat_handler): # Tool call constraint should be processed assert request.tools is not None - def test_tool_choice_none(self, chat_handler): + def test_tool_choice_none(self, serving_chat): """Test tool_choice=none disables tool calls""" request = ChatCompletionRequest( model="test-model", @@ -189,7 +189,7 @@ def test_tool_choice_none(self, chat_handler): tool_choice="none", ) - with patch.object(chat_handler, "_process_messages") as mock_process: + with patch.object(serving_chat, "_process_messages") as mock_process: mock_process.return_value = ( "Test prompt", [1, 2, 3], @@ -199,14 +199,14 @@ def test_tool_choice_none(self, chat_handler): [""], ) - adapted_request, _ = chat_handler._convert_to_internal_request( + adapted_request, _ = serving_chat._convert_to_internal_request( [request], ["test-id"] ) # Tools should not be processed when tool_choice is "none" assert adapted_request.rid == "test-id" - def test_tool_call_response_processing(self, chat_handler): + def test_tool_call_response_processing(self, serving_chat): """Test processing tool calls in response""" mock_ret_item = { "text": '{"name": "get_weather", "parameters": {"location": "Paris"}}', @@ -246,7 +246,7 @@ def test_tool_call_response_processing(self, chat_handler): mock_parser.parse_non_stream.return_value = ("", [mock_tool_call]) mock_parser_class.return_value = mock_parser - tool_calls, text, updated_finish_reason = chat_handler._process_tool_calls( + tool_calls, text, updated_finish_reason = serving_chat._process_tool_calls( mock_ret_item["text"], tools, "hermes", finish_reason ) @@ -258,7 +258,7 @@ def test_tool_call_response_processing(self, chat_handler): class TestMultimodalContent: """Test multimodal content handling from adapter.py""" - def test_multimodal_request_with_images(self, chat_handler): + def test_multimodal_request_with_images(self, serving_chat): """Test request with image content""" request = ChatCompletionRequest( model="test-model", @@ -277,9 +277,9 @@ def test_multimodal_request_with_images(self, chat_handler): ) # Set multimodal mode - chat_handler.tokenizer_manager.model_config.is_multimodal = True + serving_chat.tokenizer_manager.model_config.is_multimodal = True - with patch.object(chat_handler, "_apply_jinja_template") as mock_apply: + with patch.object(serving_chat, "_apply_jinja_template") as mock_apply: mock_apply.return_value = ( "prompt", [1, 2, 3], @@ -290,18 +290,18 @@ def test_multimodal_request_with_images(self, chat_handler): ) with patch.object( - chat_handler, "_apply_conversation_template" + serving_chat, "_apply_conversation_template" ) as mock_conv: mock_conv.return_value = ("prompt", ["image_data"], None, [], []) prompt, prompt_ids, image_data, audio_data, modalities, stop = ( - chat_handler._process_messages(request, True) + serving_chat._process_messages(request, True) ) assert image_data == ["image_data"] assert prompt == "prompt" - def test_multimodal_request_with_audio(self, chat_handler): + def test_multimodal_request_with_audio(self, serving_chat): """Test request with audio content""" request = ChatCompletionRequest( model="test-model", @@ -319,9 +319,9 @@ def test_multimodal_request_with_audio(self, chat_handler): ], ) - chat_handler.tokenizer_manager.model_config.is_multimodal = True + serving_chat.tokenizer_manager.model_config.is_multimodal = True - with patch.object(chat_handler, "_apply_jinja_template") as mock_apply: + with patch.object(serving_chat, "_apply_jinja_template") as mock_apply: mock_apply.return_value = ( "prompt", [1, 2, 3], @@ -332,12 +332,12 @@ def test_multimodal_request_with_audio(self, chat_handler): ) with patch.object( - chat_handler, "_apply_conversation_template" + serving_chat, "_apply_conversation_template" ) as mock_conv: mock_conv.return_value = ("prompt", None, ["audio_data"], ["audio"], []) prompt, prompt_ids, image_data, audio_data, modalities, stop = ( - chat_handler._process_messages(request, True) + serving_chat._process_messages(request, True) ) assert audio_data == ["audio_data"] @@ -347,17 +347,17 @@ def test_multimodal_request_with_audio(self, chat_handler): class TestTemplateHandling: """Test chat template handling from adapter.py""" - def test_jinja_template_processing(self, chat_handler): + def test_jinja_template_processing(self, serving_chat): """Test Jinja template processing""" request = ChatCompletionRequest( model="test-model", messages=[{"role": "user", "content": "Hello"}] ) # Mock the template attribute directly - chat_handler.tokenizer_manager.chat_template_name = None - chat_handler.tokenizer_manager.tokenizer.chat_template = "" + serving_chat.tokenizer_manager.chat_template_name = None + serving_chat.tokenizer_manager.tokenizer.chat_template = "" - with patch.object(chat_handler, "_apply_jinja_template") as mock_apply: + with patch.object(serving_chat, "_apply_jinja_template") as mock_apply: mock_apply.return_value = ( "processed_prompt", [1, 2, 3], @@ -372,31 +372,31 @@ def test_jinja_template_processing(self, chat_handler): mock_hasattr.return_value = True prompt, prompt_ids, image_data, audio_data, modalities, stop = ( - chat_handler._process_messages(request, False) + serving_chat._process_messages(request, False) ) assert prompt == "processed_prompt" assert prompt_ids == [1, 2, 3] - def test_conversation_template_processing(self, chat_handler): + def test_conversation_template_processing(self, serving_chat): """Test conversation template processing""" request = ChatCompletionRequest( model="test-model", messages=[{"role": "user", "content": "Hello"}] ) - chat_handler.tokenizer_manager.chat_template_name = "llama-3" + serving_chat.tokenizer_manager.chat_template_name = "llama-3" - with patch.object(chat_handler, "_apply_conversation_template") as mock_apply: + with patch.object(serving_chat, "_apply_conversation_template") as mock_apply: mock_apply.return_value = ("conv_prompt", None, None, [], [""]) prompt, prompt_ids, image_data, audio_data, modalities, stop = ( - chat_handler._process_messages(request, False) + serving_chat._process_messages(request, False) ) assert prompt == "conv_prompt" assert stop == [""] - def test_continue_final_message(self, chat_handler): + def test_continue_final_message(self, serving_chat): """Test continue_final_message functionality""" request = ChatCompletionRequest( model="test-model", @@ -407,11 +407,11 @@ def test_continue_final_message(self, chat_handler): continue_final_message=True, ) - with patch.object(chat_handler, "_apply_conversation_template") as mock_apply: + with patch.object(serving_chat, "_apply_conversation_template") as mock_apply: mock_apply.return_value = ("Hi there", None, None, [], [""]) prompt, prompt_ids, image_data, audio_data, modalities, stop = ( - chat_handler._process_messages(request, False) + serving_chat._process_messages(request, False) ) # Should handle continue_final_message properly @@ -421,7 +421,7 @@ def test_continue_final_message(self, chat_handler): class TestReasoningContent: """Test reasoning content separation from adapter.py""" - def test_reasoning_content_request(self, chat_handler): + def test_reasoning_content_request(self, serving_chat): """Test request with reasoning content separation""" request = ChatCompletionRequest( model="test-model", @@ -430,7 +430,7 @@ def test_reasoning_content_request(self, chat_handler): stream_reasoning=False, ) - with patch.object(chat_handler, "_process_messages") as mock_process: + with patch.object(serving_chat, "_process_messages") as mock_process: mock_process.return_value = ( "Test prompt", [1, 2, 3], @@ -440,14 +440,14 @@ def test_reasoning_content_request(self, chat_handler): [""], ) - adapted_request, _ = chat_handler._convert_to_internal_request( + adapted_request, _ = serving_chat._convert_to_internal_request( [request], ["test-id"] ) assert adapted_request.rid == "test-id" assert request.separate_reasoning == True - def test_reasoning_content_response(self, chat_handler): + def test_reasoning_content_response(self, serving_chat): """Test reasoning content in response""" mock_ret_item = { "text": "This is reasoningAnswer: 42", @@ -485,7 +485,7 @@ def test_reasoning_content_response(self, chat_handler): class TestSamplingParams: """Test sampling parameter handling from adapter.py""" - def test_all_sampling_parameters(self, chat_handler): + def test_all_sampling_parameters(self, serving_chat): """Test all sampling parameters are properly handled""" request = ChatCompletionRequest( model="test-model", @@ -511,7 +511,7 @@ def test_all_sampling_parameters(self, chat_handler): logit_bias={"1": 0.5, "2": -0.3}, ) - with patch.object(chat_handler, "_process_messages") as mock_process: + with patch.object(serving_chat, "_process_messages") as mock_process: mock_process.return_value = ( "Test prompt", [1, 2, 3], @@ -521,7 +521,7 @@ def test_all_sampling_parameters(self, chat_handler): [""], ) - sampling_params = chat_handler._build_sampling_params(request, [""]) + sampling_params = serving_chat._build_sampling_params(request, [""]) # Verify all parameters assert sampling_params["temperature"] == 0.8 @@ -536,7 +536,7 @@ def test_all_sampling_parameters(self, chat_handler): assert sampling_params["stop"] == [""] assert sampling_params["logit_bias"] == {"1": 0.5, "2": -0.3} - def test_response_format_json_schema(self, chat_handler): + def test_response_format_json_schema(self, serving_chat): """Test response format with JSON schema""" request = ChatCompletionRequest( model="test-model", @@ -553,7 +553,7 @@ def test_response_format_json_schema(self, chat_handler): }, ) - with patch.object(chat_handler, "_process_messages") as mock_process: + with patch.object(serving_chat, "_process_messages") as mock_process: mock_process.return_value = ( "Test prompt", [1, 2, 3], @@ -563,12 +563,12 @@ def test_response_format_json_schema(self, chat_handler): [""], ) - sampling_params = chat_handler._build_sampling_params(request, [""]) + sampling_params = serving_chat._build_sampling_params(request, [""]) assert "json_schema" in sampling_params assert '"type": "object"' in sampling_params["json_schema"] - def test_response_format_json_object(self, chat_handler): + def test_response_format_json_object(self, serving_chat): """Test response format with JSON object""" request = ChatCompletionRequest( model="test-model", @@ -576,7 +576,7 @@ def test_response_format_json_object(self, chat_handler): response_format={"type": "json_object"}, ) - with patch.object(chat_handler, "_process_messages") as mock_process: + with patch.object(serving_chat, "_process_messages") as mock_process: mock_process.return_value = ( "Test prompt", [1, 2, 3], @@ -586,100 +586,6 @@ def test_response_format_json_object(self, chat_handler): [""], ) - sampling_params = chat_handler._build_sampling_params(request, [""]) + sampling_params = serving_chat._build_sampling_params(request, [""]) assert sampling_params["json_schema"] == '{"type": "object"}' - - -class TestUtilityFunctions: - """Test utility functions that were moved from OpenAIServingBase.""" - - def test_create_error_response_functionality(self, chat_handler): - """Test that create_error_response works correctly.""" - error = chat_handler.create_error_response("Test error message") - assert isinstance(error, ErrorResponse) - assert error.message == "Test error message" - assert error.type == "BadRequestError" - assert error.code == 400 - - -class TestOpenAIServingChatCompatibility: - """Test compatibility with adapter.py functionality.""" - - def test_compatibility_request_structure(self): - """Test that the request structure matches what adapter.py expects.""" - # Test with all the parameters that adapter.py supports - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello"}], - temperature=0.8, - max_tokens=150, - top_p=0.9, - top_k=50, - presence_penalty=0.1, - frequency_penalty=0.2, - repetition_penalty=1.1, - stop=["<|endoftext|>"], - stream=False, - logprobs=True, - top_logprobs=5, - n=1, - continue_final_message=False, - separate_reasoning=True, - stream_reasoning=False, - ) - - # Verify that the request can be created without errors - assert request.model == "test-model" - assert request.temperature == 0.8 - assert request.max_tokens == 150 - assert request.top_p == 0.9 - assert request.top_k == 50 - assert request.presence_penalty == 0.1 - assert request.frequency_penalty == 0.2 - assert request.repetition_penalty == 1.1 - assert request.stop == ["<|endoftext|>"] - assert request.stream == False - assert request.logprobs == True - assert request.top_logprobs == 5 - assert request.n == 1 - assert request.continue_final_message == False - assert request.separate_reasoning == True - assert request.stream_reasoning == False - - def test_compatibility_bootstrap_params(self, chat_handler): - """Test that bootstrap parameters are properly supported.""" - request = ChatCompletionRequest( - model="test-model", - messages=[{"role": "user", "content": "Hello"}], - bootstrap_host="localhost", - bootstrap_port=8998, - bootstrap_room=12345, - ) - - assert request.bootstrap_host == "localhost" - assert request.bootstrap_port == 8998 - assert request.bootstrap_room == 12345 - - # Mock the _process_messages method to return expected values - with patch.object(chat_handler, "_process_messages") as mock_process: - mock_process.return_value = ( - "Test prompt", - [1, 2, 3], - None, - None, - [], - [""], - ) - - adapted_request, _ = chat_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - assert adapted_request.bootstrap_host == "localhost" - assert adapted_request.bootstrap_port == 8998 - assert adapted_request.bootstrap_room == 12345 - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py index 5a87332055f..2cb8a69eaa8 100644 --- a/test/srt/openai/test_serving_completions.py +++ b/test/srt/openai/test_serving_completions.py @@ -45,102 +45,39 @@ def mock_tokenizer_manager(): @pytest.fixture -def completion_handler(mock_tokenizer_manager): - """Create a completion handler instance""" +def serving_completion(mock_tokenizer_manager): + """Create a OpenAIServingCompletion instance""" return OpenAIServingCompletion(mock_tokenizer_manager) -class TestUtilityFunctions: - """Test utility functions that were moved from OpenAIServingBase.""" - - def test_create_error_response_functionality(self, completion_handler): - """Test that create_error_response works correctly.""" - error = completion_handler.create_error_response("Test error message") - assert isinstance(error, ErrorResponse) - assert error.message == "Test error message" - assert error.type == "BadRequestError" - assert error.code == 400 - - def test_create_streaming_error_response_functionality(self, completion_handler): - """Test that create_streaming_error_response works correctly.""" - error_json = completion_handler.create_streaming_error_response( - "Test streaming error" - ) - # Should return JSON string with error structure - import json - - error_data = json.loads(error_json) - assert "error" in error_data - assert error_data["error"]["message"] == "Test streaming error" - - class TestPromptHandling: """Test different prompt types and formats from adapter.py""" - def test_single_string_prompt(self, completion_handler): + def test_single_string_prompt(self, serving_completion): """Test handling single string prompt""" request = CompletionRequest( model="test-model", prompt="Hello world", max_tokens=100 ) - adapted_request, _ = completion_handler._convert_to_internal_request( + adapted_request, _ = serving_completion._convert_to_internal_request( [request], ["test-id"] ) assert adapted_request.text == "Hello world" - def test_single_token_ids_prompt(self, completion_handler): + def test_single_token_ids_prompt(self, serving_completion): """Test handling single token IDs prompt""" request = CompletionRequest( model="test-model", prompt=[1, 2, 3, 4], max_tokens=100 ) - adapted_request, _ = completion_handler._convert_to_internal_request( + adapted_request, _ = serving_completion._convert_to_internal_request( [request], ["test-id"] ) assert adapted_request.input_ids == [1, 2, 3, 4] - def test_multiple_string_prompts(self, completion_handler): - """Test handling multiple string prompts""" - requests = [ - CompletionRequest(model="test-model", prompt="Hello", max_tokens=50), - CompletionRequest(model="test-model", prompt="World", max_tokens=50), - ] - - adapted_request, _ = completion_handler._convert_to_internal_request( - requests, ["id1", "id2"] - ) - - assert adapted_request.text == ["Hello", "World"] - assert adapted_request.rid == ["id1", "id2"] - - def test_multiple_token_ids_prompts(self, completion_handler): - """Test handling multiple token IDs prompts""" - requests = [ - CompletionRequest(model="test-model", prompt=[1, 2], max_tokens=50), - CompletionRequest(model="test-model", prompt=[3, 4], max_tokens=50), - ] - - adapted_request, _ = completion_handler._convert_to_internal_request( - requests, ["id1", "id2"] - ) - - assert adapted_request.input_ids == [[1, 2], [3, 4]] - - def test_list_of_strings_prompt(self, completion_handler): - """Test handling list of strings as prompt""" - request = CompletionRequest( - model="test-model", prompt=["Hello", "world"], max_tokens=100 - ) - - adapted_request, _ = completion_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - assert adapted_request.text == ["Hello", "world"] - - def test_completion_template_handling(self, completion_handler): + def test_completion_template_handling(self, serving_completion): """Test completion template processing""" request = CompletionRequest( model="test-model", @@ -157,7 +94,7 @@ def test_completion_template_handling(self, completion_handler): "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt", return_value="processed_prompt", ): - adapted_request, _ = completion_handler._convert_to_internal_request( + adapted_request, _ = serving_completion._convert_to_internal_request( [request], ["test-id"] ) @@ -167,17 +104,17 @@ def test_completion_template_handling(self, completion_handler): class TestEchoHandling: """Test echo functionality from adapter.py""" - def test_echo_with_string_prompt_streaming(self, completion_handler): + def test_echo_with_string_prompt_streaming(self, serving_completion): """Test echo handling with string prompt in streaming""" request = CompletionRequest( model="test-model", prompt="Hello", max_tokens=100, echo=True ) # Test _get_echo_text method - echo_text = completion_handler._get_echo_text(request, 0) + echo_text = serving_completion._get_echo_text(request, 0) assert echo_text == "Hello" - def test_echo_with_list_of_strings_streaming(self, completion_handler): + def test_echo_with_list_of_strings_streaming(self, serving_completion): """Test echo handling with list of strings in streaming""" request = CompletionRequest( model="test-model", @@ -187,40 +124,40 @@ def test_echo_with_list_of_strings_streaming(self, completion_handler): n=1, ) - echo_text = completion_handler._get_echo_text(request, 0) + echo_text = serving_completion._get_echo_text(request, 0) assert echo_text == "Hello" - echo_text = completion_handler._get_echo_text(request, 1) + echo_text = serving_completion._get_echo_text(request, 1) assert echo_text == "World" - def test_echo_with_token_ids_streaming(self, completion_handler): + def test_echo_with_token_ids_streaming(self, serving_completion): """Test echo handling with token IDs in streaming""" request = CompletionRequest( model="test-model", prompt=[1, 2, 3], max_tokens=100, echo=True ) - completion_handler.tokenizer_manager.tokenizer.decode.return_value = ( + serving_completion.tokenizer_manager.tokenizer.decode.return_value = ( "decoded_prompt" ) - echo_text = completion_handler._get_echo_text(request, 0) + echo_text = serving_completion._get_echo_text(request, 0) assert echo_text == "decoded_prompt" - def test_echo_with_multiple_token_ids_streaming(self, completion_handler): + def test_echo_with_multiple_token_ids_streaming(self, serving_completion): """Test echo handling with multiple token ID prompts in streaming""" request = CompletionRequest( model="test-model", prompt=[[1, 2], [3, 4]], max_tokens=100, echo=True, n=1 ) - completion_handler.tokenizer_manager.tokenizer.decode.return_value = "decoded" - echo_text = completion_handler._get_echo_text(request, 0) + serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded" + echo_text = serving_completion._get_echo_text(request, 0) assert echo_text == "decoded" - def test_prepare_echo_prompts_non_streaming(self, completion_handler): + def test_prepare_echo_prompts_non_streaming(self, serving_completion): """Test prepare echo prompts for non-streaming response""" # Test with single string request = CompletionRequest(model="test-model", prompt="Hello", echo=True) - echo_prompts = completion_handler._prepare_echo_prompts(request) + echo_prompts = serving_completion._prepare_echo_prompts(request) assert echo_prompts == ["Hello"] # Test with list of strings @@ -228,509 +165,12 @@ def test_prepare_echo_prompts_non_streaming(self, completion_handler): model="test-model", prompt=["Hello", "World"], echo=True ) - echo_prompts = completion_handler._prepare_echo_prompts(request) + echo_prompts = serving_completion._prepare_echo_prompts(request) assert echo_prompts == ["Hello", "World"] # Test with token IDs request = CompletionRequest(model="test-model", prompt=[1, 2, 3], echo=True) - completion_handler.tokenizer_manager.tokenizer.decode.return_value = "decoded" - echo_prompts = completion_handler._prepare_echo_prompts(request) + serving_completion.tokenizer_manager.tokenizer.decode.return_value = "decoded" + echo_prompts = serving_completion._prepare_echo_prompts(request) assert echo_prompts == ["decoded"] - - -class TestCompletionRequestConversion: - """Test request conversion to internal format""" - - def test_convert_simple_string_prompt(self, completion_handler): - """Test conversion of simple string prompt""" - request = CompletionRequest( - model="test-model", prompt="Hello world", max_tokens=100, temperature=0.7 - ) - - adapted_request, processed_request = ( - completion_handler._convert_to_internal_request([request], ["test-id"]) - ) - - assert isinstance(adapted_request, GenerateReqInput) - assert adapted_request.text == "Hello world" - assert adapted_request.sampling_params["temperature"] == 0.7 - assert adapted_request.sampling_params["max_new_tokens"] == 100 - assert adapted_request.rid == "test-id" - assert processed_request == request - - def test_convert_token_ids_prompt(self, completion_handler): - """Test conversion of token IDs prompt""" - request = CompletionRequest( - model="test-model", prompt=[1, 2, 3, 4], max_tokens=100 - ) - - adapted_request, processed_request = ( - completion_handler._convert_to_internal_request([request], ["test-id"]) - ) - - assert isinstance(adapted_request, GenerateReqInput) - assert adapted_request.input_ids == [1, 2, 3, 4] - assert adapted_request.sampling_params["max_new_tokens"] == 100 - - def test_convert_logprob_start_len_with_echo_and_logprobs(self, completion_handler): - """Test logprob_start_len setting with echo and logprobs""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - max_tokens=100, - echo=True, - logprobs=5, - ) - - adapted_request, _ = completion_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - # When echo=True and logprobs is set, should be 0 - assert adapted_request.logprob_start_len == 0 - assert adapted_request.return_logprob == True - assert adapted_request.top_logprobs_num == 5 - - def test_convert_logprob_start_len_without_echo(self, completion_handler): - """Test logprob_start_len setting without echo""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - max_tokens=100, - echo=False, - logprobs=3, - ) - - adapted_request, _ = completion_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - # When echo=False, should be -1 - assert adapted_request.logprob_start_len == -1 - assert adapted_request.return_logprob == True - assert adapted_request.top_logprobs_num == 3 - - -class TestCompatibilityWithAdapter: - """Test compatibility with adapter.py functionality""" - - def test_bootstrap_parameters_support(self, completion_handler): - """Test that bootstrap parameters are supported""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - max_tokens=100, - bootstrap_host="localhost", - bootstrap_port=8080, - bootstrap_room=123, - ) - - adapted_request, _ = completion_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - assert adapted_request.bootstrap_host == "localhost" - assert adapted_request.bootstrap_port == 8080 - assert adapted_request.bootstrap_room == 123 - - def test_lora_path_support(self, completion_handler): - """Test that LoRA path is supported""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - max_tokens=100, - lora_path="/path/to/lora", - ) - - adapted_request, _ = completion_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - assert adapted_request.lora_path == "/path/to/lora" - - def test_echo_and_logprobs_compatibility(self, completion_handler): - """Test echo and logprobs handling matches adapter behavior""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - max_tokens=100, - echo=True, - logprobs=5, - ) - - adapted_request, _ = completion_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - # When echo=True and logprobs is set, logprob_start_len should be 0 - assert adapted_request.logprob_start_len == 0 - assert adapted_request.return_logprob == True - assert adapted_request.top_logprobs_num == 5 - - def test_no_echo_logprobs_compatibility(self, completion_handler): - """Test no echo but logprobs handling""" - request = CompletionRequest( - model="test-model", - prompt="Hello world", - max_tokens=100, - echo=False, - logprobs=3, - ) - - adapted_request, _ = completion_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - # When echo=False, logprob_start_len should be -1 - assert adapted_request.logprob_start_len == -1 - assert adapted_request.return_logprob == True - assert adapted_request.top_logprobs_num == 3 - - def test_return_text_in_logprobs_setting(self, completion_handler): - """Test that return_text_in_logprobs is properly set""" - request = CompletionRequest( - model="test-model", prompt="Hello world", max_tokens=100 - ) - - adapted_request, _ = completion_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - assert adapted_request.return_text_in_logprobs == True - - def test_multiple_requests_batch_handling(self, completion_handler): - """Test handling of multiple requests in batch mode""" - requests = [ - CompletionRequest( - model="test-model", prompt="Hello", max_tokens=50, lora_path="/path1" - ), - CompletionRequest( - model="test-model", prompt="World", max_tokens=50, lora_path="/path2" - ), - ] - - adapted_request, processed_requests = ( - completion_handler._convert_to_internal_request(requests, ["id1", "id2"]) - ) - - assert adapted_request.text == ["Hello", "World"] - assert adapted_request.lora_path == ["/path1", "/path2"] - assert adapted_request.rid == ["id1", "id2"] - assert ( - processed_requests == requests - ) # Should return list for multiple requests - - -class TestResponseBuilding: - """Test response building functionality""" - - def test_build_simple_response(self, completion_handler): - """Test building simple completion response""" - request = CompletionRequest(model="test-model", prompt="Hello", max_tokens=100) - - mock_ret = [ - { - "text": " world!", - "meta_info": { - "id": "test-id", - "prompt_tokens": 5, - "completion_tokens": 10, - "finish_reason": {"type": "stop"}, - }, - } - ] - - response = completion_handler._build_completion_response( - request, mock_ret, 1234567890 - ) - - assert isinstance(response, CompletionResponse) - assert response.id == "test-id" - assert response.model == "test-model" - assert response.created == 1234567890 - assert len(response.choices) == 1 - assert response.choices[0].text == " world!" - assert response.choices[0].finish_reason == "stop" - assert response.usage.prompt_tokens == 5 - assert response.usage.completion_tokens == 10 - assert response.usage.total_tokens == 15 - - def test_build_response_with_echo(self, completion_handler): - """Test building response with echo enabled""" - request = CompletionRequest( - model="test-model", prompt="Hello", max_tokens=100, echo=True - ) - - # Mock echo prompts preparation - completion_handler._prepare_echo_prompts = Mock(return_value=["Hello"]) - - mock_ret = [ - { - "text": " world!", - "meta_info": { - "id": "test-id", - "prompt_tokens": 5, - "completion_tokens": 10, - "finish_reason": {"type": "stop"}, - }, - } - ] - - response = completion_handler._build_completion_response( - request, mock_ret, 1234567890 - ) - - # With echo=True, text should include the prompt - assert response.choices[0].text == "Hello world!" - - def test_build_response_with_logprobs(self, completion_handler): - """Test building response with logprobs""" - request = CompletionRequest( - model="test-model", prompt="Hello", max_tokens=100, logprobs=3 - ) - - mock_ret = [ - { - "text": " world!", - "meta_info": { - "id": "test-id", - "prompt_tokens": 5, - "completion_tokens": 10, - "finish_reason": {"type": "stop"}, - "output_token_logprobs": [(-0.1, 1, " world"), (-0.2, 2, "!")], - "output_top_logprobs": [ - [(-0.1, 1, " world"), (-0.3, 3, " earth")], - [(-0.2, 2, "!"), (-0.4, 4, ".")], - ], - }, - } - ] - - response = completion_handler._build_completion_response( - request, mock_ret, 1234567890 - ) - - assert response.choices[0].logprobs is not None - assert len(response.choices[0].logprobs.tokens) == 2 - assert response.choices[0].logprobs.tokens[0] == " world" - assert response.choices[0].logprobs.tokens[1] == "!" - - def test_build_response_with_echo_and_logprobs(self, completion_handler): - """Test building response with both echo and logprobs""" - request = CompletionRequest( - model="test-model", prompt="Hello", max_tokens=100, echo=True, logprobs=2 - ) - - completion_handler._prepare_echo_prompts = Mock(return_value=["Hello"]) - - mock_ret = [ - { - "text": " world!", - "meta_info": { - "id": "test-id", - "prompt_tokens": 5, - "completion_tokens": 10, - "finish_reason": {"type": "stop"}, - "input_token_logprobs": [(-0.05, 0, "Hello")], - "input_top_logprobs": [[(-0.05, 0, "Hello"), (-0.1, 1, "Hi")]], - "output_token_logprobs": [(-0.1, 1, " world"), (-0.2, 2, "!")], - "output_top_logprobs": [ - [(-0.1, 1, " world"), (-0.3, 3, " earth")], - [(-0.2, 2, "!"), (-0.4, 4, ".")], - ], - }, - } - ] - - response = completion_handler._build_completion_response( - request, mock_ret, 1234567890 - ) - - assert response.choices[0].text == "Hello world!" - assert response.choices[0].logprobs is not None - # Should include both input and output logprobs - assert len(response.choices[0].logprobs.tokens) == 3 # Hello + world + ! - - def test_build_response_with_matched_stop(self, completion_handler): - """Test building response with matched stop token""" - request = CompletionRequest(model="test-model", prompt="Hello", max_tokens=100) - - mock_ret = [ - { - "text": " world!", - "meta_info": { - "id": "test-id", - "prompt_tokens": 5, - "completion_tokens": 10, - "finish_reason": {"type": "stop", "matched": ""}, - }, - } - ] - - response = completion_handler._build_completion_response( - request, mock_ret, 1234567890 - ) - - assert response.choices[0].finish_reason == "stop" - assert response.choices[0].matched_stop == "" - - def test_build_response_with_cache_report(self, completion_handler): - """Test building response with cache reporting enabled""" - request = CompletionRequest(model="test-model", prompt="Hello", max_tokens=100) - - mock_ret = [ - { - "text": " world!", - "meta_info": { - "id": "test-id", - "prompt_tokens": 5, - "completion_tokens": 10, - "cached_tokens": 3, - "finish_reason": {"type": "stop"}, - }, - } - ] - - response = completion_handler._build_completion_response( - request, mock_ret, 1234567890, cache_report=True - ) - - assert response.usage.prompt_tokens_details is not None - assert response.usage.prompt_tokens_details["cached_tokens"] == 3 - - def test_build_response_multiple_choices(self, completion_handler): - """Test building response with multiple choices (n > 1)""" - request = CompletionRequest( - model="test-model", prompt="Hello", max_tokens=100, n=2 - ) - - completion_handler._prepare_echo_prompts = Mock(return_value=["Hello"]) - - mock_ret = [ - { - "text": " world!", - "meta_info": { - "id": "test-id", - "prompt_tokens": 5, - "completion_tokens": 10, - "finish_reason": {"type": "stop"}, - }, - }, - { - "text": " there!", - "meta_info": { - "id": "test-id", - "prompt_tokens": 5, - "completion_tokens": 8, - "finish_reason": {"type": "stop"}, - }, - }, - ] - - response = completion_handler._build_completion_response( - request, mock_ret, 1234567890 - ) - - assert len(response.choices) == 2 - assert response.choices[0].text == " world!" - assert response.choices[1].text == " there!" - assert response.choices[0].index == 0 - assert response.choices[1].index == 1 - # Total tokens should be: prompt_tokens + both completion_tokens - assert response.usage.total_tokens == 5 + 10 + 8 - - -@pytest.mark.asyncio -class TestAsyncMethods: - """Test async handler methods""" - - async def test_handle_request_non_streaming(self, completion_handler): - """Test handling non-streaming request - simplified test for async flow""" - mock_request = Mock() - request = CompletionRequest( - model="test-model", prompt="Hello world", max_tokens=100, stream=False - ) - - # For now, just test that we can call the method and get some response - # The detailed functionality is tested in the sync tests above - response = await completion_handler.handle_request(request, mock_request) - - # Should return some response (either error or success, depending on mock setup) - assert response is not None - assert hasattr(response, "model_dump") - - async def test_handle_request_streaming(self, completion_handler): - """Test handling streaming request""" - mock_request = Mock() - request = CompletionRequest( - model="test-model", prompt="Hello world", max_tokens=100, stream=True - ) - - response = await completion_handler.handle_request(request, mock_request) - - # Should return StreamingResponse - from fastapi.responses import StreamingResponse - - assert isinstance(response, StreamingResponse) - - async def test_handle_streaming_with_usage(self, completion_handler): - """Test streaming with usage reporting""" - mock_request = Mock() - request = CompletionRequest( - model="test-model", - prompt="Hello world", - max_tokens=100, - stream=True, - stream_options={"include_usage": True}, - ) - - response = await completion_handler.handle_request(request, mock_request) - - from fastapi.responses import StreamingResponse - - assert isinstance(response, StreamingResponse) - - -class TestEdgeCases: - """Test edge cases and error conditions""" - - def test_multiple_requests_different_prompt_types_error(self, completion_handler): - """Test error when multiple requests have different prompt types""" - requests = [ - CompletionRequest(model="test-model", prompt="Hello", max_tokens=50), - CompletionRequest(model="test-model", prompt=[1, 2, 3], max_tokens=50), - ] - - with pytest.raises(AssertionError): - completion_handler._convert_to_internal_request(requests, ["id1", "id2"]) - - def test_multiple_requests_with_n_greater_than_1_error(self, completion_handler): - """Test error when multiple requests have n > 1""" - requests = [ - CompletionRequest(model="test-model", prompt="Hello", max_tokens=50, n=2), - CompletionRequest(model="test-model", prompt="World", max_tokens=50, n=1), - ] - - with pytest.raises(ValueError, match="Parallel sampling is not supported"): - completion_handler._convert_to_internal_request(requests, ["id1", "id2"]) - - def test_suffix_without_completion_template(self, completion_handler): - """Test that suffix is ignored when completion template is not defined""" - request = CompletionRequest( - model="test-model", - prompt="def hello():", - suffix="return 'world'", - max_tokens=100, - ) - - with patch( - "sglang.srt.entrypoints.openai.serving_completions.is_completion_template_defined", - return_value=False, - ): - adapted_request, _ = completion_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - # Should use original prompt, not processed with suffix - assert adapted_request.text == "def hello():" diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py index dce273963be..94f6c5d2d62 100644 --- a/test/srt/openai/test_serving_embedding.py +++ b/test/srt/openai/test_serving_embedding.py @@ -64,7 +64,7 @@ def mock_tokenizer_manager(): @pytest.fixture -def embedding_handler(mock_tokenizer_manager): +def serving_embedding(mock_tokenizer_manager): """Create an OpenAIServingEmbedding instance for testing.""" return OpenAIServingEmbedding(mock_tokenizer_manager) @@ -124,11 +124,11 @@ class TestOpenAIServingEmbeddingConversion: """Test request conversion methods.""" def test_convert_single_string_request( - self, embedding_handler, basic_embedding_request + self, serving_embedding, basic_embedding_request ): """Test converting single string request to internal format.""" adapted_request, processed_request = ( - embedding_handler._convert_to_internal_request( + serving_embedding._convert_to_internal_request( [basic_embedding_request], ["test-id"] ) ) @@ -139,11 +139,11 @@ def test_convert_single_string_request( assert processed_request == basic_embedding_request def test_convert_list_string_request( - self, embedding_handler, list_embedding_request + self, serving_embedding, list_embedding_request ): """Test converting list of strings request to internal format.""" adapted_request, processed_request = ( - embedding_handler._convert_to_internal_request( + serving_embedding._convert_to_internal_request( [list_embedding_request], ["test-id"] ) ) @@ -154,11 +154,11 @@ def test_convert_list_string_request( assert processed_request == list_embedding_request def test_convert_token_ids_request( - self, embedding_handler, token_ids_embedding_request + self, serving_embedding, token_ids_embedding_request ): """Test converting token IDs request to internal format.""" adapted_request, processed_request = ( - embedding_handler._convert_to_internal_request( + serving_embedding._convert_to_internal_request( [token_ids_embedding_request], ["test-id"] ) ) @@ -169,11 +169,11 @@ def test_convert_token_ids_request( assert processed_request == token_ids_embedding_request def test_convert_multimodal_request( - self, embedding_handler, multimodal_embedding_request + self, serving_embedding, multimodal_embedding_request ): """Test converting multimodal request to internal format.""" adapted_request, processed_request = ( - embedding_handler._convert_to_internal_request( + serving_embedding._convert_to_internal_request( [multimodal_embedding_request], ["test-id"] ) ) @@ -187,37 +187,11 @@ def test_convert_multimodal_request( assert adapted_request.image_data[1] is None assert adapted_request.rid == "test-id" - def test_convert_batch_requests(self, embedding_handler): - """Test converting multiple requests (batch) to internal format.""" - request1 = EmbeddingRequest(model="test-model", input="First text") - request2 = EmbeddingRequest(model="test-model", input="Second text") - - adapted_request, processed_requests = ( - embedding_handler._convert_to_internal_request( - [request1, request2], ["id1", "id2"] - ) - ) - - assert isinstance(adapted_request, EmbeddingReqInput) - assert adapted_request.text == ["First text", "Second text"] - assert adapted_request.rid == ["id1", "id2"] - assert processed_requests == [request1, request2] - - def test_convert_batch_requests_type_mismatch_error(self, embedding_handler): - """Test that batch requests with different input types raise error.""" - request1 = EmbeddingRequest(model="test-model", input="String input") - request2 = EmbeddingRequest(model="test-model", input=[1, 2, 3]) # Token IDs - - with pytest.raises(AssertionError, match="same type"): - embedding_handler._convert_to_internal_request( - [request1, request2], ["id1", "id2"] - ) - class TestEmbeddingResponseBuilding: """Test response building methods.""" - def test_build_single_embedding_response(self, embedding_handler): + def test_build_single_embedding_response(self, serving_embedding): """Test building response for single embedding.""" ret_data = [ { @@ -226,7 +200,7 @@ def test_build_single_embedding_response(self, embedding_handler): } ] - response = embedding_handler._build_embedding_response(ret_data, "test-model") + response = serving_embedding._build_embedding_response(ret_data, "test-model") assert isinstance(response, EmbeddingResponse) assert response.model == "test-model" @@ -238,7 +212,7 @@ def test_build_single_embedding_response(self, embedding_handler): assert response.usage.total_tokens == 5 assert response.usage.completion_tokens == 0 - def test_build_multiple_embedding_response(self, embedding_handler): + def test_build_multiple_embedding_response(self, serving_embedding): """Test building response for multiple embeddings.""" ret_data = [ { @@ -251,7 +225,7 @@ def test_build_multiple_embedding_response(self, embedding_handler): }, ] - response = embedding_handler._build_embedding_response(ret_data, "test-model") + response = serving_embedding._build_embedding_response(ret_data, "test-model") assert isinstance(response, EmbeddingResponse) assert len(response.data) == 2 @@ -268,7 +242,7 @@ class TestOpenAIServingEmbeddingAsyncMethods: """Test async methods of OpenAIServingEmbedding.""" async def test_handle_request_success( - self, embedding_handler, basic_embedding_request, mock_request + self, serving_embedding, basic_embedding_request, mock_request ): """Test successful embedding request handling.""" @@ -279,11 +253,11 @@ async def mock_generate(): "meta_info": {"prompt_tokens": 5}, } - embedding_handler.tokenizer_manager.generate_request = Mock( + serving_embedding.tokenizer_manager.generate_request = Mock( return_value=mock_generate() ) - response = await embedding_handler.handle_request( + response = await serving_embedding.handle_request( basic_embedding_request, mock_request ) @@ -292,18 +266,18 @@ async def mock_generate(): assert response.data[0].embedding == [0.1, 0.2, 0.3, 0.4, 0.5] async def test_handle_request_validation_error( - self, embedding_handler, mock_request + self, serving_embedding, mock_request ): """Test handling request with validation error.""" invalid_request = EmbeddingRequest(model="test-model", input="") - response = await embedding_handler.handle_request(invalid_request, mock_request) + response = await serving_embedding.handle_request(invalid_request, mock_request) assert isinstance(response, ErrorResponse) assert "empty" in response.message.lower() async def test_handle_request_generation_error( - self, embedding_handler, basic_embedding_request, mock_request + self, serving_embedding, basic_embedding_request, mock_request ): """Test handling request with generation error.""" @@ -312,11 +286,11 @@ async def mock_generate_error(): raise ValueError("Generation failed") yield # This won't be reached but needed for async generator - embedding_handler.tokenizer_manager.generate_request = Mock( + serving_embedding.tokenizer_manager.generate_request = Mock( return_value=mock_generate_error() ) - response = await embedding_handler.handle_request( + response = await serving_embedding.handle_request( basic_embedding_request, mock_request ) @@ -324,117 +298,19 @@ async def mock_generate_error(): assert "Generation failed" in response.message async def test_handle_request_internal_error( - self, embedding_handler, basic_embedding_request, mock_request + self, serving_embedding, basic_embedding_request, mock_request ): """Test handling request with internal server error.""" # Mock _convert_to_internal_request to raise an exception with patch.object( - embedding_handler, + serving_embedding, "_convert_to_internal_request", side_effect=Exception("Internal error"), ): - response = await embedding_handler.handle_request( + response = await serving_embedding.handle_request( basic_embedding_request, mock_request ) assert isinstance(response, ErrorResponse) assert "Internal server error" in response.message assert response.code == 500 - - -class TestCompatibilityWithAdapter: - """Test compatibility with original adapter.py implementation.""" - - def test_embedding_request_structure_matches_adapter(self, embedding_handler): - """Test that EmbeddingReqInput structure matches adapter expectations.""" - request = EmbeddingRequest(model="test-model", input="Test text") - - adapted_request, _ = embedding_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - # Check that adapted_request has expected fields from adapter.py - assert hasattr(adapted_request, "rid") - assert hasattr(adapted_request, "text") or hasattr(adapted_request, "input_ids") - assert adapted_request.rid == "test-id" - - def test_multimodal_embedding_processing_compatibility(self, embedding_handler): - """Test multimodal processing matches adapter patterns.""" - multimodal_input = [ - MultimodalEmbeddingInput(text="Hello", image="image_data"), - MultimodalEmbeddingInput(text="World", image=None), - ] - request = EmbeddingRequest(model="test-model", input=multimodal_input) - - adapted_request, _ = embedding_handler._convert_to_internal_request( - [request], ["test-id"] - ) - - # Should have text and image_data fields like adapter - assert hasattr(adapted_request, "text") - assert hasattr(adapted_request, "image_data") - assert len(adapted_request.text) == 2 - assert len(adapted_request.image_data) == 2 - - def test_response_format_matches_adapter(self, embedding_handler): - """Test response format matches adapter.py output.""" - ret_data = [ - { - "embedding": [0.1, 0.2, 0.3], - "meta_info": {"prompt_tokens": 3}, - } - ] - - response = embedding_handler._build_embedding_response(ret_data, "test-model") - - # Check response structure matches adapter output - assert response.object == "list" - assert isinstance(response.data, list) - assert len(response.data) == 1 - assert response.data[0].object == "embedding" - assert isinstance(response.data[0].embedding, list) - assert isinstance(response.data[0].index, int) - assert isinstance(response.usage, UsageInfo) - - -class TestEdgeCases: - """Test edge cases and error conditions.""" - - def test_multimodal_batch_not_implemented(self, embedding_handler): - """Test that multimodal batch requests raise NotImplementedError.""" - request1 = EmbeddingRequest( - model="test-model", - input=[MultimodalEmbeddingInput(text="Hello", image="img1")], - ) - request2 = EmbeddingRequest( - model="test-model", - input=[MultimodalEmbeddingInput(text="World", image="img2")], - ) - - with pytest.raises(NotImplementedError, match="multimodal.*not supported"): - embedding_handler._convert_to_internal_request( - [request1, request2], ["id1", "id2"] - ) - - def test_empty_return_data_handling(self, embedding_handler): - """Test handling of empty return data from generation.""" - # Test with empty list - response = embedding_handler._build_embedding_response([], "test-model") - assert len(response.data) == 0 - assert response.usage.prompt_tokens == 0 - assert response.usage.total_tokens == 0 - - def test_missing_meta_info_handling(self, embedding_handler): - """Test handling of missing meta_info in return data.""" - ret_data = [ - { - "embedding": [0.1, 0.2, 0.3], - "meta_info": {}, # Missing prompt_tokens - } - ] - - # Should handle missing prompt_tokens gracefully - response = embedding_handler._build_embedding_response(ret_data, "test-model") - assert len(response.data) == 1 - # Should default to 0 for missing prompt_tokens - assert response.usage.prompt_tokens == 0 From ac908e19897bc0859246fa3b5e39642c5eb8a72f Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Jun 2025 08:45:51 +0000 Subject: [PATCH 26/33] Enables tool call constraint in sampling params Passes the tool call constraint to the sampling parameters and incorporates tool call constraint handling in sampling parameter building. This allows the model to respect constraints specified for tool calls during the sampling process. Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_chat.py | 47 +++++++++---------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index d4cf5080364..861c0bbc088 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -107,7 +107,7 @@ def _convert_to_internal_request( for request in all_requests: # Process messages and apply chat template - prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( self._process_messages(request, is_multimodal) ) @@ -119,7 +119,7 @@ def _convert_to_internal_request( lora_paths.append(request.lora_path) # Build sampling parameters - sampling_params = self._build_sampling_params(request, stop) + sampling_params = self._build_sampling_params(request, stop, tool_call_constraint) sampling_params_list.append(sampling_params) image_data_list.append(image_data) @@ -179,10 +179,10 @@ def _convert_to_internal_request( def _process_messages( self, request: ChatCompletionRequest, is_multimodal: bool ) -> tuple[ - str, Union[str, List[int]], Optional[Any], Optional[Any], List[str], List[str] + str, Union[str, List[int]], Optional[Any], Optional[Any], List[str], List[str], Optional[Any] ]: """Process chat messages and apply chat template""" - tool_call_constraint = None # TODO: how to pass this to the sampling params? + tool_call_constraint = None prompt = "" prompt_ids = [] @@ -229,7 +229,7 @@ def _process_messages( modalities = [] prompt = request.messages - return prompt, prompt_ids, image_data, audio_data, modalities, stop + return prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint def _apply_jinja_template( self, @@ -365,7 +365,7 @@ def _apply_conversation_template( return prompt, image_data, audio_data, modalities, stop def _build_sampling_params( - self, request: ChatCompletionRequest, stop: List[str] + self, request: ChatCompletionRequest, stop: List[str], tool_call_constraint: Optional[Any] ) -> Dict[str, Any]: """Build sampling parameters for the request""" @@ -403,25 +403,24 @@ def _build_sampling_params( request.response_format.model_dump(by_alias=True) ) - # TODO: how to handle tool call constraint? # Check if there are already existing output constraints - # has_existing_constraints = ( - # sampling_params.get("regex") - # or sampling_params.get("ebnf") - # or sampling_params.get("structural_tag") - # or sampling_params.get("json_schema") - # ) - - # if tool_call_constraint and has_existing_constraints: - # logger.warning("Constrained decoding is not compatible with tool calls.") - # elif tool_call_constraint: - # constraint_type, constraint_value = tool_call_constraint - # if constraint_type == "structural_tag": - # sampling_params[constraint_type] = convert_json_schema_to_str( - # constraint_value.model_dump(by_alias=True) - # ) - # else: - # sampling_params[constraint_type] = constraint_value + has_existing_constraints = ( + sampling_params.get("regex") + or sampling_params.get("ebnf") + or sampling_params.get("structural_tag") + or sampling_params.get("json_schema") + ) + + if tool_call_constraint and has_existing_constraints: + logger.warning("Constrained decoding is not compatible with tool calls.") + elif tool_call_constraint: + constraint_type, constraint_value = tool_call_constraint + if constraint_type == "structural_tag": + sampling_params[constraint_type] = convert_json_schema_to_str( + constraint_value.model_dump(by_alias=True) + ) + else: + sampling_params[constraint_type] = constraint_value return sampling_params async def _handle_streaming_request( From 69e41f78a28cf6592aa273fefbab93917c8228fd Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Jun 2025 08:48:05 +0000 Subject: [PATCH 27/33] move the `text = content["text"]` in serving_chat for Better readability Signed-off-by: Xinyuan Tong --- python/sglang/srt/entrypoints/openai/serving_chat.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 861c0bbc088..71949e3d3a9 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -447,7 +447,6 @@ async def generate_stream_resp(): adapted_request, raw_request ): index = content.get("index", 0) - text = content["text"] is_first = is_firsts.get(index, True) stream_buffer = stream_buffers.get(index, "") @@ -496,7 +495,7 @@ async def generate_stream_resp(): yield f"data: {chunk.model_dump_json()}\n\n" # Process content delta - delta = text[len(stream_buffer) :] + delta = content["text"][len(stream_buffer) :] new_stream_buffer = stream_buffer + delta # Handle reasoning content From 590db9a2964ba7b68d772150d606bddb7c20f3e8 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Jun 2025 08:50:31 +0000 Subject: [PATCH 28/33] lint Signed-off-by: Xinyuan Tong --- .../srt/entrypoints/openai/serving_chat.py | 39 +++++++++++++++---- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 71949e3d3a9..af9c9aa954b 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -107,9 +107,15 @@ def _convert_to_internal_request( for request in all_requests: # Process messages and apply chat template - prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( - self._process_messages(request, is_multimodal) - ) + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = self._process_messages(request, is_multimodal) input_ids.append(prompt_ids) prompts.append(prompt) @@ -119,7 +125,9 @@ def _convert_to_internal_request( lora_paths.append(request.lora_path) # Build sampling parameters - sampling_params = self._build_sampling_params(request, stop, tool_call_constraint) + sampling_params = self._build_sampling_params( + request, stop, tool_call_constraint + ) sampling_params_list.append(sampling_params) image_data_list.append(image_data) @@ -179,7 +187,13 @@ def _convert_to_internal_request( def _process_messages( self, request: ChatCompletionRequest, is_multimodal: bool ) -> tuple[ - str, Union[str, List[int]], Optional[Any], Optional[Any], List[str], List[str], Optional[Any] + str, + Union[str, List[int]], + Optional[Any], + Optional[Any], + List[str], + List[str], + Optional[Any], ]: """Process chat messages and apply chat template""" tool_call_constraint = None @@ -229,7 +243,15 @@ def _process_messages( modalities = [] prompt = request.messages - return prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint + return ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) def _apply_jinja_template( self, @@ -365,7 +387,10 @@ def _apply_conversation_template( return prompt, image_data, audio_data, modalities, stop def _build_sampling_params( - self, request: ChatCompletionRequest, stop: List[str], tool_call_constraint: Optional[Any] + self, + request: ChatCompletionRequest, + stop: List[str], + tool_call_constraint: Optional[Any], ) -> Dict[str, Any]: """Build sampling parameters for the request""" From 4c140c80cc340b8d0434000b5fabd3c4ef0d58b1 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Jun 2025 08:51:49 +0000 Subject: [PATCH 29/33] remove redundant logic Signed-off-by: Xinyuan Tong --- python/sglang/srt/entrypoints/openai/serving_chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index af9c9aa954b..edce6406cd4 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -549,7 +549,7 @@ async def generate_stream_resp(): ) yield f"data: {chunk.model_dump_json()}\n\n" - if not delta or len(delta) == 0: + if not delta: stream_buffers[index] = new_stream_buffer is_firsts[index] = is_first n_prev_tokens[index] = n_prev_token From 7190e6fe8cfd50b0747f968bbdacfcc29ed98a4c Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Jun 2025 08:54:23 +0000 Subject: [PATCH 30/33] logic for generate_completion_prompt Signed-off-by: Xinyuan Tong --- .../sglang/srt/entrypoints/openai/serving_completions.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index 2fd26f7cc5e..af501727562 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -5,8 +5,7 @@ from fastapi.responses import StreamingResponse from sglang.srt.code_completion_parser import ( - completion_template_name, - generate_completion_prompt, + generate_completion_prompt_from_request, is_completion_template_defined, ) from sglang.srt.entrypoints.openai.protocol import ( @@ -101,10 +100,8 @@ def _convert_to_internal_request( # Process prompt prompt = request.prompt if is_completion_template_defined(): - if request.suffix: - prompt = generate_completion_prompt( - str(request.prompt), request.suffix, completion_template_name - ) + prompt = generate_completion_prompt_from_request(request) + prompts.append(prompt) lora_paths.append(request.lora_path) From 40e97fc829a5a6ff7bb7c93e376ad4e45a95be05 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Jun 2025 08:55:58 +0000 Subject: [PATCH 31/33] Add comments back Signed-off-by: Xinyuan Tong --- python/sglang/srt/entrypoints/openai/serving_chat.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index edce6406cd4..54b490131ab 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -314,7 +314,9 @@ def _apply_jinja_template( ), ) except Exception: - # Handle different tools input format (e.g., Mistral) + # This except branch will be triggered when the chosen model + # has a different tools input format that is not compatible + # with openAI's apply_chat_template tool_call format, like Mistral. tools = ( [t if "function" in t else {"function": t} for t in tools] if tools From b95a2884634518a10aa90f8ae75fe626b3ae018a Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Jun 2025 09:21:23 +0000 Subject: [PATCH 32/33] fix tests Signed-off-by: Xinyuan Tong --- test/srt/openai/test_serving_chat.py | 23 ++++++++++++++------- test/srt/openai/test_serving_completions.py | 2 +- test/srt/openai/test_serving_embedding.py | 14 ++++++------- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py index 6f2655e5865..e5852b234a1 100644 --- a/test/srt/openai/test_serving_chat.py +++ b/test/srt/openai/test_serving_chat.py @@ -125,6 +125,7 @@ def test_convert_to_internal_request_single( None, [], [""], + None, # tool_call_constraint ) adapted_request, processed_request = ( @@ -170,6 +171,7 @@ def test_tool_call_request_conversion(self, serving_chat): None, [], [""], + None, # tool_call_constraint ) adapted_request, _ = serving_chat._convert_to_internal_request( @@ -197,6 +199,7 @@ def test_tool_choice_none(self, serving_chat): None, [], [""], + None, # tool_call_constraint ) adapted_request, _ = serving_chat._convert_to_internal_request( @@ -294,7 +297,7 @@ def test_multimodal_request_with_images(self, serving_chat): ) as mock_conv: mock_conv.return_value = ("prompt", ["image_data"], None, [], []) - prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( serving_chat._process_messages(request, True) ) @@ -336,7 +339,7 @@ def test_multimodal_request_with_audio(self, serving_chat): ) as mock_conv: mock_conv.return_value = ("prompt", None, ["audio_data"], ["audio"], []) - prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( serving_chat._process_messages(request, True) ) @@ -371,7 +374,7 @@ def test_jinja_template_processing(self, serving_chat): with patch("builtins.hasattr") as mock_hasattr: mock_hasattr.return_value = True - prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( serving_chat._process_messages(request, False) ) @@ -389,7 +392,7 @@ def test_conversation_template_processing(self, serving_chat): with patch.object(serving_chat, "_apply_conversation_template") as mock_apply: mock_apply.return_value = ("conv_prompt", None, None, [], [""]) - prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( serving_chat._process_messages(request, False) ) @@ -410,7 +413,7 @@ def test_continue_final_message(self, serving_chat): with patch.object(serving_chat, "_apply_conversation_template") as mock_apply: mock_apply.return_value = ("Hi there", None, None, [], [""]) - prompt, prompt_ids, image_data, audio_data, modalities, stop = ( + prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( serving_chat._process_messages(request, False) ) @@ -438,6 +441,7 @@ def test_reasoning_content_request(self, serving_chat): None, [], [""], + None, # tool_call_constraint ) adapted_request, _ = serving_chat._convert_to_internal_request( @@ -519,9 +523,10 @@ def test_all_sampling_parameters(self, serving_chat): None, [], [""], + None, # tool_call_constraint ) - sampling_params = serving_chat._build_sampling_params(request, [""]) + sampling_params = serving_chat._build_sampling_params(request, [""], None) # Verify all parameters assert sampling_params["temperature"] == 0.8 @@ -561,9 +566,10 @@ def test_response_format_json_schema(self, serving_chat): None, [], [""], + None, # tool_call_constraint ) - sampling_params = serving_chat._build_sampling_params(request, [""]) + sampling_params = serving_chat._build_sampling_params(request, [""], None) assert "json_schema" in sampling_params assert '"type": "object"' in sampling_params["json_schema"] @@ -584,8 +590,9 @@ def test_response_format_json_object(self, serving_chat): None, [], [""], + None, # tool_call_constraint ) - sampling_params = serving_chat._build_sampling_params(request, [""]) + sampling_params = serving_chat._build_sampling_params(request, [""], None) assert sampling_params["json_schema"] == '{"type": "object"}' diff --git a/test/srt/openai/test_serving_completions.py b/test/srt/openai/test_serving_completions.py index 2cb8a69eaa8..3e8fc42c8c6 100644 --- a/test/srt/openai/test_serving_completions.py +++ b/test/srt/openai/test_serving_completions.py @@ -91,7 +91,7 @@ def test_completion_template_handling(self, serving_completion): return_value=True, ): with patch( - "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt", + "sglang.srt.entrypoints.openai.serving_completions.generate_completion_prompt_from_request", return_value="processed_prompt", ): adapted_request, _ = serving_completion._convert_to_internal_request( diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py index 94f6c5d2d62..fa9fab75fba 100644 --- a/test/srt/openai/test_serving_embedding.py +++ b/test/srt/openai/test_serving_embedding.py @@ -12,6 +12,7 @@ from typing import Any, Dict, List from unittest.mock import AsyncMock, Mock, patch +from fastapi.responses import ORJSONResponse import pytest from fastapi import Request from pydantic_core import ValidationError @@ -273,8 +274,8 @@ async def test_handle_request_validation_error( response = await serving_embedding.handle_request(invalid_request, mock_request) - assert isinstance(response, ErrorResponse) - assert "empty" in response.message.lower() + assert isinstance(response, ORJSONResponse) + assert response.status_code == 400 async def test_handle_request_generation_error( self, serving_embedding, basic_embedding_request, mock_request @@ -294,8 +295,8 @@ async def mock_generate_error(): basic_embedding_request, mock_request ) - assert isinstance(response, ErrorResponse) - assert "Generation failed" in response.message + assert isinstance(response, ORJSONResponse) + assert response.status_code == 400 async def test_handle_request_internal_error( self, serving_embedding, basic_embedding_request, mock_request @@ -311,6 +312,5 @@ async def test_handle_request_internal_error( basic_embedding_request, mock_request ) - assert isinstance(response, ErrorResponse) - assert "Internal server error" in response.message - assert response.code == 500 + assert isinstance(response, ORJSONResponse) + assert response.status_code == 500 From cc28f37a457f5a4f9a62cbc02f03bd2f082a1f83 Mon Sep 17 00:00:00 2001 From: Xinyuan Tong Date: Mon, 16 Jun 2025 17:08:54 +0000 Subject: [PATCH 33/33] fix lint Signed-off-by: Xinyuan Tong --- test/srt/openai/test_serving_chat.py | 72 +++++++++++++++++------ test/srt/openai/test_serving_embedding.py | 2 +- 2 files changed, 55 insertions(+), 19 deletions(-) diff --git a/test/srt/openai/test_serving_chat.py b/test/srt/openai/test_serving_chat.py index e5852b234a1..b2015866b63 100644 --- a/test/srt/openai/test_serving_chat.py +++ b/test/srt/openai/test_serving_chat.py @@ -297,9 +297,15 @@ def test_multimodal_request_with_images(self, serving_chat): ) as mock_conv: mock_conv.return_value = ("prompt", ["image_data"], None, [], []) - prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( - serving_chat._process_messages(request, True) - ) + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = serving_chat._process_messages(request, True) assert image_data == ["image_data"] assert prompt == "prompt" @@ -339,9 +345,15 @@ def test_multimodal_request_with_audio(self, serving_chat): ) as mock_conv: mock_conv.return_value = ("prompt", None, ["audio_data"], ["audio"], []) - prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( - serving_chat._process_messages(request, True) - ) + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = serving_chat._process_messages(request, True) assert audio_data == ["audio_data"] assert modalities == ["audio"] @@ -374,9 +386,15 @@ def test_jinja_template_processing(self, serving_chat): with patch("builtins.hasattr") as mock_hasattr: mock_hasattr.return_value = True - prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( - serving_chat._process_messages(request, False) - ) + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = serving_chat._process_messages(request, False) assert prompt == "processed_prompt" assert prompt_ids == [1, 2, 3] @@ -392,9 +410,15 @@ def test_conversation_template_processing(self, serving_chat): with patch.object(serving_chat, "_apply_conversation_template") as mock_apply: mock_apply.return_value = ("conv_prompt", None, None, [], [""]) - prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( - serving_chat._process_messages(request, False) - ) + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = serving_chat._process_messages(request, False) assert prompt == "conv_prompt" assert stop == [""] @@ -413,9 +437,15 @@ def test_continue_final_message(self, serving_chat): with patch.object(serving_chat, "_apply_conversation_template") as mock_apply: mock_apply.return_value = ("Hi there", None, None, [], [""]) - prompt, prompt_ids, image_data, audio_data, modalities, stop, tool_call_constraint = ( - serving_chat._process_messages(request, False) - ) + ( + prompt, + prompt_ids, + image_data, + audio_data, + modalities, + stop, + tool_call_constraint, + ) = serving_chat._process_messages(request, False) # Should handle continue_final_message properly assert prompt == "Hi there" @@ -526,7 +556,9 @@ def test_all_sampling_parameters(self, serving_chat): None, # tool_call_constraint ) - sampling_params = serving_chat._build_sampling_params(request, [""], None) + sampling_params = serving_chat._build_sampling_params( + request, [""], None + ) # Verify all parameters assert sampling_params["temperature"] == 0.8 @@ -569,7 +601,9 @@ def test_response_format_json_schema(self, serving_chat): None, # tool_call_constraint ) - sampling_params = serving_chat._build_sampling_params(request, [""], None) + sampling_params = serving_chat._build_sampling_params( + request, [""], None + ) assert "json_schema" in sampling_params assert '"type": "object"' in sampling_params["json_schema"] @@ -593,6 +627,8 @@ def test_response_format_json_object(self, serving_chat): None, # tool_call_constraint ) - sampling_params = serving_chat._build_sampling_params(request, [""], None) + sampling_params = serving_chat._build_sampling_params( + request, [""], None + ) assert sampling_params["json_schema"] == '{"type": "object"}' diff --git a/test/srt/openai/test_serving_embedding.py b/test/srt/openai/test_serving_embedding.py index fa9fab75fba..58438bba884 100644 --- a/test/srt/openai/test_serving_embedding.py +++ b/test/srt/openai/test_serving_embedding.py @@ -12,9 +12,9 @@ from typing import Any, Dict, List from unittest.mock import AsyncMock, Mock, patch -from fastapi.responses import ORJSONResponse import pytest from fastapi import Request +from fastapi.responses import ORJSONResponse from pydantic_core import ValidationError from sglang.srt.entrypoints.openai.protocol import (