diff --git a/tests/reasoning/test_gemma4_reasoning_parser.py b/tests/reasoning/test_gemma4_reasoning_parser.py
index 699fc509d828..b59d1f302f0a 100644
--- a/tests/reasoning/test_gemma4_reasoning_parser.py
+++ b/tests/reasoning/test_gemma4_reasoning_parser.py
@@ -249,19 +249,87 @@ def test_gemma4_reasoning(
     assert is_reasoning_end == param_dict["is_reasoning_end"]
 
 
-def test_gemma4_adjust_request(generic_tokenizer):
+def test_gemma4_adjust_request_thinking_enabled(generic_tokenizer):
+    """When thinking is enabled (default) with no tools, skip_special_tokens is
+    set to False and skip_token_ids remains None."""
     parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
         generic_tokenizer
     )
-
     request = ChatCompletionRequest(messages=[], model="test-model")
     assert request.skip_special_tokens is True
 
     result = parser.adjust_request(request)
     assert result.skip_special_tokens is False
+    # No tools provided — default tool_choice='none' must not trigger suppression.
+    assert result.skip_token_ids is None
+    assert result is request
+
+
+def test_gemma4_adjust_request_thinking_disabled(generic_tokenizer):
+    """When thinking is disabled, skip_special_tokens is left unchanged (True)."""
+    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
+        generic_tokenizer
+    )
+    request = ChatCompletionRequest(
+        messages=[],
+        model="test-model",
+        chat_template_kwargs={"enable_thinking": False},
+    )
+    assert request.skip_special_tokens is True
+
+    result = parser.adjust_request(request)
+    assert result.skip_special_tokens is True
+    assert result.skip_token_ids is None
     assert result is request
 
 
+def test_gemma4_adjust_request_thinking_enabled_tool_choice_none(generic_tokenizer):
+    """When thinking is enabled and tool_choice='none', skip_special_tokens is False
+    and skip_token_ids contains the tool-call delimiter token IDs."""
+    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
+        generic_tokenizer
+    )
+    request = ChatCompletionRequest(
+        messages=[],
+        model="test-model",
+        chat_template_kwargs={"enable_thinking": True},
+        tools=[{"type": "function", "function": {"name": "f", "parameters": {}}}],
+        tool_choice="none",
+    )
+
+    result = parser.adjust_request(request)
+    assert result.skip_special_tokens is False
+
+    vocab = generic_tokenizer.get_vocab()
+    expected_ids = {vocab["<|tool_call>"]}
+    if (end_id := vocab.get("<tool_call|>")) is not None:
+        expected_ids.add(end_id)
+    if (esc_id := vocab.get('<|"|>')) is not None:
+        expected_ids.add(esc_id)
+
+    assert result.skip_token_ids is not None
+    assert set(result.skip_token_ids) == expected_ids
+
+
+def test_gemma4_adjust_request_thinking_enabled_tool_choice_auto(generic_tokenizer):
+    """When thinking is enabled and tool_choice='auto', skip_special_tokens is False
+    but skip_token_ids is not set (tool parser handles the tokens)."""
+    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
+        generic_tokenizer
+    )
+    request = ChatCompletionRequest(
+        messages=[],
+        model="test-model",
+        chat_template_kwargs={"enable_thinking": True},
+        tools=[{"type": "function", "function": {"name": "f", "parameters": {}}}],
+        tool_choice="auto",
+    )
+
+    result = parser.adjust_request(request)
+    assert result.skip_special_tokens is False
+    assert result.skip_token_ids is None
+
+
 def test_gemma4_previous_turn_reasoning_is_reasoning_end(generic_tokenizer):
     output = (
         "<|channel>thought\n1st thought1st content\n"
diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py
index 533959df6094..4d60f7303819 100644
--- a/vllm/entrypoints/openai/chat_completion/protocol.py
+++ b/vllm/entrypoints/openai/chat_completion/protocol.py
@@ -198,6 +198,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     ignore_eos: bool = False
     min_tokens: int = 0
     skip_special_tokens: bool = True
+    skip_token_ids: list[int] | None = None
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_INT64_MAX)] | None = None
     prompt_logprobs: int | None = None
