|
1 | 1 | # SPDX-License-Identifier: Apache-2.0 |
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +import json |
3 | 4 | import logging |
4 | 5 | from abc import ABC, abstractmethod |
| 6 | +from typing import TYPE_CHECKING, Union |
5 | 7 |
|
6 | | -from openai_harmony import Message, Role, StreamState |
| 8 | +from openai_harmony import Author, Message, Role, StreamState, TextContent |
7 | 9 |
|
8 | 10 | from vllm.entrypoints.harmony_utils import ( |
9 | 11 | get_encoding, get_streamable_parser_for_assistant, render_for_completion) |
10 | 12 | from vllm.entrypoints.tool import Tool |
11 | 13 | from vllm.outputs import RequestOutput |
12 | 14 |
|
| 15 | +if TYPE_CHECKING: |
| 16 | + from mcp.client import ClientSession |
| 17 | + |
13 | 18 | logger = logging.getLogger(__name__) |
14 | 19 |
|
15 | 20 |
|
@@ -107,19 +112,41 @@ async def call_tool(self) -> list[Message]: |
107 | 112 | def render_for_completion(self) -> list[int]: |
108 | 113 | return render_for_completion(self.messages) |
109 | 114 |
|
110 | | - async def call_search_tool( |
111 | | - self, |
112 | | - tool_session: Tool, |
113 | | - last_msg: Message, |
114 | | - ) -> list[Message]: |
115 | | - return await tool_session.get_result(self) |
116 | | - |
117 | | - async def call_python_tool( |
118 | | - self, |
119 | | - tool_session: Tool, |
120 | | - last_msg: Message, |
121 | | - ) -> list[Message]: |
122 | | - return await tool_session.get_result(self) |
| 115 | + async def call_search_tool(self, tool_session: Union["ClientSession", |
| 116 | + Tool], |
| 117 | + last_msg: Message) -> list[Message]: |
| 118 | + if isinstance(tool_session, Tool): |
| 119 | + return await tool_session.get_result(self) |
| 120 | + tool_name = last_msg.recipient.split(".")[1] |
| 121 | + args = json.loads(last_msg.content[0].text) |
| 122 | + result = await tool_session.call_tool(tool_name, args) |
| 123 | + result_str = result.content[0].text |
| 124 | + content = TextContent(text=result_str) |
| 125 | + author = Author(role=Role.TOOL, name=last_msg.recipient) |
| 126 | + return [ |
| 127 | + Message(author=author, content=[content], recipient=Role.ASSISTANT) |
| 128 | + ] |
| 129 | + |
| 130 | + async def call_python_tool(self, tool_session: Union["ClientSession", |
| 131 | + Tool], |
| 132 | + last_msg: Message) -> list[Message]: |
| 133 | + if isinstance(tool_session, Tool): |
| 134 | + return await tool_session.get_result(self) |
| 135 | + param = { |
| 136 | + "code": last_msg.content[0].text, |
| 137 | + } |
| 138 | + result = await tool_session.call_tool("python", param) |
| 139 | + result_str = result.content[0].text |
| 140 | + |
| 141 | + content = TextContent(text=result_str) |
| 142 | + author = Author(role=Role.TOOL, name="python") |
| 143 | + |
| 144 | + return [ |
| 145 | + Message(author=author, |
| 146 | + content=[content], |
| 147 | + channel=last_msg.channel, |
| 148 | + recipient=Role.ASSISTANT) |
| 149 | + ] |
123 | 150 |
|
124 | 151 |
|
125 | 152 | class StreamingHarmonyContext(HarmonyContext): |
|
0 commit comments