diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py
index 7e19e3787545..bf97b6ed1882 100644
--- a/vllm/entrypoints/context.py
+++ b/vllm/entrypoints/context.py
@@ -260,7 +260,11 @@ def need_builtin_tool_call(self) -> bool:
         last_message = self.parser.chat_completion_messages[-1]["content"][-1]
         if isinstance(last_message, FunctionCall):
             # HACK: figure out which tools are MCP tools
-            if last_message.name == "code_interpreter" or last_message.name == "python":
+            if last_message.name in (
+                "code_interpreter",
+                "python",
+                "web_search_preview",
+            ):
                 return True
 
         return False
@@ -290,6 +294,46 @@ async def call_python_tool(
 
         return [message]
 
+    async def call_search_tool(
+        self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
+    ) -> list[CustomChatCompletionMessageParam]:
+        """Run the search tool for the pending call and wrap its text output.
+
+        Records "browser" in ``self.called_tools``. A ``Tool`` session is
+        delegated to ``tool_session.get_result(self)``; otherwise the call's
+        JSON arguments are sent to the session's "search" tool and the text of
+        the first returned content item becomes a single "tool" message.
+
+        If the arguments are not valid JSON, a parse-error message is returned
+        when ``envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY`` is set; otherwise
+        the ``json.JSONDecodeError`` propagates.
+        """
+        self.called_tools.add("browser")
+        if isinstance(tool_session, Tool):
+            return await tool_session.get_result(self)
+
+        try:
+            args = json.loads(last_msg.arguments)
+        except json.JSONDecodeError as e:
+            if not envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
+                raise
+            return _create_json_parse_error_messages(last_msg, e)
+
+        result = await tool_session.call_tool("search", args)
+        result_str = result.content[0].text
+
+        # NOTE(review): ``text`` receives a TextContent object, not a plain str;
+        # kept as submitted -- confirm consumers expect the wrapper.
+        content = TextContent(text=result_str)
+        message = CustomChatCompletionMessageParam(
+            role="tool",
+            content=[
+                ChatCompletionContentPartTextParam(text=content, type="text")
+            ],  # TODO: why is this nested?
+        )
+
+        return [message]
+
     async def call_tool(self) -> list[CustomChatCompletionMessageParam]:
         if not self.parser.chat_completion_messages:
             return []
@@ -299,6 +343,10 @@ async def call_tool(self) -> list[CustomChatCompletionMessageParam]:
             return await self.call_python_tool(
                 self._tool_sessions["python"], last_tool_request
             )
+        elif last_tool_request.name == "web_search_preview":
+            return await self.call_search_tool(
+                self._tool_sessions["browser"], last_tool_request
+            )
         # recipient = last_message.name == "code_interpreter"
         # if recipient is not None and recipient.startswith("python"):
         #     return await self.call_python_tool(self._tool_sessions["python"], last_tool_request)