Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions tests/reasoning/test_gemma4_reasoning_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,19 +249,87 @@ def test_gemma4_reasoning(
assert is_reasoning_end == param_dict["is_reasoning_end"]


def test_gemma4_adjust_request(generic_tokenizer):
def test_gemma4_adjust_request_thinking_enabled(generic_tokenizer):
    """Default request (thinking on, no tools) only loses special-token skipping.

    ``adjust_request`` is expected to flip ``skip_special_tokens`` to ``False``,
    leave ``skip_token_ids`` as ``None``, and return the same request object
    it was given.
    """
    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(generic_tokenizer)

    request = ChatCompletionRequest(messages=[], model="test-model")
    # Precondition: a fresh request skips special tokens.
    assert request.skip_special_tokens is True

    adjusted = parser.adjust_request(request)

    assert adjusted.skip_special_tokens is False
    # No tools provided — default tool_choice='none' must not trigger suppression.
    assert adjusted.skip_token_ids is None
    # The request is modified in place rather than copied.
    assert adjusted is request


def test_gemma4_adjust_request_thinking_disabled(generic_tokenizer):
    """With ``enable_thinking=False`` the request must come back untouched.

    ``skip_special_tokens`` stays ``True``, ``skip_token_ids`` stays ``None``,
    and the returned object is the request that was passed in.
    """
    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(generic_tokenizer)

    request_kwargs = {
        "messages": [],
        "model": "test-model",
        "chat_template_kwargs": {"enable_thinking": False},
    }
    request = ChatCompletionRequest(**request_kwargs)
    # Precondition: a fresh request skips special tokens.
    assert request.skip_special_tokens is True

    adjusted = parser.adjust_request(request)

    assert adjusted.skip_special_tokens is True
    assert adjusted.skip_token_ids is None
    assert adjusted is request


def test_gemma4_adjust_request_thinking_enabled_tool_choice_none(generic_tokenizer):
    """Thinking on + tools supplied + ``tool_choice='none'`` suppresses delimiters.

    ``skip_special_tokens`` must become ``False`` and ``skip_token_ids`` must
    hold exactly the tool-call delimiter token IDs present in the vocabulary.
    """
    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(generic_tokenizer)

    function_tool = {"type": "function", "function": {"name": "f", "parameters": {}}}
    request = ChatCompletionRequest(
        messages=[],
        model="test-model",
        chat_template_kwargs={"enable_thinking": True},
        tools=[function_tool],
        tool_choice="none",
    )

    adjusted = parser.adjust_request(request)
    assert adjusted.skip_special_tokens is False

    vocab = generic_tokenizer.get_vocab()
    # The opening delimiter is looked up strictly (KeyError if absent); the
    # closing delimiter and the escape token are optional in the vocabulary.
    expected_ids = {vocab["<|tool_call>"]}
    for optional_token in ("<tool_call|>", '<|"|>'):
        optional_id = vocab.get(optional_token)
        if optional_id is not None:
            expected_ids.add(optional_id)

    assert adjusted.skip_token_ids is not None
    # Compare as sets: ordering of the suppressed IDs is not part of the contract.
    assert set(adjusted.skip_token_ids) == expected_ids


def test_gemma4_adjust_request_thinking_enabled_tool_choice_auto(generic_tokenizer):
    """Thinking on + ``tool_choice='auto'`` must not set ``skip_token_ids``.

    ``skip_special_tokens`` still becomes ``False``, but the tool-call
    delimiters are left alone (the tool parser handles the tokens).
    """
    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
    parser: ReasoningParser = parser_cls(generic_tokenizer)

    function_tool = {"type": "function", "function": {"name": "f", "parameters": {}}}
    request = ChatCompletionRequest(
        messages=[],
        model="test-model",
        chat_template_kwargs={"enable_thinking": True},
        tools=[function_tool],
        tool_choice="auto",
    )

    adjusted = parser.adjust_request(request)

    assert adjusted.skip_special_tokens is False
    assert adjusted.skip_token_ids is None


def test_gemma4_previous_turn_reasoning_is_reasoning_end(generic_tokenizer):
output = (
"<|channel>thought\n1st thought<channel|>1st content<turn|>\n"
Expand Down
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/chat_completion/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
skip_token_ids: list[int] | None = None
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_INT64_MAX)] | None = None
prompt_logprobs: int | None = None
Expand Down Expand Up @@ -508,6 +509,7 @@ def to_sampling_params(
max_tokens=max_tokens,
min_tokens=self.min_tokens,
skip_special_tokens=self.skip_special_tokens,
skip_token_ids=self.skip_token_ids,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
output_kind=RequestOutputKind.DELTA
Expand Down
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/responses/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ class ResponsesRequest(OpenAIBaseModel):
truncation: Literal["auto", "disabled"] | None = "disabled"
user: str | None = None
skip_special_tokens: bool = True
skip_token_ids: list[int] | None = None
include_stop_str_in_output: bool = False
presence_penalty: float | None = Field(
default=None,
Expand Down Expand Up @@ -404,6 +405,7 @@ def to_sampling_params(
extra_args=extra_args,
skip_clone=True, # Created fresh per request, safe to skip clone
skip_special_tokens=self.skip_special_tokens,
skip_token_ids=self.skip_token_ids,
include_stop_str_in_output=self.include_stop_str_in_output,
)

Expand Down
26 changes: 25 additions & 1 deletion vllm/reasoning/gemma4_reasoning_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,36 @@ def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
self.new_turn_token_id = self.vocab["<|turn>"]
self.tool_call_token_id = self.vocab["<|tool_call>"]
self.tool_response_token_id = self.vocab["<|tool_response>"]
self._tool_call_end_token_id: int | None = self.vocab.get("<tool_call|>")
self._escape_token_id: int | None = self.vocab.get('<|"|>')

def adjust_request(
    self, request: "ChatCompletionRequest | ResponsesRequest"
) -> "ChatCompletionRequest | ResponsesRequest":
    """Disable special-token stripping to preserve reasoning boundary tokens.

    When ``tool_choice="none"`` and tools were supplied, the tool-call
    delimiter tokens are additionally marked for selective suppression so
    they are stripped from the decoded output without affecting the
    reasoning channel tokens.

    Args:
        request: The incoming request. It is modified in place.

    Returns:
        The same ``request`` object that was passed in.
    """
    chat_template_kwargs = getattr(request, "chat_template_kwargs", None) or {}
    if not chat_template_kwargs.get("enable_thinking", True):
        # Thinking is disabled: leave the request exactly as received.
        return request

    request.skip_special_tokens = False

    # Suppression applies only when tools were supplied AND tool_choice is
    # "none". The accompanying tests require that a request without tools
    # (where tool_choice defaults to "none") does NOT trigger suppression.
    # NOTE(review): a reviewer suggested dropping the `tools` check so the
    # delimiters are also hidden when no tools are supplied, arguing that
    # the model may still emit them; behavior is left unchanged here.
    wants_suppression = getattr(request, "tool_choice", None) == "none" and getattr(
        request, "tools", None
    )
    if not wants_suppression:
        return request

    # The opening delimiter is always included; the closing delimiter and
    # the escape token are included only when they have an ID.
    tool_call_ids = [
        self.tool_call_token_id,
        *(
            optional_id
            for optional_id in (self._tool_call_end_token_id, self._escape_token_id)
            if optional_id is not None
        ),
    ]

    # Keep any IDs already on the request first, then append the ones that
    # are not yet present, so existing entries are neither lost nor duplicated.
    existing = list(getattr(request, "skip_token_ids", None) or [])
    request.skip_token_ids = existing + [
        tid for tid in tool_call_ids if tid not in existing
    ]
    return request

@property
Expand Down
7 changes: 7 additions & 0 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ class SamplingParams(
"""Whether to detokenize the output."""
skip_special_tokens: bool = True
"""Whether to skip special tokens in the output."""
skip_token_ids: list[int] | None = None
"""If provided, these specific token IDs are always suppressed from the
detokenized output regardless of ``skip_special_tokens``. Useful for
selectively hiding certain special tokens (e.g. tool-call delimiters)
while preserving others (e.g. reasoning channel markers)."""
spaces_between_special_tokens: bool = True
"""Whether to add spaces between special tokens in the output."""
include_stop_str_in_output: bool = False
Expand Down Expand Up @@ -329,6 +334,7 @@ def from_optional(
prompt_logprobs: int | None = None,
detokenize: bool = True,
skip_special_tokens: bool = True,
skip_token_ids: list[int] | None = None,
spaces_between_special_tokens: bool = True,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
structured_outputs: StructuredOutputsParams | None = None,
Expand Down Expand Up @@ -370,6 +376,7 @@ def from_optional(
prompt_logprobs=prompt_logprobs,
detokenize=detokenize,
skip_special_tokens=skip_special_tokens,
skip_token_ids=skip_token_ids,
spaces_between_special_tokens=spaces_between_special_tokens,
output_kind=output_kind,
structured_outputs=structured_outputs,
Expand Down
18 changes: 18 additions & 0 deletions vllm/v1/engine/detokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreReques

self.request_id = request.request_id
self.skip_special_tokens = sampling_params.skip_special_tokens
skip_token_ids = sampling_params.skip_token_ids
self.skip_token_ids: frozenset[int] = (
frozenset(skip_token_ids) if skip_token_ids else frozenset()
)

self.tokenizer: Tokenizer = tokenizer._tokenizer

Expand Down Expand Up @@ -205,6 +209,10 @@ def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreReques
self.spaces_between_special_tokens = True

def decode_next(self, next_token_id: int) -> str:
if next_token_id in self.skip_token_ids:
self._protected_step(next_token_id)
return ""

token = self._protected_step(next_token_id)

if not self.spaces_between_special_tokens:
Expand Down Expand Up @@ -272,6 +280,10 @@ def __init__(self, tokenizer: TokenizerLike, request: EngineCoreRequest):
self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len)

self.skip_special_tokens = params.skip_special_tokens
skip_token_ids = params.skip_token_ids
self.skip_token_ids: frozenset[int] = (
frozenset(skip_token_ids) if skip_token_ids else frozenset()
)
self.spaces_between_special_tokens = params.spaces_between_special_tokens

@property
Expand All @@ -284,6 +296,12 @@ def num_output_tokens(self) -> int:
return len(self.token_ids) - self.prompt_len

def decode_next(self, next_token_id: int) -> str:
if next_token_id in self.skip_token_ids:
self.tokens.append("")
self.prefix_offset = self.read_offset
self.read_offset = len(self.tokens)
return ""

new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally(
tokenizer=self.tokenizer,
all_input_ids=self.token_ids,
Expand Down
Loading