diff --git a/python/sglang/srt/function_call/glm4_moe_detector.py b/python/sglang/srt/function_call/glm4_moe_detector.py index 301d0e0dedc8..0ead201878f3 100644 --- a/python/sglang/srt/function_call/glm4_moe_detector.py +++ b/python/sglang/srt/function_call/glm4_moe_detector.py @@ -6,7 +6,11 @@ from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector -from sglang.srt.function_call.core_types import StreamingParseResult, _GetInfoFunc +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + ToolCallItem, + _GetInfoFunc, +) from sglang.srt.function_call.ebnf_composer import EBNFComposer logger = logging.getLogger(__name__) @@ -99,6 +103,7 @@ def parse_streaming_increment( ) -> StreamingParseResult: """ Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format. + Now supports streaming tool names and arguments incrementally. """ self._buffer += new_text current_text = self._buffer @@ -109,38 +114,194 @@ def parse_streaming_increment( if self.current_tool_id > 0: current_text = "" return StreamingParseResult(normal_text=current_text) - # find ensures we find the first self.eot_token so there will be at most one tool_call in current_text[:end+len(self.eot_token) - end = current_text.find(self.eot_token) - if end != -1: + + # Extract normal text before tool calls + normal_text = current_text[:start] + + # Try to parse partial tool call for streaming + partial_result = self._parse_partial_tool_call(current_text[start:], tools) + if partial_result: + func_name, partial_args_str, is_complete = partial_result + # Initialize state if this is the first tool call if self.current_tool_id == -1: self.current_tool_id = 0 self.prev_tool_call_arr = [] self.streamed_args_for_tool = [""] + self.current_tool_name_sent = False + # Ensure we have enough entries in our tracking arrays while len(self.prev_tool_call_arr) <= self.current_tool_id: self.prev_tool_call_arr.append({}) while len(self.streamed_args_for_tool) <= self.current_tool_id: self.streamed_args_for_tool.append("") - result = self.detect_and_parse( - current_text[: end + len(self.eot_token)], tools=tools - ) - if result.calls: - self.prev_tool_call_arr[self.current_tool_id] = { - "name": result.calls[0].name, - "arguments": json.loads(result.calls[0].parameters), - } - self.streamed_args_for_tool[self.current_tool_id] = result.calls[ - 0 - ].parameters - result.calls[0].tool_index = self.current_tool_id + + tool_id = self.current_tool_id + calls = [] + + # Case 1: Send tool name if not sent yet + if not self.current_tool_name_sent: + self.current_tool_name_sent = True + calls.append( + ToolCallItem(tool_index=tool_id, name=func_name, parameters="") + ) + # Case 2: Stream arguments incrementally + else: + # Calculate diff between current and previously streamed arguments + prev_args_str = self.streamed_args_for_tool[tool_id] + + # Always check if there's new content to stream + if partial_args_str != prev_args_str: + # Try to parse both as JSON to compare properly + try: + prev_args = json.loads(prev_args_str) if prev_args_str else {} + current_args = json.loads(partial_args_str) + + # Find new keys or changed values + new_content = {} + for key, value in current_args.items(): + if key not in prev_args or prev_args[key] != value: + new_content[key] = value + + if new_content: + argument_diff = json.dumps(new_content) + else: + argument_diff = "" + except: + # Fallback to string comparison + if partial_args_str.startswith(prev_args_str): + argument_diff = partial_args_str[len(prev_args_str) :] + else: + # If strings don't match, try to find common prefix + common_prefix = self._find_common_prefix( + prev_args_str, partial_args_str + ) + if len(prev_args_str) < len(common_prefix): + argument_diff = partial_args_str[ + len(prev_args_str) : len(common_prefix) + ] + else: + argument_diff = "" + else: + argument_diff = "" + + if argument_diff: + # Update streamed arguments + self.streamed_args_for_tool[tool_id] += argument_diff + + calls.append( + ToolCallItem( + tool_index=tool_id, name=None, parameters=argument_diff + ) + ) + + # Update prev_tool_call_arr with current state + try: + parsed_args = json.loads(partial_args_str) + except: + parsed_args = {} + + self.prev_tool_call_arr[tool_id] = { + "name": func_name, + "arguments": parsed_args, + } + + # If complete, advance to next tool + if is_complete: + # Remove processed portion from buffer + end = current_text.find(self.eot_token) + if end != -1: + self._buffer = current_text[end + len(self.eot_token) :] + self.current_tool_name_sent = False self.current_tool_id += 1 - self._buffer = current_text[end + len(self.eot_token) :] - return result - normal_text = current_text[:start] + else: + # Keep the buffer for partial tool call + self._buffer = current_text[start:] + + return StreamingParseResult(normal_text=normal_text, calls=calls) + + # No tool call found yet, return normal text before start token self._buffer = current_text[start:] return StreamingParseResult(normal_text=normal_text) + def _parse_partial_tool_call( + self, text: str, tools: List[Tool] + ) -> tuple[str, str, bool] | None: + """ + Parse partial tool call from buffer (for streaming) + Returns (tool_name, partial_arguments_json, is_complete) + """ + if not text.startswith(self.bot_token): + return None + + after_start = text[len(self.bot_token) :] + + # Extract function name (until first newline) + name_end = after_start.find("\n") + if name_end == -1: + name_end = len(after_start) + func_name = after_start[:name_end].strip() + + if not func_name: + return None + + # Check if we have complete tool call + if self.eot_token in text: + # Complete tool call + end_pos = text.find(self.eot_token) + args_text = after_start[name_end + 1 : end_pos - len(self.bot_token)] + + # Parse arguments using existing logic + pairs = re.findall( + r"(.*?)\s*(.*?)", + args_text, + re.DOTALL, + ) + arguments = {} + for arg_key, arg_value in pairs: + arg_key = arg_key.strip() + arg_value = arg_value.strip() + arg_type = get_argument_type(func_name, arg_key, tools) + if arg_type != "string": + arg_value, is_good_json = parse_arguments(arg_value) + arguments[arg_key] = arg_value + + arguments_str = json.dumps(arguments) + return (func_name, arguments_str, True) + else: + # Partial tool call - try to parse partial arguments + args_text = after_start[name_end + 1 :] + partial_args = {} + + # Try to parse any complete key-value pairs + pairs = re.findall( + r"(.*?)\s*(.*?)", + args_text, + re.DOTALL, + ) + for arg_key, arg_value in pairs: + arg_key = arg_key.strip() + arg_value = arg_value.strip() + + if arg_key and arg_value: + arg_type = get_argument_type(func_name, arg_key, tools) + if arg_type != "string": + arg_value, is_good_json = parse_arguments(arg_value) + partial_args[arg_key] = arg_value + + partial_args_str = json.dumps(partial_args) + return (func_name, partial_args_str, False) + + def _find_common_prefix(self, s1: str, s2: str) -> str: + """Find the common prefix of two strings""" + result = [] + for c1, c2 in zip(s1, s2): + if c1 == c2: + result.append(c1) + else: + break + return "".join(result) + def supports_structural_tag(self) -> bool: return False diff --git a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs index 8b9dc502470a..da1e6ed19c66 100644 --- a/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs +++ b/sgl-router/src/tool_parser/parsers/glm4_moe_parser.rs @@ -41,6 +41,9 @@ pub struct Glm4MoeParser { /// Tracks raw JSON string content streamed to client for each tool's arguments streamed_args_for_tool: Vec, + /// Whether the current tool's name has been sent (for streaming) + current_tool_name_sent: bool, + /// Token configuration bot_token: &'static str, eot_token: &'static str, @@ -67,6 +70,7 @@ impl Glm4MoeParser { prev_tool_call_arr: Vec::new(), current_tool_id: -1, streamed_args_for_tool: Vec::new(), + current_tool_name_sent: false, bot_token: "", eot_token: "", } @@ -154,6 +158,79 @@ impl Glm4MoeParser { Ok(tools) } + + /// Parse partial tool call from buffer (for streaming) + /// Returns (tool_name, partial_arguments_json, is_complete) + fn parse_partial_tool_call(&self, text: &str) -> ParserResult> { + // Check if we have a tool call start + if let Some(start_pos) = text.find(self.bot_token) { + let after_start = &text[start_pos + self.bot_token.len()..]; + + // Extract function name (until first newline) + let name_end = after_start.find('\n').unwrap_or(after_start.len()); + let func_name = after_start[..name_end].trim().to_string(); + + if func_name.is_empty() { + return Ok(None); + } + + // Check if we have complete tool call + if let Some(end_pos) = text.find(self.eot_token) { + // Complete tool call + let args_text = &after_start[name_end + 1..end_pos - start_pos - self.bot_token.len()]; + let arguments = self.parse_arguments(args_text)?; + let arguments_str = serde_json::to_string(&arguments) + .map_err(|e| ParserError::ParsingFailed(e.to_string()))?; + + return Ok(Some((func_name, arguments_str, true))); + } else { + // Partial tool call - try to parse partial arguments + let args_text = &after_start[name_end + 1..]; + let mut partial_args = serde_json::Map::new(); + + // Try to parse any complete key-value pairs + for capture in self.arg_extractor.captures_iter(args_text) { + let key = capture.get(1).map_or("", |m| m.as_str()).trim(); + let value_str = capture.get(2).map_or("", |m| m.as_str()).trim(); + + if !key.is_empty() && !value_str.is_empty() { + // Try to parse the value as JSON first, fallback to string + let value = if let Ok(json_val) = serde_json::from_str::(value_str) { + json_val + } else { + // Try parsing as Python literal (similar to Python's ast.literal_eval) + if value_str == "true" || value_str == "True" { + Value::Bool(true) + } else if value_str == "false" || value_str == "False" { + Value::Bool(false) + } else if value_str == "null" || value_str == "None" { + Value::Null + } else if let Ok(num) = value_str.parse::() { + Value::Number(num.into()) + } else if let Ok(num) = value_str.parse::() { + if let Some(n) = serde_json::Number::from_f64(num) { + Value::Number(n) + } else { + Value::String(value_str.to_string()) + } + } else { + Value::String(value_str.to_string()) + } + }; + + partial_args.insert(key.to_string(), value); + } + } + + let partial_args_str = serde_json::to_string(&partial_args) + .map_err(|e| ParserError::ParsingFailed(e.to_string()))?; + + return Ok(Some((func_name, partial_args_str, false))); + } + } + + Ok(None) + } } impl Default for Glm4MoeParser { @@ -190,123 +267,157 @@ impl ToolParser for Glm4MoeParser { chunk: &str, tools: &[Tool], ) -> ParserResult { - // Python logic: Wait for complete tool call, then parse it all at once + // Append new text to buffer self.buffer.push_str(chunk); let current_text = &self.buffer.clone(); - // Check if we have bot_token - let start = current_text.find(self.bot_token); - if start.is_none() { - self.buffer.clear(); - // If we're in the middle of streaming (current_tool_id > 0), don't return text - let normal_text = if self.current_tool_id > 0 { - String::new() - } else { - current_text.clone() - }; - return Ok(StreamingParseResult { - normal_text, - calls: vec![], - }); - } + // Check if we have tool markers + let has_tool_start = self.has_tool_markers(current_text); - // Check if we have eot_token (end of tool call) - let end = current_text.find(self.eot_token); - if let Some(end_pos) = end { - // We have a complete tool call! + if !has_tool_start { + // No tool markers found, clear buffer and return normal text + if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() { + let normal_text = self.buffer.clone(); + self.buffer.clear(); - // Initialize state if this is the first tool call - if self.current_tool_id == -1 { - self.current_tool_id = 0; - self.prev_tool_call_arr = Vec::new(); - self.streamed_args_for_tool = vec![String::new()]; - } + // If we're in the middle of streaming (current_tool_id > 0), don't return text + let normal_text = if self.current_tool_id > 0 { + String::new() + } else { + normal_text + }; - // Ensure we have enough entries in our tracking arrays - helpers::ensure_capacity( - self.current_tool_id, - &mut self.prev_tool_call_arr, - &mut self.streamed_args_for_tool, - ); - - // Parse the complete block using shared helper - let block_end = end_pos + self.eot_token.len(); - let parsed_tools = self.parse_tool_calls_from_text(¤t_text[..block_end])?; - - // Extract normal text before tool calls - let idx = current_text.find(self.bot_token); - let normal_text = if let Some(pos) = idx { - current_text[..pos].trim().to_string() + return Ok(StreamingParseResult { + normal_text, + calls: vec![], + }); } else { - String::new() - }; - - // Build tool indices for validation - let tool_indices = helpers::get_tool_indices(tools); + // Might be partial bot_token, keep buffering + return Ok(StreamingParseResult::default()); + } + } - let mut calls = Vec::new(); + // Build tool indices for validation + let tool_indices = helpers::get_tool_indices(tools); - if !parsed_tools.is_empty() { - // Take the first tool and convert to ToolCallItem - let tool_call = &parsed_tools[0]; - let tool_id = self.current_tool_id as usize; + // Extract normal text before tool calls + let start_pos = current_text.find(self.bot_token).unwrap(); + let normal_text = current_text[..start_pos].to_string(); + // Try to parse partial tool call + match self.parse_partial_tool_call(current_text)? { + Some((func_name, partial_args_str, is_complete)) => { // Validate tool name - if !tool_indices.contains_key(&tool_call.function.name) { + if !tool_indices.contains_key(&func_name) { // Invalid tool name - skip this tool, preserve indexing for next tool - tracing::warn!("Invalid tool name '{}' - skipping", tool_call.function.name); + tracing::warn!("Invalid tool name '{}' - skipping", func_name); helpers::reset_current_tool_state( &mut self.buffer, - &mut false, // glm4_moe doesn't track name_sent per tool + &mut self.current_tool_name_sent, &mut self.streamed_args_for_tool, &self.prev_tool_call_arr, ); return Ok(StreamingParseResult::default()); } - calls.push(ToolCallItem { - tool_index: tool_id, - name: Some(tool_call.function.name.clone()), - parameters: tool_call.function.arguments.clone(), - }); + // Initialize state if this is the first tool call + if self.current_tool_id == -1 { + self.current_tool_id = 0; + self.prev_tool_call_arr = Vec::new(); + self.streamed_args_for_tool = vec![String::new()]; + } + + // Ensure we have enough entries in our tracking arrays + helpers::ensure_capacity( + self.current_tool_id, + &mut self.prev_tool_call_arr, + &mut self.streamed_args_for_tool, + ); - // Store in tracking arrays - if self.prev_tool_call_arr.len() <= tool_id { - self.prev_tool_call_arr - .resize_with(tool_id + 1, || Value::Null); + let tool_id = self.current_tool_id as usize; + let mut calls = Vec::new(); + + // Case 1: Send tool name if not sent yet + if !self.current_tool_name_sent { + self.current_tool_name_sent = true; + calls.push(ToolCallItem { + tool_index: tool_id, + name: Some(func_name.clone()), + parameters: String::new(), + }); + } + // Case 2: Stream arguments incrementally + else { + // Calculate diff between current and previously streamed arguments + let prev_args_str = self.streamed_args_for_tool + .get(tool_id) + .map(|s| s.as_str()) + .unwrap_or(""); + + // Always check if there's new content to stream + let argument_diff = if partial_args_str != prev_args_str { + if partial_args_str.starts_with(prev_args_str) { + &partial_args_str[prev_args_str.len()..] + } else { + // If strings don't match, try to find common prefix + let common_prefix = helpers::find_common_prefix(prev_args_str, &partial_args_str); + if prev_args_str.len() < common_prefix.len() { + &partial_args_str[prev_args_str.len()..common_prefix.len()] + } else { + "" + } + } + } else { + "" + }; + + if !argument_diff.is_empty() { + // Update streamed arguments + if tool_id < self.streamed_args_for_tool.len() { + self.streamed_args_for_tool[tool_id].push_str(argument_diff); + } + + calls.push(ToolCallItem { + tool_index: tool_id, + name: None, + parameters: argument_diff.to_string(), + }); + } } - // Parse parameters as JSON and store - if let Ok(args) = serde_json::from_str::(&tool_call.function.arguments) { + // Update prev_tool_call_arr with current state + if tool_id < self.prev_tool_call_arr.len() { self.prev_tool_call_arr[tool_id] = serde_json::json!({ - "name": tool_call.function.name, - "arguments": args, + "name": func_name, + "arguments": serde_json::from_str::(&partial_args_str).unwrap_or(Value::Object(serde_json::Map::new())), }); } - if self.streamed_args_for_tool.len() <= tool_id { - self.streamed_args_for_tool - .resize_with(tool_id + 1, String::new); + // If complete, advance to next tool + if is_complete { + // Remove processed portion from buffer + if let Some(end_pos) = current_text.find(self.eot_token) { + let block_end = end_pos + self.eot_token.len(); + self.buffer = current_text[block_end..].to_string(); + } + self.current_tool_name_sent = false; + self.current_tool_id += 1; + } else { + // Keep the buffer for partial tool call + self.buffer = current_text[start_pos..].to_string(); } - self.streamed_args_for_tool[tool_id] = tool_call.function.arguments.clone(); - self.current_tool_id += 1; + Ok(StreamingParseResult { normal_text, calls }) + } + None => { + // No tool call found yet, return normal text before start token + self.buffer = current_text[start_pos..].to_string(); + Ok(StreamingParseResult { + normal_text, + calls: vec![], + }) } - - // Remove processed portion from buffer - self.buffer = current_text[block_end..].to_string(); - return Ok(StreamingParseResult { normal_text, calls }); } - - // No complete tool call yet - return normal text before start token - let start_pos = start.unwrap(); - let normal_text = current_text[..start_pos].to_string(); - self.buffer = current_text[start_pos..].to_string(); - - Ok(StreamingParseResult { - normal_text, - calls: vec![], - }) } fn has_tool_markers(&self, text: &str) -> bool { @@ -321,6 +432,7 @@ impl ToolParser for Glm4MoeParser { self.buffer.clear(); self.prev_tool_call_arr.clear(); self.current_tool_id = -1; + self.current_tool_name_sent = false; self.streamed_args_for_tool.clear(); } } diff --git a/sgl-router/tests/tool_parser_glm4_moe.rs b/sgl-router/tests/tool_parser_glm4_moe.rs index 86d161c9ef73..11f29eb7651a 100644 --- a/sgl-router/tests/tool_parser_glm4_moe.rs +++ b/sgl-router/tests/tool_parser_glm4_moe.rs @@ -167,3 +167,75 @@ async fn test_glm4_nested_json_in_arg_values() { assert!(args["data"].is_object()); assert!(args["list"].is_array()); } + +#[tokio::test] +async fn test_glm4_streaming_tool_call_arguments() { + let mut parser = Glm4MoeParser::new(); + let tools = create_test_tools(); + + // Test streaming tool call arguments incrementally + let chunks = vec![ + "Let me help you with that.\nget_weather\n", + "city\nBeijing\n", + "date\n2024-12-25\n", + "" + ]; + + let mut all_calls = Vec::new(); + let mut normal_text = String::new(); + + for chunk in chunks { + let result = parser.parse_incremental(chunk, &tools).await.unwrap(); + normal_text.push_str(&result.normal_text); + all_calls.extend(result.calls); + } + + // Should have received tool name first, then arguments incrementally + assert_eq!(all_calls.len(), 3); // name + 2 argument chunks + + // First call should be tool name + assert_eq!(all_calls[0].name, Some("get_weather".to_string())); + assert_eq!(all_calls[0].parameters, ""); + + // Second call should be first argument + assert_eq!(all_calls[1].name, None); + assert!(all_calls[1].parameters.contains("Beijing")); + + // Third call should be second argument + assert_eq!(all_calls[2].name, None); + assert!(all_calls[2].parameters.contains("2024-12-25")); + + assert_eq!(normal_text, "Let me help you with that.\n"); +} + +#[tokio::test] +async fn test_glm4_streaming_partial_arguments() { + let mut parser = Glm4MoeParser::new(); + let tools = create_test_tools(); + + // Test streaming with partial arguments + let chunks = vec![ + "get_weather\n", + "city\nBeijing\n", + "date\n2024-12-25\n", + "" + ]; + + let mut all_calls = Vec::new(); + + for chunk in chunks { + let result = parser.parse_incremental(chunk, &tools).await.unwrap(); + all_calls.extend(result.calls); + } + + // Should have received tool name first, then arguments incrementally + assert_eq!(all_calls.len(), 3); // name + 2 argument chunks + + // First call should be tool name + assert_eq!(all_calls[0].name, Some("get_weather".to_string())); + assert_eq!(all_calls[0].parameters, ""); + + // Subsequent calls should be argument chunks + assert_eq!(all_calls[1].name, None); + assert_eq!(all_calls[2].name, None); +}