@@ -94,16 +94,6 @@ def need_builtin_tool_call(self) -> bool:
9494 def render_for_completion (self ) -> list [int ]:
9595 pass
9696
97- @abstractmethod
98- async def init_tool_sessions (
99- self ,
100- tool_server : ToolServer | None ,
101- exit_stack : AsyncExitStack ,
102- request_id : str ,
103- mcp_tools : dict [str , Mcp ],
104- ) -> None :
105- pass
106-
10797 @abstractmethod
10898 async def __aenter__ (self ):
10999 pass
@@ -145,15 +135,6 @@ async def call_tool(self) -> list[Message]:
145135 def render_for_completion (self ) -> list [int ]:
146136 raise NotImplementedError ("Should not be called." )
147137
148- async def init_tool_sessions (
149- self ,
150- tool_server : ToolServer | None ,
151- exit_stack : AsyncExitStack ,
152- request_id : str ,
153- mcp_tools : dict [str , Mcp ],
154- ) -> None :
155- pass
156-
157138 async def __aenter__ (self ):
158139 return self
159140
@@ -170,13 +151,17 @@ def __init__(
170151 messages : list ,
171152 available_tools : list [str ],
172153 tool_server : Optional [ToolServer ],
154+ request_id : str ,
155+ mcp_tools : dict [str , Mcp ],
173156 ):
174157 self ._messages = messages
175158 self .finish_reason : str | None = None
176159 self .available_tools = available_tools
177160 self ._tool_sessions : dict [str , ClientSession | Tool ] = {}
178161 self .called_tools : set [str ] = set ()
179162 self ._tool_server = tool_server
163+ self .request_id = request_id
164+ self .mcp_tools = mcp_tools
180165 self ._async_exit_stack : Optional [AsyncExitStack ] = None
181166 self ._reference_count = 0
182167 self ._reference_count_lock = asyncio .Lock ()
@@ -328,18 +313,6 @@ def need_builtin_tool_call(self) -> bool:
328313 or recipient .startswith ("container." )
329314 )
330315
331- async def _get_tool_session (self , tool_name : str ) -> Union ["ClientSession" , Tool ]:
332- if tool_name not in self ._tool_sessions and self ._tool_server is not None :
333- assert self ._async_exit_stack is not None , (
334- "Async exit stack not set. Please report this issue."
335- )
336- self ._tool_sessions [
337- tool_name
338- ] = await self ._async_exit_stack .enter_async_context (
339- self ._tool_server .new_session (tool_name )
340- )
341- return self ._tool_sessions [tool_name ]
342-
343316 async def call_tool (self ) -> list [Message ]:
344317 if not self .messages :
345318 return []
@@ -363,6 +336,24 @@ async def call_tool(self) -> list[Message]:
363336 def render_for_completion (self ) -> list [int ]:
364337 return render_for_completion (self .messages )
365338
339+ async def _get_tool_session (self , tool_name : str ) -> Union ["ClientSession" , Tool ]:
340+ if tool_name not in self ._tool_sessions and self ._tool_server is not None :
341+ assert self ._async_exit_stack is not None , (
342+ "Async exit stack not set. Please report this issue."
343+ )
344+ tool_type = _map_tool_name_to_tool_type (tool_name )
345+ headers = (
346+ self .mcp_tools [tool_type ].headers
347+ if tool_type in self .mcp_tools
348+ else None
349+ )
350+ self ._tool_sessions [
351+ tool_name
352+ ] = await self ._async_exit_stack .enter_async_context (
353+ self ._tool_server .new_session (tool_name , self .request_id , headers )
354+ )
355+ return self ._tool_sessions [tool_name ]
356+
366357 async def call_search_tool (
367358 self , tool_session : Union ["ClientSession" , Tool ], last_msg : Message
368359 ) -> list [Message ]:
@@ -408,26 +399,6 @@ async def call_python_tool(
408399 )
409400 ]
410401
411- async def init_tool_sessions (
412- self ,
413- tool_server : ToolServer | None ,
414- exit_stack : AsyncExitStack ,
415- request_id : str ,
416- mcp_tools : dict [str , Mcp ],
417- ):
418- if tool_server :
419- for tool_name in self .available_tools :
420- if tool_name not in self ._tool_sessions :
421- tool_type = _map_tool_name_to_tool_type (tool_name )
422- headers = (
423- mcp_tools [tool_type ].headers if tool_type in mcp_tools else None
424- )
425- tool_session = await exit_stack .enter_async_context (
426- tool_server .new_session (tool_name , request_id , headers )
427- )
428- self ._tool_sessions [tool_name ] = tool_session
429- exit_stack .push_async_exit (self .cleanup_session )
430-
431402 async def call_container_tool (
432403 self , tool_session : Union ["ClientSession" , Tool ], last_msg : Message
433404 ) -> list [Message ]:
@@ -489,10 +460,11 @@ async def __aenter__(self):
489460 if self ._async_exit_stack is None :
490461 assert self ._reference_count == 1 , (
491462 "Reference count of exit stack should be "
463+ "1 when initializing exit stack."
492464 )
493- "1 when initializing exit stack."
494465 self ._async_exit_stack = AsyncExitStack ()
495466 await self ._async_exit_stack .__aenter__ ()
467+ self ._async_exit_stack .push_async_callback (self .cleanup_session )
496468 return self
497469
498470 async def __aexit__ (self , exc_type , exc , tb ):
0 commit comments