Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 180 additions & 19 deletions python/sglang/srt/function_call/glm4_moe_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from sglang.srt.entrypoints.openai.protocol import Tool
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import StreamingParseResult, _GetInfoFunc
from sglang.srt.function_call.core_types import (
StreamingParseResult,
ToolCallItem,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,6 +103,7 @@ def parse_streaming_increment(
) -> StreamingParseResult:
"""
Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format.
Now supports streaming tool names and arguments incrementally.
"""
self._buffer += new_text
current_text = self._buffer
Expand All @@ -109,38 +114,194 @@ def parse_streaming_increment(
if self.current_tool_id > 0:
current_text = ""
return StreamingParseResult(normal_text=current_text)
# str.find returns the first self.eot_token, so current_text[: end + len(self.eot_token)] contains at most one tool call
end = current_text.find(self.eot_token)
if end != -1:

# Extract normal text before tool calls
normal_text = current_text[:start]

# Try to parse partial tool call for streaming
partial_result = self._parse_partial_tool_call(current_text[start:], tools)
if partial_result:
func_name, partial_args_str, is_complete = partial_result

# Initialize state if this is the first tool call
if self.current_tool_id == -1:
self.current_tool_id = 0
self.prev_tool_call_arr = []
self.streamed_args_for_tool = [""]
self.current_tool_name_sent = False

# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")
result = self.detect_and_parse(
current_text[: end + len(self.eot_token)], tools=tools
)
if result.calls:
self.prev_tool_call_arr[self.current_tool_id] = {
"name": result.calls[0].name,
"arguments": json.loads(result.calls[0].parameters),
}
self.streamed_args_for_tool[self.current_tool_id] = result.calls[
0
].parameters
result.calls[0].tool_index = self.current_tool_id

tool_id = self.current_tool_id
calls = []

# Case 1: Send tool name if not sent yet
if not self.current_tool_name_sent:
self.current_tool_name_sent = True
calls.append(
ToolCallItem(tool_index=tool_id, name=func_name, parameters="")
)
# Case 2: Stream arguments incrementally
else:
# Calculate diff between current and previously streamed arguments
prev_args_str = self.streamed_args_for_tool[tool_id]

# Always check if there's new content to stream
if partial_args_str != prev_args_str:
# Try to parse both as JSON to compare properly
try:
prev_args = json.loads(prev_args_str) if prev_args_str else {}
current_args = json.loads(partial_args_str)

# Find new keys or changed values
new_content = {}
for key, value in current_args.items():
if key not in prev_args or prev_args[key] != value:
new_content[key] = value

if new_content:
argument_diff = json.dumps(new_content)
else:
argument_diff = ""
except:
# Fallback to string comparison
if partial_args_str.startswith(prev_args_str):
argument_diff = partial_args_str[len(prev_args_str) :]
else:
# If strings don't match, try to find common prefix
common_prefix = self._find_common_prefix(
prev_args_str, partial_args_str
)
if len(prev_args_str) < len(common_prefix):
argument_diff = partial_args_str[
len(prev_args_str) : len(common_prefix)
]
else:
argument_diff = ""
else:
argument_diff = ""

if argument_diff:
# Update streamed arguments
self.streamed_args_for_tool[tool_id] += argument_diff

calls.append(
ToolCallItem(
tool_index=tool_id, name=None, parameters=argument_diff
)
)

# Update prev_tool_call_arr with current state
try:
parsed_args = json.loads(partial_args_str)
except:
parsed_args = {}

self.prev_tool_call_arr[tool_id] = {
"name": func_name,
"arguments": parsed_args,
}

# If complete, advance to next tool
if is_complete:
# Remove processed portion from buffer
end = current_text.find(self.eot_token)
if end != -1:
self._buffer = current_text[end + len(self.eot_token) :]
self.current_tool_name_sent = False
self.current_tool_id += 1
self._buffer = current_text[end + len(self.eot_token) :]
return result
normal_text = current_text[:start]
else:
# Keep the buffer for partial tool call
self._buffer = current_text[start:]

return StreamingParseResult(normal_text=normal_text, calls=calls)

# No tool call found yet, return normal text before start token
self._buffer = current_text[start:]
return StreamingParseResult(normal_text=normal_text)

def _parse_partial_tool_call(
    self, text: str, tools: List[Tool]
) -> tuple[str, str, bool] | None:
    """
    Parse the (possibly unfinished) tool call at the start of ``text``.

    The expected layout is ``self.bot_token``, the function name terminated
    by a newline, zero or more
    ``<arg_key>k</arg_key><arg_value>v</arg_value>`` pairs, and finally
    ``self.eot_token``.

    Args:
        text: Buffered model output. It must begin with ``self.bot_token``
            to be treated as a tool call.
        tools: Tool definitions, forwarded to ``get_argument_type`` to look
            up each argument's declared type.

    Returns:
        ``(tool_name, arguments_json, is_complete)`` where
        ``arguments_json`` is a JSON object string of the key/value pairs
        parsed so far and ``is_complete`` tells whether ``self.eot_token``
        has been seen. ``None`` is returned when ``text`` does not start
        with ``self.bot_token``, when the function name is empty, or when
        the name line has not been terminated yet (no newline and no end
        token), since the name could still be growing.
    """
    if not text.startswith(self.bot_token):
        return None

    after_start = text[len(self.bot_token) :]

    # Restrict parsing to the first tool call only. Without this bound a
    # missing newline (or a newline that only appears after the end token)
    # would pull the end token and any following text into the name.
    end_idx = after_start.find(self.eot_token)
    is_complete = end_idx != -1
    body = after_start[:end_idx] if is_complete else after_start

    raw_name, newline, args_text = body.partition("\n")
    if not newline and not is_complete:
        # The name line is still being streamed; reporting it now would
        # expose a truncated function name to the caller.
        return None

    func_name = raw_name.strip()
    if not func_name:
        return None

    # Only fully closed key/value pairs match; a half-received pair is
    # picked up on a later call once its closing tag arrives.
    pairs = re.findall(
        r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
        args_text,
        re.DOTALL,
    )
    arguments = {}
    for arg_key, arg_value in pairs:
        arg_key = arg_key.strip()
        arg_value = arg_value.strip()
        # While the call is still open, pairs with an empty key or value
        # are skipped; a finished call keeps every pair it contains.
        if not is_complete and not (arg_key and arg_value):
            continue
        if get_argument_type(func_name, arg_key, tools) != "string":
            # parse_arguments returns (value, flag); only the value is used.
            arg_value, _ = parse_arguments(arg_value)
        arguments[arg_key] = arg_value

    return (func_name, json.dumps(arguments), is_complete)

def _find_common_prefix(self, s1: str, s2: str) -> str:
    """Return the longest leading substring shared by ``s1`` and ``s2``."""
    # Only positions present in both strings can match.
    limit = min(len(s1), len(s2))
    idx = 0
    while idx < limit and s1[idx] == s2[idx]:
        idx += 1
    return s1[:idx]

def supports_structural_tag(self) -> bool:
    """Report that this detector does not support structural tags."""
    return False

Expand Down
Loading
Loading