Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 11 additions & 137 deletions tests/entrypoints/openai/chat_completion/test_serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,8 +1935,10 @@ async def result_generator():
finished=True,
)

# Collect tool-call deltas per choice from the SSE stream.
# Collect tool-call deltas and finish_reasons per choice from the SSE
# stream.
tc_deltas_by_choice: dict[int, list[dict]] = {i: [] for i in range(num_choices)}
finish_reasons_by_choice: dict[int, list[str]] = {i: [] for i in range(num_choices)}
async for chunk_str in serving_chat.chat_completion_stream_generator(
request=request,
result_generator=result_generator(),
Expand All @@ -1959,6 +1961,8 @@ async def result_generator():
if delta.get("tool_calls"):
for tc in delta["tool_calls"]:
tc_deltas_by_choice[idx].append(tc)
if choice.get("finish_reason") is not None:
finish_reasons_by_choice[idx].append(choice["finish_reason"])

# Both choices must independently produce the correct tool call.
for choice_idx in range(num_choices):
Expand All @@ -1984,141 +1988,11 @@ async def result_generator():
f"Choice {choice_idx}: expected {{'city': 'Tokyo'}}, got {parsed_args}"
)


class TestCreateRemainingArgsDelta:
"""Tests for _create_remaining_args_delta helper function.

This helper is used when streaming tool calls to preserve id/type/name
fields in the finish chunk, which would otherwise be lost.
"""

def test_preserves_id_type_name(self):
"""Test that id, type, and name are preserved from original delta."""
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
)

original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_abc123",
type="function",
function=DeltaFunctionCall(
name="get_weather",
arguments='{"location": "Paris"}',
),
)
]
reasons = finish_reasons_by_choice[choice_idx]
assert len(reasons) == 1, (
f"Choice {choice_idx}: expected exactly 1 finish_reason, got {reasons}"
)

result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '", "unit": "celsius"}', 0
)

assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 0
assert tc.id == "call_abc123"
assert tc.type == "function"
assert tc.function.name == "get_weather"
assert tc.function.arguments == '", "unit": "celsius"}'

def test_matches_by_index(self):
"""Test that the correct tool call is matched by index."""
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
)

original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_first",
type="function",
function=DeltaFunctionCall(name="func_a", arguments="{}"),
),
DeltaToolCall(
index=1,
id="call_second",
type="function",
function=DeltaFunctionCall(name="func_b", arguments="{}"),
),
]
Comment thread
sfeng33 marked this conversation as resolved.
)

result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '{"extra": true}', 1
assert reasons[0] == "tool_calls", (
f"Choice {choice_idx}: expected finish_reason='tool_calls', "
f"got '{reasons[0]}'"
)

assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 1
assert tc.id == "call_second"
assert tc.function.name == "func_b"

def test_no_matching_tool_call(self):
    """A requested index absent from the original delta is handled gracefully.

    The resulting tool call carries only the requested index and the
    remaining arguments; id, type and function name are all ``None``.
    """
    from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
    from vllm.entrypoints.openai.engine.protocol import (
        DeltaFunctionCall,
        DeltaMessage,
        DeltaToolCall,
    )

    # The original delta only knows about the tool call at index 0.
    only_call = DeltaToolCall(
        index=0,
        id="call_zero",
        type="function",
        function=DeltaFunctionCall(name="func", arguments="{}"),
    )
    original_delta = DeltaMessage(tool_calls=[only_call])

    # Ask for index 5, which has no counterpart in original_delta.
    result = OpenAIServingChat._create_remaining_args_delta(
        original_delta, '{"arg": 1}', 5
    )

    assert len(result.tool_calls) == 1
    produced = result.tool_calls[0]
    assert produced.index == 5
    assert produced.id is None
    assert produced.type is None
    assert produced.function.name is None
    assert produced.function.arguments == '{"arg": 1}'

def test_function_is_none(self):
    """An original tool call without a ``function`` still yields a usable delta.

    The id and type are carried over from the original tool call, the
    function name is ``None``, and the arguments are the remaining
    arguments passed in.
    """
    from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
    from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall

    # Tool call that matches the requested index but has no function payload.
    functionless_call = DeltaToolCall(
        index=0,
        id="call_nofunc",
        type="function",
        function=None,
    )
    original_delta = DeltaMessage(tool_calls=[functionless_call])

    result = OpenAIServingChat._create_remaining_args_delta(
        original_delta, '{"data": "value"}', 0
    )

    assert len(result.tool_calls) == 1
    produced = result.tool_calls[0]
    assert produced.index == 0
    assert produced.id == "call_nofunc"
    assert produced.type == "function"
    assert produced.function.name is None
    assert produced.function.arguments == '{"data": "value"}'