@@ -508,6 +509,7 @@ def to_sampling_params(
             max_tokens=max_tokens,
             min_tokens=self.min_tokens,
             skip_special_tokens=self.skip_special_tokens,
+            skip_token_ids=self.skip_token_ids,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
             output_kind=RequestOutputKind.DELTA
diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py
index 79f5894fb91d..e1586363486b 100644
--- a/vllm/entrypoints/openai/responses/protocol.py
+++ b/vllm/entrypoints/openai/responses/protocol.py
@@ -172,6 +172,7 @@ class ResponsesRequest(OpenAIBaseModel):
     truncation: Literal["auto", "disabled"] | None = "disabled"
     user: str | None = None
     skip_special_tokens: bool = True
+    skip_token_ids: list[int] | None = None
     include_stop_str_in_output: bool = False
     presence_penalty: float | None = Field(
         default=None,
@@ -404,6 +405,7 @@ def to_sampling_params(
             extra_args=extra_args,
             skip_clone=True,  # Created fresh per request, safe to skip clone
             skip_special_tokens=self.skip_special_tokens,
+            skip_token_ids=self.skip_token_ids,
             include_stop_str_in_output=self.include_stop_str_in_output,
         )
 
diff --git a/vllm/reasoning/gemma4_reasoning_parser.py b/vllm/reasoning/gemma4_reasoning_parser.py
index 6f2241603f9a..0f7a00fd8bd0 100644
--- a/vllm/reasoning/gemma4_reasoning_parser.py
+++ b/vllm/reasoning/gemma4_reasoning_parser.py
@@ -55,12 +55,36 @@ def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
         self.new_turn_token_id = self.vocab["<|turn>"]
         self.tool_call_token_id = self.vocab["<|tool_call>"]
         self.tool_response_token_id = self.vocab["<|tool_response>"]
+        self._tool_call_end_token_id: int | None = self.vocab.get("<tool_call|>")
+        self._escape_token_id: int | None = self.vocab.get('<|"|>')
 
     def adjust_request(
         self, request: "ChatCompletionRequest | ResponsesRequest"
     ) -> "ChatCompletionRequest | ResponsesRequest":
-        """Disable special-token stripping to preserve boundary tokens."""
+        """Disable special-token stripping to preserve reasoning boundary tokens.
+
+        When ``tool_choice="none"`` we additionally mark the tool-call delimiter
+        tokens for selective suppression so they are stripped from the decoded
+        output without affecting the reasoning channel tokens.
+        """
+        chat_template_kwargs = getattr(request, "chat_template_kwargs", None) or {}
+        if not chat_template_kwargs.get("enable_thinking", True):
+            return request
         request.skip_special_tokens = False
+        if getattr(request, "tool_choice", None) == "none" and getattr(
+            request, "tools", None
+        ):
+            tool_call_ids = [self.tool_call_token_id]
+            for optional_id in (
+                self._tool_call_end_token_id,
+                self._escape_token_id,
+            ):
+                if optional_id is not None:
+                    tool_call_ids.append(optional_id)
+            existing = list(getattr(request, "skip_token_ids", None) or [])
+            request.skip_token_ids = existing + [
+                tid for tid in tool_call_ids if tid not in existing
+            ]
         return request
 
     @property
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 9bcc669591eb..9f2a92f2a7dd 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -255,6 +255,11 @@ class SamplingParams(
     """Whether to detokenize the output."""
     skip_special_tokens: bool = True
     """Whether to skip special tokens in the output."""
+    skip_token_ids: list[int] | None = None
+    """If provided, these specific token IDs are always suppressed from the
+    detokenized output regardless of ``skip_special_tokens``. Useful for
+    selectively hiding certain special tokens (e.g. tool-call delimiters)
+    while preserving others (e.g. reasoning channel markers)."""
     spaces_between_special_tokens: bool = True
     """Whether to add spaces between special tokens in the output."""
     include_stop_str_in_output: bool = False
@@ -329,6 +334,7 @@ def from_optional(
         prompt_logprobs: int | None = None,
         detokenize: bool = True,
         skip_special_tokens: bool = True,
+        skip_token_ids: list[int] | None = None,
         spaces_between_special_tokens: bool = True,
         output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
         structured_outputs: StructuredOutputsParams | None = None,
@@ -370,6 +376,7 @@ def from_optional(
             prompt_logprobs=prompt_logprobs,
             detokenize=detokenize,
             skip_special_tokens=skip_special_tokens,
+            skip_token_ids=skip_token_ids,
             spaces_between_special_tokens=spaces_between_special_tokens,
             output_kind=output_kind,
             structured_outputs=structured_outputs,
diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py
index 2f81ba4f6c78..ab2e794d2581 100644
--- a/vllm/v1/engine/detokenizer.py
+++ b/vllm/v1/engine/detokenizer.py
@@ -173,6 +173,10 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreReques
 
         self.request_id = request.request_id
         self.skip_special_tokens = sampling_params.skip_special_tokens
+        skip_token_ids = sampling_params.skip_token_ids
+        self.skip_token_ids: frozenset[int] = (
+            frozenset(skip_token_ids) if skip_token_ids else frozenset()
+        )
 
         self.tokenizer: Tokenizer = tokenizer._tokenizer
 
@@ -205,6 +209,10 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreReques
                 self.spaces_between_special_tokens = True
 
     def decode_next(self, next_token_id: int) -> str:
+        if next_token_id in self.skip_token_ids:
+            self._protected_step(next_token_id)
+            return ""
+
         token = self._protected_step(next_token_id)
 
         if not self.spaces_between_special_tokens:
@@ -272,6 +280,10 @@ def __init__(self, tokenizer: TokenizerLike, request: EngineCoreRequest):
         self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len)
 
         self.skip_special_tokens = params.skip_special_tokens
+        skip_token_ids = params.skip_token_ids
+        self.skip_token_ids: frozenset[int] = (
+            frozenset(skip_token_ids) if skip_token_ids else frozenset()
+        )
         self.spaces_between_special_tokens = params.spaces_between_special_tokens
 
     @property
@@ -284,6 +296,12 @@ def num_output_tokens(self) -> int:
         return len(self.token_ids) - self.prompt_len
 
     def decode_next(self, next_token_id: int) -> str:
+        if next_token_id in self.skip_token_ids:
+            self.tokens.append("")
+            self.prefix_offset = self.read_offset
+            self.read_offset = len(self.tokens)
+            return ""
+
         new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally(
             tokenizer=self.tokenizer,
             all_input_ids=self.token_ids,