diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index a9356a8a403d..fe26430a4e91 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -94,6 +94,8 @@ from vllm.entrypoints.openai.responses.streaming_events import ( StreamingState, emit_content_delta_events, + emit_function_call_delta_events, + emit_function_call_done_events, emit_previous_item_done_events, emit_tool_action_events, ) @@ -1259,6 +1261,13 @@ async def _process_simple_streaming_events( reasoning_parser = None if self.parser and self.parser.reasoning_parser_cls: reasoning_parser = self.parser.reasoning_parser_cls(tokenizer) + tool_parser = None + if self.parser and self.parser.tool_parser_cls: + tool_parser = self.parser.tool_parser_cls(tokenizer) + tool_streaming_state = StreamingState() + tools_streamed = False + reasoning_ended = False + tool_call_text_started = False previous_text = "" previous_token_ids: list[int] = [] first_delta_sent = False @@ -1271,21 +1280,79 @@ async def _process_simple_streaming_events( output = ctx.last_output.outputs[0] # finish_reason='error' indicates a retryable error self._raise_if_error(output.finish_reason, request.request_id) - if reasoning_parser: + delta_text = output.text + delta_token_ids = list(output.token_ids) + current_text = previous_text + delta_text + current_token_ids = previous_token_ids + delta_token_ids + + if reasoning_parser and tool_parser: + # Both reasoning and tool calls: reasoning + # first, then tool calls after reasoning ends. + if not reasoning_ended: + delta_message = reasoning_parser.extract_reasoning_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + ) + if reasoning_parser.is_reasoning_end(delta_token_ids): + reasoning_ended = True + # Reset text/token state so the tool + # parser sees a fresh stream after + # reasoning. + current_token_ids = reasoning_parser.extract_content_ids( + delta_token_ids + ) + if delta_message and delta_message.content: + current_text = delta_message.content + delta_message.content = None + else: + current_text = "" + + if reasoning_ended: + if not tool_call_text_started: + tool_call_text_started = True + previous_text = "" + previous_token_ids = [] + delta_text = current_text + delta_token_ids = current_token_ids + + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, + ) + elif reasoning_parser: delta_message = reasoning_parser.extract_reasoning_streaming( previous_text=previous_text, - current_text=previous_text + output.text, - delta_text=output.text, + current_text=current_text, + delta_text=delta_text, previous_token_ids=previous_token_ids, - current_token_ids=previous_token_ids + output.token_ids, - delta_token_ids=output.token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + ) + elif tool_parser: + delta_message = tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, ) else: delta_message = DeltaMessage( content=output.text, ) - previous_text += output.text - previous_token_ids += output.token_ids + previous_text = current_text + previous_token_ids = current_token_ids if not delta_message: continue if not first_delta_sent: @@ -1317,7 +1384,10 @@ async def _process_simple_streaming_events( ), ) ) - else: + elif not delta_message.tool_calls: + # Only emit message output item for content, + # not for tool calls (handled by + # emit_function_call_delta_events) yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( type="response.output_item.added", @@ -1348,7 +1418,104 @@ async def _process_simple_streaming_events( ) ) first_delta_sent = True - # todo(kebe7jun) tool call support + # Handle tool call deltas from the tool parser + if delta_message.tool_calls: + if not tools_streamed: + tools_streamed = True + # Close the message output item if content was + # emitted before tool calls (e.g. "Sure!\n\n + # ..."). Done events must precede + # the function-call OutputItemAdded event. + if previous_delta_messages: + final_content = "".join( + pm.content + for pm in previous_delta_messages + if pm.content is not None + ) + yield _increment_sequence_number_and_return( + ResponseTextDoneEvent( + type="response.output_text.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=final_content, + logprobs=[], + item_id=current_item_id, + ) + ) + yield _increment_sequence_number_and_return( + ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=ResponseOutputText( + text=final_content, + type="output_text", + annotations=[], + ), + ) + ) + yield _increment_sequence_number_and_return( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=ResponseOutputMessage( + type="message", + role="assistant", + content=[ + ResponseOutputText( + text=final_content, + type="output_text", + annotations=[], + ) + ], + status="completed", + id=current_item_id, + ), + ) + ) + previous_delta_messages = [] + current_output_index += 1 + # Sync tool streaming state with current + # output index so function call items get + # the correct index. + tool_streaming_state.current_output_index = current_output_index + for tc in delta_message.tool_calls: + if tc.function: + fn_name = tc.function.name or "" + args_delta = tc.function.arguments or "" + # When a new tool call starts (id is set), + # reset state for a new output item + if tc.id is not None: + if tool_streaming_state.is_first_function_call_delta: + # Previous tool call finished, + # advance to next output item + tool_streaming_state.reset_for_new_item() + for event in emit_function_call_delta_events( + args_delta, + fn_name, + tool_streaming_state, + ): + yield _increment_sequence_number_and_return(event) + # When arguments close with "}", emit done + # event immediately while state IDs are valid + tc_idx = tc.index if tc.index is not None else 0 + assert tool_parser is not None + if args_delta == "}" and tc_idx < len( + tool_parser.prev_tool_call_arr + ): + tc_info = tool_parser.prev_tool_call_arr[tc_idx] + for event in emit_function_call_done_events( + tc_info.get("name", fn_name), + tc_info.get("arguments", "{}"), + tool_streaming_state, + ): + yield _increment_sequence_number_and_return( + event + ) # check delta message and previous delta message are # same as content or reasoning content @@ -1473,7 +1640,7 @@ async def _process_simple_streaming_events( delta=delta_message.reasoning, ) ) - elif delta_message.content is not None: + elif delta_message.content and not tools_streamed: yield _increment_sequence_number_and_return( ResponseTextDeltaEvent( type="response.output_text.delta", @@ -1495,7 +1662,8 @@ async def _process_simple_streaming_events( ) ) - previous_delta_messages.append(delta_message) + if not tools_streamed: + previous_delta_messages.append(delta_message) if previous_delta_messages: if previous_delta_messages[-1].reasoning is not None: reason_content = "".join(