10 changes: 9 additions & 1 deletion vllm/entrypoints/chat_utils.py
@@ -294,12 +294,12 @@
"""
Parse a MiniMax-M2 formatted prompt back into structured conversation messages.

This function reverses the operation of apply_hf_chat_template for the MiniMax-M2 format.

[Check failure: GitHub Actions / pre-commit — Ruff (E501): vllm/entrypoints/chat_utils.py:297:89 Line too long (89 > 88)]

Args:
prompt: The formatted prompt string with MiniMax-M2 special tokens
Format: ']~!b[]~b]role\ncontent[e~[\n]~b]role\ncontent[e~[...'
include_incomplete: Whether to include incomplete messages (e.g., partial assistant responses)

[Check failure: GitHub Actions / pre-commit — Ruff (E501): vllm/entrypoints/chat_utils.py:302:89 Line too long (102 > 88)]

Returns:
List of conversation messages in the format:
@@ -1495,6 +1495,12 @@
role = message["role"]
content = message.get("content")
reasoning = message.get("reasoning") or message.get("reasoning_content")
# TODO: get from reasoning_content?

# HACK: force the OpenAI content-part format for tool messages
if role == "tool":
content_format = "openai"

if content is None:
content = []
elif isinstance(content, str):
@@ -1503,7 +1509,9 @@
role,
content, # type: ignore
mm_tracker,
wrap_dicts=(content_format == "openai"),
wrap_dicts=(
content_format == "openai"
),  # Kimi-K2 treats this as a plain string, which breaks on tool messages
interleave_strings=interleave_strings,
)

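Note: the MiniMax-M2 prompt format documented in the docstring above can be illustrated with a short sketch. This is a minimal illustration assuming ']~!b[' opens the prompt, ']~b]' opens each message, and '[e~[' closes it; parse_minimax_m2_prompt is a hypothetical helper, not the PR's implementation.

import re

def parse_minimax_m2_prompt(prompt: str) -> list[dict[str, str]]:
    # Drop the beginning-of-prompt marker, then scan role/content pairs.
    body = prompt.removeprefix("]~!b[")
    messages = []
    for match in re.finditer(r"\]~b\](.*?)\[e~\[", body, flags=re.DOTALL):
        # Each message body is 'role\ncontent'; split at the first newline.
        role, _, content = match.group(1).partition("\n")
        messages.append({"role": role.strip(), "content": content})
    return messages

print(parse_minimax_m2_prompt("]~!b[]~b]user\nhi[e~[\n]~b]assistant\nhello[e~["))
# [{'role': 'user', 'content': 'hi'}, {'role': 'assistant', 'content': 'hello'}]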
145 changes: 125 additions & 20 deletions vllm/entrypoints/context.py
@@ -8,10 +8,17 @@
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Union

from openai.types.chat.chat_completion_content_part_text_param import (
ChatCompletionContentPartTextParam,
)
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent

from vllm import envs
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
CustomChatCompletionMessageParam,
)
from vllm.entrypoints.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
@@ -20,6 +27,7 @@
from vllm.entrypoints.openai.parser.parser import (
get_streamable_parser_for_simple_context,
)
from vllm.entrypoints.openai.protocol import FunctionCall, ResponsesRequest
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
@@ -189,9 +197,15 @@
def __init__(
self,
*,
sentences: list,
chat_completion_messages: list[CustomChatCompletionMessageParam],
tokenizer: AnyTokenizer,
reasoning_parser: ReasoningParser,
request: ResponsesRequest,
available_tools: list[str] | None,
tool_parser_cls,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
tool_dicts: list[dict] | None = None,
):
self.last_output = None
self.num_prompt_tokens = 0
@@ -203,49 +217,140 @@
self.all_turn_metrics = []

self.parser = get_streamable_parser_for_simple_context(
tokenizer=tokenizer, reasoning_parser=reasoning_parser, sentences=sentences
tokenizer=tokenizer,
reasoning_parser=reasoning_parser,
chat_completion_messages=chat_completion_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)
self.tool_parser_cls = tool_parser_cls
self.request = request
self.tokenizer = tokenizer
self.reasoning_parser = reasoning_parser

self.num_init_sentences = len(sentences)
self.sentences = sentences
self.num_init_chat_completion_messages = len(chat_completion_messages)
self.chat_completion_messages = chat_completion_messages

def append_output(self, output) -> None:
self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {}
self.called_tools: set[str] = set()

self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format
self.tool_dicts = tool_dicts

def append_output(self, output: RequestOutput) -> None:
self.last_output = output
if not isinstance(output, RequestOutput):
raise ValueError("SimpleContext only supports RequestOutput.")
self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])

output_token_ids = output.outputs[0].token_ids
for token_id in output_token_ids:
self.parser.process(token_id)
# output_token_ids = output.outputs[0].token_ids
# for token_id in output_token_ids:
# self.parser.process(token_id)
self.parser.process(output.outputs[0])

def append_tool_output(self, output) -> None:
raise NotImplementedError("Should not be called.")
def append_tool_output(
self, output: list[CustomChatCompletionMessageParam]
) -> None:
self.parser.chat_completion_messages.extend(output)

def need_builtin_tool_call(self) -> bool:
"""Return true if the last message is a MCP tool call"""
last_message = self.parser.chat_completion_messages[-1]["content"][-1]
if isinstance(last_message, FunctionCall):
# HACK: figure out which tools are MCP tools
if last_message.name == "code_interpreter" or last_message.name == "python":

[Check failure: GitHub Actions / pre-commit — Ruff (SIM102): vllm/entrypoints/context.py:261:9 Use a single `if` statement instead of nested `if` statements]
return True

return False

async def call_tool(self) -> list[Message]:
raise NotImplementedError("Should not be called.")
async def call_python_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
) -> list[CustomChatCompletionMessageParam]:
self.called_tools.add("python")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
args = json.loads(last_msg.arguments)
param = {
"code": args["code"],
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text

def render_for_completion(self) -> list[int]:
raise NotImplementedError("Should not be called.")
# author = Author(role=Role.TOOL, name="python")

message = CustomChatCompletionMessageParam(
role="tool",
content=[
ChatCompletionContentPartTextParam(text=result_str, type="text")
], # TODO: why is this nested?
)

return [message]

async def call_tool(self) -> list[CustomChatCompletionMessageParam]:
if not self.parser.chat_completion_messages:
return []
last_msg = self.parser.chat_completion_messages[-1]
last_tool_request = last_msg["content"][-1]
if last_tool_request.name == "code_interpreter":
return await self.call_python_tool(
self._tool_sessions["python"], last_tool_request
)
# recipient = last_message.name == "code_interpreter"
# if recipient is not None and recipient.startswith("python"):
# return await self.call_python_tool(self._tool_sessions["python"], last_tool_request)

[Check failure: GitHub Actions / pre-commit — Ruff (E501): vllm/entrypoints/context.py:304:89 Line too long (98 > 88)]

def render_for_completion(self):
return [
self.request,
self.tokenizer,
self.parser.chat_completion_messages,
self.tool_dicts,
self.tool_parser_cls,
self.chat_template,
self.chat_template_content_format,
]

async def init_tool_sessions(
self,
tool_server: ToolServer | None,
exit_stack: AsyncExitStack,
request_id: str,
mcp_tools: dict[str, Mcp],
) -> None:
pass
):
if tool_server:
for tool_name in self.available_tools:
if tool_name not in self._tool_sessions:
tool_type = _map_tool_name_to_tool_type(tool_name)
headers = (
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
)
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id, headers)
)
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)

async def cleanup_session(self) -> None:
raise NotImplementedError("Should not be called.")
async def cleanup_session(self, *args, **kwargs) -> None:
"""Can be used as coro to used in __aexit__"""

async def cleanup_tool_session(tool_session):
if not isinstance(tool_session, Tool):
logger.info(
"Cleaning up tool session for %s", tool_session._client_info
)
with contextlib.suppress(Exception):
await tool_session.call_tool("cleanup_session", {})

await asyncio.gather(
*(
cleanup_tool_session(self._tool_sessions[tool])
for tool in self.called_tools
)
)


class HarmonyContext(ConversationContext):
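Note: pieced together from the signatures in this diff, the multi-turn tool flow that SimpleContext now supports composes roughly as below. The run_turns driver and the engine object are hypothetical stand-ins (not part of this PR); only the four context methods come from the diff.

async def run_turns(context, engine, max_turns: int = 8) -> None:
    # Hypothetical driver loop; 'engine' stands in for whatever turns the
    # rendered conversation into a RequestOutput.
    for _ in range(max_turns):
        # In this PR, render_for_completion() returns chat-template args,
        # not token ids.
        template_args = context.render_for_completion()
        output = await engine.generate(template_args)
        context.append_output(output)  # parser extracts reasoning, text, tool calls
        if not context.need_builtin_tool_call():
            break  # no pending MCP tool call; this turn is final
        tool_messages = await context.call_tool()
        context.append_tool_output(tool_messages)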
97 changes: 64 additions & 33 deletions vllm/entrypoints/openai/parser/parser.py
@@ -5,11 +5,10 @@
from openai.types.chat.chat_completion_content_part_text_param import (
ChatCompletionContentPartTextParam,
)
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)

from vllm.entrypoints.chat_utils import CustomChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import FunctionCall, ResponsesRequest
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser

logger = logging.getLogger(__name__)
@@ -18,13 +17,25 @@
class StreamableParser:
"""Incremental parser over completion tokens with reasoning support."""

def __init__(self, *, tokenizer, reasoning_parser: ReasoningParser):
self.chat_completion_messages: list[CustomChatCompletionMessageParam] = []
def __init__(
self,
*,
tokenizer,
reasoning_parser: ReasoningParser,
chat_completion_messages: list[CustomChatCompletionMessageParam],
request: ResponsesRequest,
tool_parser_cls, #: Callable[[AnyTokenizer], ToolParser]
):
self.chat_completion_messages: list[CustomChatCompletionMessageParam] = (
chat_completion_messages
)
self.tokens: list[int] = []
self.tokenizer = tokenizer
self.request = request

# Initialize reasoning parser instance if provided
self.reasoning_parser_instance = reasoning_parser(tokenizer)

[Check failure: GitHub Actions / pre-commit — "ReasoningParser" not callable [operator] (vllm/entrypoints/openai/parser/parser.py:37)]
self.tool_parser_instance = tool_parser_cls(tokenizer)

# start like this
self.current_role = "assistant"
@@ -42,56 +53,76 @@
to help generate the initial prompt?"""
pass

def process(self, token: int) -> "StreamableParser":
# see process_next()
# https://github.com/openai/harmony/blob/main/src/encoding.rs#L1114
self.tokens.append(token)
decoded = self.tokenizer.decode(token)
if self.reasoning_parser_instance.is_reasoning_end([token]):
# TODO: how to capture reasoning?
# new_content = {
# "role": "assistant",
# "reasoning_content": self.current_text
# }

new_content = ResponseReasoningTextContent(
text=self.current_text, type="reasoning_text"
def process(self, output: CompletionOutput) -> "StreamableParser":
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
output.text, request=None
)
if reasoning_content:
new_content = ChatCompletionContentPartTextParam(
text=reasoning_content, type="reasoning_text"
)

self.current_chat_completion_message["content"].append(new_content)

[Check failure: GitHub Actions / pre-commit — Item "str" of "str | list[...]" has no attribute "append" [union-attr] (vllm/entrypoints/openai/parser/parser.py:65)]

self.current_text = ""
self.current_channel = "final"
elif token == self.tokenizer.eos_token_id:
# end of sentence
new_content = ChatCompletionContentPartTextParam(
text=self.current_text, type="text"
function_calls = []
tool_call_info = self.tool_parser_instance.extract_tool_calls(
content if content is not None else "",
request=self.request, # type: ignore
)
if tool_call_info is not None and tool_call_info.tools_called:
# extract_tool_calls() returns a list of tool calls.
function_calls.extend(
# TODO: this should be a TypedDict
FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
for tool_call in tool_call_info.tool_calls
)
content = tool_call_info.content

if content:
new_content = ChatCompletionContentPartTextParam(text=content, type="text")
self.current_chat_completion_message["content"].append(new_content)
self.chat_completion_messages.append(self.current_chat_completion_message)
if len(function_calls) > 0:
self.current_chat_completion_message["content"].extend(function_calls)

self.current_text = ""
self.current_channel = None
else:
self.current_text += decoded
self.chat_completion_messages.append(self.current_chat_completion_message)

self.current_chat_completion_message = CustomChatCompletionMessageParam(
role=self.current_role, content=[]
)

# TODO: current state of sentences, etc
return self


def get_streamable_parser_for_simple_context(
*, tokenizer, reasoning_parser: ReasoningParser, sentences
*,
tokenizer,
reasoning_parser: ReasoningParser,
chat_completion_messages: list[CustomChatCompletionMessageParam],
request: ResponsesRequest,
tool_parser_cls,
) -> StreamableParser:
"""Factory function to create a StreamableParser with optional reasoning parser.

Args:
tokenizer: The tokenizer to use for decoding tokens
reasoning_parser: Optional reasoning parser class (e.g., MiniMaxM2ReasoningParser)

[Check failure: GitHub Actions / pre-commit — Ruff (E501): vllm/entrypoints/openai/parser/parser.py:111:89 Line too long (90 > 88)]

Returns:
StreamableParser instance configured with the provided parser
"""
return StreamableParser(tokenizer=tokenizer, reasoning_parser=reasoning_parser)
return StreamableParser(
tokenizer=tokenizer,
reasoning_parser=reasoning_parser,
chat_completion_messages=chat_completion_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)


# def render_parser_for_completion():


"""
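Note: with this change StreamableParser.process() consumes a full CompletionOutput instead of single tokens, splitting it into reasoning text, visible text, and tool calls that all land in one assistant message. A sketch of the resulting message shape, with illustrative values only:

message = {
    "role": "assistant",
    "content": [
        # reasoning extracted by the reasoning parser
        {"type": "reasoning_text", "text": "Let me run the snippet..."},
        # visible text left over after tool-call extraction
        {"type": "text", "text": "Running it now."},
        # plus FunctionCall(name="code_interpreter", arguments='{"code": "print(1)"}')
        # entries appended when the tool parser reports tools_called
    ],
}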