Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions python/sglang/srt/function_call/base_format_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,13 @@ def parse_streaming_increment(
# parallel tool calls because the bot_token (e.g., '[') can also
# appear inside array parameters of the current tool, and we must not
# mistakenly identify that as the start of a new tool.
used_separator_branch = False
if self.current_tool_id > 0 and current_text.startswith(
self.tool_call_separator
):
start_idx = len(self.tool_call_separator)
used_separator_branch = True
else:
# Only search for bot_token when not processing a subsequent tool call
tool_call_pos = current_text.find(self.bot_token)
if tool_call_pos != -1:
start_idx = tool_call_pos + len(self.bot_token)
Expand All @@ -186,7 +187,23 @@ def parse_streaming_increment(
if start_idx >= len(current_text):
return StreamingParseResult()

obj, end_idx = _partial_json_loads(current_text[start_idx:], flags)
try:
obj, end_idx = _partial_json_loads(current_text[start_idx:], flags)
except (MalformedJSON, json.JSONDecodeError):
# Separator landed on non-JSON markup; fall back to
# bot_token which skips past all inter-object markup.
# e.g. Qwen25: separator "," matches between eot/bot tags.
if used_separator_branch and self.bot_token in current_text:
start_idx = current_text.find(self.bot_token) + len(
self.bot_token
)
if start_idx >= len(current_text):
return StreamingParseResult()
obj, end_idx = _partial_json_loads(
current_text[start_idx:], flags
)
else:
raise

is_current_complete = _is_complete_json(
current_text[start_idx : start_idx + end_idx]
Expand All @@ -212,7 +229,7 @@ def parse_streaming_increment(

current_tool_call = obj

except MalformedJSON:
except (MalformedJSON, json.JSONDecodeError):
return StreamingParseResult()

if not current_tool_call:
Expand Down
155 changes: 155 additions & 0 deletions test/registered/unit/function_call/test_function_call_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3853,5 +3853,160 @@ def test_streaming_function_call_marker_json_split_at_quotes(self):
self.assertEqual(params["city"], "Rome")


class TestQwen25Detector(unittest.TestCase):
    """Exercise Qwen25Detector on single and multiple tool calls.

    Covers one-shot parsing via ``detect_and_parse`` as well as incremental
    parsing via ``parse_streaming_increment``.
    """

    def setUp(self):
        # The detector is imported here, at fixture time, rather than at
        # module level.
        from sglang.srt.function_call.qwen25_detector import Qwen25Detector

        self.detector = Qwen25Detector()

        # JSON schema describing the arguments of the single weather tool.
        weather_schema = {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city name",
                },
                "state": {
                    "type": "string",
                    "description": "Two-letter state abbreviation",
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                },
            },
            "required": ["city", "state", "unit"],
        }
        weather_function = Function(
            name="get_current_weather",
            description="Get the current weather in a given location",
            parameters=weather_schema,
        )
        self.tools = [Tool(type="function", function=weather_function)]

    # -- Non-streaming tests --

    def test_detect_and_parse_single_tool_call(self):
        """A lone tool_call block yields exactly one parsed call."""
        payload = '<tool_call>\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}\n</tool_call>'
        parsed = self.detector.detect_and_parse(payload, self.tools)
        self.assertEqual(len(parsed.calls), 1)
        only_call = parsed.calls[0]
        self.assertEqual(only_call.name, "get_current_weather")
        arguments = json.loads(only_call.parameters)
        self.assertEqual(arguments["city"], "NYC")

    def test_detect_and_parse_multiple_tool_calls(self):
        """Four consecutive tool_call blocks are all returned, in order."""
        payload = (
            '<tool_call>\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}\n</tool_call>\n'
            '<tool_call>\n{"name": "get_current_weather", "arguments": {"city": "Baltimore", "state": "MD", "unit": "fahrenheit"}}\n</tool_call>\n'
            '<tool_call>\n{"name": "get_current_weather", "arguments": {"city": "Minneapolis", "state": "MN", "unit": "fahrenheit"}}\n</tool_call>\n'
            '<tool_call>\n{"name": "get_current_weather", "arguments": {"city": "Los Angeles", "state": "CA", "unit": "fahrenheit"}}\n</tool_call>'
        )
        parsed = self.detector.detect_and_parse(payload, self.tools)
        self.assertEqual(len(parsed.calls), 4)
        seen_cities = [json.loads(call.parameters)["city"] for call in parsed.calls]
        self.assertEqual(
            seen_cities, ["NYC", "Baltimore", "Minneapolis", "Los Angeles"]
        )

    def test_detect_and_parse_with_normal_text_prefix(self):
        """Plain text preceding the tool call is reported as normal text."""
        payload = (
            "Sure, let me check the weather.\n"
            '<tool_call>\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "celsius"}}\n</tool_call>'
        )
        parsed = self.detector.detect_and_parse(payload, self.tools)
        self.assertEqual(len(parsed.calls), 1)
        self.assertIn("let me check", parsed.normal_text)

    # -- Streaming tests --

    def _collect_streaming_tool_calls(self, chunks):
        """Feed chunks to the streaming parser and merge the emitted calls.

        Returns a dict keyed by tool index. Each value holds the most recent
        non-empty ``name`` and the concatenation of every ``parameters``
        fragment emitted for that index. Calls whose ``tool_index`` is None
        are ignored.
        """
        collected = {}
        for piece in chunks:
            increment = self.detector.parse_streaming_increment(piece, self.tools)
            for call in increment.calls:
                if call.tool_index is None:
                    continue
                slot = collected.setdefault(
                    call.tool_index, {"name": "", "parameters": ""}
                )
                if call.name:
                    slot["name"] = call.name
                if call.parameters:
                    slot["parameters"] += call.parameters
        return collected

    def test_streaming_single_tool_call(self):
        """One tool call split across several chunks is reassembled."""
        chunks = [
            "<tool_call>\n",
            '{"name": "get_current_weather",',
            ' "arguments": {"city": "NYC",',
            ' "state": "NY",',
            ' "unit": "fahrenheit"}}',
            "\n</tool_call>",
        ]
        merged = self._collect_streaming_tool_calls(chunks)
        self.assertEqual(len(merged), 1)
        first = merged[0]
        self.assertEqual(first["name"], "get_current_weather")
        arguments = json.loads(first["parameters"])
        self.assertEqual(arguments["city"], "NYC")

    def test_streaming_multiple_tool_calls(self):
        """Regression: every tool call in a streamed sequence must be parsed."""
        chunks = [
            "<tool_call>\n",
            '{"name": "get_current_weather",',
            ' "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}',
            "\n</tool_call>\n",
            "<tool_call>\n",
            '{"name": "get_current_weather",',
            ' "arguments": {"city": "Baltimore", "state": "MD", "unit": "fahrenheit"}}',
            "\n</tool_call>\n",
            "<tool_call>\n",
            '{"name": "get_current_weather",',
            ' "arguments": {"city": "LA", "state": "CA", "unit": "fahrenheit"}}',
            "\n</tool_call>",
        ]
        merged = self._collect_streaming_tool_calls(chunks)
        self.assertEqual(len(merged), 3, f"Expected 3 tool calls, got {len(merged)}")
        seen_cities = [
            json.loads(merged[idx]["parameters"])["city"] for idx in sorted(merged)
        ]
        self.assertEqual(seen_cities, ["NYC", "Baltimore", "LA"])

    def test_streaming_multiple_tool_calls_fused_chunks(self):
        """The closing tag and the next opening tag may share one chunk."""
        chunks = [
            '<tool_call>\n{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}',
            '\n</tool_call>\n<tool_call>\n{"name": "get_current_weather",',
            ' "arguments": {"city": "LA", "state": "CA", "unit": "fahrenheit"}}',
            "\n</tool_call>",
        ]
        merged = self._collect_streaming_tool_calls(chunks)
        self.assertEqual(len(merged), 2, f"Expected 2 tool calls, got {len(merged)}")
        seen_cities = [
            json.loads(merged[idx]["parameters"])["city"] for idx in sorted(merged)
        ]
        self.assertEqual(seen_cities, ["NYC", "LA"])

    def test_streaming_multiple_tool_calls_char_by_char_separator(self):
        """The markup between two calls may arrive one character at a time."""
        call1 = '{"name": "get_current_weather", "arguments": {"city": "NYC", "state": "NY", "unit": "fahrenheit"}}'
        call2 = '{"name": "get_current_weather", "arguments": {"city": "LA", "state": "CA", "unit": "celsius"}}'
        separator = "\n</tool_call>\n<tool_call>\n"

        # Unpacking the separator string yields one chunk per character.
        chunks = ["<tool_call>\n", call1, *separator, call2, "\n</tool_call>"]

        merged = self._collect_streaming_tool_calls(chunks)
        self.assertEqual(len(merged), 2, f"Expected 2 tool calls, got {len(merged)}")
        seen_cities = [
            json.loads(merged[idx]["parameters"])["city"] for idx in sorted(merged)
        ]
        self.assertEqual(seen_cities, ["NYC", "LA"])


if __name__ == "__main__":
    # Run the test cases in this module when the file is executed directly.
    unittest.main()
Loading