Skip to content
6 changes: 6 additions & 0 deletions vllm/entrypoints/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,6 +824,7 @@ def __init__(self, *args, **kwargs):
self.encoding = get_encoding()
self.last_tok = None
self.first_tok_of_message = True
self.last_content_delta = None

@property
def messages(self) -> list:
Expand All @@ -832,15 +833,20 @@ def messages(self) -> list:
def append_output(self, output: RequestOutput) -> None:
    """Feed one streaming ``RequestOutput`` into the harmony parser.

    Updates token-usage accounting and accumulates the text content
    produced by this call into ``self.last_content_delta`` so that a
    single combined delta is visible per ``append_output`` invocation.
    """
    # append_output is called for each output token in streaming case,
    # so we only want to add the prompt tokens once for each message.
    # Clear the per-call delta first; it is rebuilt below from whatever
    # content the parser emits for this batch of tokens.
    self.last_content_delta = None
    if self.first_tok_of_message:
        self._update_prefill_token_usage(output)
    # Reset self.first_tok_of_message if needed:
    # if the current token is the last one of the current message
    # (finished=True), then the next token processed will mark the
    # beginning of a new message
    self.first_tok_of_message = output.finished
    # Accumulate the parser's per-token content deltas; a single
    # RequestOutput may carry several token ids, and each process()
    # call can overwrite parser.last_content_delta — so concatenate
    # them here rather than reading only the last one.
    last_delta_text = ""
    for tok in output.outputs[0].token_ids:
        self.parser.process(tok)
        last_delta_text += self.parser.last_content_delta or ""
    # Only publish a delta when some content was actually produced;
    # otherwise leave last_content_delta as None (falsy for callers).
    if last_delta_text:
        self.last_content_delta = last_delta_text
    self._update_decode_token_usage(output)

# For streaming, update previous turn when message is complete
Expand Down
14 changes: 7 additions & 7 deletions vllm/entrypoints/openai/serving_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,7 +1811,7 @@ def _emit_final_channel_delta_events(
content_index=state.current_content_index,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
delta=ctx.last_content_delta,
# TODO, use logprobs from ctx.last_request_output
logprobs=[],
)
Expand Down Expand Up @@ -1861,7 +1861,7 @@ def _emit_analysis_channel_delta_events(
item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
delta=ctx.parser.last_content_delta,
delta=ctx.last_content_delta,
sequence_number=-1,
)
)
Expand Down Expand Up @@ -1908,7 +1908,7 @@ def _emit_mcp_tool_delta_events(
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
delta=ctx.last_content_delta,
)
)
return events
Expand Down Expand Up @@ -1952,7 +1952,7 @@ def _emit_code_interpreter_delta_events(
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
delta=ctx.last_content_delta,
)
)
return events
Expand Down Expand Up @@ -1999,7 +1999,7 @@ def _emit_mcp_prefix_delta_events(
sequence_number=-1,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
delta=ctx.last_content_delta,
)
)
return events
Expand All @@ -2010,7 +2010,7 @@ def _emit_content_delta_events(
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events for content delta streaming based on channel type."""
if not ctx.parser.last_content_delta:
if not ctx.last_content_delta:
return []

if (
Expand Down Expand Up @@ -2364,7 +2364,7 @@ def _emit_function_call_delta_events(
events.append(
ResponseFunctionCallArgumentsDeltaEvent(
item_id=state.current_item_id,
delta=ctx.parser.last_content_delta,
delta=ctx.last_content_delta,
output_index=state.current_output_index,
sequence_number=-1,
type="response.function_call_arguments.delta",
Expand Down
Loading