diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py
index 89b2d3ab625..6bc975c04eb 100644
--- a/python/sglang/srt/entrypoints/openai/protocol.py
+++ b/python/sglang/srt/entrypoints/openai/protocol.py
@@ -14,7 +14,8 @@
 """Pydantic models for OpenAI API protocol"""

 import time
-from typing import Dict, List, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union

 from pydantic import (
     BaseModel,
@@ -587,3 +588,30 @@ class RerankResponse(BaseModel):
     ScoringRequest,
     V1RerankReqInput,
 ]
+
+
+@dataclass
+class MessageProcessingResult:
+    """Result of processing chat messages and applying templates.
+
+    This dataclass encapsulates all the outputs from message processing including
+    prompt generation, multimodal data extraction, and constraint preparation.
+    Used internally by OpenAIServingChat to pass processed data between methods.
+
+    Args:
+        prompt: The final text prompt after applying chat template
+        prompt_ids: Either the text prompt (str) or tokenized IDs (List[int])
+        image_data: Extracted image data from messages, if any
+        audio_data: Extracted audio data from messages, if any
+        modalities: List of modality types present in the messages
+        stop: Combined stop strings from template and request
+        tool_call_constraint: Optional constraint for structured tool calls
+    """
+
+    prompt: str
+    prompt_ids: Union[str, List[int]]
+    image_data: Optional[Any]
+    audio_data: Optional[Any]
+    modalities: List[str]
+    stop: List[str]
+    tool_call_constraint: Optional[Any] = None
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py
index ea72452e178..9857bcd9e90 100644
--- a/python/sglang/srt/entrypoints/openai/serving_chat.py
+++ b/python/sglang/srt/entrypoints/openai/serving_chat.py
@@ -22,6 +22,7 @@
     ErrorResponse,
     FunctionResponse,
     LogProbs,
+    MessageProcessingResult,
     ToolCall,
     TopLogprob,
 )
@@ -62,41 +63,33 @@ def _convert_to_internal_request(
         is_multimodal = self.tokenizer_manager.model_config.is_multimodal

         # Process messages and apply chat template
-        (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        ) = self._process_messages(request, is_multimodal)
+        processed_messages = self._process_messages(request, is_multimodal)

         # Build sampling parameters
         sampling_params = self._build_sampling_params(
-            request, stop, tool_call_constraint
+            request, processed_messages.stop, processed_messages.tool_call_constraint
         )

         # Handle single vs multiple requests
         if is_multimodal:
-            prompt_kwargs = {"text": prompt}
+            prompt_kwargs = {"text": processed_messages.prompt}
         else:
-            if isinstance(prompt_ids, str):
-                prompt_kwargs = {"text": prompt_ids}
+            if isinstance(processed_messages.prompt_ids, str):
+                prompt_kwargs = {"text": processed_messages.prompt_ids}
             else:
-                prompt_kwargs = {"input_ids": prompt_ids}
+                prompt_kwargs = {"input_ids": processed_messages.prompt_ids}

         adapted_request = GenerateReqInput(
             **prompt_kwargs,
-            image_data=image_data,
-            audio_data=audio_data,
+            image_data=processed_messages.image_data,
+            audio_data=processed_messages.audio_data,
             sampling_params=sampling_params,
             return_logprob=request.logprobs,
             logprob_start_len=-1,
             top_logprobs_num=request.top_logprobs or 0,
             stream=request.stream,
             return_text_in_logprobs=True,
-            modalities=modalities,
+            modalities=processed_messages.modalities,
             lora_path=request.lora_path,
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
@@ -108,74 +101,42 @@ def _convert_to_internal_request(

     def _process_messages(
         self, request: ChatCompletionRequest, is_multimodal: bool
-    ) -> tuple[
-        str,
-        Union[str, List[int]],
-        Optional[Any],
-        Optional[Any],
-        List[str],
-        List[str],
-        Optional[Any],
-    ]:
+    ) -> MessageProcessingResult:
         """Process chat messages and apply chat template"""
         tool_call_constraint = None
-        prompt = ""
-        prompt_ids = []

-        if not isinstance(request.messages, str):
-            # Apply chat template and its stop strings
-            tools = None
-            if request.tools and request.tool_choice != "none":
-                request.skip_special_tokens = False
-                if not isinstance(request.tool_choice, str):
-                    tools = [
-                        item.function.model_dump()
-                        for item in request.tools
-                        if item.function.name == request.tool_choice.function.name
-                    ]
-                else:
-                    tools = [item.function.model_dump() for item in request.tools]
+        # Apply chat template and its stop strings
+        tools = None
+        if request.tools and request.tool_choice != "none":
+            request.skip_special_tokens = False
+            if not isinstance(request.tool_choice, str):
+                tools = [
+                    item.function.model_dump()
+                    for item in request.tools
+                    if item.function.name == request.tool_choice.function.name
+                ]
+            else:
+                tools = [item.function.model_dump() for item in request.tools]

-                tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
-                parser = FunctionCallParser(request.tools, tool_call_parser)
-                tool_call_constraint = parser.get_structure_constraint(
-                    request.tool_choice
-                )
+            tool_call_parser = self.tokenizer_manager.server_args.tool_call_parser
+            parser = FunctionCallParser(request.tools, tool_call_parser)
+            tool_call_constraint = parser.get_structure_constraint(request.tool_choice)

-            # Use chat template
-            if self.template_manager.chat_template_name is None:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_jinja_template(request, tools, is_multimodal)
-                )
-            else:
-                prompt, prompt_ids, image_data, audio_data, modalities, stop = (
-                    self._apply_conversation_template(request, is_multimodal)
-                )
+        # Use chat template
+        if self.template_manager.chat_template_name is None:
+            result = self._apply_jinja_template(request, tools, is_multimodal)
         else:
-            # Use raw prompt
-            prompt_ids = request.messages
-            stop = request.stop or []
-            image_data = None
-            audio_data = None
-            modalities = []
-            prompt = request.messages
-
-        return (
-            prompt,
-            prompt_ids,
-            image_data,
-            audio_data,
-            modalities,
-            stop,
-            tool_call_constraint,
-        )
+            result = self._apply_conversation_template(request, is_multimodal)
+
+        result.tool_call_constraint = tool_call_constraint
+        return result

     def _apply_jinja_template(
         self,
         request: ChatCompletionRequest,
         tools: Optional[List[Dict]],
         is_multimodal: bool,
-    ) -> tuple[str, List[int], Optional[Any], Optional[Any], List[str], List[str]]:
+    ) -> MessageProcessingResult:
         """Apply Jinja chat template"""
         prompt = ""
         prompt_ids = []
