Skip to content

Commit 21f2e30

Browse files
Resolve cancel scope error in MCP session cleanup with lifetime task (#931)
Previously, session cleanup was calling client.__aexit__() from a different task context than where __aenter__() was called, violating anyio's CancelScope requirement that enter and exit must happen in the same task. This caused "Attempted to exit cancel scope in a different task" errors during session cleanup. Solution: - Introduce a per-client lifetime task that manages the entire client lifecycle - The lifetime task enters the client context (async with client:) and waits for a stop_event signal before exiting - Session cleanup now signals the stop_event and waits for the lifetime task to complete, ensuring __aexit__ runs in the correct task context This ensures proper cancel scope handling and prevents resource leaks while maintaining thread-safe session management. ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/NeMo-Agent-Toolkit/blob/develop/docs/source/resources/contributing.md). - We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. - Any contribution which contains commits that are not Signed-Off will not be accepted. - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. ## Summary by CodeRabbit * **Bug Fixes** * Sessions now shut down cleanly (stop signaling and fallback cleanup), reducing hangs and timeouts; inactive sessions are reliably removed. * **Refactor** * Session lifecycle reworked to include explicit stop events and managed background lifetime tasks for more robust session management. * **Chores** * Initialization and cleanup wiring improved to avoid orphaned tasks and lower resource use during idle/long runs. * **Tests** * Test harness updated with session cleanup helpers to ensure per-session tasks are torn down and improve isolation. Authors: - Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah) Approvers: - Yuchen Zhang (https://github.com/yczhang-nv) URL: #931
1 parent 5d11923 commit 21f2e30

File tree

2 files changed

+415
-14
lines changed

2 files changed

+415
-14
lines changed

packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_impl.py

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ class SessionData:
4444
ref_count: int = 0
4545
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
4646

47+
# lifetime task to respect task boundaries
48+
stop_event: asyncio.Event = field(default_factory=asyncio.Event)
49+
lifetime_task: asyncio.Task | None = None
50+
4751

4852
class MCPFunctionGroup(FunctionGroup):
4953
"""
@@ -202,7 +206,7 @@ async def _cleanup_inactive_sessions(self, max_age: timedelta | None = None):
202206
if max_age is None:
203207
max_age = self._client_config.session_idle_timeout if self._client_config else timedelta(hours=1)
204208

205-
to_close: list[tuple[str, MCPBaseClient]] = []
209+
to_close: list[tuple[str, SessionData]] = []
206210

207211
async with self._session_rwlock.writer:
208212
current_time = datetime.now()
@@ -222,7 +226,7 @@ async def _cleanup_inactive_sessions(self, max_age: timedelta | None = None):
222226
session_data = self._sessions[session_id]
223227
# Close the client connection
224228
if session_data:
225-
to_close.append((session_id, session_data.client))
229+
to_close.append((session_id, session_data))
226230
except Exception as e:
227231
logger.warning("Error cleaning up session client %s: %s", truncate_session_id(session_id), e)
228232
finally:
@@ -232,11 +236,22 @@ async def _cleanup_inactive_sessions(self, max_age: timedelta | None = None):
232236
logger.info(" Total sessions: %d", len(self._sessions))
233237

234238
# Close sessions outside the writer lock to avoid deadlock
235-
for session_id, client in to_close:
239+
for session_id, sdata in to_close:
236240
try:
237-
logger.info("Cleaning up inactive session client: %s", truncate_session_id(session_id))
238-
await client.__aexit__(None, None, None)
239-
logger.info("Cleaned up inactive session client: %s", truncate_session_id(session_id))
241+
if sdata.stop_event and sdata.lifetime_task:
242+
if not sdata.lifetime_task.done():
243+
# Instead of directly exiting the task, set the stop event
244+
# and wait for the task to exit. This ensures the cancel scope
245+
# is entered and exited in the same task.
246+
sdata.stop_event.set()
247+
await sdata.lifetime_task # __aexit__ runs in that task
248+
else:
249+
logger.debug("Session client %s lifetime task already done", truncate_session_id(session_id))
250+
else:
251+
# add fallback to ensure we clean up the client
252+
logger.warning("Session client %s lifetime task not found, cleaning up client",
253+
truncate_session_id(session_id))
254+
await sdata.client.__aexit__(None, None, None)
240255
except Exception as e:
241256
logger.warning("Error cleaning up session client %s: %s", truncate_session_id(session_id), e)
242257

@@ -284,10 +299,14 @@ async def _get_session_client(self, session_id: str) -> MCPBaseClient:
284299

285300
# Create session client lazily
286301
logger.info("Creating new MCP client for session: %s", truncate_session_id(session_id))
287-
session_client = await self._create_session_client(session_id)
288-
289-
# Create session data with all components
290-
session_data = SessionData(client=session_client, last_activity=datetime.now(), ref_count=0)
302+
session_client, stop_event, lifetime_task = await self._create_session_client(session_id)
303+
session_data = SessionData(
304+
client=session_client,
305+
last_activity=datetime.now(),
306+
ref_count=0,
307+
stop_event=stop_event,
308+
lifetime_task=lifetime_task,
309+
)
291310

292311
# Cache the session data
293312
self._sessions[session_id] = session_data
@@ -325,7 +344,7 @@ async def _session_usage_context(self, session_id: str):
325344
sdata.ref_count -= 1
326345
sdata.last_activity = datetime.now()
327346

328-
async def _create_session_client(self, session_id: str) -> MCPBaseClient:
347+
async def _create_session_client(self, session_id: str) -> tuple[MCPBaseClient, asyncio.Event, asyncio.Task]:
329348
"""Create a new MCP client instance for the session."""
330349
from nat.plugins.mcp.client_base import MCPStreamableHTTPClient
331350

@@ -348,11 +367,50 @@ async def _create_session_client(self, session_id: str) -> MCPBaseClient:
348367
# per-user sessions are only supported for streamable-http transport
349368
raise ValueError(f"Unsupported transport: {config.server.transport}")
350369

351-
# Initialize the client
352-
await client.__aenter__()
370+
ready = asyncio.Event()
371+
stop_event = asyncio.Event()
372+
373+
async def _lifetime():
374+
"""
375+
Create a lifetime task to respect task boundaries and ensure the
376+
cancel scope is entered and exited in the same task.
377+
"""
378+
try:
379+
async with client:
380+
ready.set()
381+
await stop_event.wait()
382+
except Exception:
383+
ready.set() # Ensure we don't hang the waiter
384+
raise
385+
386+
task = asyncio.create_task(_lifetime(), name=f"mcp-session-{truncate_session_id(session_id)}")
387+
388+
# Wait for initialization with timeout to prevent infinite hangs
389+
timeout = config.tool_call_timeout.total_seconds() if config else 300
390+
try:
391+
await asyncio.wait_for(ready.wait(), timeout=timeout)
392+
except TimeoutError:
393+
task.cancel()
394+
try:
395+
await task
396+
except asyncio.CancelledError:
397+
pass
398+
logger.error("Session client initialization timed out after %ds for %s",
399+
timeout,
400+
truncate_session_id(session_id))
401+
raise RuntimeError(f"Session client initialization timed out after {timeout}s")
402+
403+
# Check if initialization failed before ready was set
404+
if task.done():
405+
try:
406+
await task # Re-raise exception if the task failed
407+
except Exception as e:
408+
logger.error("Failed to initialize session client for %s: %s", truncate_session_id(session_id), e)
409+
raise RuntimeError(f"Failed to initialize session client: {e}") from e
353410

354411
logger.info("Created session client for session: %s", truncate_session_id(session_id))
355-
return client
412+
# NOTE: caller will place client into SessionData and attach stop_event/task
413+
return client, stop_event, task
356414

357415

358416
def mcp_session_tool_function(tool, function_group: MCPFunctionGroup):

0 commit comments

Comments
 (0)