From 9e4263bdde303d9ae3d9ba0a3e36c5f0ec31859c Mon Sep 17 00:00:00 2001 From: jfrery Date: Fri, 6 Mar 2026 13:48:55 +0100 Subject: [PATCH] [Bugfix] Fix Harmony streaming tool call recovery across chunk and stop boundaries Signed-off-by: jfrery --- tests/entrypoints/openai/test_serving_chat.py | 82 ++++- .../test_serving_chat_stream_harmony.py | 268 ++++++++++----- .../openai/chat_completion/serving.py | 210 ++++++++++-- .../openai/chat_completion/stream_harmony.py | 321 +++++++++++------- 4 files changed, 631 insertions(+), 250 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index e1380d4290f8..16279af7c7f4 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1077,7 +1077,12 @@ def serving_chat(self, mock_engine) -> OpenAIServingChat: return chat def mock_request_output_from_req_and_token_ids( - self, req: ChatCompletionRequest, token_ids: list[int], finished: bool = False + self, + req: ChatCompletionRequest, + token_ids: list[int], + finished: bool = False, + finish_reason: str | None = None, + stop_reason: int | str | None = None, ) -> RequestOutput: # Our tests don't use most fields, so just get the token ids correct completion_output = CompletionOutput( @@ -1086,6 +1091,8 @@ def mock_request_output_from_req_and_token_ids( token_ids=token_ids, cumulative_logprob=0.0, logprobs=None, + finish_reason=finish_reason, + stop_reason=stop_reason, ) return RequestOutput( request_id=req.request_id, @@ -1130,18 +1137,27 @@ async def generate_response_from_harmony_str( req: ChatCompletionRequest, harmony_str: str, stream: bool = False, + terminal_stream_chunk: bool = False, ) -> ChatCompletionResponse: harmony_token_ids = get_encoding().encode(harmony_str, allowed_special="all") async def result_generator(): if stream: - for token_id in harmony_token_ids: + if terminal_stream_chunk: yield self.mock_request_output_from_req_and_token_ids( - req, [token_id] + req, + harmony_token_ids, + finished=True, + finish_reason="stop", + ) + else: + for token_id in harmony_token_ids: + yield self.mock_request_output_from_req_and_token_ids( + req, [token_id] + ) + yield self.mock_request_output_from_req_and_token_ids( + req, [], finished=True ) - yield self.mock_request_output_from_req_and_token_ids( - req, [], finished=True - ) else: yield self.mock_request_output_from_req_and_token_ids( req, harmony_token_ids, finished=True @@ -1377,6 +1393,60 @@ async def test_tools_and_reasoning( ], ) + @pytest.mark.asyncio + @pytest.mark.skip_global_cleanup + async def test_streaming_terminal_chunk_recovers_analysis_tool_call( + self, serving_chat, weather_tools, weather_messages_start + ): + tool_args_str = '{"location": "Paris"}' + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=list(weather_messages_start), + tools=weather_tools, + include_reasoning=False, + ) + serving_chat.tool_parser = None + response_str = ( + "<|start|>assistant to=functions.get_weather<|channel|>analysis" + f"<|constrain|>json<|message|>{tool_args_str}<|call|>" + ) + + response = await self.generate_response_from_harmony_str( + serving_chat, + req, + response_str, + stream=True, + terminal_stream_chunk=True, + ) + + verify_chat_response(response, tool_calls=[("get_weather", tool_args_str)]) + assert response.choices[0].finish_reason == "tool_calls" + + @pytest.mark.asyncio + @pytest.mark.skip_global_cleanup + async def test_streaming_terminal_chunk_does_not_promote_reasoning_to_content( + self, serving_chat, weather_tools, weather_messages_start + ): + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=list(weather_messages_start), + tools=weather_tools, + include_reasoning=False, + ) + serving_chat.tool_parser = None + response_str = "<|channel|>analysis<|message|>I'll think about it.<|end|>" + + response = await self.generate_response_from_harmony_str( + serving_chat, + req, + response_str, + stream=True, + terminal_stream_chunk=True, + ) + + verify_chat_response(response) + assert response.choices[0].finish_reason == "stop" + @pytest.mark.asyncio async def test_multi_turn_tools_and_reasoning( self, serving_chat, stream, weather_tools, weather_messages_start diff --git a/tests/entrypoints/openai/test_serving_chat_stream_harmony.py b/tests/entrypoints/openai/test_serving_chat_stream_harmony.py index 9f8c36f0473d..3448d56b5237 100644 --- a/tests/entrypoints/openai/test_serving_chat_stream_harmony.py +++ b/tests/entrypoints/openai/test_serving_chat_stream_harmony.py @@ -10,7 +10,7 @@ import pytest from vllm.entrypoints.openai.chat_completion.stream_harmony import ( - TokenState, + HarmonyStreamingState, extract_harmony_streaming_delta, ) @@ -21,6 +21,7 @@ class MockMessage: channel: str | None = None recipient: str | None = None + content: str = "" @dataclass @@ -28,6 +29,9 @@ class MockStreamableParser: """Mock StreamableParser for testing without openai_harmony dependency.""" messages: list[MockMessage] = field(default_factory=list) + current_channel: str | None = None + current_recipient: str | None = None + current_content: str | None = None class TestExtractHarmonyStreamingDelta: @@ -43,19 +47,21 @@ class TestExtractHarmonyStreamingDelta: def test_final_channel_returns_content_delta(self, delta_text, expected_content): """Test that final channel returns a DeltaMessage with content.""" parser = MockStreamableParser() - - # Updated to use TokenState list - token_states = [TokenState(channel="final", recipient=None, text=delta_text)] + parser.current_channel = "final" + parser.current_content = delta_text + stream_state = HarmonyStreamingState() delta_message, tools_streamed = extract_harmony_streaming_delta( harmony_parser=parser, - token_states=token_states, - prev_recipient=None, + stream_state=stream_state, include_reasoning=False, ) - assert delta_message is not None - assert delta_message.content == expected_content + if expected_content: + assert delta_message is not None + assert delta_message.content == expected_content + else: + assert delta_message is None assert tools_streamed is False @pytest.mark.parametrize( @@ -69,12 +75,13 @@ def test_analysis_channel_reasoning(self, include_reasoning, expected_has_messag """Test analysis channel respects include_reasoning flag.""" parser = MockStreamableParser() text = "Let me think..." - token_states = [TokenState(channel="analysis", recipient=None, text=text)] + parser.current_channel = "analysis" + parser.current_content = text + stream_state = HarmonyStreamingState() delta_message, tools_streamed = extract_harmony_streaming_delta( harmony_parser=parser, - token_states=token_states, - prev_recipient=None, + stream_state=stream_state, include_reasoning=include_reasoning, ) @@ -91,15 +98,14 @@ def test_new_tool_call(self, mock_make_tool_call_id, channel): """Test new tool call creation when recipient changes.""" mock_make_tool_call_id.return_value = "call_test123" parser = MockStreamableParser() - - token_states = [ - TokenState(channel=channel, recipient="functions.get_weather", text="") - ] + parser.current_channel = channel + parser.current_recipient = "functions.get_weather" + parser.current_content = "" + stream_state = HarmonyStreamingState() delta_message, tools_streamed = extract_harmony_streaming_delta( harmony_parser=parser, - token_states=token_states, - prev_recipient=None, + stream_state=stream_state, include_reasoning=False, ) @@ -118,17 +124,19 @@ def test_tool_call_argument_streaming(self, channel): """Test streaming tool call arguments (same recipient).""" parser = MockStreamableParser() args_text = '{"location": "Paris"}' - - token_states = [ - TokenState( - channel=channel, recipient="functions.get_weather", text=args_text - ) - ] + parser.current_channel = channel + parser.current_recipient = "functions.get_weather" + parser.current_content = args_text + stream_state = HarmonyStreamingState( + prev_current_signature=(channel, "functions.get_weather"), + prev_current_emitted_len=0, + prev_current_tool_index=0, + prev_current_tool_header_emitted=True, + ) delta_message, tools_streamed = extract_harmony_streaming_delta( harmony_parser=parser, - token_states=token_states, - prev_recipient="functions.get_weather", + stream_state=stream_state, include_reasoning=False, ) @@ -143,15 +151,19 @@ def test_tool_call_argument_streaming(self, channel): def test_tool_call_empty_arguments_returns_none(self, channel): """Test empty delta_text with same recipient returns None.""" parser = MockStreamableParser() - - token_states = [ - TokenState(channel=channel, recipient="functions.get_weather", text="") - ] + parser.current_channel = channel + parser.current_recipient = "functions.get_weather" + parser.current_content = "" + stream_state = HarmonyStreamingState( + prev_current_signature=(channel, "functions.get_weather"), + prev_current_emitted_len=0, + prev_current_tool_index=0, + prev_current_tool_header_emitted=True, + ) delta_message, tools_streamed = extract_harmony_streaming_delta( harmony_parser=parser, - token_states=token_states, - prev_recipient="functions.get_weather", + stream_state=stream_state, include_reasoning=False, ) @@ -166,15 +178,17 @@ def test_tool_call_index_from_previous_messages(self): MockMessage(channel="final", recipient=None), # Not counted ] parser = MockStreamableParser(messages=messages) - - token_states = [ - TokenState(channel="commentary", recipient="functions.tool2", text="args") - ] + parser.current_channel = "commentary" + parser.current_recipient = "functions.tool2" + parser.current_content = "args" + stream_state = HarmonyStreamingState( + emitted_message_count=3, + next_tool_call_index=1, + ) delta_message, _ = extract_harmony_streaming_delta( harmony_parser=parser, - token_states=token_states, - prev_recipient="functions.tool2", + stream_state=stream_state, include_reasoning=False, ) @@ -184,15 +198,13 @@ def test_returns_preambles_as_content(self): """Test that commentary with no recipient (preamble) is user content.""" parser = MockStreamableParser() delta_text = "some text" - - token_states = [ - TokenState(channel="commentary", recipient=None, text=delta_text) - ] + parser.current_channel = "commentary" + parser.current_content = delta_text + stream_state = HarmonyStreamingState() delta_message, tools_streamed = extract_harmony_streaming_delta( harmony_parser=parser, - token_states=token_states, - prev_recipient=None, + stream_state=stream_state, include_reasoning=True, ) @@ -210,19 +222,22 @@ def test_returns_preambles_as_content(self): def test_returns_none_for_invalid_inputs(self, channel, recipient): """Test that invalid channel/recipient combinations return None.""" parser = MockStreamableParser() - - token_states = [ - TokenState(channel=channel, recipient=recipient, text="some text") - ] + parser.current_channel = channel + parser.current_recipient = recipient + parser.current_content = "some text" + stream_state = HarmonyStreamingState() delta_message, tools_streamed = extract_harmony_streaming_delta( harmony_parser=parser, - token_states=token_states, - prev_recipient=None, + stream_state=stream_state, include_reasoning=True, ) - assert delta_message is None + if channel == "commentary" and recipient is not None: + assert delta_message is not None + assert delta_message.content == "some text" + else: + assert delta_message is None assert tools_streamed is False def test_consecutive_token_grouping(self): @@ -231,18 +246,13 @@ def test_consecutive_token_grouping(self): are merged into a single processing group. """ parser = MockStreamableParser() - token_states = [ - TokenState("final", None, "H"), - TokenState("final", None, "el"), - TokenState("final", None, "lo"), - TokenState("final", None, ","), - TokenState("final", None, " World"), - ] + parser.current_channel = "final" + parser.current_content = "Hello, World" + stream_state = HarmonyStreamingState() delta_message, _ = extract_harmony_streaming_delta( harmony_parser=parser, - token_states=token_states, - prev_recipient=None, + stream_state=stream_state, include_reasoning=False, ) @@ -258,21 +268,21 @@ def test_complex_batch_permutation(self, mock_make_id): """ mock_make_id.return_value = "call_batch_test" parser = MockStreamableParser() - - token_states = [ - # 1. Reasoning - TokenState("analysis", None, "Reasoning about query..."), - # 2. Tool Calling - TokenState("commentary", "functions.search", '{"query":'), - TokenState("commentary", "functions.search", ' "vllm"}'), - # 3. Final Content - TokenState("final", None, "."), + parser.messages = [ + MockMessage(channel="analysis", content="Reasoning about query..."), + MockMessage( + channel="commentary", + recipient="functions.search", + content='{"query": "vllm"}', + ), ] + parser.current_channel = "final" + parser.current_content = "." + stream_state = HarmonyStreamingState() delta_message, tools_streamed = extract_harmony_streaming_delta( harmony_parser=parser, - token_states=token_states, - prev_recipient=None, + stream_state=stream_state, include_reasoning=True, ) @@ -304,25 +314,34 @@ def test_tool_call_index_consistency_with_ongoing_call(self, mock_make_id): Test that an ongoing tool call continuation and subsequent new calls maintain correct indexing when interleaved with content. """ - mock_make_id.side_effect = ["id_b", "id_c"] + mock_make_id.side_effect = ["id_prev", "id_a", "id_b", "id_c"] messages = [ - MockMessage(channel="commentary", recipient="functions.previous_tool") + MockMessage(channel="commentary", recipient="functions.previous_tool"), + MockMessage( + channel="commentary", + recipient="functions.tool_a", + content='{"key_a": "val_a"}', + ), + MockMessage(channel="final", content="Thinking..."), + MockMessage( + channel="commentary", + recipient="functions.tool_b", + content='{"key_b": "val_b"}', + ), + MockMessage(channel="final", content=" Thinking again..."), + MockMessage( + channel="commentary", + recipient="functions.tool_c", + content='{"key_c": "val_c"}', + ), ] parser = MockStreamableParser(messages=messages) - - token_states = [ - TokenState("commentary", "functions.tool_a", '{"key_a": "val_a"}'), - TokenState("final", None, "Thinking..."), - TokenState("commentary", "functions.tool_b", '{"key_b": "val_b"}'), - TokenState("final", None, " Thinking again..."), - TokenState("commentary", "functions.tool_c", '{"key_c": "val_c"}'), - ] + stream_state = HarmonyStreamingState() delta_message, _ = extract_harmony_streaming_delta( harmony_parser=parser, - token_states=token_states, - prev_recipient="functions.tool_a", + stream_state=stream_state, include_reasoning=False, ) @@ -330,8 +349,8 @@ def test_tool_call_index_consistency_with_ongoing_call(self, mock_make_id): tool_a_deltas = [t for t in delta_message.tool_calls if t.index == 1] assert len(tool_a_deltas) > 0 - assert tool_a_deltas[0].id is None - assert tool_a_deltas[0].function.arguments == '{"key_a": "val_a"}' + tool_a_args = next(t for t in tool_a_deltas if t.id is None) + assert tool_a_args.function.arguments == '{"key_a": "val_a"}' tool_b_header = next(t for t in delta_message.tool_calls if t.id == "id_b") assert tool_b_header.index == 2 @@ -348,3 +367,84 @@ def test_tool_call_index_consistency_with_ongoing_call(self, mock_make_id): assert tool_c_args.function.arguments == '{"key_c": "val_c"}' assert delta_message.content == "Thinking... Thinking again..." + + def test_carryover_avoids_reemitting_streamed_content(self): + """ + Simulate two consecutive calls. Call 1 streams partial content + from the in-progress message. Call 2 sees that message completed + and should only emit the new portion. + """ + parser = MockStreamableParser() + parser.current_channel = "final" + parser.current_content = "Hello" + stream_state = HarmonyStreamingState() + + delta_message, _ = extract_harmony_streaming_delta( + harmony_parser=parser, + stream_state=stream_state, + include_reasoning=False, + ) + + assert delta_message is not None + assert delta_message.content == "Hello" + + parser.messages = [ + MockMessage(channel="final", content="Hello, World"), + ] + parser.current_channel = None + parser.current_content = None + + delta_message, _ = extract_harmony_streaming_delta( + harmony_parser=parser, + stream_state=stream_state, + include_reasoning=False, + ) + + assert delta_message is not None + assert delta_message.content == ", World" + + @patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id") + def test_carryover_tool_call_avoids_reemitting_header(self, mock_make_id): + """ + Simulate two consecutive calls for a tool call. Call 1 streams + the header + partial args. Call 2 sees the message completed and + should only emit the remaining args without a duplicate header. + """ + mock_make_id.return_value = "call_carry" + parser = MockStreamableParser() + parser.current_channel = "commentary" + parser.current_recipient = "functions.search" + parser.current_content = '{"q":' + stream_state = HarmonyStreamingState() + + delta_message, tools_streamed = extract_harmony_streaming_delta( + harmony_parser=parser, + stream_state=stream_state, + include_reasoning=False, + ) + + assert tools_streamed is True + assert len(delta_message.tool_calls) == 2 + assert delta_message.tool_calls[0].id == "call_carry" + assert delta_message.tool_calls[1].function.arguments == '{"q":' + + parser.messages = [ + MockMessage( + channel="commentary", + recipient="functions.search", + content='{"q": "vllm"}', + ), + ] + parser.current_channel = None + parser.current_content = None + + delta_message, tools_streamed = extract_harmony_streaming_delta( + harmony_parser=parser, + stream_state=stream_state, + include_reasoning=False, + ) + + assert tools_streamed is True + assert len(delta_message.tool_calls) == 1 + assert delta_message.tool_calls[0].id is None + assert delta_message.tool_calls[0].function.arguments == ' "vllm"}' diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 08c783f87d83..ca4cdc6080e6 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -35,7 +35,7 @@ ChatMessage, ) from vllm.entrypoints.openai.chat_completion.stream_harmony import ( - TokenState, + HarmonyStreamingState, extract_harmony_streaming_delta, ) from vllm.entrypoints.openai.engine.protocol import ( @@ -62,6 +62,7 @@ get_system_message, parse_chat_inputs_to_harmony_messages, parse_chat_output, + parse_output_into_messages, render_for_completion, ) from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls @@ -298,7 +299,7 @@ async def render_chat_request( ) else: # For GPT-OSS. - should_include_tools = tool_dicts is not None + should_include_tools = bool(tool_dicts) conversation, engine_prompts = self._make_request_with_harmony( request, should_include_tools ) @@ -621,13 +622,20 @@ async def chat_completion_stream_generator( harmony_parsers = [ get_streamable_parser_for_assistant() for _ in range(num_choices) ] + harmony_stream_states = [ + HarmonyStreamingState() for _ in range(num_choices) + ] harmony_tools_streamed = [False] * num_choices + harmony_fallback_to_text = [False] * num_choices + harmony_all_token_ids: list[list[int]] = [[] for _ in range(num_choices)] + harmony_visible_output_streamed = [False] * num_choices tools_streamed = [False] * num_choices if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name else: tool_choice_function_name = None + harmony_allow_tool_calls = bool(request.tools) # Determine whether tools are in use with "auto" tool choice tool_choice_auto = ( @@ -776,6 +784,8 @@ async def chat_completion_stream_generator( for output in res.outputs: i = output.index tool_parser = tool_parsers[i] + if self.use_harmony and output.token_ids: + harmony_all_token_ids[i].extend(as_list(output.token_ids)) if ( reasoning_parser @@ -803,29 +813,27 @@ async def chat_completion_stream_generator( logprobs = None if self.use_harmony: - harmony_parser = harmony_parsers[i] - prev_recipient = harmony_parser.current_recipient - - # Track accumulated content per token with their state - token_states: list[TokenState] = [] - for token_id in output.token_ids: - harmony_parser.process(token_id) - token_delta = harmony_parser.last_content_delta or "" - token_states.append( - TokenState( - harmony_parser.current_channel, - harmony_parser.current_recipient, - token_delta, + if harmony_fallback_to_text[i]: + delta_text = output.text + else: + harmony_parser = harmony_parsers[i] + delta_text = "" + try: + for token_id in output.token_ids: + harmony_parser.process(token_id) + delta_text += ( + harmony_parser.last_content_delta or "" + ) + except Exception as e: + logger.warning( + "Harmony parser failed in streaming; " + "falling back to raw text deltas. " + "choice=%d, error=%s", + i, + e, ) - ) - delta_text = "".join(delta for _, _, delta in token_states) - cur_channel = harmony_parser.current_channel - - # handle the case where several tokens where generated at once - # including the final token, leading to a delta in the text - # but the current channel to be empty (start state) - if not cur_channel and delta_text: - cur_channel = "final" + harmony_fallback_to_text[i] = True + delta_text = output.text else: delta_text = output.text @@ -855,14 +863,35 @@ async def chat_completion_stream_generator( current_token_ids = as_list(output.token_ids) if self.use_harmony: - delta_message, tools_streamed_flag = ( - extract_harmony_streaming_delta( - harmony_parser=harmony_parser, - token_states=token_states, - prev_recipient=prev_recipient, - include_reasoning=request.include_reasoning, + if harmony_fallback_to_text[i]: + delta_message = DeltaMessage(content=delta_text) + tools_streamed_flag = False + else: + delta_message, tools_streamed_flag = ( + extract_harmony_streaming_delta( + harmony_parser=harmony_parser, + stream_state=harmony_stream_states[i], + include_reasoning=request.include_reasoning, + ) ) - ) + if ( + not harmony_allow_tool_calls + and delta_message is not None + and delta_message.tool_calls + ): + if ( + delta_message.role is not None + or delta_message.content is not None + or delta_message.reasoning is not None + ): + delta_message = DeltaMessage( + role=delta_message.role, + content=delta_message.content, + reasoning=delta_message.reasoning, + ) + else: + delta_message = None + tools_streamed_flag = False harmony_tools_streamed[i] |= tools_streamed_flag # handle streaming deltas for tools with named tool_choice elif tool_choice_function_name: @@ -1138,6 +1167,15 @@ async def chat_completion_stream_generator( continue delta_message = DeltaMessage() + if self.use_harmony and ( + ( + delta_message.content is not None + and delta_message.content != "" + ) + or bool(delta_message.tool_calls) + ): + harmony_visible_output_streamed[i] = True + # Log streaming delta if output logging is enabled if self.enable_log_outputs and self.request_logger: delta_content_parts = [] @@ -1202,10 +1240,14 @@ async def chat_completion_stream_generator( index = 0 if ( - self._should_check_for_unstreamed_tool_arg_tokens( + not self.use_harmony + and auto_tools_called + and tool_parser is not None + and index < len(tool_parser.prev_tool_call_arr) + and index < len(tool_parser.streamed_args_for_tool) + and self._should_check_for_unstreamed_tool_arg_tokens( delta_message, output ) - and tool_parser ): latest_delta_len = 0 if ( @@ -1252,6 +1294,24 @@ async def chat_completion_stream_generator( delta_message, remaining_call, index ) + if self.use_harmony and not harmony_visible_output_streamed[i]: + recovered_delta, recovered_tools_called = ( + self._recover_harmony_terminal_delta( + request=request, + tokenizer=tokenizer, + token_ids=harmony_all_token_ids[i], + ) + ) + if recovered_delta is not None: + delta_message = recovered_delta + if ( + recovered_delta.content is not None + and recovered_delta.content != "" + ) or recovered_delta.tool_calls: + harmony_visible_output_streamed[i] = True + if recovered_tools_called: + harmony_tools_streamed[i] = True + # Send the finish response for each request.n only once # In OpenAI's API, when a tool is called, the # finish_reason is: @@ -1421,7 +1481,7 @@ async def chat_completion_full_generator( if not request.include_reasoning: reasoning = None - if self.tool_parser is not None: + if self.tool_parser is not None and request.tools: if tokenizer is None: raise ValueError( "Tokenizer not available when `skip_tokenizer_init=True`" @@ -1915,6 +1975,88 @@ def _create_remaining_args_delta( ] ) + def _recover_harmony_terminal_delta( + self, + *, + request: ChatCompletionRequest, + tokenizer: TokenizerLike | None, + token_ids: list[int], + ) -> tuple[DeltaMessage | None, bool]: + """ + Recover a final Harmony delta from aggregate token_ids when streaming + produced no visible content/tool delta for a choice. + """ + if not token_ids: + return None, False + + recovered_content: str | None = None + if self.tool_parser is not None and request.tools and tokenizer is not None: + try: + tool_parser = self.tool_parser(tokenizer) + tool_call_info = tool_parser.extract_tool_calls( + "", + request=request, + token_ids=token_ids, # type: ignore + ) + except Exception: + logger.exception("Failed to recover Harmony terminal tool calls.") + else: + if tool_call_info.tools_called and tool_call_info.tool_calls: + recovered_tool_calls: list[DeltaToolCall] = [] + for idx, tc in enumerate(tool_call_info.tool_calls): + fn = tc.function + recovered_tool_calls.append( + DeltaToolCall( + index=idx, + id=tc.id, + type=tc.type, + function=DeltaFunctionCall( + name=fn.name if fn else None, + arguments=fn.arguments if fn else None, + ), + ) + ) + return ( + DeltaMessage( + content=tool_call_info.content, + tool_calls=recovered_tool_calls, + ), + True, + ) + recovered_content = tool_call_info.content + + try: + harmony_parser = parse_output_into_messages(token_ids) + except Exception: + logger.exception("Failed to parse Harmony terminal output from token IDs.") + return None, False + + delta_message, tools_called = extract_harmony_streaming_delta( + harmony_parser=harmony_parser, + stream_state=HarmonyStreamingState(), + include_reasoning=request.include_reasoning, + ) + + if delta_message is not None and delta_message.tool_calls and not request.tools: + if ( + delta_message.role is not None + or delta_message.content is not None + or delta_message.reasoning is not None + ): + delta_message = DeltaMessage( + role=delta_message.role, + content=delta_message.content, + reasoning=delta_message.reasoning, + ) + else: + delta_message = None + tools_called = False + + if delta_message is None and recovered_content: + return DeltaMessage(content=recovered_content), False + + return delta_message, tools_called + def _make_request_with_harmony( self, request: ChatCompletionRequest, diff --git a/vllm/entrypoints/openai/chat_completion/stream_harmony.py b/vllm/entrypoints/openai/chat_completion/stream_harmony.py index 87f2f9b92275..07c053627829 100644 --- a/vllm/entrypoints/openai/chat_completion/stream_harmony.py +++ b/vllm/entrypoints/openai/chat_completion/stream_harmony.py @@ -7,6 +7,7 @@ harmony parser state during streaming chat completions. """ +from dataclasses import dataclass from typing import NamedTuple from openai_harmony import StreamableParser @@ -25,147 +26,215 @@ class TokenState(NamedTuple): text: str +@dataclass +class HarmonyStreamingState: + emitted_message_count: int = 0 + next_tool_call_index: int = 0 + prev_current_signature: tuple[str | None, str | None] | None = None + prev_current_emitted_len: int = 0 + prev_current_tool_index: int | None = None + prev_current_tool_header_emitted: bool = False + + +def _is_function_tool_message(channel: str | None, recipient: str | None) -> bool: + return ( + channel in ("commentary", "analysis") + and recipient is not None + and recipient.startswith("functions.") + ) + + +def _extract_message_text(message) -> str: + content = getattr(message, "content", None) + if isinstance(content, str): + return content + if isinstance(content, list): + texts: list[str] = [] + for part in content: + if isinstance(part, str): + texts.append(part) + else: + text = getattr(part, "text", None) + if isinstance(text, str): + texts.append(text) + return "".join(texts) + + text = getattr(message, "text", None) + if isinstance(text, str): + return text + + return "" + + +def _append_tool_deltas( + *, + tool_messages: list[DeltaToolCall], + recipient: str, + tool_index: int, + emit_header: bool, + args_delta: str, +) -> None: + if emit_header: + tool_messages.append( + DeltaToolCall( + id=make_tool_call_id(), + type="function", + function=DeltaFunctionCall( + name=recipient.split("functions.", 1)[1], + arguments="", + ), + index=tool_index, + ) + ) + + if args_delta: + tool_messages.append( + DeltaToolCall( + index=tool_index, + function=DeltaFunctionCall(arguments=args_delta), + ) + ) + + def extract_harmony_streaming_delta( harmony_parser: StreamableParser, - token_states: list[TokenState], - prev_recipient: str | None, + stream_state: HarmonyStreamingState, include_reasoning: bool, ) -> tuple[DeltaMessage | None, bool]: """ Extract a DeltaMessage from harmony parser state during streaming. - Args: - harmony_parser: The StreamableParser instance tracking parse state - token_states: List of TokenState tuples for each token - prev_recipient: Previous recipient for detecting tool call transitions - include_reasoning: Whether to include reasoning content - - Returns: - A tuple of (DeltaMessage or None, tools_streamed_flag) + Unlike the previous token-group heuristic, this function diffs parser-level + completed messages and parser current_content against persistent state. + This makes tool call indexing and argument streaming robust across arbitrary + chunk boundaries, including repeated recipients. """ - if not token_states: - return None, False - - tools_streamed = False - - # Group consecutive tokens with same channel/recipient - groups: list[TokenState] = [] + tool_messages: list[DeltaToolCall] = [] + combined_content = "" + combined_reasoning = "" + content_encountered = False - current_channel = token_states[0].channel - current_recipient = token_states[0].recipient - current_text = token_states[0].text + prev_signature = stream_state.prev_current_signature + prev_emitted_len = stream_state.prev_current_emitted_len + prev_tool_index = stream_state.prev_current_tool_index + prev_tool_header_emitted = stream_state.prev_current_tool_header_emitted - for i in range(1, len(token_states)): - state = token_states[i] - if state.channel == current_channel and state.recipient == current_recipient: - current_text += state.text - else: - groups.append(TokenState(current_channel, current_recipient, current_text)) - current_channel = state.channel - current_recipient = state.recipient - current_text = state.text + carryover_consumed = False - groups.append(TokenState(current_channel, current_recipient, current_text)) + for msg in harmony_parser.messages[stream_state.emitted_message_count :]: + channel = msg.channel + recipient = msg.recipient + signature = (channel, recipient) + msg_text = _extract_message_text(msg) - # Process each group and create delta messages - delta_message = None - combined_content = "" - combined_reasoning = "" - tool_messages = [] - content_encountered = False + already_emitted_len = 0 + tool_index: int | None = None + tool_header_emitted = False - # Calculate base_index once before the loop - # This counts completed tool calls in messages - base_index = 0 - for msg in harmony_parser.messages: if ( - (msg.channel == "commentary" or msg.channel == "analysis") - and msg.recipient - and msg.recipient.startswith("functions.") - ): - base_index += 1 - - # If there's an ongoing tool call from previous chunk, - # the next new tool call starts at base_index + 1 - if prev_recipient and prev_recipient.startswith("functions."): - next_tool_index = base_index + 1 - # Ongoing call is at base_index - ongoing_tool_index = base_index - else: - # No ongoing call, next new call is at base_index - next_tool_index = base_index - ongoing_tool_index = None - - for group in groups: - if group.channel == "final": - combined_content += group.text - content_encountered = True - elif ( - (group.channel == "commentary" or group.channel == "analysis") - and group.recipient - and group.recipient.startswith("functions.") + not carryover_consumed + and prev_signature is not None + and signature == prev_signature ): - opened_new_call = False - if prev_recipient != group.recipient: - # New tool call - emit the opening message - tool_name = group.recipient.split("functions.", 1)[1] - tool_messages.append( - DeltaToolCall( - id=make_tool_call_id(), - type="function", - function=DeltaFunctionCall( - name=tool_name, - arguments="", - ), - index=next_tool_index, - ) - ) - opened_new_call = True - prev_recipient = group.recipient - # Increment for subsequent new tool calls - next_tool_index += 1 - - if group.text: - # Stream arguments for the ongoing tool call - if opened_new_call: - # Just opened in this group - tool_call_index = next_tool_index - 1 - else: - # Continuing from previous chunk - # If ongoing_tool_index is None here, it means - # we're continuing a call but prev_recipient - # wasn't a function. Use base_index. - tool_call_index = ( - ongoing_tool_index - if ongoing_tool_index is not None - else base_index - ) - tool_messages.append( - DeltaToolCall( - index=tool_call_index, - function=DeltaFunctionCall(arguments=group.text), - ) - ) - elif group.channel == "commentary" and group.recipient is None: - # Tool call preambles meant to be shown to the user - combined_content += group.text - content_encountered = True - elif group.channel == "analysis" and include_reasoning: - combined_reasoning += group.text - - # Combine all non-empty fields into a single message - if content_encountered or combined_reasoning or tool_messages: - delta_kwargs: dict[str, str | list[DeltaToolCall]] = {} - if content_encountered: - delta_kwargs["content"] = combined_content - if combined_reasoning: - delta_kwargs["reasoning"] = combined_reasoning - if tool_messages: - delta_kwargs["tool_calls"] = tool_messages - tools_streamed = True - delta_message = DeltaMessage(**delta_kwargs) + already_emitted_len = min(prev_emitted_len, len(msg_text)) + tool_index = prev_tool_index + tool_header_emitted = prev_tool_header_emitted + carryover_consumed = True + + delta_text = msg_text[already_emitted_len:] + + if _is_function_tool_message(channel, recipient): + assert recipient is not None + if tool_index is None: + tool_index = stream_state.next_tool_call_index + stream_state.next_tool_call_index += 1 + _append_tool_deltas( + tool_messages=tool_messages, + recipient=recipient, + tool_index=tool_index, + emit_header=not tool_header_emitted, + args_delta=delta_text, + ) + elif channel in ("final", "commentary"): + if delta_text: + combined_content += delta_text + content_encountered = True + elif channel == "analysis" and include_reasoning and delta_text: + combined_reasoning += delta_text + + stream_state.emitted_message_count = len(harmony_parser.messages) + + current_channel = harmony_parser.current_channel + current_recipient = harmony_parser.current_recipient + current_signature: tuple[str | None, str | None] | None = None + if current_channel is not None: + current_signature = (current_channel, current_recipient) + current_content = harmony_parser.current_content or "" + + if current_signature is None: + stream_state.prev_current_signature = None + stream_state.prev_current_emitted_len = 0 + stream_state.prev_current_tool_index = None + stream_state.prev_current_tool_header_emitted = False else: - delta_message = None + same_current_message = ( + not carryover_consumed + and prev_signature is not None + and current_signature == prev_signature + and len(current_content) >= prev_emitted_len + ) + + if same_current_message: + already_emitted_len = prev_emitted_len + current_tool_index = prev_tool_index + current_header_emitted = prev_tool_header_emitted + else: + already_emitted_len = 0 + current_tool_index = None + current_header_emitted = False + + delta_text = current_content[already_emitted_len:] + + if _is_function_tool_message(current_channel, current_recipient): + assert current_recipient is not None + if current_tool_index is None: + current_tool_index = stream_state.next_tool_call_index + stream_state.next_tool_call_index += 1 + + _append_tool_deltas( + tool_messages=tool_messages, + recipient=current_recipient, + tool_index=current_tool_index, + emit_header=not current_header_emitted, + args_delta=delta_text, + ) + + stream_state.prev_current_tool_index = current_tool_index + stream_state.prev_current_tool_header_emitted = True + else: + if current_channel in ("final", "commentary"): + if delta_text: + combined_content += delta_text + content_encountered = True + elif current_channel == "analysis" and include_reasoning and delta_text: + combined_reasoning += delta_text + + stream_state.prev_current_tool_index = None + stream_state.prev_current_tool_header_emitted = False + + stream_state.prev_current_signature = current_signature + stream_state.prev_current_emitted_len = len(current_content) + + if not (content_encountered or combined_reasoning or tool_messages): + return None, False + + delta_kwargs: dict[str, str | list[DeltaToolCall]] = {} + if content_encountered: + delta_kwargs["content"] = combined_content + if combined_reasoning: + delta_kwargs["reasoning"] = combined_reasoning + if tool_messages: + delta_kwargs["tool_calls"] = tool_messages - return delta_message, tools_streamed + return DeltaMessage(**delta_kwargs), bool(tool_messages)