@@ -253,13 +214,20 @@ def _apply_jinja_template(
         image_data = image_data if image_data else None
         audio_data = audio_data if audio_data else None
         modalities = modalities if modalities else []
-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )

     def _apply_conversation_template(
         self,
         request: ChatCompletionRequest,
         is_multimodal: bool,
-    ) -> tuple[str, Optional[Any], Optional[Any], List[str], List[str], List[str]]:
+    ) -> MessageProcessingResult:
         """Apply conversation template"""
         prompt = ""
         prompt_ids = []
@@ -304,7 +272,14 @@ def _apply_conversation_template(

         if not is_multimodal:
             prompt_ids = self.tokenizer_manager.tokenizer.encode(prompt)
-        return prompt, prompt_ids, image_data, audio_data, modalities, stop
+        return MessageProcessingResult(
+            prompt=prompt,
+            prompt_ids=prompt_ids,
+            image_data=image_data,
+            audio_data=audio_data,
+            modalities=modalities,
+            stop=stop,
+        )

     def _build_sampling_params(
         self,
diff --git a/test/srt/openai_server/basic/test_serving_chat.py b/test/srt/openai_server/basic/test_serving_chat.py
index 701dc2e55f6..7108b405d5d 100644
--- a/test/srt/openai_server/basic/test_serving_chat.py
+++ b/test/srt/openai_server/basic/test_serving_chat.py
@@ -13,7 +13,10 @@
 from fastapi import Request

-from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    MessageProcessingResult,
+)
 from sglang.srt.entrypoints.openai.serving_chat import OpenAIServingChat
 from sglang.srt.managers.io_struct import GenerateReqInput

@@ -104,7 +107,7 @@ def test_convert_to_internal_request_single(self):
             conv_ins.stop_str = [""]
             conv_mock.return_value = conv_ins

-            proc_mock.return_value = (
+            proc_mock.return_value = MessageProcessingResult(
                 "Test prompt",
                 [1, 2, 3],
                 None,
@@ -119,6 +122,59 @@
         self.assertFalse(adapted.stream)
         self.assertEqual(processed, self.basic_req)

+    def test_stop_str_isolation_between_requests(self):
+        """Test that stop strings from one request don't affect subsequent requests.
+
+        This tests the fix for the bug where conv.stop_str was being mutated globally,
+        causing stop strings from one request to persist in subsequent requests.
+        """
+        # Mock conversation template with initial stop_str
+        initial_stop_str = ["\n"]
+
+        with patch(
+            "sglang.srt.entrypoints.openai.serving_chat.generate_chat_conv"
+        ) as conv_mock:
+            # Create a mock conversation object that will be returned by generate_chat_conv
+            conv_ins = Mock()
+            conv_ins.get_prompt.return_value = "Test prompt"
+            conv_ins.image_data = None
+            conv_ins.audio_data = None
+            conv_ins.modalities = []
+            conv_ins.stop_str = (
+                initial_stop_str.copy()
+            )  # Template's default stop strings
+            conv_mock.return_value = conv_ins
+
+            # First request with additional stop string
+            req1 = ChatCompletionRequest(
+                model="x",
+                messages=[{"role": "user", "content": "First request"}],
+                stop=["CUSTOM_STOP"],
+            )
+
+            # Call the actual _apply_conversation_template method (not mocked)
+            result1 = self.chat._apply_conversation_template(req1, is_multimodal=False)
+
+            # Verify first request has both stop strings
+            expected_stop1 = initial_stop_str + ["CUSTOM_STOP"]
+            self.assertEqual(result1.stop, expected_stop1)
+
+            # Verify the original template's stop_str wasn't mutated after first request
+            self.assertEqual(conv_ins.stop_str, initial_stop_str)
+
+            # Second request without additional stop string
+            req2 = ChatCompletionRequest(
+                model="x",
+                messages=[{"role": "user", "content": "Second request"}],
+                # No custom stop strings
+            )
+            result2 = self.chat._apply_conversation_template(req2, is_multimodal=False)
+
+            # Verify second request only has original stop strings (no CUSTOM_STOP from req1)
+            self.assertEqual(result2.stop, initial_stop_str)
+            self.assertNotIn("CUSTOM_STOP", result2.stop)
+            self.assertEqual(conv_ins.stop_str, initial_stop_str)
+
     # ------------- sampling-params -------------
     def test_sampling_param_build(self):
         req = ChatCompletionRequest(