Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ async def lifespan(fast_api_app: FastAPI):
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_tokenize = OpenAIServingTokenize(
_global_state.tokenizer_manager
_global_state.tokenizer_manager, _global_state.template_manager
)
fast_api_app.state.openai_serving_detokenize = OpenAIServingDetokenize(
_global_state.tokenizer_manager
Expand Down
29 changes: 28 additions & 1 deletion python/sglang/srt/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from openai.types.responses.tool import Tool
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_serializer,
Expand Down Expand Up @@ -1118,13 +1119,39 @@ def _serialize(self, handler):
class TokenizeRequest(BaseModel):
"""Request schema for the /tokenize endpoint."""

model_config = ConfigDict(extra="allow")

model: str = DEFAULT_MODEL_NAME
prompt: Union[str, List[str]]
prompt: Optional[Union[str, List[str]]] = None
messages: Optional[List[ChatCompletionMessageParam]] = None
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
tool_choice: Optional[Union[ToolChoice, Literal["auto", "required", "none"]]] = (
Field(default=None, examples=["auto"])
)
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None
continue_final_message: bool = False
chat_template_kwargs: Optional[Dict] = None
add_special_tokens: bool = Field(
default=True,
description="whether to add model-specific special tokens (e.g. BOS/EOS) during encoding.",
)

@model_validator(mode="after")
def validate_tokenize_input(self) -> "TokenizeRequest":
if (self.prompt is None) == (self.messages is None):
raise ValueError("Exactly one of 'prompt' or 'messages' must be provided.")
return self

def to_chat_completion_request(self) -> ChatCompletionRequest:
data = self.model_dump(
exclude={"prompt", "add_special_tokens"},
exclude_none=True,
)
extra = getattr(self, "__pydantic_extra__", None)
if extra:
data.update(extra)
return ChatCompletionRequest.model_validate(data)


class TokenizeResponse(BaseModel):
"""Response schema for the /tokenize endpoint."""
Expand Down
49 changes: 47 additions & 2 deletions python/sglang/srt/entrypoints/openai/serving_tokenize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from http import HTTPStatus
from typing import List, Union
from typing import List, Optional, Union

from fastapi import Request

Expand All @@ -12,13 +12,22 @@
TokenizeResponse,
)
from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat

logger = logging.getLogger(__name__)


class OpenAIServingTokenize(OpenAIServingBase):
"""Handler for /v1/tokenize requests"""

def __init__(self, tokenizer_manager, template_manager=None):
super().__init__(tokenizer_manager)
self.chat_serving: Optional[OpenAIServingChat] = (
OpenAIServingChat(tokenizer_manager, template_manager)
if template_manager is not None
else None
)

def _request_id_prefix(self) -> str:
return "tok-"

Expand All @@ -37,7 +46,11 @@ async def _handle_non_streaming_request(
tokenizer = self.tokenizer_manager.tokenizer
max_model_len = getattr(tokenizer, "model_max_length", -1)

if isinstance(request.prompt, str):
if request.messages is not None:
token_ids = self._tokenize_chat_request(request)
tokens = token_ids
count = len(token_ids)
elif isinstance(request.prompt, str):
token_ids = tokenizer.encode(
request.prompt,
add_special_tokens=request.add_special_tokens,
Expand All @@ -61,6 +74,8 @@ async def _handle_non_streaming_request(
return TokenizeResponse(
tokens=tokens, count=count, max_model_len=max_model_len
)
except ValueError as e:
return self.create_error_response(str(e))
except Exception as e:
logger.error("Error during tokenization", exc_info=True)
return self.create_error_response(
Expand All @@ -69,6 +84,36 @@ async def _handle_non_streaming_request(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)

def _tokenize_chat_request(self, request: TokenizeRequest) -> List[int]:
if self.chat_serving is None:
raise ValueError("Chat template tokenization requires a template manager.")

chat_request = request.to_chat_completion_request()
validation_error = self.chat_serving._validate_request(chat_request)
if validation_error:
raise ValueError(validation_error)

is_multimodal = self.tokenizer_manager.model_config.is_multimodal
processed_messages = self.chat_serving._process_messages(
chat_request, is_multimodal
)

prompt_ids = processed_messages.prompt_ids
if isinstance(prompt_ids, list) and (
prompt_ids or not processed_messages.prompt
):
return prompt_ids
if isinstance(prompt_ids, str):
return self.tokenizer_manager.tokenizer.encode(
prompt_ids, add_special_tokens=False
)
if processed_messages.prompt:
return self.tokenizer_manager.tokenizer.encode(
processed_messages.prompt, add_special_tokens=False
)

raise ValueError("Failed to render chat messages into token ids.")


class OpenAIServingDetokenize(OpenAIServingBase):
"""Handler for /v1/detokenize requests"""
Expand Down
Loading