Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 67 additions & 25 deletions tests/tool_parsers/test_glm47_moe_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def sample_tools():
ChatCompletionToolsParam(
function=FunctionDefinition(name="get_current_date", parameters={}),
),
ChatCompletionToolsParam(
function=FunctionDefinition(
name="mcp__gmail-cleanup__list_filters",
parameters={},
),
),
ChatCompletionToolsParam(
function=FunctionDefinition(
name="get_weather",
Expand Down Expand Up @@ -125,23 +131,70 @@ def _reset(parser):
parser._sent_content_idx = 0


def _stream_chunks(parser, request, chunks):
current_text = ""
deltas = []
for chunk in chunks:
current_text += chunk
delta = parser.extract_tool_calls_streaming(
previous_text="",
current_text=current_text,
delta_text=chunk,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=request,
)
if delta and delta.tool_calls:
deltas.extend(delta.tool_calls)
return deltas


def _tool_call_args_from_deltas(deltas, index):
return "".join(
tool_call.function.arguments or ""
for tool_call in deltas
if tool_call.index == index and tool_call.function
)


class TestGlm47Streaming:
def test_no_args(self, glm47_tool_parser, mock_request):
_reset(glm47_tool_parser)
chunks = ["<tool_call>", "get_current_date", "</tool_call>"]
current_text = ""
for chunk in chunks:
current_text += chunk
glm47_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text=current_text,
delta_text=chunk,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=mock_request,
)
assert len(glm47_tool_parser.prev_tool_call_arr) >= 1
deltas = _stream_chunks(glm47_tool_parser, mock_request, chunks)

assert glm47_tool_parser.prev_tool_call_arr[0]["name"] == "get_current_date"
assert json.loads(glm47_tool_parser.prev_tool_call_arr[0]["arguments"]) == {}
assert any(
tool_call.function
and tool_call.function.name == "get_current_date"
and tool_call.index == 0
for tool_call in deltas
)
assert json.loads(_tool_call_args_from_deltas(deltas, 0)) == {}

def test_no_args_mcp_tool_name(self, glm47_tool_parser, mock_request):
_reset(glm47_tool_parser)
chunks = [
"<tool_call>",
"mcp__gmail-cleanup__list_filters",
"</tool_call>",
]
deltas = _stream_chunks(glm47_tool_parser, mock_request, chunks)

assert (
glm47_tool_parser.prev_tool_call_arr[0]["name"]
== "mcp__gmail-cleanup__list_filters"
)
assert json.loads(glm47_tool_parser.prev_tool_call_arr[0]["arguments"]) == {}
assert any(
tool_call.function
and tool_call.function.name == "mcp__gmail-cleanup__list_filters"
and tool_call.index == 0
for tool_call in deltas
)
assert json.loads(_tool_call_args_from_deltas(deltas, 0)) == {}

def test_with_args(self, glm47_tool_parser, mock_request):
_reset(glm47_tool_parser)
Expand All @@ -154,17 +207,6 @@ def test_with_args(self, glm47_tool_parser, mock_request):
"</arg_value>",
"</tool_call>",
]
current_text = ""
for chunk in chunks:
current_text += chunk
glm47_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text=current_text,
delta_text=chunk,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=mock_request,
)
_stream_chunks(glm47_tool_parser, mock_request, chunks)
args = json.loads(glm47_tool_parser.prev_tool_call_arr[0]["arguments"])
assert args["city"] == "Beijing"
12 changes: 9 additions & 3 deletions vllm/tool_parsers/glm4_moe_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,16 +313,22 @@ def _extract_tool_call_regions(self, text: str) -> list[tuple[str, bool]]:
break
return results

def _extract_tool_name_from_region(self, inner_text: str) -> str | None:
def _extract_tool_name_from_region(
self, inner_text: str, is_complete: bool = False
) -> str | None:
"""Extract the tool name from the beginning of a tool-call region.

The name is everything before the first ``\\n`` or ``<arg_key>``.
The name is everything before the first ``\\n`` or ``<arg_key>``. For
complete zero-argument calls, the whole region is the tool name.
Returns ``None`` if the name hasn't fully arrived yet.
"""
nl = inner_text.find("\n")
ak = inner_text.find(self.arg_key_start)
candidates = [i for i in [nl, ak] if i != -1]
if not candidates:
if is_complete:
name = inner_text.strip()
return name if name and "<" not in name else None
return None
cut = min(candidates)
name = inner_text[:cut].strip()
Expand Down Expand Up @@ -458,7 +464,7 @@ def extract_tool_calls_streaming(
self._ensure_tool_state_for(i)

# Extract tool name
tool_name = self._extract_tool_name_from_region(inner_text)
tool_name = self._extract_tool_name_from_region(inner_text, is_complete)
if not tool_name:
break

Expand Down
Loading