Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 15 additions & 22 deletions tests/tool_use/test_tool_choice_required.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from copy import deepcopy
from unittest.mock import MagicMock

import pytest
import regex as re
Expand All @@ -11,7 +10,7 @@
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.tool_parsers.streaming import extract_required_tool_call_streaming
from vllm.tool_parsers.utils import get_json_schema_from_tools

pytestmark = pytest.mark.cpu_test
Expand Down Expand Up @@ -281,8 +280,6 @@ def test_structured_outputs_json_without_parameters(
@pytest.mark.parametrize("empty_params", [False, True])
@pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def test_streaming_output_valid(output, empty_params, delta_len):
self = MagicMock()

output = deepcopy(output)
if empty_params:
output = [{"name": o["name"], "parameters": {}} for o in output]
Expand All @@ -295,14 +292,13 @@ def test_streaming_output_valid(output, empty_params, delta_len):
delta_text = output_json[i : i + delta_len]
current_text = previous_text + delta_text

delta_message, function_name_returned = (
OpenAIServingChat.extract_tool_call_required_streaming(
self,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
)
delta_message, function_name_returned = extract_required_tool_call_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
tool_call_idx=None,
tool_call_id_type="random",
)

if delta_message:
Expand Down Expand Up @@ -332,8 +328,6 @@ def test_streaming_output_valid(output, empty_params, delta_len):


def test_streaming_output_valid_with_trailing_extra_data():
self = MagicMock()

output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
output_json = json.dumps(output) + "\nDONE"

Expand All @@ -345,14 +339,13 @@ def test_streaming_output_valid_with_trailing_extra_data():
delta_text = output_json[i : i + delta_len]
current_text = previous_text + delta_text

delta_message, function_name_returned = (
OpenAIServingChat.extract_tool_call_required_streaming(
self,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
)
delta_message, function_name_returned = extract_required_tool_call_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
tool_call_idx=None,
tool_call_id_type="random",
)

if delta_message:
Expand Down
179 changes: 10 additions & 169 deletions vllm/entrypoints/openai/chat_completion/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@
from vllm.renderers import ChatParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.streaming import (
extract_named_tool_call_streaming,
extract_required_tool_call_streaming,
)
from vllm.utils.collection_utils import as_list
from vllm.utils.mistral import is_mistral_tokenizer, is_mistral_tool_parser

Expand Down Expand Up @@ -389,23 +385,6 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
return self.response_role
return request.messages[-1]["role"]

def extract_tool_call_required_streaming(
self,
previous_text: str,
current_text: str | None,
delta_text: str,
function_name_returned: bool,
tool_call_idx: int | None = None,
) -> tuple[DeltaMessage | None, bool]:
return extract_required_tool_call_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=function_name_returned,
tool_call_idx=tool_call_idx,
tool_call_id_type=self.tool_call_id_type,
)

