diff --git a/python/sglang/srt/function_call/llama32_detector.py b/python/sglang/srt/function_call/llama32_detector.py index a2aaba3fefd..065ffd7f632 100644 --- a/python/sglang/srt/function_call/llama32_detector.py +++ b/python/sglang/srt/function_call/llama32_detector.py @@ -42,31 +42,41 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult return StreamingParseResult(normal_text=text, calls=[]) if "<|python_tag|>" in text: - normal_text, action_text = text.split("<|python_tag|>") + normal_text, action_text = text.split("<|python_tag|>", maxsplit=1) else: normal_text, action_text = "", text - # Split by semicolon and process each part - json_parts = [ - part.strip() - for part in action_text.split(self.tool_call_separator) - if part.strip() - ] + decoder = json.JSONDecoder() + idx = 0 + safe_idx = idx # offset just past the last successfully parsed JSON object (and the separator width) all_actions = [] - for part in json_parts: + action_text_len = len(action_text) + while idx < action_text_len: try: - # Parse each individual JSON object - action = json.loads(part) - all_actions.append(action) + obj, end = decoder.raw_decode(action_text[idx:]) + all_actions.append(obj) + idx += end + len(self.tool_call_separator) + safe_idx = idx except json.JSONDecodeError as e: - logger.warning(f"Failed to parse JSON part: {part}") - logger.warning(f"JSON parse error: {str(e)}") + # Find where next `{"name"` appears and try again + logger.warning( + f"Failed to parse JSON part: {action_text[idx:]}, JSON parse error: {str(e)}" + ) + next_obj_start = action_text.find('{"name":', idx + 1) + if next_obj_start == -1: + break + idx = next_obj_start continue - calls = [] + # Only process if we found valid JSON objects - if all_actions: - calls = self.parse_base_json(all_actions, tools) - return StreamingParseResult(normal_text=normal_text, calls=calls) + calls = self.parse_base_json(all_actions, tools) if all_actions else [] + # Use safe_idx to avoid idx containing the last part of an invalid JSON 
object + trailing_text = ( + action_text[safe_idx:].strip() if safe_idx < action_text_len else "" + ) + return StreamingParseResult( + normal_text=normal_text + trailing_text, calls=calls + ) def structure_info(self) -> _GetInfoFunc: return lambda name: StructureInfo( diff --git a/test/srt/test_function_call_parser.py b/test/srt/test_function_call_parser.py index d418fcd9038..6fbd7b34862 100644 --- a/test/srt/test_function_call_parser.py +++ b/test/srt/test_function_call_parser.py @@ -824,5 +824,101 @@ def test_buffer_reset_on_invalid_tool(self): ) +class TestLlama32Detector(unittest.TestCase): + def setUp(self): + """Set up test tools and detector for Llama 3.2 format testing.""" + self.tools = [ + Tool( + type="function", + function=Function( + name="get_weather", + description="Get weather information", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_tourist_attractions", + description="Get tourist attractions", + parameters={ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + } + }, + "required": ["city"], + }, + ), + ), + ] + self.detector = Llama32Detector() + + def test_single_json(self): + text = '{"name": "get_weather", "parameters": {"city": "Paris"}}' + result = self.detector.detect_and_parse(text, self.tools) + assert len(result.calls) == 1 + assert result.calls[0].name == "get_weather" + assert result.normal_text == "" + + def test_multiple_json_with_separator(self): + text = ( + '<|python_tag|>{"name": "get_weather", "parameters": {"city": "Paris"}};' + '{"name": "get_tourist_attractions", "parameters": {"city": "Paris"}}' + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[1].name, "get_tourist_attractions") + self.assertEqual(result.normal_text, 
"") + + def test_multiple_json_with_separator_customized(self): + text = ( + '<|python_tag|>{"name": "get_weather", "parameters": {}}' + '<|python_tag|>{"name": "get_tourist_attractions", "parameters": {}}' + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[1].name, "get_tourist_attractions") + self.assertEqual(result.normal_text, "") + + def test_json_with_trailing_text(self): + text = '{"name": "get_weather", "parameters": {}} Some follow-up text' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertIn("follow-up", result.normal_text) + + def test_invalid_then_valid_json(self): + text = ( + '{"name": "get_weather", "parameters": {' # malformed + '{"name": "get_weather", "parameters": {}}' + ) + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "get_weather") + + def test_plain_text_only(self): + text = "This is just plain explanation text." + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(result.calls, []) + self.assertEqual(result.normal_text, text) + + def test_with_python_tag_prefix(self): + text = 'Some intro. <|python_tag|>{"name": "get_weather", "parameters": {}}' + result = self.detector.detect_and_parse(text, self.tools) + self.assertEqual(len(result.calls), 1) + self.assertTrue(result.normal_text.strip().startswith("Some intro.")) + + if __name__ == "__main__": unittest.main()