diff --git a/tests/tool_parsers/test_gemma4_tool_parser.py b/tests/tool_parsers/test_gemma4_tool_parser.py index 6f3709e19a45..0bb87dff114f 100644 --- a/tests/tool_parsers/test_gemma4_tool_parser.py +++ b/tests/tool_parsers/test_gemma4_tool_parser.py @@ -728,3 +728,143 @@ def test_streaming_trailing_bare_bool_not_duplicated(self, parser, mock_request) } assert args_text.count("replace_all") == 1 + + +# --------------------------------------------------------------------------- +# Regression tests for issue #44522: tokenizer decodes special tokens to +# strings that differ from their vocabulary keys. +# --------------------------------------------------------------------------- + + +# Alternative decoded forms that some Gemma4 tokenizer builds produce. +_ALT_TOOL_CALL_START = "<|tool_call|>" # extra |> suffix (13 chars vs 12) +_ALT_TOOL_CALL_END = "" # same as canonical — end token OK +_ALT_STRING_DELIM = '<|">' # 4-char form vs 5-char <|"|> + + +@pytest.fixture +def mock_tokenizer_alt_decode(): + """Tokenizer whose vocab keys are canonical but decode() returns alt forms. + + This simulates the real-world mismatch between get_vocab() keys and the + AddedToken.content values used by the fast-tokenizer's decode path. + """ + tokenizer = MagicMock() + tokenizer.encode.return_value = [1, 2, 3] + tokenizer.get_vocab.return_value = { + TOOL_CALL_START: 48, + TOOL_CALL_END: 49, + '<|"|>': 52, + } + + def _decode(token_ids, skip_special_tokens=True): + mapping = { + 48: _ALT_TOOL_CALL_START, + 49: _ALT_TOOL_CALL_END, + 52: _ALT_STRING_DELIM, + } + return "".join(mapping.get(t, "") for t in token_ids) + + tokenizer.decode.side_effect = _decode + return tokenizer + + +@pytest.fixture +def parser_alt(mock_tokenizer_alt_decode): + return Gemma4ToolParser(mock_tokenizer_alt_decode) + + +class TestAltTokenStrings: + """Regression tests for issue #44522. + + Some Gemma4 tokenizer builds decode the tool-call start token to + ``<|tool_call|>`` (with an extra ``|>`` suffix) instead of the + canonical ``<|tool_call>``. Before the fix, the streaming guard + ``self.tool_call_start_token not in current_text`` would always fire + because the parser looked for the vocabulary key, not the decoded form, + leaking the entire tool-call block as plain content. + """ + + def _stream(self, parser, request, chunks): + """Simulate token-by-token streaming using the parser's actual tokens.""" + results = [] + prev_text = "" + prev_ids = [] + start_tok = parser.tool_call_start_token + end_tok = parser.tool_call_end_token + + for chunk in chunks: + curr_text = prev_text + chunk + if start_tok in chunk: + delta_ids = [48] + elif end_tok in chunk: + delta_ids = [49] + else: + delta_ids = [0] + curr_ids = prev_ids + delta_ids + delta = parser.extract_tool_calls_streaming( + previous_text=prev_text, + current_text=curr_text, + delta_text=chunk, + previous_token_ids=tuple(prev_ids), + current_token_ids=tuple(curr_ids), + delta_token_ids=tuple(delta_ids), + request=request, + ) + results.append(delta) + prev_text = curr_text + prev_ids = list(curr_ids) + return results + + def test_parser_uses_decoded_start_token(self, parser_alt): + """Parser must use the decoded token string, not the vocab key.""" + assert parser_alt.tool_call_start_token == _ALT_TOOL_CALL_START + assert parser_alt.string_delim == _ALT_STRING_DELIM + + def test_streaming_no_content_leak_with_alt_start_token( + self, parser_alt, mock_request + ): + """Tool calls must be detected even when the start token decodes to + an alternative string (e.g. <|tool_call|> instead of <|tool_call>). + + Regression for issue #44522 where raw tool-call text leaked as + content because the guard check used the vocab key instead of the + decoded form. + """ + start = parser_alt.tool_call_start_token # "<|tool_call|>" + end = parser_alt.tool_call_end_token # "" + sd = parser_alt.string_delim # '<|">' + + chunks = [ + start, + f"call:preview_url{{explanation:{sd}Hello{sd},url:{sd}/foo{sd}}}", + end, + ] + results = self._stream(parser_alt, mock_request, chunks) + + # No chunk should produce content — all must be tool_calls or None + for delta in results: + assert delta is None or delta.content is None, ( + f"Raw tool-call text leaked as content: {delta.content!r}" + ) + + # At least one delta should carry tool call info + tool_call_deltas = [d for d in results if d is not None and d.tool_calls] + assert tool_call_deltas, "No tool call deltas were emitted" + + def test_non_streaming_with_alt_token_strings(self, parser_alt, mock_request): + """Non-streaming extraction must also use the detected token strings.""" + start = parser_alt.tool_call_start_token + end = parser_alt.tool_call_end_token + sd = parser_alt.string_delim + + model_output = f"{start}call:bash{{command:{sd}ls -la{sd}}}{end}" + result = parser_alt.extract_tool_calls(model_output, mock_request) + + assert result.tools_called is True, ( + "extract_tool_calls failed to detect tool call with alt token strings" + ) + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].function.name == "bash" + args = json.loads(result.tool_calls[0].function.arguments) + assert args == {"command": "ls -la"} diff --git a/vllm/tool_parsers/gemma4_tool_parser.py b/vllm/tool_parsers/gemma4_tool_parser.py index 9925284273f9..7f4eb059a1c3 100644 --- a/vllm/tool_parsers/gemma4_tool_parser.py +++ b/vllm/tool_parsers/gemma4_tool_parser.py @@ -43,12 +43,35 @@ logger = init_logger(__name__) -# Gemma4 special tokens for tool calls +# Gemma4 special tokens for tool calls — canonical vocabulary keys. +# NOTE: The tokenizer may DECODE these tokens to different strings than +# the vocab keys (e.g., <|tool_call> may decode as <|tool_call|>). All +# runtime text-matching uses the detected decoded form (see __init__). TOOL_CALL_START = "<|tool_call>" TOOL_CALL_END = "" STRING_DELIM = '<|"|>' +def _decoded_token_str(tokenizer: "TokenizerLike", token_id: int, fallback: str) -> str: + """Return the decoded string for *token_id* using *tokenizer*. + + Some tokenizer versions store a different `content` string in their + ``added_tokens_decoder`` than the vocabulary key returned by + ``get_vocab()``. For example, the Gemma4 start token may be keyed as + ``<|tool_call>`` in the vocabulary but decoded as ``<|tool_call|>`` + by the fast tokenizer (issue #44522). Using the vocabulary key for + all text comparisons then causes the streaming guard check to miss + the token in ``current_text``, leaking raw tool-call text as content. + + Falls back to *fallback* if decoding fails or returns an empty string. + """ + try: + decoded = tokenizer.decode([token_id], skip_special_tokens=False) + return decoded if isinstance(decoded, str) and decoded else fallback + except Exception: + return fallback + + # --------------------------------------------------------------------------- # Gemma4 argument parser (used by both streaming and non-streaming paths) # --------------------------------------------------------------------------- @@ -82,7 +105,12 @@ def _parse_gemma4_value(value_str: str) -> object: return value_str -def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict: +def _parse_gemma4_args( + args_str: str, + *, + partial: bool = False, + string_delim: str = STRING_DELIM, +) -> dict: """Parse Gemma4's custom key:value format into a Python dict. Format examples:: @@ -98,6 +126,9 @@ def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict: partial: When True (streaming), bare values at end of string are omitted because they may be incomplete and type-unstable (e.g. partial boolean parsed as bare string). + string_delim: The string delimiter token as decoded by the + tokenizer (defaults to the canonical ``STRING_DELIM`` constant + but may differ for some tokenizer versions). Returns a dict ready for ``json.dumps()``. """ @@ -138,17 +169,17 @@ def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict: result[key] = "" break - # String value: <|"|>...<|"|> - if args_str[i:].startswith(STRING_DELIM): - i += len(STRING_DELIM) + # String value: <|"|>...<|"|> (actual delimiter may differ per tokenizer) + if args_str[i:].startswith(string_delim): + i += len(string_delim) val_start = i - end_pos = args_str.find(STRING_DELIM, i) + end_pos = args_str.find(string_delim, i) if end_pos == -1: # Unterminated string — take rest result[key] = args_str[val_start:] break result[key] = args_str[val_start:end_pos] - i = end_pos + len(STRING_DELIM) + i = end_pos + len(string_delim) # Nested object: {...} elif args_str[i] == "{": @@ -156,11 +187,11 @@ def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict: obj_start = i + 1 i += 1 while i < n and depth > 0: - if args_str[i:].startswith(STRING_DELIM): + if args_str[i:].startswith(string_delim): # Skip over string contents to avoid counting { inside strings - i += len(STRING_DELIM) - next_delim = args_str.find(STRING_DELIM, i) - i = n if next_delim == -1 else next_delim + len(STRING_DELIM) + i += len(string_delim) + next_delim = args_str.find(string_delim, i) + i = n if next_delim == -1 else next_delim + len(string_delim) continue if args_str[i] == "{": depth += 1 @@ -170,9 +201,13 @@ def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict: if depth > 0: # Incomplete nested object — use i (not i-1) to avoid # dropping the last char, and recurse as partial. - result[key] = _parse_gemma4_args(args_str[obj_start:i], partial=True) + result[key] = _parse_gemma4_args( + args_str[obj_start:i], partial=True, string_delim=string_delim + ) else: - result[key] = _parse_gemma4_args(args_str[obj_start : i - 1]) + result[key] = _parse_gemma4_args( + args_str[obj_start : i - 1], string_delim=string_delim + ) # Array: [...] elif args_str[i] == "[": @@ -180,10 +215,10 @@ def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict: arr_start = i + 1 i += 1 while i < n and depth > 0: - if args_str[i:].startswith(STRING_DELIM): - i += len(STRING_DELIM) - next_delim = args_str.find(STRING_DELIM, i) - i = n if next_delim == -1 else next_delim + len(STRING_DELIM) + if args_str[i:].startswith(string_delim): + i += len(string_delim) + next_delim = args_str.find(string_delim, i) + i = n if next_delim == -1 else next_delim + len(string_delim) continue if args_str[i] == "[": depth += 1 @@ -191,9 +226,13 @@ def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict: depth -= 1 i += 1 if depth > 0: - result[key] = _parse_gemma4_array(args_str[arr_start:i], partial=True) + result[key] = _parse_gemma4_array( + args_str[arr_start:i], partial=True, string_delim=string_delim + ) else: - result[key] = _parse_gemma4_array(args_str[arr_start : i - 1]) + result[key] = _parse_gemma4_array( + args_str[arr_start : i - 1], string_delim=string_delim + ) # Bare value (number, boolean, etc.) else: @@ -224,7 +263,12 @@ def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict: return result -def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list: +def _parse_gemma4_array( + arr_str: str, + *, + partial: bool = False, + string_delim: str = STRING_DELIM, +) -> list: """Parse a Gemma4 array content string into a Python list.""" items: list = [] i = 0 @@ -237,14 +281,14 @@ def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list: break # String element - if arr_str[i:].startswith(STRING_DELIM): - i += len(STRING_DELIM) - end_pos = arr_str.find(STRING_DELIM, i) + if arr_str[i:].startswith(string_delim): + i += len(string_delim) + end_pos = arr_str.find(string_delim, i) if end_pos == -1: items.append(arr_str[i:]) break items.append(arr_str[i:end_pos]) - i = end_pos + len(STRING_DELIM) + i = end_pos + len(string_delim) # Nested object elif arr_str[i] == "{": @@ -252,10 +296,10 @@ def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list: obj_start = i + 1 i += 1 while i < n and depth > 0: - if arr_str[i:].startswith(STRING_DELIM): - i += len(STRING_DELIM) - nd = arr_str.find(STRING_DELIM, i) - i = nd + len(STRING_DELIM) if nd != -1 else n + if arr_str[i:].startswith(string_delim): + i += len(string_delim) + nd = arr_str.find(string_delim, i) + i = nd + len(string_delim) if nd != -1 else n continue if arr_str[i] == "{": depth += 1 @@ -263,9 +307,17 @@ def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list: depth -= 1 i += 1 if depth > 0: - items.append(_parse_gemma4_args(arr_str[obj_start:i], partial=True)) + items.append( + _parse_gemma4_args( + arr_str[obj_start:i], partial=True, string_delim=string_delim + ) + ) else: - items.append(_parse_gemma4_args(arr_str[obj_start : i - 1])) + items.append( + _parse_gemma4_args( + arr_str[obj_start : i - 1], string_delim=string_delim + ) + ) # Nested array elif arr_str[i] == "[": @@ -273,10 +325,10 @@ def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list: sub_start = i + 1 i += 1 while i < n and depth > 0: - if arr_str[i:].startswith(STRING_DELIM): - i += len(STRING_DELIM) - nd = arr_str.find(STRING_DELIM, i) - i = nd + len(STRING_DELIM) if nd != -1 else n + if arr_str[i:].startswith(string_delim): + i += len(string_delim) + nd = arr_str.find(string_delim, i) + i = nd + len(string_delim) if nd != -1 else n continue if arr_str[i] == "[": depth += 1 @@ -284,9 +336,17 @@ def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list: depth -= 1 i += 1 if depth > 0: - items.append(_parse_gemma4_array(arr_str[sub_start:i], partial=True)) + items.append( + _parse_gemma4_array( + arr_str[sub_start:i], partial=True, string_delim=string_delim + ) + ) else: - items.append(_parse_gemma4_array(arr_str[sub_start : i - 1])) + items.append( + _parse_gemma4_array( + arr_str[sub_start : i - 1], string_delim=string_delim + ) + ) # Bare value else: @@ -352,13 +412,10 @@ def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): "constructor during construction." ) - # Token strings - self.tool_call_start_token = TOOL_CALL_START - self.tool_call_end_token = TOOL_CALL_END - - # Token IDs + # Token IDs (looked up via vocabulary keys) self.tool_call_start_token_id = self.vocab.get(TOOL_CALL_START) self.tool_call_end_token_id = self.vocab.get(TOOL_CALL_END) + _string_delim_id = self.vocab.get(STRING_DELIM) if self.tool_call_start_token_id is None: raise RuntimeError( @@ -366,11 +423,40 @@ def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): f"token '{TOOL_CALL_START}' in the tokenizer!" ) + # Detect the actual decoded token strings. + # + # Some tokenizer builds (see issue #44522) store the canonical + # vocabulary key (e.g. ``<|tool_call>``) in ``get_vocab()`` while + # decoding the same token ID to a *different* string + # (e.g. ``<|tool_call|>``). The fast-tokenizer's + # ``added_tokens_decoder[id].content`` is what ends up in + # ``output.text``, so all text-matching in this parser must use + # that form, not the vocabulary key. + self.tool_call_start_token = _decoded_token_str( + self.model_tokenizer, self.tool_call_start_token_id, TOOL_CALL_START + ) + self.tool_call_end_token = ( + _decoded_token_str( + self.model_tokenizer, self.tool_call_end_token_id, TOOL_CALL_END + ) + if self.tool_call_end_token_id is not None + else TOOL_CALL_END + ) + self.string_delim = ( + _decoded_token_str( + self.model_tokenizer, _string_delim_id, STRING_DELIM + ) + if _string_delim_id is not None + else STRING_DELIM + ) + # Regex for non-streaming: extract complete tool calls. # Supports function names with letters, digits, underscores, # hyphens, and dots (e.g. "get-weather", "module.func"). self.tool_call_regex = re.compile( - r"<\|tool_call>call:([\w\-\.]+)\{(.*?)\}", + re.escape(self.tool_call_start_token) + + r"call:([\w\-\.]+)\{(.*?)\}" + + re.escape(self.tool_call_end_token), re.DOTALL, ) @@ -418,12 +504,14 @@ def _buffer_delta_text(self, delta_text: str) -> str: combined = self.buffered_delta_text + delta_text # Check if combined ends with a complete special token - if combined.endswith(TOOL_CALL_START) or combined.endswith(TOOL_CALL_END): + if combined.endswith(self.tool_call_start_token) or combined.endswith( + self.tool_call_end_token + ): self.buffered_delta_text = "" return combined # Check if combined ends with a partial prefix of a special token - for tag in [TOOL_CALL_START, TOOL_CALL_END]: + for tag in [self.tool_call_start_token, self.tool_call_end_token]: for i in range(1, len(tag)): if combined.endswith(tag[:i]): self.buffered_delta_text = combined[-i:] @@ -456,7 +544,7 @@ def extract_tool_calls( tool_calls: list[ToolCall] = [] for func_name, args_str in matches: - arguments = _parse_gemma4_args(args_str) + arguments = _parse_gemma4_args(args_str, string_delim=self.string_delim) tool_calls.append( ToolCall( type="function", @@ -671,7 +759,7 @@ def _handle_tool_call_end(self, current_text: str) -> DeltaMessage | None: all_matches = self.tool_call_regex.findall(current_text) if self.current_tool_id < len(all_matches): _, args_str = all_matches[self.current_tool_id] - final_args = _parse_gemma4_args(args_str) + final_args = _parse_gemma4_args(args_str, string_delim=self.string_delim) final_args_json = json.dumps(final_args, ensure_ascii=False) prev_streamed = self.streamed_args_for_tool[self.current_tool_id] @@ -726,7 +814,9 @@ def _emit_argument_diff(self, raw_args_str: str) -> DeltaMessage | None: DeltaMessage with the argument diff, or None if no new content. """ try: - current_args = _parse_gemma4_args(raw_args_str, partial=True) + current_args = _parse_gemma4_args( + raw_args_str, partial=True, string_delim=self.string_delim + ) except Exception: logger.debug( "Could not parse partial Gemma4 args yet: %s",