Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 13 additions & 147 deletions vllm/tool_parsers/llama4_pythonic_tool_parser.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
import json
from collections.abc import Sequence
from typing import Any

import regex as re
from transformers import PreTrainedTokenizerBase
Expand All @@ -13,25 +12,23 @@
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.tool_parsers.utils import (
UnexpectedAstError,
compute_tool_delta,
handle_single_tool,
make_valid_python,
)

logger = init_logger(__name__)


class _UnexpectedAstError(Exception):
    """Raised when tool-call text has a shape this parser does not accept.

    Used for output that is not a list of function calls, for arguments
    that are not literals, and for mismatched brackets.
    """


class Llama4PythonicToolParser(ToolParser):
"""
Toolcall parser for Llama4 that produce tool calls in a pythonic style
Expand Down Expand Up @@ -103,15 +100,13 @@ def extract_tool_calls(
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=[
_handle_single_tool(e) # type: ignore
handle_single_tool(e) # type: ignore
for e in parsed.elts
],
content=None,
)
else:
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
raise UnexpectedAstError("Tool output must be a list of function calls")
except Exception:
logger.exception("Error in extracting tool call from response.")
# Treat as regular text
Expand Down Expand Up @@ -140,7 +135,7 @@ def extract_tool_calls_streaming(
current_text = current_text[len("<|python_start|>") :]
if current_text.endswith("<|python_end|>"):
current_text = current_text[: current_text.rfind("<|python_end|>")]
valid_and_added_text = _make_valid_python(current_text)
valid_and_added_text = make_valid_python(current_text)
if valid_and_added_text is None:
return None
valid_text, added_text = valid_and_added_text
Expand All @@ -150,11 +145,9 @@ def extract_tool_calls_streaming(
if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts
):
raise _UnexpectedAstError(
"Tool output must be a list of function calls"
)
raise UnexpectedAstError("Tool output must be a list of function calls")
tool_calls = [
_handle_single_tool(e) # type: ignore
handle_single_tool(e) # type: ignore
for e in parsed.elts
]

Expand All @@ -180,7 +173,7 @@ def extract_tool_calls_streaming(
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta(
delta = compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix
)

Expand Down Expand Up @@ -214,130 +207,3 @@ def extract_tool_calls_streaming(
"Skipping chunk as a result of tool streaming extraction error"
)
return None


def _get_parameter_value(val: ast.expr) -> Any:
    """Convert a literal AST node from a tool-call argument to a Python value.

    Supported forms are constants, signed numeric constants (``-1``,
    ``+2.5``), dicts whose keys are all constants, and lists. Dict values and
    list elements are converted recursively.

    Args:
        val: The AST expression node holding the argument value.

    Returns:
        The Python value the node represents.

    Raises:
        _UnexpectedAstError: If ``val`` or a nested node is not one of the
            supported literal forms.
    """
    if isinstance(val, ast.Constant):
        return val.value

    # ``-1`` is parsed as UnaryOp(USub, Constant(1)) rather than a Constant,
    # so signed numbers need explicit handling to count as literals.
    if isinstance(val, ast.UnaryOp) and isinstance(val.op, (ast.USub, ast.UAdd)):
        operand = val.operand
        if (
            isinstance(operand, ast.Constant)
            and isinstance(operand.value, (int, float))
            # bool is a subclass of int; ``-True`` is not a numeric literal.
            and not isinstance(operand.value, bool)
        ):
            if isinstance(val.op, ast.USub):
                return -operand.value
            return operand.value

    if isinstance(val, ast.Dict):
        # ``{**other}`` yields a ``None`` key, which is rejected here as well.
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
        return {
            k.value: _get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }

    if isinstance(val, ast.List):
        return [_get_parameter_value(v) for v in val.elts]

    raise _UnexpectedAstError("Tool call arguments must be literals")


def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(name=function_name, arguments=json.dumps(arguments)),
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

json.dumps uses ensure_ascii: bool = True as default, so we might be changing the behavior here.
I think it is fine as well, but please take a look and check.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be safe too. Both ensure_ascii=True and ensure_ascii=False produce valid JSON. The only difference is whether non-ASCII characters are escaped vs preserved. The pythonic and olmo3 parsers already used ensure_ascii=False, so this makes the behavior consistent across all three parsers.

)


def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)

text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name

added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'

return text + added_text, added_text


def _compute_tool_delta(
    previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
    """Build the streaming delta for one tool call.

    Args:
        previously_sent_args: Argument text already streamed for this tool
            call; empty if nothing has been sent yet.
        new_call: The tool call as currently parsed.
        index: Position of the tool call, copied into the delta.
        withheld_suffix: Text at the end of the new arguments that must not
            be streamed yet; empty to send everything.

    Returns:
        A delta carrying the call id, name and arguments when nothing was
        sent before; a delta carrying only the new argument text when there
        is some; ``None`` when there is no new argument text.

    Raises:
        AssertionError: If ``withheld_suffix`` is non-empty and the new
            arguments do not end with it.
    """
    new_call_args = new_call.function.arguments
    if withheld_suffix:
        # Explicit check instead of ``assert`` so the invariant still holds
        # under ``python -O``; otherwise the slice below would silently drop
        # unrelated characters from the arguments.
        if not new_call_args.endswith(withheld_suffix):
            raise AssertionError(
                "Tool call arguments do not end with the withheld suffix"
            )
        new_call_args = new_call_args[: -len(withheld_suffix)]

    if not previously_sent_args:
        # First delta for this call: include id and function name.
        return DeltaToolCall(
            id=new_call.id,
            type="function",
            index=index,
            function=DeltaFunctionCall(
                name=new_call.function.name,
                arguments=new_call_args,
            ),
        )

    arg_diff = new_call_args[len(previously_sent_args) :]
    if not arg_diff:
        return None
    return DeltaToolCall(
        id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
    )
Loading