83 changes: 83 additions & 0 deletions tests/parser/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,86 @@ def test_parse_delta_reasoning_only_thinking_disabled(tokenizer, request_obj):
assert "Hello" in content
assert "assist" in content
assert len(tool_calls) == 0


def test_parse_delta_finished_no_flush_without_tool_call_delta(tokenizer, request_obj):
    """With finished=True but no tool-call delta from the final parse_delta,
    unstreamed args are not flushed."""
    parser = make_parser(tokenizer, reasoning=False, tool=True)

    streamed_deltas = stream_text(
        parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[]
    )
    _reasoning, _content, tool_calls = collect_fields(streamed_deltas)
    assert len(tool_calls) > 0

    # Pretend the last five characters of the arguments were never streamed.
    tool_parser = parser._tool_parser
    already_streamed = tool_parser.streamed_args_for_tool[0]
    assert len(already_streamed) > 5
    tool_parser.streamed_args_for_tool[0] = already_streamed[:-5]

    # Prevent normal extraction from catching the gap — without a
    # tool-call delta to merge into, the flush is skipped.
    def _no_tool_call_delta(*args, **kwargs):
        return None

    tool_parser.extract_tool_calls_streaming = _no_tool_call_delta

    final_delta = parser.parse_delta("", [], request_obj, finished=True)
    assert final_delta is None or final_delta.tool_calls is None


def test_parse_delta_finished_no_extra_args_when_fully_streamed(tokenizer, request_obj):
    """Once every argument has been streamed, a finished=True call must not
    emit extra or duplicate arguments."""
    parser = make_parser(tokenizer, reasoning=False, tool=True)
    streamed_deltas = stream_text(
        parser, tokenizer, MODEL_OUTPUT, request_obj, prompt_token_ids=[]
    )
    _reasoning, _content, tool_calls = collect_fields(streamed_deltas)

    assert len(tool_calls) > 0
    assert tool_calls[0].function.name == "get_weather"

    # Stitch the argument fragments together; they must form the full JSON.
    arg_fragments = [
        tc.function.arguments for tc in tool_calls if tc.function.arguments
    ]
    assert json.loads("".join(arg_fragments)) == {"city": "Dallas"}

    # Nothing is left to flush, so the final call must carry no tool calls.
    final_delta = parser.parse_delta("", [], request_obj, finished=True)
    assert final_delta is None or final_delta.tool_calls is None


def test_parse_delta_finished_appends_remaining_args(tokenizer, request_obj):
    """When finished=True and the tool parser has unstreamed args,
    parse_delta appends the remaining arguments to the tool-call delta.

    Tokens are fed one at a time. As soon as the previous delta carried
    tool-call arguments, ``get_remaining_unstreamed_args`` is stubbed to
    return a fixed remainder and the next token is parsed with
    ``finished=True``; the loop then stops. The concatenated streamed
    arguments must end with that remainder.
    """
    parser = make_parser(tokenizer, reasoning=False, tool=True)
    token_ids = tokenizer.encode(MODEL_OUTPUT, add_special_tokens=False)

    remainder = ',"unit":"celsius"}'
    # Only the first parse_delta call receives prompt token ids.
    prompt_ids: list[int] | None = []
    results: list[DeltaMessage | None] = []
    for tid in token_ids:
        prev = results[-1] if results else None
        # Coerce to a real bool: the bare `and` chain would evaluate to
        # None / [] / a DeltaMessage, and the value is passed on as the
        # `finished` flag below.
        prev_had_args = bool(
            prev
            and prev.tool_calls
            and any(tc.function and tc.function.arguments for tc in prev.tool_calls)
        )

        if prev_had_args:
            # Simulate a tool parser that still holds unstreamed arguments.
            parser._tool_parser.get_remaining_unstreamed_args = lambda: remainder

        result = parser.parse_delta(
            tokenizer.decode([tid]),
            [tid],
            request_obj,
            prompt_token_ids=prompt_ids,
            finished=prev_had_args,
        )
        prompt_ids = None
        results.append(result)

        if prev_had_args:
            # The finishing call has been made; nothing more to stream.
            break

    _, _, tool_calls = collect_fields(results)
    tool_args = "".join(
        tc.function.arguments for tc in tool_calls if tc.function.arguments
    )
    assert tool_args.endswith(remainder)
Loading
Loading