diff --git a/vllm/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/tool_parsers/llama4_pythonic_tool_parser.py index 707cdd6625c7..93807196dd67 100644 --- a/vllm/tool_parsers/llama4_pythonic_tool_parser.py +++ b/vllm/tool_parsers/llama4_pythonic_tool_parser.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import ast -import json from collections.abc import Sequence -from typing import Any import regex as re from transformers import PreTrainedTokenizerBase @@ -13,25 +12,23 @@ ChatCompletionRequest, ) from vllm.entrypoints.openai.engine.protocol import ( - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, ExtractedToolCallInformation, - FunctionCall, - ToolCall, ) from vllm.logger import init_logger from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) +from vllm.tool_parsers.utils import ( + UnexpectedAstError, + compute_tool_delta, + handle_single_tool, + make_valid_python, +) logger = init_logger(__name__) -class _UnexpectedAstError(Exception): - pass - - class Llama4PythonicToolParser(ToolParser): """ Toolcall parser for Llama4 that produce tool calls in a pythonic style @@ -103,15 +100,13 @@ def extract_tool_calls( return ExtractedToolCallInformation( tools_called=True, tool_calls=[ - _handle_single_tool(e) # type: ignore + handle_single_tool(e) # type: ignore for e in parsed.elts ], content=None, ) else: - raise _UnexpectedAstError( - "Tool output must be a list of function calls" - ) + raise UnexpectedAstError("Tool output must be a list of function calls") except Exception: logger.exception("Error in extracting tool call from response.") # Treat as regular text @@ -140,7 +135,7 @@ def extract_tool_calls_streaming( current_text = current_text[len("<|python_start|>") :] if current_text.endswith("<|python_end|>"): current_text = current_text[: current_text.rfind("<|python_end|>")] - valid_and_added_text = _make_valid_python(current_text) + valid_and_added_text = make_valid_python(current_text) if valid_and_added_text is None: return None valid_text, added_text = valid_and_added_text @@ -150,11 +145,9 @@ def extract_tool_calls_streaming( if not isinstance(parsed, ast.List) or not all( isinstance(e, ast.Call) for e in parsed.elts ): - raise _UnexpectedAstError( - "Tool output must be a list of function calls" - ) + raise UnexpectedAstError("Tool output must be a list of function calls") tool_calls = [ - _handle_single_tool(e) # type: ignore + handle_single_tool(e) # type: ignore for e in parsed.elts ] @@ -180,7 +173,7 @@ def extract_tool_calls_streaming( # Strings get single quotes in the model-produced string. # JSON requires double quotes. withheld_suffix = withheld_suffix.replace("'", '"') - delta = _compute_tool_delta( + delta = compute_tool_delta( self.streamed_args_for_tool[index], new_call, index, withheld_suffix ) @@ -214,130 +207,3 @@ def extract_tool_calls_streaming( "Skipping chunk as a result of tool streaming extraction error" ) return None - - -def _get_parameter_value(val: ast.expr) -> Any: - if isinstance(val, ast.Constant): - return val.value - elif isinstance(val, ast.Dict): - if not all(isinstance(k, ast.Constant) for k in val.keys): - raise _UnexpectedAstError("Dict tool call arguments must have literal keys") - return { - k.value: _get_parameter_value(v) # type: ignore - for k, v in zip(val.keys, val.values) - } - elif isinstance(val, ast.List): - return [_get_parameter_value(v) for v in val.elts] - else: - raise _UnexpectedAstError("Tool call arguments must be literals") - - -def _handle_single_tool(call: ast.Call) -> ToolCall: - if not isinstance(call.func, ast.Name): - raise _UnexpectedAstError("Invalid tool call name") - function_name = call.func.id - arguments = {} - for keyword in call.keywords: - arguments[keyword.arg] = _get_parameter_value(keyword.value) - return ToolCall( - type="function", - function=FunctionCall(name=function_name, arguments=json.dumps(arguments)), - ) - - -def _make_valid_python(text: str) -> tuple[str, str] | None: - bracket_stack = [] - for index, char in enumerate(text): - if char in {"[", "(", "{"}: - bracket_stack.append(char) - elif char == "]": - if not bracket_stack or bracket_stack.pop() != "[": - raise _UnexpectedAstError("Mismatched square brackets") - elif char == ")": - if not bracket_stack or bracket_stack.pop() != "(": - raise _UnexpectedAstError("Mismatched parentheses") - elif char == "}": - if not bracket_stack or bracket_stack.pop() != "{": - raise _UnexpectedAstError("Mismatched curly braces") - elif char in {"'", '"'}: - if bracket_stack and bracket_stack[-1] == char: - if index > 0 and text[index - 1] == "\\": - # Treat an escaped quote as a regular character - pass - else: - bracket_stack.pop() - elif bracket_stack and bracket_stack[-1] in {"'", '"'}: - # Double quote within a single quote string or vice versa. - pass - else: - bracket_stack.append(char) - - text = text.rstrip() - if text.endswith("=") or text.endswith(":"): - # Since we have no type information for this property/parameter value, - # we can't fill in a valid value. - return None - if bracket_stack and bracket_stack[-1] == "{": - trailing_dict_text = text[: text.rfind("{")] - num_keys = trailing_dict_text.count(":") - num_values = trailing_dict_text.count(",") - if num_keys <= num_values: - return None # Incomplete property name within parameter value - if bracket_stack and bracket_stack[-1] == "(": - trailing_params_text = text[: text.rfind("(")] - num_full_param_names = trailing_params_text.count("=") - num_full_param_values = trailing_params_text.count(",") - if num_full_param_names <= num_full_param_values: - return None # Incomplete parameter name - if text.endswith(","): - text = text[:-1] - if ( - bracket_stack - and bracket_stack[-1] == "[" - and not text.endswith("[") - and not text.endswith(")") - ): - return None # Incomplete function name - - added_text = "" - for char in reversed(bracket_stack): - if char == "[": - added_text += "]" - elif char == "(": - added_text += ")" - elif char == "{": - added_text += "}" - elif char == "'": - added_text += "'" - elif char == '"': - added_text += '"' - - return text + added_text, added_text - - -def _compute_tool_delta( - previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str -) -> DeltaToolCall | None: - new_call_args = new_call.function.arguments - if withheld_suffix: - assert new_call_args.endswith(withheld_suffix) - new_call_args = new_call_args[: -len(withheld_suffix)] - if not previously_sent_args: - return DeltaToolCall( - id=new_call.id, - type="function", - index=index, - function=DeltaFunctionCall( - name=new_call.function.name, - arguments=new_call_args, - ), - ) - - arg_diff = new_call_args[len(previously_sent_args) :] - return ( - DeltaToolCall( - id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff) - ) - if arg_diff - else None - ) diff --git a/vllm/tool_parsers/olmo3_tool_parser.py b/vllm/tool_parsers/olmo3_tool_parser.py index 7b0d609d51df..dd63b108635c 100644 --- a/vllm/tool_parsers/olmo3_tool_parser.py +++ b/vllm/tool_parsers/olmo3_tool_parser.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import ast -import json from collections.abc import Sequence -from typing import Any import regex as re from transformers import PreTrainedTokenizerBase @@ -13,25 +12,23 @@ ChatCompletionRequest, ) from vllm.entrypoints.openai.engine.protocol import ( - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, ExtractedToolCallInformation, - FunctionCall, - ToolCall, ) from vllm.logger import init_logger from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) +from vllm.tool_parsers.utils import ( + UnexpectedAstError, + compute_tool_delta, + handle_single_tool, + make_valid_python, +) logger = init_logger(__name__) -class _UnexpectedAstError(Exception): - pass - - class Olmo3PythonicToolParser(ToolParser): """ Tool call parser for Olmo 3 models that produce tool calls as @@ -113,15 +110,13 @@ def extract_tool_calls( return ExtractedToolCallInformation( tools_called=True, tool_calls=[ - _handle_single_tool(e) # type: ignore + handle_single_tool(e) # type: ignore for e in parsed.elts ], content=None, ) else: - raise _UnexpectedAstError( - "Tool output must be a list of function calls" - ) + raise UnexpectedAstError("Tool output must be a list of function calls") except Exception: logger.exception("Error in extracting tool call from response.") # Treat as regular text @@ -151,7 +146,7 @@ def extract_tool_calls_streaming( if current_text.endswith(""): current_text = current_text[: -len("")] - valid_and_added_text = _make_valid_python(current_text) + valid_and_added_text = make_valid_python(current_text) if valid_and_added_text is None: return None valid_text, added_text = valid_and_added_text @@ -166,11 +161,11 @@ def extract_tool_calls_streaming( if not isinstance(parsed, ast.List) or not all( isinstance(e, ast.Call) for e in parsed.elts ): - raise _UnexpectedAstError( + raise UnexpectedAstError( "Tool output must be a sequence of newline-separated calls" ) tool_calls = [ - _handle_single_tool(e) # type: ignore + handle_single_tool(e) # type: ignore for e in parsed.elts ] @@ -194,7 +189,7 @@ def extract_tool_calls_streaming( # Strings get single quotes in the model-produced string. # JSON requires double quotes. withheld_suffix = withheld_suffix.replace("'", '"') - delta = _compute_tool_delta( + delta = compute_tool_delta( self.streamed_args_for_tool[index], new_call, index, withheld_suffix ) @@ -228,141 +223,3 @@ def extract_tool_calls_streaming( "Skipping chunk as a result of tool streaming extraction error" ) return None - - -def _get_parameter_value(val: ast.expr) -> Any: - if isinstance(val, ast.Constant): - return val.value - elif isinstance(val, ast.Dict): - if not all(isinstance(k, ast.Constant) for k in val.keys): - raise _UnexpectedAstError("Dict tool call arguments must have literal keys") - return { - k.value: _get_parameter_value(v) # type: ignore - for k, v in zip(val.keys, val.values) - } - elif isinstance(val, ast.List): - return [_get_parameter_value(v) for v in val.elts] - # The model may return function calls where the values are null/true/false - # because the system prompt has API description in json. - elif isinstance(val, ast.Name) and val.id in ["null", "true", "false"]: - if val.id == "null": - return None - elif val.id == "true": - return True - elif val.id == "false": - return False - else: - raise _UnexpectedAstError("Tool call arguments must be literals") - - -def _handle_single_tool(call: ast.Call) -> ToolCall: - if not isinstance(call.func, ast.Name): - raise _UnexpectedAstError("Invalid tool call name") - function_name = call.func.id - arguments = {} - for keyword in call.keywords: - arguments[keyword.arg] = _get_parameter_value(keyword.value) - return ToolCall( - type="function", - function=FunctionCall( - name=function_name, arguments=json.dumps(arguments, ensure_ascii=False) - ), - ) - - -def _make_valid_python(text: str) -> tuple[str, str] | None: - bracket_stack = [] - for index, char in enumerate(text): - if char in {"[", "(", "{"}: - bracket_stack.append(char) - elif char == "]": - if not bracket_stack or bracket_stack.pop() != "[": - raise _UnexpectedAstError("Mismatched square brackets") - elif char == ")": - if not bracket_stack or bracket_stack.pop() != "(": - raise _UnexpectedAstError("Mismatched parentheses") - elif char == "}": - if not bracket_stack or bracket_stack.pop() != "{": - raise _UnexpectedAstError("Mismatched curly braces") - elif char in {"'", '"'}: - if bracket_stack and bracket_stack[-1] == char: - if index > 0 and text[index - 1] == "\\": - # Treat an escaped quote as a regular character - pass - else: - bracket_stack.pop() - elif bracket_stack and bracket_stack[-1] in {"'", '"'}: - # Double quote within a single quote string or vice versa. - pass - else: - bracket_stack.append(char) - - text = text.rstrip() - if text.endswith("=") or text.endswith(":"): - # Since we have no type information for this property/parameter value, - # we can't fill in a valid value. - return None - if bracket_stack and bracket_stack[-1] == "{": - trailing_dict_text = text[: text.rfind("{")] - num_keys = trailing_dict_text.count(":") - num_values = trailing_dict_text.count(",") - if num_keys <= num_values: - return None # Incomplete property name within parameter value - if bracket_stack and bracket_stack[-1] == "(": - trailing_params_text = text[: text.rfind("(")] - num_full_param_names = trailing_params_text.count("=") - num_full_param_values = trailing_params_text.count(",") - if num_full_param_names <= num_full_param_values: - return None # Incomplete parameter name - if text.endswith(","): - text = text[:-1] - if ( - bracket_stack - and bracket_stack[-1] == "[" - and not text.endswith("[") - and not text.endswith(")") - ): - return None # Incomplete function name - - added_text = "" - for char in reversed(bracket_stack): - if char == "[": - added_text += "]" - elif char == "(": - added_text += ")" - elif char == "{": - added_text += "}" - elif char == "'": - added_text += "'" - elif char == '"': - added_text += '"' - - return text + added_text, added_text - - -def _compute_tool_delta( - previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str -) -> DeltaToolCall | None: - new_call_args = new_call.function.arguments - if withheld_suffix: - assert new_call_args.endswith(withheld_suffix) - new_call_args = new_call_args[: -len(withheld_suffix)] - if not previously_sent_args: - return DeltaToolCall( - id=new_call.id, - type="function", - index=index, - function=DeltaFunctionCall( - name=new_call.function.name, - arguments=new_call_args, - ), - ) - - arg_diff = new_call_args[len(previously_sent_args) :] - return ( - DeltaToolCall( - id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff) - ) - if arg_diff - else None - ) diff --git a/vllm/tool_parsers/pythonic_tool_parser.py b/vllm/tool_parsers/pythonic_tool_parser.py index dc9926608e60..9c9f3e183d34 100644 --- a/vllm/tool_parsers/pythonic_tool_parser.py +++ b/vllm/tool_parsers/pythonic_tool_parser.py @@ -2,9 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast -import json from collections.abc import Sequence -from typing import Any import regex as re from transformers import PreTrainedTokenizerBase @@ -14,25 +12,23 @@ ChatCompletionRequest, ) from vllm.entrypoints.openai.engine.protocol import ( - DeltaFunctionCall, DeltaMessage, - DeltaToolCall, ExtractedToolCallInformation, - FunctionCall, - ToolCall, ) from vllm.logger import init_logger from vllm.tool_parsers.abstract_tool_parser import ( ToolParser, ) +from vllm.tool_parsers.utils import ( + UnexpectedAstError, + compute_tool_delta, + handle_single_tool, + make_valid_python, +) logger = init_logger(__name__) -class _UnexpectedAstError(Exception): - pass - - class PythonicToolParser(ToolParser): """ Tool call parser for models that produce tool calls in a pythonic style, @@ -99,15 +95,13 @@ def extract_tool_calls( return ExtractedToolCallInformation( tools_called=True, tool_calls=[ - _handle_single_tool(e) # type: ignore + handle_single_tool(e) # type: ignore for e in parsed.elts ], content=None, ) else: - raise _UnexpectedAstError( - "Tool output must be a list of function calls" - ) + raise UnexpectedAstError("Tool output must be a list of function calls") except Exception: logger.exception("Error in extracting tool call from response.") # Treat as regular text @@ -129,7 +123,7 @@ def extract_tool_calls_streaming( return DeltaMessage(content=delta_text) try: - valid_and_added_text = _make_valid_python(current_text) + valid_and_added_text = make_valid_python(current_text) if valid_and_added_text is None: return None valid_text, added_text = valid_and_added_text @@ -139,11 +133,9 @@ def extract_tool_calls_streaming( if not isinstance(parsed, ast.List) or not all( isinstance(e, ast.Call) for e in parsed.elts ): - raise _UnexpectedAstError( - "Tool output must be a list of function calls" - ) + raise UnexpectedAstError("Tool output must be a list of function calls") tool_calls = [ - _handle_single_tool(e) # type: ignore + handle_single_tool(e) # type: ignore for e in parsed.elts ] @@ -169,7 +161,7 @@ def extract_tool_calls_streaming( # Strings get single quotes in the model-produced string. # JSON requires double quotes. withheld_suffix = withheld_suffix.replace("'", '"') - delta = _compute_tool_delta( + delta = compute_tool_delta( self.streamed_args_for_tool[index], new_call, index, withheld_suffix ) @@ -203,132 +195,3 @@ def extract_tool_calls_streaming( "Skipping chunk as a result of tool streaming extraction error" ) return None - - -def _get_parameter_value(val: ast.expr) -> Any: - if isinstance(val, ast.Constant): - return val.value - elif isinstance(val, ast.Dict): - if not all(isinstance(k, ast.Constant) for k in val.keys): - raise _UnexpectedAstError("Dict tool call arguments must have literal keys") - return { - k.value: _get_parameter_value(v) # type: ignore - for k, v in zip(val.keys, val.values) - } - elif isinstance(val, ast.List): - return [_get_parameter_value(v) for v in val.elts] - else: - raise _UnexpectedAstError("Tool call arguments must be literals") - - -def _handle_single_tool(call: ast.Call) -> ToolCall: - if not isinstance(call.func, ast.Name): - raise _UnexpectedAstError("Invalid tool call name") - function_name = call.func.id - arguments = {} - for keyword in call.keywords: - arguments[keyword.arg] = _get_parameter_value(keyword.value) - return ToolCall( - type="function", - function=FunctionCall( - name=function_name, arguments=json.dumps(arguments, ensure_ascii=False) - ), - ) - - -def _make_valid_python(text: str) -> tuple[str, str] | None: - bracket_stack = [] - for index, char in enumerate(text): - if char in {"[", "(", "{"}: - bracket_stack.append(char) - elif char == "]": - if not bracket_stack or bracket_stack.pop() != "[": - raise _UnexpectedAstError("Mismatched square brackets") - elif char == ")": - if not bracket_stack or bracket_stack.pop() != "(": - raise _UnexpectedAstError("Mismatched parentheses") - elif char == "}": - if not bracket_stack or bracket_stack.pop() != "{": - raise _UnexpectedAstError("Mismatched curly braces") - elif char in {"'", '"'}: - if bracket_stack and bracket_stack[-1] == char: - if index > 0 and text[index - 1] == "\\": - # Treat an escaped quote as a regular character - pass - else: - bracket_stack.pop() - elif bracket_stack and bracket_stack[-1] in {"'", '"'}: - # Double quote within a single quote string or vice versa. - pass - else: - bracket_stack.append(char) - - text = text.rstrip() - if text.endswith("=") or text.endswith(":"): - # Since we have no type information for this property/parameter value, - # we can't fill in a valid value. - return None - if bracket_stack and bracket_stack[-1] == "{": - trailing_dict_text = text[: text.rfind("{")] - num_keys = trailing_dict_text.count(":") - num_values = trailing_dict_text.count(",") - if num_keys <= num_values: - return None # Incomplete property name within parameter value - if bracket_stack and bracket_stack[-1] == "(": - trailing_params_text = text[: text.rfind("(")] - num_full_param_names = trailing_params_text.count("=") - num_full_param_values = trailing_params_text.count(",") - if num_full_param_names <= num_full_param_values: - return None # Incomplete parameter name - if text.endswith(","): - text = text[:-1] - if ( - bracket_stack - and bracket_stack[-1] == "[" - and not text.endswith("[") - and not text.endswith(")") - ): - return None # Incomplete function name - - added_text = "" - for char in reversed(bracket_stack): - if char == "[": - added_text += "]" - elif char == "(": - added_text += ")" - elif char == "{": - added_text += "}" - elif char == "'": - added_text += "'" - elif char == '"': - added_text += '"' - - return text + added_text, added_text - - -def _compute_tool_delta( - previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str -) -> DeltaToolCall | None: - new_call_args = new_call.function.arguments - if withheld_suffix: - assert new_call_args.endswith(withheld_suffix) - new_call_args = new_call_args[: -len(withheld_suffix)] - if not previously_sent_args: - return DeltaToolCall( - id=new_call.id, - type="function", - index=index, - function=DeltaFunctionCall( - name=new_call.function.name, - arguments=new_call_args, - ), - ) - - arg_diff = new_call_args[len(previously_sent_args) :] - return ( - DeltaToolCall( - id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff) - ) - if arg_diff - else None - ) diff --git a/vllm/tool_parsers/utils.py b/vllm/tool_parsers/utils.py index 49dd023d4788..a279e5b9b59c 100644 --- a/vllm/tool_parsers/utils.py +++ b/vllm/tool_parsers/utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ast import json from json import JSONDecodeError, JSONDecoder from typing import Any @@ -17,6 +18,15 @@ ChatCompletionNamedToolChoiceParam, ChatCompletionToolsParam, ) +from vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, + DeltaToolCall, + FunctionCall, + ToolCall, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) def find_common_prefix(s1: str, s2: str) -> str: @@ -212,3 +222,202 @@ def get_json_schema_from_tools( return _get_json_schema_from_tools(tools) # tool_choice: "auto" return None + + +# --------------------------------------------------------------------------- +# Shared utilities for pythonic-style tool call parsers +# (PythonicToolParser, Llama4PythonicToolParser, Olmo3PythonicToolParser) +# --------------------------------------------------------------------------- + + +class UnexpectedAstError(Exception): + """Raised when the AST structure does not match the expected + pythonic tool call format.""" + + pass + + +_JSON_NAME_LITERALS = { + "null": None, + "true": True, + "false": False, +} + + +def get_parameter_value(val: ast.expr) -> Any: + """Extract a Python literal value from an AST expression node. + + Handles constants, dicts, lists, and JSON-style name literals + (null, true, false) that some models produce instead of Python + literals (None, True, False). + + Raises: + UnexpectedAstError: If the AST node is not a supported literal type. + """ + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + if not all(isinstance(k, ast.Constant) for k in val.keys): + logger.warning( + "Dict argument keys are not all literals: %s", + ast.dump(val), + ) + raise UnexpectedAstError("Dict tool call arguments must have literal keys") + return { + k.value: get_parameter_value(v) # type: ignore + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [get_parameter_value(v) for v in val.elts] + elif isinstance(val, ast.Name) and val.id in _JSON_NAME_LITERALS: + return _JSON_NAME_LITERALS[val.id] + else: + logger.warning( + "Unsupported AST node type in tool call arguments: %s", + ast.dump(val), + ) + raise UnexpectedAstError("Tool call arguments must be literals") + + +def handle_single_tool(call: ast.Call) -> ToolCall: + """Convert a single AST function call node into a ToolCall object. + + Raises: + UnexpectedAstError: If the call node does not have a simple + function name (e.g. it's an attribute access or subscript). + """ + if not isinstance(call.func, ast.Name): + logger.warning( + "Tool call has non-simple function name: %s", + ast.dump(call.func), + ) + raise UnexpectedAstError("Invalid tool call name") + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = get_parameter_value(keyword.value) + return ToolCall( + type="function", + function=FunctionCall( + name=function_name, + arguments=json.dumps(arguments, ensure_ascii=False), + ), + ) + + +def make_valid_python(text: str) -> tuple[str, str] | None: + """Attempt to close all open brackets/quotes to make partial Python valid. + + Used during streaming to parse incomplete tool call expressions by + appending the necessary closing characters. + + Returns: + A tuple of (completed_text, added_suffix) if the text can be + made valid, or None if the text is too incomplete to complete + meaningfully (e.g. mid-parameter-name or mid-dict-key). + + Raises: + UnexpectedAstError: If mismatched brackets or parentheses + are detected. + """ + bracket_stack: list[str] = [] + for index, char in enumerate(text): + if char in {"[", "(", "{"}: + bracket_stack.append(char) + elif char == "]": + if not bracket_stack or bracket_stack.pop() != "[": + raise UnexpectedAstError("Mismatched square brackets") + elif char == ")": + if not bracket_stack or bracket_stack.pop() != "(": + raise UnexpectedAstError("Mismatched parentheses") + elif char == "}": + if not bracket_stack or bracket_stack.pop() != "{": + raise UnexpectedAstError("Mismatched curly braces") + elif char in {"'", '"'}: + if bracket_stack and bracket_stack[-1] == char: + if index > 0 and text[index - 1] == "\\": + pass + else: + bracket_stack.pop() + elif bracket_stack and bracket_stack[-1] in {"'", '"'}: + pass + else: + bracket_stack.append(char) + + text = text.rstrip() + if text.endswith("=") or text.endswith(":"): + return None + if bracket_stack and bracket_stack[-1] == "{": + trailing_dict_text = text[: text.rfind("{")] + num_keys = trailing_dict_text.count(":") + num_values = trailing_dict_text.count(",") + if num_keys <= num_values: + return None + if bracket_stack and bracket_stack[-1] == "(": + trailing_params_text = text[: text.rfind("(")] + num_full_param_names = trailing_params_text.count("=") + num_full_param_values = trailing_params_text.count(",") + if num_full_param_names <= num_full_param_values: + return None + if text.endswith(","): + text = text[:-1] + if ( + bracket_stack + and bracket_stack[-1] == "[" + and not text.endswith("[") + and not text.endswith(")") + ): + return None + + _CLOSING = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'} + added_text = "" + for char in reversed(bracket_stack): + added_text += _CLOSING[char] + + return text + added_text, added_text + + +def compute_tool_delta( + previously_sent_args: str, + new_call: ToolCall, + index: int, + withheld_suffix: str, +) -> DeltaToolCall | None: + """Compute the incremental delta between previously streamed arguments + and the current tool call state. + + Returns: + A DeltaToolCall with only the new argument characters, or None + if there is no difference from what was previously sent. + """ + new_call_args = new_call.function.arguments + if withheld_suffix: + if not new_call_args.endswith(withheld_suffix): + msg = ( + f"Tool call arguments '{new_call_args}' do not end with " + f"expected withheld suffix '{withheld_suffix}'" + ) + logger.error(msg) + raise ValueError(msg) + new_call_args = new_call_args[: -len(withheld_suffix)] + if not previously_sent_args: + return DeltaToolCall( + id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + ), + ) + + arg_diff = new_call_args[len(previously_sent_args) :] + return ( + DeltaToolCall( + id=None, + index=index, + function=DeltaFunctionCall(arguments=arg_diff), + ) + if arg_diff + else None + )