14 changes: 13 additions & 1 deletion vllm/entrypoints/chat_utils.py
@@ -35,6 +35,9 @@
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
from openai.types.responses import ResponseInputImageParam
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
from openai_harmony import Message as OpenAIHarmonyMessage
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
@@ -232,6 +235,7 @@ class CustomThinkCompletionContentParam(TypedDict, total=False):
| CustomChatCompletionContentSimpleVideoParam
| str
| CustomThinkCompletionContentParam
| ResponseReasoningTextContent
)


@@ -1530,6 +1534,12 @@ def _parse_chat_message_content(
role = message["role"]
content = message.get("content")
reasoning = message.get("reasoning") or message.get("reasoning_content")
# TODO: get from reasoning_content?

# HACK
if role == "tool":
content_format = "openai"

if content is None:
content = []
elif isinstance(content, str):
@@ -1538,7 +1548,9 @@
role,
content, # type: ignore
mm_tracker,
wrap_dicts=(content_format == "openai"),
wrap_dicts=(
content_format == "openai"
), # Kimi K2 treats this as a string, which breaks on tool messages
interleave_strings=interleave_strings,
)

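For reference, a minimal sketch of the new content part that the extended union above admits; the Content model and its fields come from the OpenAI SDK import shown in the diff, while the wrapping message dict is a hypothetical payload:

from openai.types.responses.response_reasoning_item import (
    Content as ResponseReasoningTextContent,
)

# A typed reasoning-text part, the new union member.
reasoning_part = ResponseReasoningTextContent(
    type="reasoning_text",
    text="Weighing the available tools before answering...",
)

# Hypothetical assistant message carrying the part; with this change,
# _parse_chat_message_content accepts it alongside the existing part types.
message = {
    "role": "assistant",
    "content": [reasoning_part],
}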
167 changes: 167 additions & 0 deletions vllm/entrypoints/context.py
@@ -8,18 +8,35 @@
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Union

from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent

from vllm import envs
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
CustomChatCompletionMessageParam,
)
from vllm.entrypoints.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
render_for_completion,
)
from vllm.entrypoints.openai.parser.parser import (
get_streamable_parser_for_simple_context,
)
from vllm.entrypoints.openai.protocol import (
FunctionCall,
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer

if TYPE_CHECKING:
from mcp.client import ClientSession
@@ -180,6 +197,156 @@
raise NotImplementedError("Should not be called.")


class ParsableContext(ConversationContext):
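"""ConversationContext that parses raw completion text into typed Responses items using a reasoning parser and a tool-call parser."""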
def __init__(
self,
*,
response_messages: list[CustomChatCompletionMessageParam],
tokenizer: AnyTokenizer,
reasoning_parser: ReasoningParser,
request: ResponsesRequest,
available_tools: list[str] | None,
tool_parser_cls,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
tool_dicts: list[dict] | None = None,
):
self.last_output = None
self.num_prompt_tokens = 0
self.num_output_tokens = 0
self.num_cached_tokens = 0
# TODO: num_reasoning_tokens is not implemented yet.
self.num_reasoning_tokens = 0
# not implemented yet for ParsableContext
self.all_turn_metrics = []

Check failure on line 221 (GitHub Actions / pre-commit): Need type annotation for "all_turn_metrics" (hint: "all_turn_metrics: list[<type>] = ...") [var-annotated]

self.parser = get_streamable_parser_for_simple_context(
tokenizer=tokenizer,
reasoning_parser=reasoning_parser,
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)
self.tool_parser_cls = tool_parser_cls
self.request = request
self.tokenizer = tokenizer
self.reasoning_parser = reasoning_parser

self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {}
self.called_tools: set[str] = set()

self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.tool_dicts = tool_dicts

def append_output(self, output: RequestOutput) -> None:
self.last_output = output

Check failure on line 244 (GitHub Actions / pre-commit): Incompatible types in assignment (expression has type "RequestOutput", variable has type "None") [assignment]
self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
self.parser.process(output.outputs[0])

def append_tool_output(
self, output: list[CustomChatCompletionMessageParam]
) -> None:
self.parser.response_messages.extend(output)

def need_builtin_tool_call(self) -> bool:
"""Return true if the last message is a MCP tool call"""
last_message = self.parser.response_messages[-1]
# HACK: figure out which tools are MCP tools
if ( # noqa: SIM103
last_message.type == "function_call"
and (
last_message.name == "code_interpreter" or last_message.name == "python"
)
):
return True

return False

async def call_python_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
) -> list[ResponseInputOutputItem]:
self.called_tools.add("python")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
args = json.loads(last_msg.arguments)
param = {
"code": args["code"],
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text

message = ResponseFunctionToolCallOutputItem(
id="temp",
type="function_call_output",
call_id="temp",
output=result_str,
status="completed",
)

return [message]

async def call_tool(self) -> list[ResponseInputOutputItem]:
if not self.parser.response_messages:
return []
last_msg = self.parser.response_messages[-1]
if last_msg.name == "code_interpreter":
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
return []

def render_for_completion(self):
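# Note: despite the name, this returns the argument bundle for a later chat-template rendering step rather than rendered token ids.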
return [
self.request,
self.tokenizer,
self.parser.response_messages,
self.tool_dicts,
self.tool_parser_cls,
self.chat_template,
self.chat_template_content_format,
]

async def init_tool_sessions(
self,
tool_server: ToolServer | None,
exit_stack: AsyncExitStack,
request_id: str,
mcp_tools: dict[str, Mcp],
):
if tool_server:
for tool_name in self.available_tools:
if tool_name not in self._tool_sessions:
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = (
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
)
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id, headers)
)
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)

async def cleanup_session(self, *args, **kwargs) -> None:
"""Can be used as coro to used in __aexit__"""

async def cleanup_tool_session(tool_session):
if not isinstance(tool_session, Tool):
logger.info(
"Cleaning up tool session for %s", tool_session._client_info
)
with contextlib.suppress(Exception):
await tool_session.call_tool("cleanup_session", {})

await asyncio.gather(
*(
cleanup_tool_session(self._tool_sessions[tool])
for tool in self.called_tools
)
)


class HarmonyContext(ConversationContext):
def __init__(
self,
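To make the control flow concrete, here is a minimal sketch of the generate/tool loop a ParsableContext is designed to sit in. The context methods are those defined above; `engine` and its `generate` call are placeholders for the serving layer, not a real vLLM API:

async def run_turns(ctx: ParsableContext, engine) -> None:
    while True:
        # Placeholder: however the serving layer turns the rendered
        # argument bundle into a RequestOutput.
        output = await engine.generate(*ctx.render_for_completion())
        ctx.append_output(output)  # parse the new completion into typed items
        if not ctx.need_builtin_tool_call():
            break  # no MCP tool call pending; the turn is finished
        tool_msgs = await ctx.call_tool()  # e.g. the python/code_interpreter tool
        ctx.append_tool_output(tool_msgs)  # feed results into the next turn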
144 changes: 144 additions & 0 deletions vllm/entrypoints/openai/parser/parser.py
@@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging

from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import (
Content,
ResponseReasoningItem,
)

from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser

logger = logging.getLogger(__name__)


class ResponseParser:
"""Incremental parser over completion tokens with reasoning support."""

def __init__(
self,
*,
tokenizer,
reasoning_parser: ReasoningParser,
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls,
):
self.response_messages: list[ResponseInputOutputItem] = (
# TODO: initial messages may not be properly typed
response_messages
)
self.num_init_messages = len(response_messages)
self.tokens: list[int] = []
self.tokenizer = tokenizer
self.request = request

# Initialize reasoning parser instance if provided
self.reasoning_parser_instance = reasoning_parser(tokenizer)

Check failure on line 42 (GitHub Actions / pre-commit): "ReasoningParser" not callable [operator]
self.tool_parser_instance = tool_parser_cls(tokenizer)

def process(self, output: CompletionOutput) -> "ResponseParser":
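# Split the completion text into reasoning vs. final content, then run
# the tool parser over the non-reasoning part.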
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
output.text, request=None
)
if reasoning_content:
# HACK
self.response_messages.append(
ResponseReasoningItem(
type="reasoning",
id="temp",
summary=[],
content=[
Content(
type="reasoning_text",
text=reasoning_content,
)
],
)
)

function_calls: list[ResponseFunctionToolCall] = []
tool_call_info = self.tool_parser_instance.extract_tool_calls(
content if content is not None else "",
request=self.request, # type: ignore
)
if tool_call_info is not None and tool_call_info.tools_called:
# extract_tool_calls() returns a list of tool calls.
function_calls.extend(
ResponseFunctionToolCall(
id="fc_lol",
call_id="call_lol",
type="function_call",
status="completed",
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
for tool_call in tool_call_info.tool_calls
)
content = tool_call_info.content
if content and content.strip() == "":
content = None

if content:
self.response_messages.append(
ResponseOutputMessage(
type="message",
id="lol",
status="completed",
role="assistant",
content=[
ResponseOutputText(
type="output_text", text=content, annotations=[]
)
],
)
)
if len(function_calls) > 0:
self.response_messages.extend(function_calls)

return self


def get_streamable_parser_for_simple_context(
*,
tokenizer,
reasoning_parser: ReasoningParser,
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls,
) -> ResponseParser:
"""Factory function to create a ResponseParser with
optional reasoning parser.

Args:
tokenizer: The tokenizer to use for decoding tokens
reasoning_parser: Optional reasoning parser class (e.g., MiniMaxM2ReasoningParser)
response_messages: Prior conversation items to seed the parser with
request: The originating ResponsesRequest
tool_parser_cls: Tool-call parser class, instantiated with the tokenizer

Check failure on line 120 (Ruff E501): vllm/entrypoints/openai/parser/parser.py:120:89: Line too long (90 > 88)

Returns:
ResponseParser instance configured with the provided parser
"""
return ResponseParser(
tokenizer=tokenizer,
reasoning_parser=reasoning_parser,
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)


# def render_parser_for_completion():


"""
TODO:
how to figure out which tokens are special tokens

system
tool
ai
"""
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -2409,6 +2409,7 @@ class ResponseInProgressEvent(OpenAIResponseInProgressEvent):
| ResponseCodeInterpreterCallCompletedEvent
)


BatchRequestInputBody: TypeAlias = (
ChatCompletionRequest | EmbeddingRequest | ScoreRequest | RerankRequest
)