Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 76 additions & 6 deletions tests/entrypoints/openai/test_serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,12 @@ def serving_chat(self, mock_engine) -> OpenAIServingChat:
return chat

def mock_request_output_from_req_and_token_ids(
self, req: ChatCompletionRequest, token_ids: list[int], finished: bool = False
self,
req: ChatCompletionRequest,
token_ids: list[int],
finished: bool = False,
finish_reason: str | None = None,
stop_reason: int | str | None = None,
) -> RequestOutput:
# Our tests don't use most fields, so just get the token ids correct
completion_output = CompletionOutput(
Expand All @@ -1086,6 +1091,8 @@ def mock_request_output_from_req_and_token_ids(
token_ids=token_ids,
cumulative_logprob=0.0,
logprobs=None,
finish_reason=finish_reason,
stop_reason=stop_reason,
)
return RequestOutput(
request_id=req.request_id,
Expand Down Expand Up @@ -1130,18 +1137,27 @@ async def generate_response_from_harmony_str(
req: ChatCompletionRequest,
harmony_str: str,
stream: bool = False,
terminal_stream_chunk: bool = False,
) -> ChatCompletionResponse:
harmony_token_ids = get_encoding().encode(harmony_str, allowed_special="all")

async def result_generator():
if stream:
for token_id in harmony_token_ids:
if terminal_stream_chunk:
yield self.mock_request_output_from_req_and_token_ids(
req, [token_id]
req,
harmony_token_ids,
finished=True,
finish_reason="stop",
)
else:
for token_id in harmony_token_ids:
yield self.mock_request_output_from_req_and_token_ids(
req, [token_id]
)
yield self.mock_request_output_from_req_and_token_ids(
req, [], finished=True
)
yield self.mock_request_output_from_req_and_token_ids(
req, [], finished=True
)
else:
yield self.mock_request_output_from_req_and_token_ids(
req, harmony_token_ids, finished=True
Expand Down Expand Up @@ -1377,6 +1393,60 @@ async def test_tools_and_reasoning(
],
)

@pytest.mark.asyncio
@pytest.mark.skip_global_cleanup
async def test_streaming_terminal_chunk_recovers_analysis_tool_call(
    self, serving_chat, weather_tools, weather_messages_start
):
    """A single terminal stream chunk carrying a tool call on the analysis
    channel must still be parsed into a tool call with finish_reason
    'tool_calls', even with the tool parser disabled."""
    # Disable the dedicated tool parser so recovery relies on harmony parsing.
    serving_chat.tool_parser = None

    expected_args = '{"location": "Paris"}'
    harmony_output = (
        "<|start|>assistant to=functions.get_weather<|channel|>analysis"
        "<|constrain|>json<|message|>" + expected_args + "<|call|>"
    )
    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=list(weather_messages_start),
        tools=weather_tools,
        include_reasoning=False,
    )

    # Stream the whole response as one finished chunk (terminal chunk path).
    response = await self.generate_response_from_harmony_str(
        serving_chat,
        request,
        harmony_output,
        stream=True,
        terminal_stream_chunk=True,
    )

    verify_chat_response(response, tool_calls=[("get_weather", expected_args)])
    assert response.choices[0].finish_reason == "tool_calls"

@pytest.mark.asyncio
@pytest.mark.skip_global_cleanup
async def test_streaming_terminal_chunk_does_not_promote_reasoning_to_content(
    self, serving_chat, weather_tools, weather_messages_start
):
    """Analysis-channel text arriving in a terminal stream chunk must stay
    reasoning (not become message content), finishing with reason 'stop'."""
    # No tool parser: exercise the plain harmony parsing path.
    serving_chat.tool_parser = None

    request = ChatCompletionRequest(
        model=MODEL_NAME,
        messages=list(weather_messages_start),
        tools=weather_tools,
        include_reasoning=False,
    )
    harmony_output = "<|channel|>analysis<|message|>I'll think about it.<|end|>"

    # Deliver the entire analysis message as a single finished chunk.
    response = await self.generate_response_from_harmony_str(
        serving_chat,
        request,
        harmony_output,
        stream=True,
        terminal_stream_chunk=True,
    )

    # Expect an ordinary (content-free) chat response, ended by "stop".
    verify_chat_response(response)
    assert response.choices[0].finish_reason == "stop"

@pytest.mark.asyncio
async def test_multi_turn_tools_and_reasoning(
self, serving_chat, stream, weather_tools, weather_messages_start
Expand Down
Loading