diff --git a/python/sglang/srt/function_call/glm4_moe_detector.py b/python/sglang/srt/function_call/glm4_moe_detector.py index b0fc78249aca..95423b414d40 100644 --- a/python/sglang/srt/function_call/glm4_moe_detector.py +++ b/python/sglang/srt/function_call/glm4_moe_detector.py @@ -2,16 +2,43 @@ import json import logging import re -from typing import List +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple from sglang.srt.entrypoints.openai.protocol import Tool from sglang.srt.function_call.base_format_detector import BaseFormatDetector -from sglang.srt.function_call.core_types import StreamingParseResult, _GetInfoFunc +from sglang.srt.function_call.core_types import ( + StreamingParseResult, + ToolCallItem, + _GetInfoFunc, +) logger = logging.getLogger(__name__) -def get_argument_type(func_name: str, arg_key: str, defined_tools: list): +class StreamState(str, Enum): + """State machine states for XML to JSON streaming conversion.""" + + INIT = "INIT" + BETWEEN = "BETWEEN" + IN_KEY = "IN_KEY" + WAITING_VALUE = "WAITING_VALUE" + IN_VALUE = "IN_VALUE" + + +def get_argument_type( + func_name: str, arg_key: str, defined_tools: List[Tool] +) -> Optional[str]: + """Get the expected type of a function argument from tool definitions. + + Args: + func_name: Name of the function/tool + arg_key: Name of the argument + defined_tools: List of available tools + + Returns: + The type string (e.g., 'string', 'number', 'object') or None if not found + """ name2tool = {tool.function.name: tool for tool in defined_tools} if func_name not in name2tool: return None @@ -21,32 +48,90 @@ def get_argument_type(func_name: str, arg_key: str, defined_tools: list): return tool.function.parameters["properties"][arg_key].get("type", None) -def parse_arguments(json_value): +def _convert_to_number(value: str) -> Any: + """Convert string to appropriate number type (int or float). + + Args: + value: String value to convert + + Returns: + Converted number or original string if conversion fails + """ + try: + if "." in value or "e" in value.lower(): + return float(value) + else: + return int(value) + except (ValueError, AttributeError): + return value + + +def parse_arguments( + json_value: str, arg_type: Optional[str] = None +) -> Tuple[Any, bool]: + """Parse argument value with multiple fallback strategies. + + Args: + json_value: Raw string value to parse + arg_type: Expected type hint ('string', 'number', 'object', etc.) + + Returns: + Tuple of (parsed_value, is_valid_json) + """ + # Strategy 1: Direct JSON parsing try: parsed_value = json.loads(json_value) + + # Type coercion for number type + if arg_type == "number" and isinstance(parsed_value, str): + parsed_value = _convert_to_number(parsed_value) + return parsed_value, True - except: - # If that fails, try wrapping it to unescape JSON characters - try: - # Wrap the value as a JSON string field - wrapped = json.loads('{"tmp": "' + json_value + '"}') - # parse the unescaped value - parsed_value = json.loads(wrapped["tmp"]) - return parsed_value, True - except: - # Final fallback to ast.literal_eval - try: - parsed_value = ast.literal_eval(json_value) - return parsed_value, True - except: - return json_value, False + except (json.JSONDecodeError, ValueError): + pass + + # Strategy 2: Unescape and parse + try: + wrapped = json.loads('{"tmp": "' + json_value + '"}') + parsed_value = json.loads(wrapped["tmp"]) + + if arg_type == "number" and isinstance(parsed_value, str): + parsed_value = _convert_to_number(parsed_value) + + return parsed_value, True + except (json.JSONDecodeError, ValueError, KeyError): + pass + + # Strategy 3: ast.literal_eval + try: + parsed_value = ast.literal_eval(json_value) + return parsed_value, True + except (ValueError, SyntaxError): + pass + + # Strategy 4: Treat as string + try: + quoted_value = json.dumps(str(json_value)) + return json.loads(quoted_value), True + except (json.JSONDecodeError, ValueError): + return json_value, False class Glm4MoeDetector(BaseFormatDetector): """ Detector for GLM-4.5 and GLM-4.6 models. - Assumes function call format: - get_weather\ncity\n北京\ndate\n2024-06-27\n\nget_weather\ncity\n上海\ndate\n2024-06-27\n + Assumes function call format (with actual newlines): + get_weather + city + 北京 + date + 2024-06-27 + + + Or with literal \n characters (escaped as \\n in the output): + get_weather\ncity\n北京\n + + Uses a streaming state machine to convert XML to JSON incrementally for maximum speed. """ def __init__(self): @@ -61,6 +146,23 @@ def __init__(self): r"(.*?)(?:\\n|\s)*(.*?)", re.DOTALL, ) + self._last_arguments = "" + self.current_tool_id = -1 + self.current_tool_name_sent = False + self._streamed_raw_length = 0 + self._reset_streaming_state() + + def _reset_streaming_state(self) -> None: + """Reset the streaming state machine for a new tool call.""" + self._stream_state = StreamState.INIT + self._current_key = "" + self._current_value = "" + self._xml_tag_buffer = "" + self._is_first_param = True + self._value_started = False + self._cached_value_type: Optional[str] = ( + None # Cache the value type for consistency + ) def has_tool_call(self, text: str) -> bool: """Check if the text contains a glm-4.5 / glm-4.6 format tool call.""" @@ -87,69 +189,399 @@ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult func_name = func_detail.group(1) func_args = func_detail.group(2) pairs = self.func_arg_regex.findall(func_args) - arguments = {} - for arg_key, arg_value in pairs: - arg_key = arg_key.strip() - arg_value = arg_value.strip() - arg_type = get_argument_type(func_name, arg_key, tools) - if arg_type != "string": - arg_value, is_good_json = parse_arguments(arg_value) - arguments[arg_key] = arg_value + + # Parse arguments using shared method + arguments = self._parse_argument_pairs(pairs, func_name, tools) + # construct match_result for parse_base_json match_result = {"name": func_name, "parameters": arguments} calls.extend(self.parse_base_json(match_result, tools)) return StreamingParseResult(normal_text=normal_text, calls=calls) except Exception as e: - logger.error(f"Error in detect_and_parse: {e}") + logger.error(f"Error in detect_and_parse: {e}", exc_info=True) # return the normal text if parsing fails return StreamingParseResult(normal_text=text) + def _get_value_type(self, func_name: str, key: str, tools: List[Tool]) -> str: + """Get parameter type from tool definition, with fallback to auto-detection. + + Args: + func_name: Name of the function + key: Parameter name + tools: List of available tools + + Returns: + Type string: 'string', 'number', or 'object' + """ + arg_type = get_argument_type(func_name, key, tools) + if arg_type: + return arg_type + + # Auto-detect type from value (best effort) + first_chars = self._current_value.strip()[:10] if self._current_value else "" + if first_chars: + first_char = first_chars[0] + if first_char.isdigit() or first_char in ["-", "."]: + return "number" + elif first_char in ["{", "["]: + return "object" + + return "string" + + def _format_value_complete(self, value: str, value_type: str) -> str: + """Format complete value based on type. + + Args: + value: Raw value string + value_type: Expected type ('string', 'number', 'object') + + Returns: + Properly formatted JSON value string + """ + if value_type == "string": + # Ensure proper JSON string formatting with quotes + return json.dumps(value, ensure_ascii=False) + elif value_type == "number": + try: + num = _convert_to_number(value.strip()) + return str(num) + except (ValueError, AttributeError): + # Fallback to string if not a valid number + logger.warning( + f"Failed to parse '{value}' as number, treating as string" + ) + return json.dumps(str(value), ensure_ascii=False) + else: + # For object/array types, return as-is (should already be valid JSON) + return value + + def _process_xml_to_json_streaming( + self, raw_increment: str, func_name: str, tools: List[Tool] + ) -> str: + """Convert XML increment to JSON streaming output using state machine. + + This method processes XML fragments character by character and converts them + to JSON format incrementally. It maintains state across calls to handle + partial XML tags and values. + + Args: + raw_increment: New XML content to process + func_name: Name of the function being called + tools: List of available tools for type inference + + Returns: + JSON string increment to append to the output + """ + json_output = "" + + for char in raw_increment: + self._xml_tag_buffer += char + + if self._stream_state in [StreamState.INIT, StreamState.BETWEEN]: + if self._xml_tag_buffer.endswith(""): + self._stream_state = StreamState.IN_KEY + self._current_key = "" + self._xml_tag_buffer = "" + json_output += "{" if self._is_first_param else ", " + self._is_first_param = False + + elif self._stream_state == StreamState.IN_KEY: + if self._xml_tag_buffer.endswith(""): + self._current_key = self._xml_tag_buffer[:-10].strip() + self._xml_tag_buffer = "" + self._stream_state = StreamState.WAITING_VALUE + json_output += ( + json.dumps(self._current_key, ensure_ascii=False) + ": " + ) + + elif self._stream_state == StreamState.WAITING_VALUE: + if self._xml_tag_buffer.endswith(""): + self._stream_state = StreamState.IN_VALUE + self._current_value = "" + self._xml_tag_buffer = "" + self._value_started = False + # Determine and cache the value type at the start + self._cached_value_type = self._get_value_type( + func_name, self._current_key, tools + ) + + elif self._stream_state == StreamState.IN_VALUE: + if self._xml_tag_buffer.endswith(""): + final_value = self._xml_tag_buffer[:-12] + self._current_value += final_value + + # Use cached value type for consistency + value_type = self._cached_value_type or "string" + + if self._value_started: + # Output any remaining content + if final_value: + if value_type == "string": + json_output += json.dumps( + final_value, ensure_ascii=False + )[1:-1] + else: + json_output += final_value + # Always output closing quote for string type when value was started + if value_type == "string": + json_output += '"' + else: + # Value was never started (empty or complete in one chunk) + json_output += self._format_value_complete( + self._current_value, value_type + ) + + self._xml_tag_buffer = "" + self._stream_state = StreamState.BETWEEN + self._current_value = "" + self._value_started = False + self._cached_value_type = None # Reset cached type + else: + closing_tag = "" + is_potential_closing = len(self._xml_tag_buffer) <= len( + closing_tag + ) and closing_tag.startswith(self._xml_tag_buffer) + + if not is_potential_closing: + content = self._xml_tag_buffer + # Use cached value type for consistency + value_type = self._cached_value_type or "string" + + if value_type == "string": + if not self._value_started: + json_output += '"' + self._value_started = True + if content: + json_output += json.dumps(content, ensure_ascii=False)[ + 1:-1 + ] + self._current_value += content + self._xml_tag_buffer = "" + elif value_type == "number": + if content: + if not self._value_started: + self._value_started = True + json_output += content + self._current_value += content + self._xml_tag_buffer = "" + else: + # For object/array types, output as-is + if content: + if not self._value_started: + self._value_started = True + json_output += content + self._current_value += content + self._xml_tag_buffer = "" + + return json_output + def parse_streaming_increment( self, new_text: str, tools: List[Tool] ) -> StreamingParseResult: """ Streaming incremental parsing tool calls for GLM-4.5 and GLM-4.6 format. + Uses a state machine to convert XML to JSON incrementally for true character-by-character streaming. + Outputs JSON increments immediately as XML data arrives. """ self._buffer += new_text current_text = self._buffer - start = current_text.find(self.bot_token) - if start == -1: - self._buffer = "" - if self.current_tool_id > 0: - current_text = "" - return StreamingParseResult(normal_text=current_text) - # find ensures we find the first self.eot_token so there will be at most one tool_call in current_text[:end+len(self.eot_token) - end = current_text.find(self.eot_token) - if end != -1: - # Initialize state if this is the first tool call - if self.current_tool_id == -1: - self.current_tool_id = 0 - self.prev_tool_call_arr = [] - self.streamed_args_for_tool = [""] - # Ensure we have enough entries in our tracking arrays - while len(self.prev_tool_call_arr) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - while len(self.streamed_args_for_tool) <= self.current_tool_id: - self.streamed_args_for_tool.append("") - result = self.detect_and_parse( - current_text[: end + len(self.eot_token)], tools=tools + # Check if we have a tool call + has_tool_call = self.bot_token in current_text + + if not has_tool_call: + # Check if buffer could be the start of a tool call + # Keep buffer if it could be a partial match of bot_token + is_potential_start = any( + self.bot_token.startswith(current_text[-i:]) + for i in range(1, min(len(current_text), len(self.bot_token)) + 1) ) - if result.calls: - self.prev_tool_call_arr[self.current_tool_id] = { - "name": result.calls[0].name, - "arguments": json.loads(result.calls[0].parameters), - } - self.streamed_args_for_tool[self.current_tool_id] = result.calls[ - 0 - ].parameters - result.calls[0].tool_index = self.current_tool_id - self.current_tool_id += 1 - self._buffer = current_text[end + len(self.eot_token) :] - return result - normal_text = current_text[:start] - self._buffer = current_text[start:] - return StreamingParseResult(normal_text=normal_text) + + if not is_potential_start: + # Not a potential tool call, return as normal text + # Must return the entire buffer (current_text), not just new_text, + # because buffer may contain previously accumulated characters like '<' + # that turned out not to be part of a tool call + output_text = current_text + self._buffer = "" + if self.eot_token in output_text: + output_text = output_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=output_text) + else: + # Could be start of tool call, keep buffering + return StreamingParseResult(normal_text="", calls=[]) + + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + calls: list[ToolCallItem] = [] + try: + # Try to match a partial or complete tool call + partial_match = re.search( + pattern=r"(.*?)(?:\\n|\n)(.*?)(|$)", + string=current_text, + flags=re.DOTALL, + ) + if partial_match: + func_name = partial_match.group(1).strip() + func_args_raw = partial_match.group(2).strip() + is_tool_end = partial_match.group(3) + + # Initialize state if this is the first tool call + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + self._streamed_raw_length = 0 + self.current_tool_name_sent = False + self._reset_streaming_state() + + # Ensure we have enough entries in our tracking arrays + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + # Send tool name first if not sent yet + if not self.current_tool_name_sent: + assert func_name, "func_name should not be empty" + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self._streamed_raw_length = 0 + self._reset_streaming_state() + # Store the tool call info + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + else: + # Process XML to JSON streaming + current_raw_length = len(func_args_raw) + + if current_raw_length > self._streamed_raw_length: + # Get the new raw XML content + raw_increment = func_args_raw[self._streamed_raw_length :] + + # Convert XML increment to JSON increment using state machine + json_increment = self._process_xml_to_json_streaming( + raw_increment, func_name, tools + ) + + if json_increment: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=json_increment, + ) + ) + self._last_arguments += json_increment + self.streamed_args_for_tool[ + self.current_tool_id + ] += json_increment + + # Update the streamed length + self._streamed_raw_length = current_raw_length + + if is_tool_end == self.eot_token: + if self._is_first_param: + empty_object = "{}" + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=empty_object, + ) + ) + self._last_arguments += empty_object + elif not self._last_arguments.endswith("}"): + closing_brace = "}" + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=None, + parameters=closing_brace, + ) + ) + self._last_arguments += closing_brace + self.streamed_args_for_tool[ + self.current_tool_id + ] += closing_brace + + try: + pairs = self.func_arg_regex.findall(func_args_raw) + if pairs: + arguments = self._parse_argument_pairs( + pairs, func_name, tools + ) + self.prev_tool_call_arr[self.current_tool_id][ + "arguments" + ] = arguments + except Exception as e: + logger.debug( + f"Failed to parse arguments: {e}", exc_info=True + ) + + # Remove the completed tool call from buffer + self._buffer = current_text[partial_match.end(3) :] + + result = StreamingParseResult(normal_text="", calls=calls) + self.current_tool_id += 1 + self._last_arguments = "" + self.current_tool_name_sent = False + self._streamed_raw_length = 0 + self._reset_streaming_state() + return result + + return StreamingParseResult(normal_text="", calls=calls) + + except Exception as e: + logger.error(f"Error in parse_streaming_increment: {e}", exc_info=True) + return StreamingParseResult(normal_text=current_text) + + def _parse_argument_pairs( + self, pairs: List[Tuple[str, str]], func_name: str, tools: List[Tool] + ) -> Dict[str, Any]: + """Parse argument key-value pairs with type coercion. + + Args: + pairs: List of (key, value) tuples from regex matching + func_name: Name of the function + tools: List of available tools + + Returns: + Dictionary of parsed arguments + """ + arguments = {} + for arg_key, arg_value in pairs: + arg_key = arg_key.strip() + arg_value = arg_value.strip() + arg_type = get_argument_type(func_name, arg_key, tools) + parsed_value, is_good_json = parse_arguments(arg_value, arg_type) + + if arg_type == "string": + # Only convert to string if explicitly defined as string type + if isinstance(parsed_value, str): + arguments[arg_key] = parsed_value + elif isinstance(parsed_value, (dict, list)): + # If parsed as dict/list but schema says string, convert to JSON string + arguments[arg_key] = json.dumps(parsed_value, ensure_ascii=False) + else: + arguments[arg_key] = str(parsed_value) + elif arg_type is None: + # If type is not defined, keep the parsed value as-is + arguments[arg_key] = parsed_value if is_good_json else arg_value + else: + # For other types (number, object, array, etc.), use parsed value + arguments[arg_key] = parsed_value if is_good_json else arg_value + + return arguments def supports_structural_tag(self) -> bool: return False diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index b3c448187320..d5484fdad5b0 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -2034,12 +2034,12 @@ def test_streaming_tool_call(self): and tool_call_chunk.tool_index is not None ): while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": {}}) + tool_calls.append({"name": "", "parameters": ""}) tc = tool_calls[tool_call_chunk.tool_index] if tool_call_chunk.name: tc["name"] = tool_call_chunk.name if tool_call_chunk.parameters: - tc["parameters"] = tool_call_chunk.parameters + tc["parameters"] += tool_call_chunk.parameters self.assertEqual(len(tool_calls), 1) self.assertEqual(tool_calls[0]["name"], "get_weather") self.assertEqual( @@ -2066,12 +2066,12 @@ def test_streaming_multiple_tool_calls(self): and tool_call_chunk.tool_index is not None ): while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": {}}) + tool_calls.append({"name": "", "parameters": ""}) tc = tool_calls[tool_call_chunk.tool_index] if tool_call_chunk.name: tc["name"] = tool_call_chunk.name if tool_call_chunk.parameters: - tc["parameters"] = tool_call_chunk.parameters + tc["parameters"] += tool_call_chunk.parameters self.assertEqual(len(tool_calls), 2) self.assertEqual(tool_calls[0]["name"], "get_weather") self.assertEqual( @@ -2102,19 +2102,33 @@ def test_invalid_tool_call(self): def test_partial_tool_call(self): """Test parsing a partial tool call that spans multiple chunks.""" - text1 = "get_weather\ncity\n" - result1 = self.detector.parse_streaming_increment(text1, self.tools) - self.assertEqual(result1.normal_text, "") - self.assertEqual(result1.calls, []) - self.assertEqual(self.detector._buffer, text1) - text2 = "Beijing\ndate\n2024-06-27\n" - result2 = self.detector.parse_streaming_increment(text2, self.tools) - self.assertEqual(len(result2.calls), 1) - self.assertEqual(result2.calls[0].name, "get_weather") + chunks = [ + "get_weather\n", + "city\nBeijing\n", + "date\n2024-06-27\n", + ] + + tool_calls = [] + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for tool_call_chunk in result.calls: + if ( + hasattr(tool_call_chunk, "tool_index") + and tool_call_chunk.tool_index is not None + ): + while len(tool_calls) <= tool_call_chunk.tool_index: + tool_calls.append({"name": "", "parameters": ""}) + tc = tool_calls[tool_call_chunk.tool_index] + if tool_call_chunk.name: + tc["name"] = tool_call_chunk.name + if tool_call_chunk.parameters: + tc["parameters"] += tool_call_chunk.parameters + + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0]["name"], "get_weather") self.assertEqual( - result2.calls[0].parameters, '{"city": "Beijing", "date": "2024-06-27"}' + tool_calls[0]["parameters"], '{"city": "Beijing", "date": "2024-06-27"}' ) - self.assertEqual(self.detector._buffer, "") def test_array_argument_with_escaped_json(self): """Test that array arguments with escaped JSON are properly handled without double-escaping."""