30 changes: 29 additions & 1 deletion python/sglang/srt/entrypoints/openai/protocol.py
@@ -14,7 +14,8 @@
"""Pydantic models for OpenAI API protocol"""

import time
-from typing import Dict, List, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union

from pydantic import (
BaseModel,
@@ -587,3 +588,30 @@ class RerankResponse(BaseModel):
ScoringRequest,
V1RerankReqInput,
]


+@dataclass
+class MessageProcessingResult:
+    """Result of processing chat messages and applying templates.
+
+    This dataclass encapsulates all outputs of message processing, including
+    prompt generation, multimodal data extraction, and constraint preparation.
+    Used internally by OpenAIServingChat to pass processed data between methods.
+
+    Attributes:
+        prompt: The final text prompt after applying the chat template
+        prompt_ids: Either the text prompt (str) or tokenized IDs (List[int])
+        image_data: Extracted image data from messages, if any
+        audio_data: Extracted audio data from messages, if any
+        modalities: List of modality types present in the messages
+        stop: Combined stop strings from the template and the request
+        tool_call_constraint: Optional constraint for structured tool calls
+    """
+
+    prompt: str
+    prompt_ids: Union[str, List[int]]
+    image_data: Optional[Any]
+    audio_data: Optional[Any]
+    modalities: List[str]
+    stop: List[str]
+    tool_call_constraint: Optional[Any] = None
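
For illustration, here is how the new dataclass is consumed compared with the 7-tuple it replaces. The field names and the import path come from this diff; the values below are hypothetical:

from sglang.srt.entrypoints.openai.protocol import MessageProcessingResult

# Before: seven positional values whose meaning depended on their order.
# prompt, prompt_ids, image_data, audio_data, modalities, stop, constraint = ...

# After: the same data travels under stable attribute names.
result = MessageProcessingResult(
    prompt="<|user|>Hello<|assistant|>",  # hypothetical rendered prompt
    prompt_ids=[1, 2, 3],                 # hypothetical token IDs
    image_data=None,
    audio_data=None,
    modalities=[],
    stop=["</s>"],
)
assert result.tool_call_constraint is None  # defaulted field, set later by the caller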
127 changes: 51 additions & 76 deletions python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -22,6 +22,7 @@
ErrorResponse,
FunctionResponse,
LogProbs,
+    MessageProcessingResult,
ToolCall,
TopLogprob,
)
@@ -62,41 +63,33 @@ def _convert_to_internal_request(
is_multimodal = self.tokenizer_manager.model_config.is_multimodal

# Process messages and apply chat template
-        (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        ) = self._process_messages(request, is_multimodal)
+        processed_messages = self._process_messages(request, is_multimodal)

# Build sampling parameters
sampling_params = self._build_sampling_params(
-            request, stop, tool_call_constraint
+            request, processed_messages.stop, processed_messages.tool_call_constraint
)

# Handle single vs multiple requests
if is_multimodal:
prompt_kwargs = {"text": prompt}
prompt_kwargs = {"text": processed_messages.prompt}
else:
if isinstance(prompt_ids, str):
prompt_kwargs = {"text": prompt_ids}
if isinstance(processed_messages.prompt_ids, str):
prompt_kwargs = {"text": processed_messages.prompt_ids}
else:
prompt_kwargs = {"input_ids": prompt_ids}
prompt_kwargs = {"input_ids": processed_messages.prompt_ids}

adapted_request = GenerateReqInput(
**prompt_kwargs,
-            image_data=image_data,
-            audio_data=audio_data,
+            image_data=processed_messages.image_data,
+            audio_data=processed_messages.audio_data,
sampling_params=sampling_params,
return_logprob=request.logprobs,
logprob_start_len=-1,
top_logprobs_num=request.top_logprobs or 0,
stream=request.stream,
return_text_in_logprobs=True,
-            modalities=modalities,
+            modalities=processed_messages.modalities,
lora_path=request.lora_path,
bootstrap_host=request.bootstrap_host,
bootstrap_port=request.bootstrap_port,
@@ -108,74 +101,42 @@ def _convert_to_internal_request(

def _process_messages(
self, request: ChatCompletionRequest, is_multimodal: bool
-    ) -> tuple[
-        str,
-        Union[str, List[int]],
-        Optional[Any],
-        Optional[Any],
-        List[str],
-        List[str],
-        Optional[Any],
-    ]:
+    ) -> MessageProcessingResult:
"""Process chat messages and apply chat template"""
tool_call_constraint = None
prompt = ""
prompt_ids = []

-        if not isinstance(request.messages, str):
-            # Apply chat template and its stop strings
-            tools = None
-            if request.tools and request.tool_choice != "none":
-                request.skip_special_tokens = False
-                if not isinstance(request.tool_choice, str):
-                    tools = [
-                        item.function.model_dump()
-                        for item in request.tools
-                        if item.function.name == request.tool_choice.function.name
-                    ]
-                else:
-                    tools = [item.function.model_dump() for item in request.tools]
+        # Apply chat template and its stop strings
+        tools = None
+        if request.tools and request.tool_choice != "none":
+            request.skip_special_tokens = False
+            if not isinstance(request.tool_choice, str):
+                tools = [
+                    item.function.model_dump()
+                    for item in request.tools
+                    if item.function.name == request.tool_choice.function.name
+                ]
+            else:
+                tools = [item.function.model_dump() for item in request.tools]

-            tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
-            parser = FunctionCallParser(request.tools, tool_call_parser)
-            tool_call_constraint = parser.get_structure_constraint(
-                request.tool_choice
-            )
+            tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
+            parser = FunctionCallParser(request.tools, tool_call_parser)
+            tool_call_constraint = parser.get_structure_constraint(request.tool_choice)

-            # Use chat template
-            if self.template_manager.chat_template_name is None:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_jinja_template(request, tools, is_multimodal)
-                )
-            else:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_conversation_template(request, is_multimodal)
-                )
+        # Use chat template
+        if self.template_manager.chat_template_name is None:
+            result = self._apply_jinja_template(request, tools, is_multimodal)
        else:
-            # Use raw prompt
-            prompt_ids = request.messages
-            stop = request.stop or []
-            image_data = None
-            audio_data = None
-            modalities = []
-            prompt = request.messages
-
-        return (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        )
+            result = self._apply_conversation_template(request, is_multimodal)
+
+        result.tool_call_constraint = tool_call_constraint
+        return result
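
A design note on the hunk above: `tool_call_constraint` is the only field with a default, so `_apply_jinja_template` and `_apply_conversation_template` can build the result without knowing about tool constraints, and `_process_messages` assigns it afterwards. A minimal sketch of that pattern with toy names (not SGLang code):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Result:
    prompt: str
    tool_call_constraint: Optional[Any] = None  # defaulted, so helpers may omit it

def helper() -> Result:
    # Knows only about prompt construction.
    return Result(prompt="hi")

def orchestrator() -> Result:
    # Fills in the constraint after the fact, as _process_messages does above.
    result = helper()
    result.tool_call_constraint = ("json_schema", "{}")  # hypothetical constraint
    return result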

def _apply_jinja_template(
self,
request: ChatCompletionRequest,
tools: Optional[List[Dict]],
is_multimodal: bool,
-    ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
+    ) -> MessageProcessingResult:
"""Apply Jinja chat template"""
prompt = ""
prompt_ids = []
@@ -253,13 +214,20 @@ def _apply_jinja_template(
image_data = image_data if image_data else None
audio_data = audio_data if audio_data else None
modalities = modalities if modalities else []
-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )

def _apply_conversation_template(
self,
request: ChatCompletionRequest,
is_multimodal: bool,
-    ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
+    ) -> MessageProcessingResult:
"""Apply conversation template"""
prompt = ""
prompt_ids = []
@@ -304,7 +272,14 @@ def _apply_conversation_template(
if not is_multimodal:
prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)

-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )

def _build_sampling_params(
self,
60 changes: 58 additions & 2 deletions test/srt/openai_server/basic/test_serving_chat.py
@@ -13,7 +13,10 @@

from fastapi import Request

-from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    MessageProcessingResult,
+)
from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
from sglang.srt.managers.io_struct import GenerateReqInput

@@ -104,7 +107,7 @@ def test_convert_to_internal_request_single(self):
conv_ins.stop_str = ["</s>"]
conv_mock.return_value = conv_ins

-            proc_mock.return_value = (
+            proc_mock.return_value = MessageProcessingResult(
"Test prompt",
[1, 2, 3],
None,
@@ -119,6 +122,59 @@ def test_convert_to_internal_request_single(self):
self.assertFalse(adapted.stream)
self.assertEqual(processed, self.basic_req)

+    def test_stop_str_isolation_between_requests(self):
+        """Test that stop strings from one request don't affect subsequent requests.
+
+        This tests the fix for the bug where conv.stop_str was mutated globally,
+        causing stop strings from one request to persist in subsequent requests.
+        """
+        # Mock conversation template with initial stop_str
+        initial_stop_str = ["\n"]
+
+        with patch(
+            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
+        ) as conv_mock:
+            # Mock conversation object returned by generate_chat_conv
+            conv_ins = Mock()
+            conv_ins.get_prompt.return_value = "Test prompt"
+            conv_ins.image_data = None
+            conv_ins.audio_data = None
+            conv_ins.modalities = []
+            conv_ins.stop_str = initial_stop_str.copy()  # template's default stop strings
+            conv_mock.return_value = conv_ins
+
+            # First request with an additional stop string
+            req1 = ChatCompletionRequest(
+                model="x",
+                messages=[{"role": "user", "content": "First request"}],
+                stop=["CUSTOM_STOP"],
+            )
+
+            # Call the real _apply_conversation_template method (not mocked)
+            result1 = self.chat._apply_conversation_template(req1, is_multimodal=False)
+
+            # The first request sees both stop strings
+            expected_stop1 = initial_stop_str + ["CUSTOM_STOP"]
+            self.assertEqual(result1.stop, expected_stop1)
+
+            # The template's stop_str must not be mutated by the first request
+            self.assertEqual(conv_ins.stop_str, initial_stop_str)
+
+            # Second request without a custom stop string
+            req2 = ChatCompletionRequest(
+                model="x",
+                messages=[{"role": "user", "content": "Second request"}],
+            )
+            result2 = self.chat._apply_conversation_template(req2, is_multimodal=False)
+
+            # The second request sees only the template's stop strings
+            self.assertEqual(result2.stop, initial_stop_str)
+            self.assertNotIn("CUSTOM_STOP", result2.stop)
+            self.assertEqual(conv_ins.stop_str, initial_stop_str)
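
The failure mode this test guards against is plain list aliasing, reproducible without SGLang:

# The bug in miniature: extending an aliased list mutates shared state.
template_stop = ["\n"]
stop = template_stop                 # alias, not a copy
stop.append("CUSTOM_STOP")           # also mutates template_stop
assert template_stop == ["\n", "CUSTOM_STOP"]  # leaked into the template

# The fix: copy before extending.
template_stop = ["\n"]
stop = template_stop.copy()
stop.append("CUSTOM_STOP")
assert template_stop == ["\n"]       # template state untouched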

# ------------- sampling-params -------------
def test_sampling_param_build(self):
req = ChatCompletionRequest(