diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index e817f07ef594..f70e1fc207f8 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -3,6 +3,7 @@ import json import logging from abc import ABC, abstractmethod +from collections.abc import Sequence from typing import TYPE_CHECKING, Union from openai_harmony import Author, Message, Role, StreamState, TextContent @@ -67,15 +68,27 @@ def __init__( self.parser = get_streamable_parser_for_assistant() self.num_init_messages = len(messages) - # TODO(woosuk): Implement the following fields. self.num_prompt_tokens = 0 - self.num_cached_tokens = 0 self.num_output_tokens = 0 + # TODO(woosuk): Implement the following fields. + self.num_cached_tokens = 0 self.num_reasoning_tokens = 0 + def _update_num_prompt_tokens(self, output: RequestOutput): + if output.prompt_token_ids and len(output.prompt_token_ids) > 0: + # NOTE: with built-in tools, there might be multiple rounds in + # the conversation, with the full conversation being resent + # as new prompt each time. Hence the sum. + self.num_prompt_tokens += len(output.prompt_token_ids) + + def _update_num_output_tokens(self, token_ids: Sequence[int]): + self.num_output_tokens += len(token_ids) + def append_output(self, output) -> None: if isinstance(output, RequestOutput): + self._update_num_prompt_tokens(output) output_token_ids = output.outputs[0].token_ids + self._update_num_output_tokens(output_token_ids) self.parser = get_streamable_parser_for_assistant() for token_id in output_token_ids: self.parser.process(token_id) @@ -158,6 +171,7 @@ def __init__(self, *args, **kwargs): self.parser = get_streamable_parser_for_assistant() self.encoding = get_encoding() self.last_tok = None + self.first_tok_of_message = True @property def messages(self) -> list: @@ -165,8 +179,18 @@ def messages(self) -> list: def append_output(self, output) -> None: if isinstance(output, RequestOutput): + # append_output is called for each output token in streaming case, + # so we only want to add the prompt tokens once for each message. + if self.first_tok_of_message: + self._update_num_prompt_tokens(output) + # Reset self.first_tok_of_message if needed: + # if the current token is the last one of the current message + # (finished=True), then the next token processed will mark the + # beginning of a new message + self.first_tok_of_message = output.finished tok = output.outputs[0].token_ids[0] self.parser.process(tok) + self._update_num_output_tokens(output.outputs[0].token_ids) self.last_tok = tok else: # Handle the case of tool output in direct message format