diff --git a/litellm/llms/custom_httpx/aiohttp_transport.py b/litellm/llms/custom_httpx/aiohttp_transport.py index 1b03ec47643..30abf57f533 100644 --- a/litellm/llms/custom_httpx/aiohttp_transport.py +++ b/litellm/llms/custom_httpx/aiohttp_transport.py @@ -119,8 +119,9 @@ async def aclose(self) -> None: class AiohttpTransport(httpx.AsyncBaseTransport): - def __init__(self, client: Union[ClientSession, Callable[[], ClientSession]]) -> None: + def __init__(self, client: Union[ClientSession, Callable[[], ClientSession]], _owns_session: bool = True) -> None: self.client = client + self._owns_session = _owns_session ######################################################### # Class variables for proxy settings @@ -128,7 +129,7 @@ def __init__(self, client: Union[ClientSession, Callable[[], ClientSession]]) -> self.proxy_cache: Dict[str, Optional[str]] = {} async def aclose(self) -> None: - if isinstance(self.client, ClientSession): + if self._owns_session and isinstance(self.client, ClientSession): await self.client.close() @@ -147,7 +148,11 @@ def __init__( ): self.client = client self._ssl_verify = ssl_verify # Store for per-request SSL override - super().__init__(client=client) + # If a pre-existing ClientSession is passed in, we don't own it + # and should not close it. If a factory is passed, we own sessions + # we create from it. + owns_session = not isinstance(client, ClientSession) + super().__init__(client=client, _owns_session=owns_session) # Store the client factory for recreating sessions when needed if callable(client): self._client_factory = client @@ -167,6 +172,7 @@ def _get_valid_client_session(self) -> ClientSession: self.client = self._client_factory() else: self.client = ClientSession() + self._owns_session = True # We created this session, so we own it # Don't return yet - check if the newly created session is valid # Check if the session itself is closed @@ -177,6 +183,7 @@ def _get_valid_client_session(self) -> ClientSession: self.client = self._client_factory() else: self.client = ClientSession() + self._owns_session = True # We created this session, so we own it return self.client # Check if the existing session is still valid for the current event loop @@ -186,23 +193,25 @@ def _get_valid_client_session(self) -> ClientSession: # If session is from a different or closed loop, recreate it if session_loop is None or session_loop != current_loop or session_loop.is_closed(): - # Close old session to prevent leaks + # Close old session to prevent leaks (only if we own it) old_session = self.client - try: - if not old_session.closed: - try: - asyncio.create_task(old_session.close()) - except RuntimeError: - # Different event loop - can't schedule task, rely on GC - verbose_logger.debug("Old session from different loop, relying on GC") - except Exception as e: - verbose_logger.debug(f"Error closing old session: {e}") + if self._owns_session: + try: + if not old_session.closed: + try: + asyncio.create_task(old_session.close()) + except RuntimeError: + # Different event loop - can't schedule task, rely on GC + verbose_logger.debug("Old session from different loop, relying on GC") + except Exception as e: + verbose_logger.debug(f"Error closing old session: {e}") # Create a new session in the current event loop if hasattr(self, "_client_factory") and callable(self._client_factory): self.client = self._client_factory() else: self.client = ClientSession() + self._owns_session = True # We created this session, so we own it except (RuntimeError, AttributeError): # If we can't check the loop or session is invalid, recreate it @@ -210,6 +219,7 @@ def _get_valid_client_session(self) -> ClientSession: self.client = self._client_factory() else: self.client = ClientSession() + self._owns_session = True # We created this session, so we own it return self.client @@ -306,6 +316,7 @@ async def handle_async_request( self.client = self._client_factory() else: self.client = ClientSession() + self._owns_session = True # We created this session, so we own it client_session = self.client # Retry the request with the new session diff --git a/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py b/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py index 002fa81b9b5..13d4e1f11cc 100644 --- a/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py +++ b/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py @@ -396,12 +396,12 @@ async def streaming_handler(request): # but each chunk arrives quickly response = web.StreamResponse() await response.prepare(request) - + # Send 5 chunks over 0.5 seconds total (0.1s between chunks) for i in range(5): await asyncio.sleep(0.05) # Less than sock_read timeout await response.write(f"chunk{i}\n".encode()) - + await response.write_eof() return response @@ -436,12 +436,12 @@ def factory(): # This should succeed without timing out response = await transport.handle_async_request(request) assert response.status_code == 200 - + # Read the streaming response chunks = [] async for chunk in response.aiter_bytes(): chunks.append(chunk) - + # Verify we got all chunks full_response = b"".join(chunks).decode() assert "chunk0" in full_response @@ -512,3 +512,43 @@ def factory(): assert counts["requests"] == 2 # First request failed, second succeeded assert counts["sessions"] == 2 # Created 2 sessions for retry assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_shared_session_not_closed_by_transport(): + """Test that aclose() does NOT close a shared ClientSession passed to the transport (#21116).""" + import aiohttp + + shared_session = aiohttp.ClientSession() + try: + transport = LiteLLMAiohttpTransport(client=shared_session) + + # Transport should not own the session + assert transport._owns_session is False + + # aclose() should NOT close the shared session + await transport.aclose() + assert not shared_session.closed, "Transport should not close a shared ClientSession it does not own" + finally: + await shared_session.close() + + +@pytest.mark.asyncio +async def test_factory_created_session_closed_by_transport(): + """Test that aclose() DOES close a session created from a factory.""" + import aiohttp + + def factory(): + return aiohttp.ClientSession() + + transport = LiteLLMAiohttpTransport(client=factory) # type: ignore + + # Transport should own sessions it creates + assert transport._owns_session is True + + # Force session creation + session = transport._get_valid_client_session() + + # aclose() should close the session we created + await transport.aclose() + assert session.closed, "Transport should close a session it created from a factory"