Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions tests/tool_parsers/test_gemma4_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,3 +728,143 @@ def test_streaming_trailing_bare_bool_not_duplicated(self, parser, mock_request)
}

assert args_text.count("replace_all") == 1


# ---------------------------------------------------------------------------
# Regression tests for issue #44522: tokenizer decodes special tokens to
# strings that differ from their vocabulary keys.
# ---------------------------------------------------------------------------


# Alternative decoded forms that some Gemma4 tokenizer builds produce.
_ALT_TOOL_CALL_START = "<|tool_call|>" # extra |> suffix (13 chars vs 12)
_ALT_TOOL_CALL_END = "<tool_call|>" # same as canonical — end token OK
_ALT_STRING_DELIM = '<|">' # 4-char form vs 5-char <|"|>


@pytest.fixture
def mock_tokenizer_alt_decode():
"""Tokenizer whose vocab keys are canonical but decode() returns alt forms.

This simulates the real-world mismatch between get_vocab() keys and the
AddedToken.content values used by the fast-tokenizer's decode path.
"""
tokenizer = MagicMock()
tokenizer.encode.return_value = [1, 2, 3]
tokenizer.get_vocab.return_value = {
TOOL_CALL_START: 48,
TOOL_CALL_END: 49,
'<|"|>': 52,
}

def _decode(token_ids, skip_special_tokens=True):
mapping = {
48: _ALT_TOOL_CALL_START,
49: _ALT_TOOL_CALL_END,
52: _ALT_STRING_DELIM,
}
return "".join(mapping.get(t, "") for t in token_ids)

tokenizer.decode.side_effect = _decode
return tokenizer


@pytest.fixture
def parser_alt(mock_tokenizer_alt_decode):
return Gemma4ToolParser(mock_tokenizer_alt_decode)


class TestAltTokenStrings:
"""Regression tests for issue #44522.

Some Gemma4 tokenizer builds decode the tool-call start token to
``<|tool_call|>`` (with an extra ``|>`` suffix) instead of the
canonical ``<|tool_call>``. Before the fix, the streaming guard
``self.tool_call_start_token not in current_text`` would always fire
because the parser looked for the vocabulary key, not the decoded form,
leaking the entire tool-call block as plain content.
"""

def _stream(self, parser, request, chunks):
"""Simulate token-by-token streaming using the parser's actual tokens."""
results = []
prev_text = ""
prev_ids = []
start_tok = parser.tool_call_start_token
end_tok = parser.tool_call_end_token

for chunk in chunks:
curr_text = prev_text + chunk
if start_tok in chunk:
delta_ids = [48]
elif end_tok in chunk:
delta_ids = [49]
else:
delta_ids = [0]
curr_ids = prev_ids + delta_ids
delta = parser.extract_tool_calls_streaming(
previous_text=prev_text,
current_text=curr_text,
delta_text=chunk,
previous_token_ids=tuple(prev_ids),
current_token_ids=tuple(curr_ids),
delta_token_ids=tuple(delta_ids),
request=request,
)
results.append(delta)
prev_text = curr_text
prev_ids = list(curr_ids)
return results

def test_parser_uses_decoded_start_token(self, parser_alt):
"""Parser must use the decoded token string, not the vocab key."""
assert parser_alt.tool_call_start_token == _ALT_TOOL_CALL_START
assert parser_alt.string_delim == _ALT_STRING_DELIM

def test_streaming_no_content_leak_with_alt_start_token(
self, parser_alt, mock_request
):
"""Tool calls must be detected even when the start token decodes to
an alternative string (e.g. <|tool_call|> instead of <|tool_call>).

Regression for issue #44522 where raw tool-call text leaked as
content because the guard check used the vocab key instead of the
decoded form.
"""
start = parser_alt.tool_call_start_token # "<|tool_call|>"
end = parser_alt.tool_call_end_token # "<tool_call|>"
sd = parser_alt.string_delim # '<|">'

chunks = [
start,
f"call:preview_url{{explanation:{sd}Hello{sd},url:{sd}/foo{sd}}}",
end,
]
results = self._stream(parser_alt, mock_request, chunks)

# No chunk should produce content — all must be tool_calls or None
for delta in results:
assert delta is None or delta.content is None, (
f"Raw tool-call text leaked as content: {delta.content!r}"
)

# At least one delta should carry tool call info
tool_call_deltas = [d for d in results if d is not None and d.tool_calls]
assert tool_call_deltas, "No tool call deltas were emitted"

def test_non_streaming_with_alt_token_strings(self, parser_alt, mock_request):
"""Non-streaming extraction must also use the detected token strings."""
start = parser_alt.tool_call_start_token
end = parser_alt.tool_call_end_token
sd = parser_alt.string_delim

model_output = f"{start}call:bash{{command:{sd}ls -la{sd}}}{end}"
result = parser_alt.extract_tool_calls(model_output, mock_request)

assert result.tools_called is True, (
"extract_tool_calls failed to detect tool call with alt token strings"
)
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "bash"
args = json.loads(result.tool_calls[0].function.arguments)
assert args == {"command": "ls -la"}
Loading