diff --git a/python/sglang/srt/function_call/qwen3_coder_detector.py b/python/sglang/srt/function_call/qwen3_coder_detector.py index 597d8600fc46..9dd77903d429 100644 --- a/python/sglang/srt/function_call/qwen3_coder_detector.py +++ b/python/sglang/srt/function_call/qwen3_coder_detector.py @@ -1,12 +1,10 @@ import ast -import html import json import logging import re -from typing import Any, Dict, List, Tuple +from typing import Any, List, Optional from sglang.srt.entrypoints.openai.protocol import Tool -from sglang.srt.environ import envs from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ( StreamingParseResult, @@ -17,334 +15,457 @@ logger = logging.getLogger(__name__) -def _safe_val(raw: str) -> Any: - raw = html.unescape(raw.strip()) - try: - return json.loads(raw) - except Exception: - try: - return ast.literal_eval(raw) - except Exception: - return raw - - class Qwen3CoderDetector(BaseFormatDetector): - """ - Detector for Qwen 3 models. - Assumes function call format: - - - - pwd && ls - - - - """ - def __init__(self): super().__init__() + + # Sentinel tokens self.tool_call_start_token: str = "" self.tool_call_end_token: str = "" self.tool_call_prefix: str = "(.*?)|(.*?)$", re.DOTALL - ) + self.function_end_token: str = "" + self.parameter_prefix: str = "(.*?)", re.DOTALL) self.tool_call_function_regex = re.compile( r"|||(?=)|$)", + re.DOTALL, ) - self._buf: str = "" - # Streaming state variables - self._current_function_name: str = "" - self._current_parameters: Dict[str, Any] = {} - self._streamed_parameters: Dict[str, str] = ( - {} - ) # Track what parameter content we've streamed - self._in_tool_call: bool = False - self._function_name_sent: bool = False + # Streaming State + # Base class already initializes _buffer, we just use it directly + # No need to check with hasattr - we control the lifecycle through inheritance + + # Index pointing to the next character to be processed in buffer + self.parsed_pos: int = 0 + # Parameter count inside the current tool being processed, used to determine whether to add comma + self.current_tool_param_count: int = 0 + # Flag indicating whether current tool has already sent '{' + self.json_started: bool = False + + # [FIX] New state flag: mark whether inside tool_call structure block + self.is_inside_tool_call: bool = False + + # Initialize attributes that were missing in the original PR + self.current_func_name: Optional[str] = None def has_tool_call(self, text: str) -> bool: return self.tool_call_start_token in text + def _get_arguments_config( + self, func_name: str, tools: Optional[list[Tool]] + ) -> dict: + """Extract argument configuration for a function.""" + if tools is None: + return {} + for config in tools: + try: + config_type = config.type + config_function = config.function + config_function_name = config_function.name + except AttributeError: + continue + + if config_type == "function" and config_function_name == func_name: + try: + params = config_function.parameters + except AttributeError: + return {} + + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning(f"Tool '{func_name}' is not defined in the tools list.") + return {} + + def _convert_param_value( + self, param_value: str, param_name: str, param_config: dict, func_name: str + ) -> Any: + """Convert parameter value based on its type in the schema.""" + # Handle null value for any type + if param_value.lower() == "null": + return None + + if param_name not in param_config: + if param_config != {}: + logger.warning( + f"Parsed parameter '{param_name}' is not defined in the tool " + f"parameters for tool '{func_name}', directly returning the string value." + ) + return param_value + + if ( + isinstance(param_config[param_name], dict) + and "type" in param_config[param_name] + ): + param_type = str(param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + return param_value + elif ( + param_type.startswith("int") + or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned") + ): + try: + param_value = int(param_value) + except Exception: + logger.warning( + f"Parsed value '{param_value}' of parameter '{param_name}' is not an integer in tool " + f"'{func_name}', degenerating to string." + ) + return param_value + elif param_type.startswith("num") or param_type.startswith("float"): + try: + maybe_convert = ( + False if "." in param_value or "e" in param_value.lower() else True + ) + param_value: float = float(param_value) + if maybe_convert and param_value.is_integer(): + param_value = int(param_value) + except Exception: + logger.warning( + f"Parsed value '{param_value}' of parameter '{param_name}' is not a float in tool " + f"'{func_name}', degenerating to string." + ) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + f"Parsed value '{param_value}' of parameter '{param_name}' is not a boolean (`true` of `false`) in tool '{func_name}', degenerating to false." + ) + return param_value == "true" + else: + if ( + param_type in ["object", "array", "arr"] + or param_type.startswith("dict") + or param_type.startswith("list") + ): + try: + param_value = json.loads(param_value) + return param_value + except Exception: + logger.warning( + f"Parsed value '{param_value}' of parameter '{param_name}' cannot be parsed with json.loads in tool " + f"'{func_name}', will try other methods to parse it." + ) + try: + param_value = ast.literal_eval(param_value) # safer + except Exception: + logger.warning( + f"Parsed value '{param_value}' of parameter '{param_name}' cannot be converted via Python `ast.literal_eval()` in tool '{func_name}', degenerating to string." + ) + return param_value + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: - normal, calls = self._extract(text, tools) - return StreamingParseResult(normal_text=normal, calls=calls) + """One-shot parsing for non-streaming scenarios.""" + if self.tool_call_start_token not in text: + return StreamingParseResult(normal_text=text) + + calls = [] + try: + # Simple cleanup of the text to find tool calls + # Note: This is a simplified regex approach consistent with vLLM + raw_tool_calls = self.tool_call_regex.findall(text) + if not raw_tool_calls: + # Fallback: maybe the whole text is inside the tag or tags are stripped + if self.tool_call_prefix in text: + raw_tool_calls = [text] + + tool_idx = 0 + for tool_content in raw_tool_calls: + # Find function calls + funcs = self.tool_call_function_regex.findall(tool_content) + for func_match in funcs: + func_body = func_match[0] or func_match[1] + if ">" not in func_body: + continue + + name_end = func_body.index(">") + func_name = func_body[:name_end] + params_str = func_body[name_end + 1 :] + + param_config = self._get_arguments_config(func_name, tools) + parsed_params = {} + + for p_match in self.tool_call_parameter_regex.findall(params_str): + if ">" not in p_match: + continue + p_idx = p_match.index(">") + p_name = p_match[:p_idx] + p_val = p_match[p_idx + 1 :] + # Remove prefixing and trailing \n + if p_val.startswith("\n"): + p_val = p_val[1:] + if p_val.endswith("\n"): + p_val = p_val[:-1] + + parsed_params[p_name] = self._convert_param_value( + p_val, p_name, param_config, func_name + ) + + calls.append( + ToolCallItem( + tool_index=tool_idx, + name=func_name, + parameters=json.dumps(parsed_params, ensure_ascii=False), + ) + ) + tool_idx += 1 + + # Determine normal text (text before the first tool call) + start_idx = text.find(self.tool_call_start_token) + if start_idx == -1: + start_idx = text.find(self.tool_call_prefix) + normal_text = text[:start_idx] if start_idx > 0 else "" + + return StreamingParseResult(normal_text=normal_text, calls=calls) + + except Exception as e: + logger.error(f"Error in detect_and_parse: {e}") + return StreamingParseResult(normal_text=text) def parse_streaming_increment( self, new_text: str, tools: List[Tool] ) -> StreamingParseResult: - self._buf += new_text - normal = "" - calls: List[ToolCallItem] = [] + """ + Robust cursor-based streaming parser. + """ + self._buffer += new_text - # Build tool indices for validation - if not hasattr(self, "_tool_indices"): - self._tool_indices = self._get_tool_indices(tools) + # Guard against empty buffer + if not self._buffer: + return StreamingParseResult() + + calls = [] + normal_text_chunks = [] while True: - # If we're not in a tool call and don't see a start token, return normal text - if not self._in_tool_call and self.tool_call_start_token not in self._buf: - normal += self._buf - self._buf = "" - break + # Working text slice + current_slice = self._buffer[self.parsed_pos :] - # Look for tool call start - if not self._in_tool_call: - s = self._buf.find(self.tool_call_start_token) - if s == -1: - normal += self._buf - self._buf = "" - break + # Optimization: If almost empty, wait for more + if not current_slice: + break - normal += self._buf[:s] - self._buf = self._buf[s:] + # ------------------------------------------------------- + # 1. Priority detection: check if it's the start of Tool Call + # ------------------------------------------------------- + if current_slice.startswith(self.tool_call_start_token): + self.parsed_pos += len(self.tool_call_start_token) + self.is_inside_tool_call = True + continue - self._in_tool_call = True - self._function_name_sent = False - self._current_function_name = "" - self._current_parameters = {} - self._streamed_parameters = {} + # ------------------------------------------------------- + # 2. Function Name: + # ------------------------------------------------------- + if current_slice.startswith(self.tool_call_prefix): + end_angle = current_slice.find(">") + if end_angle != -1: + func_name = current_slice[len(self.tool_call_prefix) : end_angle] - # Remove the start token - self._buf = self._buf[len(self.tool_call_start_token) :] - continue + self.current_tool_id += 1 + self.current_tool_name_sent = True + self.current_tool_param_count = 0 + self.json_started = False + self.current_func_name = func_name - # We're in a tool call, try to parse function name if not sent yet - if not self._function_name_sent: - # Look for function name pattern: - function_match = re.search(r"]+)>", self._buf) - if function_match: - function_name = function_match.group(1).strip() - - # Validate function name - is_valid = function_name in self._tool_indices - if not is_valid: - logger.warning(f"Invalid function name: {function_name}") - if not envs.SGLANG_FORWARD_UNKNOWN_TOOLS.get(): - # Reset state and skip (default legacy behavior) - self._reset_streaming_state() - normal += self._buf - self._buf = "" - break - - # Process tool call (valid or unknown with env=TRUE) - self._current_function_name = function_name - self._function_name_sent = True - - # Initialize tool call tracking - if self.current_tool_id == -1: - self.current_tool_id = 0 - - # Ensure tracking arrays are large enough - while len(self.prev_tool_call_arr) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - while len(self.streamed_args_for_tool) <= self.current_tool_id: - self.streamed_args_for_tool.append("") - - # Store tool call info - self.prev_tool_call_arr[self.current_tool_id] = { - "name": function_name, - "arguments": {}, - } - - # Send tool name with empty parameters calls.append( ToolCallItem( tool_index=self.current_tool_id, - name=function_name, + name=func_name, parameters="", ) ) - # Remove the processed function declaration - self._buf = self._buf[function_match.end() :] + self.parsed_pos += end_angle + 1 continue else: - # Function name not complete yet, wait for more text + # Incomplete tag break - # Parse parameters incrementally - if self._function_name_sent: - # Process parameters and get any calls to emit - parameter_calls = self._parse_and_stream_parameters(self._buf) - calls.extend(parameter_calls) - - # Check if tool call is complete - if self.tool_call_end_token in self._buf: - end_pos = self._buf.find(self.tool_call_end_token) - - # Add closing brace to complete the JSON object - current_streamed = self.streamed_args_for_tool[self.current_tool_id] - if current_streamed: - # Count opening and closing braces to check if JSON is complete - open_braces = current_streamed.count("{") - close_braces = current_streamed.count("}") - if open_braces > close_braces: + # ------------------------------------------------------- + # 3. Parameter: value... + # ------------------------------------------------------- + if current_slice.startswith(self.parameter_prefix): + name_end = current_slice.find(">") + if name_end != -1: + value_start_idx = name_end + 1 + rest_of_slice = current_slice[value_start_idx:] + + # A parameter can end in multiple ways: + # 1. [Normal] Encounter + # 2. [Abnormal] Encounter next + # So we need to find the smallest one as the parameter end position. + cand_end_param = rest_of_slice.find(self.parameter_end_token) + cand_next_param = rest_of_slice.find(self.parameter_prefix) + cand_end_func = rest_of_slice.find(self.function_end_token) + + candidates = [] + if cand_end_param != -1: + candidates.append( + (cand_end_param, len(self.parameter_end_token)) + ) + if cand_next_param != -1: + candidates.append((cand_next_param, 0)) + if cand_end_func != -1: + candidates.append((cand_end_func, 0)) + + if candidates: + best_cand = min(candidates, key=lambda x: x[0]) + end_pos = best_cand[0] + end_token_len = best_cand[1] + + param_name = current_slice[ + len(self.parameter_prefix) : name_end + ] + raw_value = rest_of_slice[:end_pos] + + # Cleanup value + if raw_value.startswith("\n"): + raw_value = raw_value[1:] + if raw_value.endswith("\n"): + raw_value = raw_value[:-1] + + # JSON Construction + if not self.json_started: calls.append( ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters="}", + tool_index=self.current_tool_id, parameters="{" ) ) - self.streamed_args_for_tool[self.current_tool_id] = ( - current_streamed + "}" - ) - - # Complete the tool call - self._buf = self._buf[end_pos + len(self.tool_call_end_token) :] - self._reset_streaming_state() - self.current_tool_id += 1 - continue - else: - # Tool call not complete yet, wait for more text - break + self.json_started = True - return StreamingParseResult(normal_text=normal, calls=calls) + param_config = self._get_arguments_config( + self.current_func_name, tools + ) + converted_val = self._convert_param_value( + raw_value, param_name, param_config, self.current_func_name + ) - def _parse_and_stream_parameters(self, text_to_parse: str) -> List[ToolCallItem]: - """ - Parse complete parameter blocks from text and return any tool call items to emit. + # Construct JSON fragment: "key": value + # Note: We must be careful with json.dumps to ensure valid JSON streaming + json_key_val = f"{json.dumps(param_name)}: {json.dumps(converted_val, ensure_ascii=False)}" - This method: - 1. Finds all complete blocks - 2. Parses them into a dictionary - 3. Compares with current parameters and generates diff if needed - 4. Updates internal state + if self.current_tool_param_count > 0: + fragment = f", {json_key_val}" + else: + fragment = json_key_val - Args: - text_to_parse: The text to search for parameter blocks + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, parameters=fragment + ) + ) + self.current_tool_param_count += 1 - Returns: - List of ToolCallItem objects to emit (may be empty) - """ - calls: List[ToolCallItem] = [] + # Advance cursor + total_len = (name_end + 1) + end_pos + end_token_len + self.parsed_pos += total_len + continue - # Find all complete parameter patterns - param_matches = list( - re.finditer( - r"]+)>(.*?)", text_to_parse, re.DOTALL - ) - ) + # Incomplete parameter tag or value + break - # Build new parameters dictionary - new_params = {} - for match in param_matches: - param_name = match.group(1).strip() - param_value = match.group(2) - new_params[param_name] = _safe_val(param_value) - - # Calculate parameter diff to stream with proper incremental JSON building - if new_params != self._current_parameters: - previous_args_json = self.streamed_args_for_tool[self.current_tool_id] - - # Build incremental JSON properly - if not self._current_parameters: - # First parameter(s) - start JSON object but don't close it yet - items = [] - for key, value in new_params.items(): - items.append( - f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}" + # ------------------------------------------------------- + # 4. Function End: + # ------------------------------------------------------- + if current_slice.startswith(self.function_end_token): + if not self.json_started: + calls.append( + ToolCallItem(tool_index=self.current_tool_id, parameters="{") ) - json_fragment = "{" + ", ".join(items) + self.json_started = True calls.append( - ToolCallItem( - tool_index=self.current_tool_id, - name=None, - parameters=json_fragment, - ) + ToolCallItem(tool_index=self.current_tool_id, parameters="}") ) - self.streamed_args_for_tool[self.current_tool_id] = json_fragment + self.parsed_pos += len(self.function_end_token) + self.current_func_name = None + continue - else: - # Additional parameters - add them incrementally - new_keys = set(new_params.keys()) - set(self._current_parameters.keys()) - if new_keys: - # Build the continuation part (no closing brace yet) - continuation_parts = [] - for key in new_keys: - value = new_params[key] - continuation_parts.append( - f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}" - ) + # ------------------------------------------------------- + # 5. Tool Call End: + # ------------------------------------------------------- + if current_slice.startswith(self.tool_call_end_token): + self.parsed_pos += len(self.tool_call_end_token) + self.is_inside_tool_call = False # [FIX] Exit tool call region + continue - json_fragment = ", " + ", ".join(continuation_parts) + # ------------------------------------------------------- + # 6. Handling content / whitespace / normal text + # ------------------------------------------------------- + # If current position is not the start of a tag (i.e., doesn't start with <), it might be plain text, + # or a newline between two tags. + # But we need to be careful not to output truncated tags like " Tuple[str, List[ToolCallItem]]: - normal_parts: List[str] = [] - calls: List[ToolCallItem] = [] - cursor = 0 - while True: - s = text.find(self.tool_call_start_token, cursor) - if s == -1: - normal_parts.append(text[cursor:]) - break - normal_parts.append(text[cursor:s]) - e = text.find(self.tool_call_end_token, s) - if e == -1: - normal_parts.append(text[s:]) - break - block = text[s : e + len(self.tool_call_end_token)] - cursor = e + len(self.tool_call_end_token) - calls.extend(self._parse_block(block, tools)) - return "".join(normal_parts), calls - - def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]: - res: List[ToolCallItem] = [] - for m in self.tool_call_function_regex.findall(block): - txt = m[0] if m[0] else m[1] - if ">" not in txt: + else: + # '<' is in the middle + text_segment = current_slice[:next_open_angle] + if not self.is_inside_tool_call: + normal_text_chunks.append(text_segment) + # [FIX] If inside tool call, discard whitespace/text before Tag + self.parsed_pos += next_open_angle continue - idx = txt.index(">") - fname = txt[:idx].strip() - body = txt[idx + 1 :] - params: Dict[str, Any] = {} - for pm in self.tool_call_parameter_regex.findall(body): - ptxt = pm[0] if pm[0] else pm[1] - if ">" not in ptxt: - continue - pidx = ptxt.index(">") - pname = ptxt[:pidx].strip() - pval = ptxt[pidx + 1 :].lstrip("\n").rstrip("\n") - params[pname] = _safe_val(pval) - raw = {"name": fname, "arguments": params} - try: - # TODO: fix idx in function call, the index for a function - # call will always be -1 in parse_base_json - res.extend(self.parse_base_json(raw, tools)) - except Exception: - logger.warning("invalid tool call for %s dropped", fname) - return res + + # Memory Cleanup: Slice the buffer + # Keep unparsed part, discard parsed part + if self.parsed_pos > 0: + self._buffer = self._buffer[self.parsed_pos :] + self.parsed_pos = 0 + + normal_text = "".join(normal_text_chunks) if normal_text_chunks else "" + return StreamingParseResult(calls=calls, normal_text=normal_text) def supports_structural_tag(self) -> bool: return False diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index e2bb53f8b33a..0e855e7636a2 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -1613,507 +1613,356 @@ def test_streaming_no_parameters_with_whitespace(self): class TestQwen3CoderDetector(unittest.TestCase): + """Test suite for Qwen3CoderDetector.""" + def setUp(self): - # Create sample tools for testing + """Initialize test fixtures before each test method.""" self.tools = [ Tool( type="function", function=Function( name="get_current_weather", - description="Get the current weather", parameters={ + "type": "object", "properties": { - "city": {"type": "string", "description": "The city name"}, - "state": { - "type": "string", - "description": "The state code", - }, + "location": {"type": "string"}, "unit": { "type": "string", - "enum": ["fahrenheit", "celsius"], + "enum": ["celsius", "fahrenheit"], }, + "days": {"type": "integer"}, }, - "required": ["city", "state"], + "required": ["location"], }, ), ), Tool( type="function", function=Function( - name="calculate_area", - description="Calculate area of a shape", + name="sql_interpreter", parameters={ + "type": "object", "properties": { - "shape": {"type": "string"}, - "dimensions": {"type": "object"}, - "precision": {"type": "integer"}, - } + "query": {"type": "string"}, + "dry_run": {"type": "boolean"}, + }, + }, + ), + ), + Tool( + type="function", + function=Function( + name="TodoWrite", + parameters={ + "type": "object", + "properties": { + "todos": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": {"type": "string"}, + "status": {"type": "string"}, + }, + "required": ["content", "status"], + }, + }, + }, }, ), ), ] self.detector = Qwen3CoderDetector() - def test_has_tool_call(self): - """Test detection of tool call markers.""" - self.assertTrue(self.detector.has_tool_call("test")) - self.assertFalse(self.detector.has_tool_call("No tool call here")) + # ==================== Basic Functionality Tests ==================== - def test_detect_and_parse_no_tools(self): - """Test parsing text without tool calls.""" - model_output = "This is a test response without any tool calls" - result = self.detector.detect_and_parse(model_output, tools=[]) - self.assertEqual(result.normal_text, model_output) - self.assertEqual(result.calls, []) + def test_plain_text_only(self): + """ + Test parsing of plain text without any tool calls. - def test_detect_and_parse_single_tool(self): - """Test parsing a single tool call.""" - model_output = """ + Scenario: Input contains only plain text, no tool call markers. + Purpose: Verify that plain text is correctly identified and no false tool calls are detected. + """ + text = "This is plain text without any tool calls." + result = self.detector.detect_and_parse(text, self.tools) + + self.assertEqual(result.normal_text, text) + self.assertEqual(len(result.calls), 0) + + def test_single_tool_call(self): + """ + Test parsing of a single tool call. + + Scenario: Input contains one complete tool call with parameters. + Purpose: Verify correct extraction of tool name and parameters. + """ + text = """ - -Dallas - - -TX - - -fahrenheit - +Boston +celsius +3 """ + result = self.detector.detect_and_parse(text, self.tools) - result = self.detector.detect_and_parse(model_output, tools=self.tools) - - self.assertEqual(result.normal_text, "") self.assertEqual(len(result.calls), 1) self.assertEqual(result.calls[0].name, "get_current_weather") params = json.loads(result.calls[0].parameters) - self.assertEqual(params["city"], "Dallas") - self.assertEqual(params["state"], "TX") - self.assertEqual(params["unit"], "fahrenheit") - - def test_detect_and_parse_with_content(self): - """Test parsing tool call with surrounding text.""" - model_output = """Sure! Let me check the weather for you. - - -Dallas - - -TX - - -fahrenheit - - -""" + self.assertEqual(params["location"], "Boston") + self.assertEqual(params["unit"], "celsius") + self.assertEqual(params["days"], 3) - result = self.detector.detect_and_parse(model_output, tools=self.tools) + def test_single_tool_call_with_text_prefix(self): + """ + Test parsing of tool call with preceding text. - self.assertEqual(result.normal_text, "Sure! Let me check the weather for you.") - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_current_weather") + Scenario: Input has plain text followed by a tool call. + Purpose: Verify correct separation of text and tool call. + """ + text = """Let me check the weather for you. - def test_detect_and_parse_multiline_param(self): - """Test parsing tool call with multiline parameter values.""" - model_output = """ - - -rectangle - - -{"width": 10, - "height": 20} - - -2 - + + +New York """ + result = self.detector.detect_and_parse(text, self.tools) - result = self.detector.detect_and_parse(model_output, tools=self.tools) - + self.assertTrue(result.normal_text.startswith("Let me check")) self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "calculate_area") + self.assertEqual(result.calls[0].name, "get_current_weather") - params = json.loads(result.calls[0].parameters) - self.assertEqual(params["shape"], "rectangle") - self.assertEqual(params["dimensions"], {"width": 10, "height": 20}) - self.assertEqual(params["precision"], 2) + def test_multiple_tool_calls(self): + """ + Test parsing of multiple consecutive tool calls. - def test_detect_and_parse_parallel_tools(self): - """Test parsing multiple tool calls.""" - model_output = """ + Scenario: Input contains two tool calls one after another. + Purpose: Verify that multiple tool calls are correctly identified and parsed. + """ + text = """ - -Dallas - - -TX - - -fahrenheit - +New York - - -Orlando - - -FL - - -fahrenheit - + +SELECT * FROM users +True """ + result = self.detector.detect_and_parse(text, self.tools) - result = self.detector.detect_and_parse(model_output, tools=self.tools) - - self.assertEqual(result.normal_text, "\n") self.assertEqual(len(result.calls), 2) - - # First call self.assertEqual(result.calls[0].name, "get_current_weather") + self.assertEqual(result.calls[1].name, "sql_interpreter") + params1 = json.loads(result.calls[0].parameters) - self.assertEqual(params1["city"], "Dallas") - self.assertEqual(params1["state"], "TX") + self.assertEqual(params1["location"], "New York") - # Second call - self.assertEqual(result.calls[1].name, "get_current_weather") params2 = json.loads(result.calls[1].parameters) - self.assertEqual(params2["city"], "Orlando") - self.assertEqual(params2["state"], "FL") + self.assertEqual(params2["query"], "SELECT * FROM users") + self.assertEqual(params2["dry_run"], True) + + # ==================== Streaming Tests ==================== + + def test_streaming_single_tool_call(self): + """ + Test streaming parsing of a single tool call. - def test_parse_streaming_simple(self): - """Test basic streaming parsing.""" + Scenario: Tool call is fed incrementally in chunks. + Purpose: Verify streaming parser correctly assembles tool call from chunks. + """ chunks = [ - "Sure! ", - "Let me check ", - "the weather.", "", - "\n", - "\n", - "\nDallas", - "\n", - "\n", - "\nTX", - "\n", - "\n", - "\n", + "", + "", + "Boston", + "", + "celsius", + "", + "", ] - accumulated_text = "" - accumulated_calls = [] - tool_calls_by_index = {} + detector = Qwen3CoderDetector() + all_calls = [] + collected_params = "" for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, tools=self.tools) - accumulated_text += result.normal_text - - # Track calls by tool_index to handle streaming properly + result = detector.parse_streaming_increment(chunk, self.tools) + all_calls.extend(result.calls) for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } - - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters + if call.parameters: + collected_params += call.parameters - self.assertEqual(accumulated_text, "Sure! Let me check the weather.") - self.assertEqual(len(tool_calls_by_index), 1) + # Verify we got the tool call + self.assertGreater(len(all_calls), 0) - # Get the complete tool call - tool_call = tool_calls_by_index[0] - self.assertEqual(tool_call["name"], "get_current_weather") + # Verify parameters were collected + if collected_params: + params = json.loads(collected_params) + self.assertEqual(params["location"], "Boston") + self.assertEqual(params["unit"], "celsius") - # Parse the accumulated parameters - params = json.loads(tool_call["parameters"]) - self.assertEqual(params["city"], "Dallas") - self.assertEqual(params["state"], "TX") + def test_streaming_with_text_and_tool(self): + """ + Test streaming parsing with mixed text and tool call. - def test_parse_streaming_incomplete(self): - """Test streaming with incomplete tool call.""" - # Send incomplete tool call + Scenario: Stream contains plain text followed by a tool call. + Purpose: Verify correct separation in streaming mode. + """ chunks = [ + "Let me ", + "help you.\n\n", "", - "\n", - "\n", - "\nDallas", - "\n", - "\n", - "\nTX", - # Missing , , + "", + "Paris", + "", + "", ] - tool_calls_by_index = {} - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, tools=self.tools) + detector = Qwen3CoderDetector() + full_text = "" + all_calls = [] - # Track calls by tool_index to handle streaming properly - for call in result.calls: - if call.tool_index is not None: - if call.tool_index not in tool_calls_by_index: - tool_calls_by_index[call.tool_index] = { - "name": "", - "parameters": "", - } + for chunk in chunks: + result = detector.parse_streaming_increment(chunk, self.tools) + if result.normal_text: + full_text += result.normal_text + all_calls.extend(result.calls) - if call.name: - tool_calls_by_index[call.tool_index]["name"] = call.name - if call.parameters: - tool_calls_by_index[call.tool_index][ - "parameters" - ] += call.parameters + self.assertTrue(full_text.startswith("Let me")) + self.assertGreater(len(all_calls), 0) - # Should have partial tool call with name but incomplete parameters - self.assertGreater(len(tool_calls_by_index), 0) - self.assertEqual(tool_calls_by_index[0]["name"], "get_current_weather") + # ==================== Parameter Type Tests ==================== - # Parameters should be incomplete (no closing brace) - params_str = tool_calls_by_index[0]["parameters"] - self.assertTrue(params_str.startswith('{"city": "Dallas"')) - self.assertFalse(params_str.endswith("}")) + def test_integer_parameter_conversion(self): + """ + Test correct type conversion for integer parameters. - # Now complete it - result = self.detector.parse_streaming_increment( - "\n\n\n", tools=self.tools - ) + Scenario: Tool call with integer parameter. + Purpose: Verify integer values are correctly parsed and typed. + """ + text = """ + +Tokyo +5 + +""" + result = self.detector.detect_and_parse(text, self.tools) - # Update the accumulated parameters - for call in result.calls: - if call.tool_index is not None and call.parameters: - tool_calls_by_index[call.tool_index]["parameters"] += call.parameters + params = json.loads(result.calls[0].parameters) + self.assertIsInstance(params["days"], int) + self.assertEqual(params["days"], 5) - # Now should have complete parameters - final_params = json.loads(tool_calls_by_index[0]["parameters"]) - self.assertEqual(final_params["city"], "Dallas") - self.assertEqual(final_params["state"], "TX") + def test_boolean_parameter_conversion(self): + """ + Test correct type conversion for boolean parameters. - def test_edge_case_no_parameters(self): - """Test tool call without parameters.""" - model_output = """ - + Scenario: Tool call with boolean parameter. + Purpose: Verify boolean values are correctly parsed. + """ + text = """ + +SELECT 1 +True """ + result = self.detector.detect_and_parse(text, self.tools) - result = self.detector.detect_and_parse(model_output, tools=self.tools) - self.assertEqual(len(result.calls), 1) - self.assertEqual(result.calls[0].name, "get_current_weather") - self.assertEqual(json.loads(result.calls[0].parameters), {}) + params = json.loads(result.calls[0].parameters) + self.assertIsInstance(params["dry_run"], bool) + self.assertEqual(params["dry_run"], True) - def test_edge_case_special_chars_in_value(self): - """Test parameter with special characters in value.""" - model_output = """ - - -Dallas->TX + def test_complex_array_parameter(self): + """ + Test parsing of complex array parameters. + + Scenario: Tool call with array of objects as parameter. + Purpose: Verify complex nested structures are correctly parsed. + """ + text = """ + + +[ + {"content": "Buy groceries", "status": "pending"}, + {"content": "Finish report", "status": "completed"} +] """ - - result = self.detector.detect_and_parse(model_output, tools=self.tools) - self.assertEqual(len(result.calls), 1) + result = self.detector.detect_and_parse(text, self.tools) params = json.loads(result.calls[0].parameters) - self.assertEqual(params["city"], "Dallas->TX") + self.assertIsInstance(params["todos"], list) + self.assertEqual(len(params["todos"]), 2) + self.assertEqual(params["todos"][0]["content"], "Buy groceries") + self.assertEqual(params["todos"][1]["status"], "completed") - def test_extract_tool_calls_fallback_no_tags(self): - """Test fallback parsing when XML tags are missing (just function without tool_call wrapper).""" - model_output = """ - -Dallas - - -TX - -""" - - result = self.detector.detect_and_parse(model_output, tools=self.tools) - - self.assertIsNotNone(result) - - def test_extract_tool_calls_type_conversion(self): - """Test parameter type conversion based on tool schema.""" - test_tool = Tool( - type="function", - function=Function( - name="test_types", - parameters={ - "type": "object", - "properties": { - "int_param": {"type": "integer"}, - "float_param": {"type": "float"}, - "bool_param": {"type": "boolean"}, - "str_param": {"type": "string"}, - "obj_param": {"type": "object"}, - }, - }, - ), - ) + # ==================== Edge Cases ==================== - model_output = """ - - -42 - - -3.14 - - -true - - -hello world - - -{"key": "value"} - + def test_empty_parameter_value(self): + """ + Test handling of empty parameter values. + + Scenario: Tool call with empty parameter value. + Purpose: Verify empty values are handled gracefully. + """ + text = """ + + """ - - result = self.detector.detect_and_parse(model_output, tools=[test_tool]) + result = self.detector.detect_and_parse(text, self.tools) self.assertEqual(len(result.calls), 1) params = json.loads(result.calls[0].parameters) - self.assertEqual(params["int_param"], 42) - self.assertEqual(params["float_param"], 3.14) - self.assertEqual(params["bool_param"], True) - self.assertEqual(params["str_param"], "hello world") - self.assertEqual(params["obj_param"], {"key": "value"}) - - def test_parse_streaming_incremental(self): - """Test that streaming is truly incremental with very small chunks.""" - model_output = """I'll check the weather. - - - Dallas - - - TX - - - """ - - # Simulate more realistic token-based chunks where is a single token - chunks = [ - "I'll check the weather.", - "", - "\n\n", - "\n", - "Dallas\n", - "\n", - "\n", - "TX\n", - "\n", - "\n", - "", - ] + self.assertEqual(params["location"], "") - accumulated_text = "" - tool_calls = [] - chunks_count = 0 - - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - accumulated_text += result.normal_text - chunks_count += 1 - for tool_call_chunk in result.calls: - if ( - hasattr(tool_call_chunk, "tool_index") - and tool_call_chunk.tool_index is not None - ): - while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - tc = tool_calls[tool_call_chunk.tool_index] - if tool_call_chunk.name: - tc["name"] = tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters + def test_parameter_with_special_characters(self): + """ + Test handling of parameters with special characters. - self.assertGreater(chunks_count, 3) + Scenario: Parameter value contains special characters like quotes, newlines. + Purpose: Verify special characters are correctly preserved. + """ + text = """ + +SELECT * FROM users WHERE name = 'John "Doe"' + +""" + result = self.detector.detect_and_parse(text, self.tools) - # Verify the accumulated results - self.assertIn("I'll check the weather.", accumulated_text) - self.assertEqual(len(tool_calls), 1) - self.assertEqual(tool_calls[0]["name"], "get_current_weather") - - params = json.loads(tool_calls[0]["parameters"]) - self.assertEqual(params, {"city": "Dallas", "state": "TX"}) - - def test_parse_streaming_multiple_tools(self): - """Test streaming with multiple tool calls.""" - model_output = """ - - - Dallas - - - TX - - - - Some text in between. - - - - circle - - - {"radius": 5} - - - """ - - # Simulate streaming by chunks - chunk_size = 20 - chunks = [ - model_output[i : i + chunk_size] - for i in range(0, len(model_output), chunk_size) - ] + params = json.loads(result.calls[0].parameters) + self.assertIn("John", params["query"]) + self.assertIn("Doe", params["query"]) - accumulated_text = "" - tool_calls = [] - chunks_count = 0 + def test_incomplete_tool_call(self): + """ + Test handling of incomplete tool call at end of stream. - for chunk in chunks: - result = self.detector.parse_streaming_increment(chunk, self.tools) - accumulated_text += result.normal_text - chunks_count += 1 - for tool_call_chunk in result.calls: - if ( - hasattr(tool_call_chunk, "tool_index") - and tool_call_chunk.tool_index is not None - ): - while len(tool_calls) <= tool_call_chunk.tool_index: - tool_calls.append({"name": "", "parameters": ""}) - tc = tool_calls[tool_call_chunk.tool_index] - if tool_call_chunk.name: - tc["name"] = tool_call_chunk.name - if tool_call_chunk.parameters: - tc["parameters"] += tool_call_chunk.parameters + Scenario: Stream ends with an incomplete tool call (missing closing tag). + Purpose: Verify detector handles incomplete input gracefully without crashing. + """ + text = """ + +London""" - self.assertIn("Some text in between.", accumulated_text) - self.assertEqual(len(tool_calls), 2) - self.assertEqual(tool_calls[0]["name"], "get_current_weather") - self.assertEqual(tool_calls[1]["name"], "calculate_area") + # Should not crash + result = self.detector.detect_and_parse(text, self.tools) + self.assertIsInstance(result, StreamingParseResult) - # Verify parameters - params1 = json.loads(tool_calls[0]["parameters"]) - self.assertEqual(params1, {"city": "Dallas", "state": "TX"}) + def test_has_tool_call_detection(self): + """ + Test the has_tool_call method for detecting tool call markers. - params2 = json.loads(tool_calls[1]["parameters"]) - self.assertEqual(params2, {"shape": "circle", "dimensions": {"radius": 5}}) + Scenario: Various inputs with and without tool call markers. + Purpose: Verify correct detection of tool call presence. + """ + self.assertTrue(self.detector.has_tool_call("")) + self.assertTrue(self.detector.has_tool_call("text more")) + self.assertFalse(self.detector.has_tool_call("plain text only")) + self.assertFalse(self.detector.has_tool_call("")) class TestGlm4MoeDetector(unittest.TestCase):