Merged
Commits
40 commits
1d17465
Add refactored OpenAI API server modules implementation
JustinTong0323 Jun 14, 2025
42bb560
Merge branch 'main' into refactor_oai_server_serving
JustinTong0323 Jun 14, 2025
d9ceddd
feat: add serving_embedding
JustinTong0323 Jun 14, 2025
f8d604b
Refactors request handling in OpenAI endpoints
JustinTong0323 Jun 14, 2025
a86bf27
Adds documentation to OpenAI API endpoints
JustinTong0323 Jun 14, 2025
5ddc8fc
Simplifies getting enable_thinking value
JustinTong0323 Jun 14, 2025
2ddbb40
rename serving_engine to serving_base
JustinTong0323 Jun 14, 2025
26771ad
Merge branch 'main' into refactor_oai_server_serving
JustinTong0323 Jun 14, 2025
4596b52
Makes chat template caching instance-specific
JustinTong0323 Jun 14, 2025
47d54dc
Refactors logprobs processing
JustinTong0323 Jun 14, 2025
8ac4349
Update python/sglang/srt/entrypoints/openai/protocol.py
JustinTong0323 Jun 14, 2025
00b202c
Improve test cases for eagle infer (#7173)
merrymercy Jun 14, 2025
fb4ae05
fix CI
JustinTong0323 Jun 14, 2025
81f5e41
Merge branch 'main' into refactor_oai_server_serving
JustinTong0323 Jun 14, 2025
2a10db7
Merge branch 'main' into refactor_oai_server_serving
JustinTong0323 Jun 14, 2025
3b28fdb
Removes unused utility functions
JustinTong0323 Jun 14, 2025
012bcb5
Refactors request validation for OpenAI endpoints
JustinTong0323 Jun 15, 2025
27341ae
Improves OpenAI serving base class logic
JustinTong0323 Jun 15, 2025
286751a
Refactors error handling for OpenAI endpoints
JustinTong0323 Jun 15, 2025
50d57d1
Refactors request ID generation
JustinTong0323 Jun 15, 2025
960f917
Removes RequestContext
JustinTong0323 Jun 15, 2025
30663a5
Simplifies enable_thinking handling and remove unused functions
JustinTong0323 Jun 15, 2025
eb6784d
Refactors sampling parameter building
JustinTong0323 Jun 15, 2025
47da102
Renames OpenAI serving handler classes
JustinTong0323 Jun 15, 2025
177efdc
Merge branch 'main' into refactor_oai_server_serving
JustinTong0323 Jun 15, 2025
c5a60e0
cleanup docs and imports
JustinTong0323 Jun 15, 2025
d433e43
Fixes usage calculation in streaming mode
JustinTong0323 Jun 15, 2025
ba42ea1
Refactors error response handling in OpenAIServingBase
JustinTong0323 Jun 16, 2025
48586bf
Apply suggestions from code review
JustinTong0323 Jun 16, 2025
3e03b74
Refactors test fixtures for clarity and remove some tests
JustinTong0323 Jun 16, 2025
ac908e1
Enables tool call constraint in sampling params
JustinTong0323 Jun 16, 2025
69e41f7
move the `text = content["text"]` in serving_chat for Better readability
JustinTong0323 Jun 16, 2025
590db9a
lint
JustinTong0323 Jun 16, 2025
4c140c8
remove redundant logic
JustinTong0323 Jun 16, 2025
7190e6f
logic for generate_completion_prompt
JustinTong0323 Jun 16, 2025
40e97fc
Add comments back
JustinTong0323 Jun 16, 2025
84f6037
Merge branch 'main' into refactor_oai_server_serving
JustinTong0323 Jun 16, 2025
b95a288
fix tests
JustinTong0323 Jun 16, 2025
cc28f37
fix lint
JustinTong0323 Jun 16, 2025
ea30a8c
Merge branch 'main' into refactor_oai_server_serving
zhyncs Jun 17, 2025
539 changes: 539 additions & 0 deletions python/sglang/srt/entrypoints/openai/protocol.py

Large diffs are not rendered by default.
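Since the protocol.py diff is collapsed above, here is a minimal, hypothetical sketch of the ErrorResponse and UsageInfo models that serving_base.py (below) relies on. The fields are inferred only from how the base class uses them; any defaults or additional fields in the actual PR may differ.

# Hypothetical sketch of two models from protocol.py, reconstructed only from the
# fields that serving_base.py uses; the real definitions in this PR may differ.
from typing import Dict, Optional

from pydantic import BaseModel


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    # Only populated when the server runs with enable_cache_report,
    # e.g. {"cached_tokens": 42}.
    prompt_tokens_details: Optional[Dict[str, int]] = None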

178 changes: 178 additions & 0 deletions python/sglang/srt/entrypoints/openai/serving_base.py
@@ -0,0 +1,178 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse

from sglang.srt.entrypoints.openai.protocol import (
    ErrorResponse,
    OpenAIServingRequest,
    UsageInfo,
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.managers.tokenizer_manager import TokenizerManager

logger = logging.getLogger(__name__)


# Base class for specific endpoint handlers
class OpenAIServingBase(ABC):
"""Abstract base class for OpenAI endpoint handlers"""

def __init__(self, tokenizer_manager: TokenizerManager):
self.tokenizer_manager = tokenizer_manager

async def handle_request(
self, request: OpenAIServingRequest, raw_request: Request
) -> Union[Any, StreamingResponse, ErrorResponse]:
"""Handle the specific request type with common pattern"""
try:
# Validate request
error_msg = self._validate_request(request)
if error_msg:
return self.create_error_response(error_msg)

# Convert to internal format
adapted_request, processed_request = self._convert_to_internal_request(
[request], [self._generate_request_id_base(request)]
)

# Note(Xinyuan): raw_request below is only used for detecting the connection of the client
if hasattr(request, "stream") and request.stream:
return await self._handle_streaming_request(
adapted_request, processed_request, raw_request
)
else:
return await self._handle_non_streaming_request(
adapted_request, processed_request, raw_request
)

except Exception as e:
logger.error(f"Error in request: {e}")
return self.create_error_response(
message=f"Internal server error: {str(e)}",
err_type="InternalServerError",
status_code=500,
)

@abstractmethod
def _request_id_prefix(self) -> str:
"""Generate request ID based on request type"""
pass

def _generate_request_id_base(self, request: OpenAIServingRequest) -> str:
"""Generate request ID based on request type"""
if rid := getattr(request, "rid", None):
return rid

return f"{self._request_id_prefix()}{uuid.uuid4().hex}"

@abstractmethod
def _convert_to_internal_request(
self,
all_requests: List[OpenAIServingRequest],
request_ids: List[str],
) -> tuple[
GenerateReqInput, Union[OpenAIServingRequest, List[OpenAIServingRequest]]
]:
"""Convert OpenAI request to internal format"""
pass

async def _handle_streaming_request(
self,
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
) -> StreamingResponse:
"""Handle streaming request

Override this method in child classes that support streaming requests.
"""
return self.create_error_response(
message=f"{self.__class__.__name__} does not support streaming requests",
err_type="NotImplementedError",
status_code=501,
)

async def _handle_non_streaming_request(
self,
adapted_request: GenerateReqInput,
request: OpenAIServingRequest,
raw_request: Request,
) -> Union[Any, ErrorResponse]:
"""Handle non-streaming request

Override this method in child classes that support non-streaming requests.
"""
return self.create_error_response(
message=f"{self.__class__.__name__} does not support non-streaming requests",
err_type="NotImplementedError",
status_code=501,
)

def _validate_request(self, request: OpenAIServingRequest) -> Optional[str]:
"""Validate request"""
pass

def _calculate_streaming_usage_base(
self,
prompt_tokens: Dict[int, int],
completion_tokens: Dict[int, int],
cached_tokens: Dict[int, int],
n_choices: int,
) -> UsageInfo:
"""Calculate usage information for streaming responses (common logic)"""
total_prompt_tokens = sum(
tokens for i, tokens in prompt_tokens.items() if i % n_choices == 0
)
total_completion_tokens = sum(tokens for tokens in completion_tokens.values())

cache_report = self.tokenizer_manager.server_args.enable_cache_report
prompt_tokens_details = None
if cache_report:
cached_tokens_sum = sum(tokens for tokens in cached_tokens.values())
if cached_tokens_sum > 0:
prompt_tokens_details = {"cached_tokens": cached_tokens_sum}

return UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
prompt_tokens_details=prompt_tokens_details,
)

def create_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: int = 400,
param: Optional[str] = None,
) -> ORJSONResponse:
"""Create an error response"""
error = ErrorResponse(
object="error",
message=message,
type=err_type,
param=param,
code=status_code,
)
return ORJSONResponse(content=error.model_dump(), status_code=status_code)

def create_streaming_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: int = 400,
) -> str:
"""Create a streaming error response"""
error = ErrorResponse(
object="error",
message=message,
type=err_type,
param=None,
code=status_code,
)
return json.dumps({"error": error.model_dump()})
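
To make the extension points concrete, here is a hypothetical minimal subclass, not part of this PR: the class name, the echo behavior, and the response shape are invented for illustration, while the overridden methods mirror the hooks defined in OpenAIServingBase above.

# Hypothetical example, not part of this PR: a toy non-streaming handler built on
# OpenAIServingBase to show the intended extension points.
from typing import Any, List, Optional

from fastapi import Request

from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.managers.io_struct import GenerateReqInput


class OpenAIServingEcho(OpenAIServingBase):
    """Toy handler that echoes the prompt back instead of generating text."""

    def _request_id_prefix(self) -> str:
        # Request IDs for this endpoint look like "echo-<uuid4 hex>".
        return "echo-"

    def _validate_request(self, request) -> Optional[str]:
        # Returning a string makes handle_request() answer with a 400 error.
        if not getattr(request, "prompt", None):
            return "prompt must be a non-empty string"
        return None

    def _convert_to_internal_request(
        self,
        all_requests: List[Any],
        request_ids: List[str],
    ):
        request = all_requests[0]
        adapted = GenerateReqInput(text=request.prompt, rid=request_ids[0])
        return adapted, request

    async def _handle_non_streaming_request(
        self,
        adapted_request: GenerateReqInput,
        request: Any,
        raw_request: Request,
    ):
        # A real handler would call into self.tokenizer_manager here; the toy
        # version just echoes the prompt back to the client.
        return {"id": adapted_request.rid, "echo": request.prompt}

A real chat or completions handler follows the same shape: validation failures short-circuit through create_error_response, and endpoints that support streaming additionally override _handle_streaming_request.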