Skip to content

Commit b610210

Browse files
committed
rebase main
Signed-off-by: wuhang <[email protected]>
1 parent e72ca86 commit b610210

File tree

2 files changed

+39
-54
lines changed

2 files changed

+39
-54
lines changed

vllm/entrypoints/context.py

Lines changed: 24 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -94,16 +94,6 @@ def need_builtin_tool_call(self) -> bool:
9494
def render_for_completion(self) -> list[int]:
9595
pass
9696

97-
@abstractmethod
98-
async def init_tool_sessions(
99-
self,
100-
tool_server: ToolServer | None,
101-
exit_stack: AsyncExitStack,
102-
request_id: str,
103-
mcp_tools: dict[str, Mcp],
104-
) -> None:
105-
pass
106-
10797
@abstractmethod
10898
async def __aenter__(self):
10999
pass
@@ -145,15 +135,6 @@ async def call_tool(self) -> list[Message]:
145135
def render_for_completion(self) -> list[int]:
146136
raise NotImplementedError("Should not be called.")
147137

148-
async def init_tool_sessions(
149-
self,
150-
tool_server: ToolServer | None,
151-
exit_stack: AsyncExitStack,
152-
request_id: str,
153-
mcp_tools: dict[str, Mcp],
154-
) -> None:
155-
pass
156-
157138
async def __aenter__(self):
158139
return self
159140

@@ -170,13 +151,17 @@ def __init__(
170151
messages: list,
171152
available_tools: list[str],
172153
tool_server: Optional[ToolServer],
154+
request_id: str,
155+
mcp_tools: dict[str, Mcp],
173156
):
174157
self._messages = messages
175158
self.finish_reason: str | None = None
176159
self.available_tools = available_tools
177160
self._tool_sessions: dict[str, ClientSession | Tool] = {}
178161
self.called_tools: set[str] = set()
179162
self._tool_server = tool_server
163+
self.request_id = request_id
164+
self.mcp_tools = mcp_tools
180165
self._async_exit_stack: Optional[AsyncExitStack] = None
181166
self._reference_count = 0
182167
self._reference_count_lock = asyncio.Lock()
@@ -328,18 +313,6 @@ def need_builtin_tool_call(self) -> bool:
328313
or recipient.startswith("container.")
329314
)
330315

331-
async def _get_tool_session(self, tool_name: str) -> Union["ClientSession", Tool]:
332-
if tool_name not in self._tool_sessions and self._tool_server is not None:
333-
assert self._async_exit_stack is not None, (
334-
"Async exit stack not set. Please report this issue."
335-
)
336-
self._tool_sessions[
337-
tool_name
338-
] = await self._async_exit_stack.enter_async_context(
339-
self._tool_server.new_session(tool_name)
340-
)
341-
return self._tool_sessions[tool_name]
342-
343316
async def call_tool(self) -> list[Message]:
344317
if not self.messages:
345318
return []
@@ -363,6 +336,24 @@ async def call_tool(self) -> list[Message]:
363336
def render_for_completion(self) -> list[int]:
364337
return render_for_completion(self.messages)
365338

339+
async def _get_tool_session(self, tool_name: str) -> Union["ClientSession", Tool]:
340+
if tool_name not in self._tool_sessions and self._tool_server is not None:
341+
assert self._async_exit_stack is not None, (
342+
"Async exit stack not set. Please report this issue."
343+
)
344+
tool_type = _map_tool_name_to_tool_type(tool_name)
345+
headers = (
346+
self.mcp_tools[tool_type].headers
347+
if tool_type in self.mcp_tools
348+
else None
349+
)
350+
self._tool_sessions[
351+
tool_name
352+
] = await self._async_exit_stack.enter_async_context(
353+
self._tool_server.new_session(tool_name, self.request_id, headers)
354+
)
355+
return self._tool_sessions[tool_name]
356+
366357
async def call_search_tool(
367358
self, tool_session: Union["ClientSession", Tool], last_msg: Message
368359
) -> list[Message]:
@@ -408,26 +399,6 @@ async def call_python_tool(
408399
)
409400
]
410401

411-
async def init_tool_sessions(
412-
self,
413-
tool_server: ToolServer | None,
414-
exit_stack: AsyncExitStack,
415-
request_id: str,
416-
mcp_tools: dict[str, Mcp],
417-
):
418-
if tool_server:
419-
for tool_name in self.available_tools:
420-
if tool_name not in self._tool_sessions:
421-
tool_type = _map_tool_name_to_tool_type(tool_name)
422-
headers = (
423-
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
424-
)
425-
tool_session = await exit_stack.enter_async_context(
426-
tool_server.new_session(tool_name, request_id, headers)
427-
)
428-
self._tool_sessions[tool_name] = tool_session
429-
exit_stack.push_async_exit(self.cleanup_session)
430-
431402
async def call_container_tool(
432403
self, tool_session: Union["ClientSession", Tool], last_msg: Message
433404
) -> list[Message]:
@@ -489,10 +460,11 @@ async def __aenter__(self):
489460
if self._async_exit_stack is None:
490461
assert self._reference_count == 1, (
491462
"Reference count of exit stack should be "
463+
"1 when initializing exit stack."
492464
)
493-
"1 when initializing exit stack."
494465
self._async_exit_stack = AsyncExitStack()
495466
await self._async_exit_stack.__aenter__()
467+
self._async_exit_stack.push_async_callback(self.cleanup_session)
496468
return self
497469

498470
async def __aexit__(self, exc_type, exc, tb):

vllm/entrypoints/openai/serving_responses.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -349,15 +349,28 @@ async def create_responses(
349349
else await self._get_trace_headers(raw_request.headers)
350350
)
351351

352+
mcp_tools = {
353+
tool.server_label: tool
354+
for tool in request.tools
355+
if tool.type == "mcp"
356+
}
352357
context: ConversationContext
353358
if self.use_harmony:
354359
if request.stream:
355360
context = StreamingHarmonyContext(
356-
messages, available_tools, self.tool_server
361+
messages,
362+
available_tools,
363+
self.tool_server,
364+
request.request_id,
365+
mcp_tools,
357366
)
358367
else:
359368
context = HarmonyContext(
360-
messages, available_tools, self.tool_server
369+
messages,
370+
available_tools,
371+
self.tool_server,
372+
request.request_id,
373+
mcp_tools,
361374
)
362375
else:
363376
context = SimpleContext()

0 commit comments

Comments
 (0)