class TokenizeRequest(BaseModel):
    """Request schema for the /tokenize endpoint."""

    # Keys that are not declared below are accepted rather than rejected;
    # to_chat_completion_request() merges them into the chat request.
    model_config = ConfigDict(extra="allow")

    model: str = DEFAULT_MODEL_NAME

    # Input: exactly one of `prompt` / `messages` must be set
    # (enforced by validate_tokenize_input).
    prompt: Optional[Union[str, List[str]]] = None
    messages: Optional[List[ChatCompletionMessageParam]] = None

    # Chat-only options, handed to ChatCompletionRequest by
    # to_chat_completion_request().
    tools: Optional[List[Tool]] = Field(default=None, examples=[None])
    tool_choice: Optional[
        Union[ToolChoice, Literal["auto", "required", "none"]]
    ] = Field(default=None, examples=["auto"])
    reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None
    continue_final_message: bool = False
    chat_template_kwargs: Optional[Dict] = None

    # Not forwarded to the chat request (excluded in
    # to_chat_completion_request).
    add_special_tokens: bool = Field(
        default=True,
        description="whether to add model-specific special tokens (e.g. BOS/EOS) during encoding.",
    )

    @model_validator(mode="after")
    def validate_tokenize_input(self) -> "TokenizeRequest":
        """Reject requests that set both `prompt` and `messages`, or neither.

        Raises:
            ValueError: if the two inputs are both present or both absent.
        """
        has_prompt = self.prompt is not None
        has_messages = self.messages is not None
        if has_prompt == has_messages:
            raise ValueError("Exactly one of 'prompt' or 'messages' must be provided.")
        return self

    def to_chat_completion_request(self) -> ChatCompletionRequest:
        """Build a ChatCompletionRequest from this request.

        `prompt`, `add_special_tokens` and every field whose value is None
        are left out of the dump; extra fields stored on the model are then
        merged on top before the result is validated.
        """
        left_out = {"prompt", "add_special_tokens"}
        payload = self.model_dump(exclude=left_out, exclude_none=True)
        extras = getattr(self, "__pydantic_extra__", None) or {}
        return ChatCompletionRequest.model_validate({**payload, **extras})
def _tokenize_chat_request(self, request: TokenizeRequest) -> List[int]:
    """Turn a chat-style tokenize request into token ids.

    The request is converted to a chat completion request, validated and
    rendered through the chat serving handler. The rendered result is then
    reduced to a list of token ids:

    * a list `prompt_ids` is returned as-is when it is non-empty, or when
      there is no rendered `prompt` text to fall back on;
    * a string `prompt_ids` is encoded with the tokenizer;
    * otherwise the rendered `prompt` text is encoded.

    Encoding always uses `add_special_tokens=False`.

    Raises:
        ValueError: if no chat handler was configured, if the chat handler
            reports a validation error, or if rendering produced neither
            token ids nor prompt text.
    """
    chat_handler = self.chat_serving
    if chat_handler is None:
        raise ValueError("Chat template tokenization requires a template manager.")

    chat_request = request.to_chat_completion_request()
    problem = chat_handler._validate_request(chat_request)
    if problem:
        raise ValueError(problem)

    multimodal = self.tokenizer_manager.model_config.is_multimodal
    rendered = chat_handler._process_messages(chat_request, multimodal)
    ids = rendered.prompt_ids

    # Decide which text (if any) still has to be encoded.
    if isinstance(ids, list):
        if ids or not rendered.prompt:
            return ids
        # Empty id list but rendered text exists: encode the text instead.
        text = rendered.prompt
    elif isinstance(ids, str):
        text = ids
    elif rendered.prompt:
        text = rendered.prompt
    else:
        raise ValueError("Failed to render chat messages into token ids.")

    return self.tokenizer_manager.tokenizer.encode(text, add_special_tokens=False)
def test_openai_tokenize_chat_messages(self):
    # Exercises /v1/tokenize with chat `messages` in three variants:
    #   1. messages only    -> must equal the local tokenizer's chat template output
    #   2. messages + tools -> must differ from (1)
    #   3. messages + tools + tool_choice="none" -> must equal (1)
    chat = [{"role": "user", "content": "What is the weather in Paris?"}]

    def tokenize(**fields):
        # Key order of the JSON body: model, messages, then the given fields.
        body = {"model": self.model, "messages": chat, **fields}
        return self._post_json(self.openai_tokenize_url, body)

    plain = tokenize()

    reference = self.tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
    )
    # The template call may hand back a mapping and/or an array-like
    # instead of a plain list; normalise before comparing.
    if not isinstance(reference, list):
        reference = reference["input_ids"]
    if hasattr(reference, "tolist"):
        reference = reference.tolist()
    self.assertEqual(plain["tokens"], reference)
    self.assertEqual(plain["count"], len(reference))

    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
    tools = [weather_tool]

    with_tools = tokenize(tools=tools)
    self.assertIsInstance(with_tools["tokens"], list)
    self.assertEqual(with_tools["count"], len(with_tools["tokens"]))
    self.assertNotEqual(with_tools["tokens"], plain["tokens"])

    tools_disabled = tokenize(tools=tools, tool_choice="none")
    self.assertEqual(tools_disabled["tokens"], plain["tokens"])
    self.assertEqual(tools_disabled["count"], plain["count"])