async def chat_completion_stream_generator(
self,
request: ChatCompletionRequest,
Expand Down Expand Up @@ -448,22 +427,7 @@ async def chat_completion_stream_generator(
and self._should_stream_with_auto_tool_parsing(request)
)

# Determine whether required/named tool_choice should fall back to
# the auto tool_parser path instead of the standard JSON-based parsing.
# This happens when the parser declares supports_required_and_named=False
# (e.g. GLM models that output XML instead of JSON).
tool_choice_uses_parser = (
self.tool_parser is not None
and not self.tool_parser.supports_required_and_named
and request.tools
and (
request.tool_choice == "required"
or isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
)
)

all_previous_token_ids: list[list[int]] | None
function_name_returned = [False] * num_choices
if self.tool_call_id_type == "kimi_k2":
history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
else:
Expand All @@ -477,10 +441,10 @@ async def chat_completion_stream_generator(
if (
is_mistral_grammar_path
or tool_choice_auto
or tool_choice_uses_parser
or tool_choice_function_name
or request.tool_choice == "required"
or reasoning_parser
):
# These are only required in "auto" tool choice case
all_previous_token_ids = [[] for _ in range(num_choices)]
reasoning_end_arr = [False] * num_choices
prompt_is_reasoning_end_arr: list[bool | None] = [None] * num_choices
Expand All @@ -501,6 +465,10 @@ async def chat_completion_stream_generator(
)
for _ in range(num_choices)
]
for p in parsers:
if p is not None:
p._stream_state.tool_call_id_type = self.tool_call_id_type
p._stream_state.history_tool_call_cnt = history_tool_call_cnt
else:
parsers = [None] * num_choices
except Exception as e:
Expand Down Expand Up @@ -677,7 +645,8 @@ async def chat_completion_stream_generator(
if (
is_mistral_grammar_path
or tool_choice_auto
or tool_choice_uses_parser
or tool_choice_function_name
or request.tool_choice == "required"
or reasoning_parser
):
assert previous_texts is not None
Expand Down Expand Up @@ -731,135 +700,6 @@ async def chat_completion_stream_generator(
current_token_ids = result.current_token_ids
if result.tools_called:
tools_streamed[i] = True
# handle streaming deltas for tools with named tool_choice
# Skip when tool_choice_uses_parser so it falls through
# to the auto tool_parser branches below.
elif tool_choice_function_name and not tool_choice_uses_parser:
# When encountering think end id in prompt_token_ids
# i.e {"enable_thinking": False},
# check BEFORE calling the parser to avoid a spurious
# reasoning delta on the first chunk.
if (
reasoning_parser
and not reasoning_end_arr[i]
and prompt_is_reasoning_end_arr[i]
):
reasoning_end_arr[i] = True

if (
reasoning_parser
and not reasoning_end_arr[i]
and not reasoning_parser.is_reasoning_end(
previous_token_ids
)
):
assert reasoning_parser is not None
delta_message = (
reasoning_parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output.token_ids,
)
)
# When encountering think end id in delta_token_ids,
# set reasoning status to end.
# Only keep 'content', remove 'reasoning'.
if reasoning_parser.is_reasoning_end(
as_list(output.token_ids)
):
reasoning_end_arr[i] = True
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
else:
# Just to add remaining `content`
if reasoning_parser:
delta_text = previous_text + delta_text
current_text = ""

delta_message, function_name_returned[i] = (
extract_named_tool_call_streaming(
delta_text=delta_text,
function_name=tool_choice_function_name,
function_name_returned=function_name_returned[i],
tool_call_idx=history_tool_call_cnt,
tool_call_id_type=self.tool_call_id_type,
tokenizer=tokenizer,
tool_call_array_index=i,
)
)
if (
delta_message
and delta_message.tool_calls
and delta_message.tool_calls[0].id is not None
):
history_tool_call_cnt += 1
tools_streamed[i] = True

# Skip when tool_choice_uses_parser so it falls through
# to the auto tool_parser branches below.
elif (
request.tool_choice == "required"
and not tool_choice_uses_parser
):
assert previous_texts is not None
previous_text = previous_texts[i]
current_text = previous_text + delta_text
fn_name_returned = function_name_returned[i]
output_token_ids = as_list(output.token_ids)

if (
reasoning_parser is not None
and not reasoning_end_arr[i]
and prompt_is_reasoning_end_arr[i]
):
reasoning_end_arr[i] = True

if reasoning_parser and not reasoning_end_arr[i]:
delta_message = (
reasoning_parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output_token_ids,
)
)
if reasoning_parser.is_reasoning_end(output_token_ids):
reasoning_end_arr[i] = True
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
# reasoning ended
current_text = ""

else:
# either finished reasoning or no reasoning at all
content = current_text

delta_message, function_name_returned[i] = (
self.extract_tool_call_required_streaming(
previous_text=previous_text,
current_text=content,
delta_text=delta_text,
function_name_returned=fn_name_returned,
tool_call_idx=history_tool_call_cnt,
)
)
if (
delta_message
and delta_message.tool_calls
and delta_message.tool_calls[0].id is not None
):
history_tool_call_cnt += 1
tools_streamed[i] = True

elif parser is not None:
delta_message = parser.parse_delta(
Expand All @@ -878,7 +718,8 @@ async def chat_completion_stream_generator(
if (
is_mistral_grammar_path
or tool_choice_auto
or tool_choice_uses_parser
or tool_choice_function_name
or request.tool_choice == "required"
or reasoning_parser
) and not self.use_harmony:
assert previous_texts is not None
Expand Down
20 changes: 16 additions & 4 deletions vllm/parser/abstract_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,9 +587,15 @@ def _extract_tool_calls_streaming(
tool_call_id_type: str = "random",
function_name_returned: bool = False,
) -> tuple[DeltaMessage | None, bool]:
if request.tool_choice and isinstance(
request.tool_choice,
(ToolChoiceFunction, ChatCompletionNamedToolChoiceParam),
assert self._tool_parser is not None
supports_required_and_named = self._tool_parser.supports_required_and_named
if (
supports_required_and_named
and request.tool_choice
and isinstance(
request.tool_choice,
(ToolChoiceFunction, ChatCompletionNamedToolChoiceParam),
)
):
delta_message, function_name_returned = extract_named_tool_call_streaming(
delta_text=delta_text,
Expand All @@ -601,7 +607,7 @@ def _extract_tool_calls_streaming(
)
return delta_message, function_name_returned

if request.tool_choice == "required":
if supports_required_and_named and request.tool_choice == "required":
delta_message, function_name_returned = (
extract_required_tool_call_streaming(
previous_text=previous_text,
Expand Down Expand Up @@ -706,6 +712,12 @@ def parse_delta(
function_name_returned=state.function_name_returned,
)
)
if (
delta_message
and delta_message.tool_calls
and delta_message.tool_calls[0].id is not None
):
state.history_tool_call_cnt += 1
Comment thread
sfeng33 marked this conversation as resolved.

# No phase active: pass through as content
if (
Expand Down
Loading