diff --git a/litellm/llms/custom_httpx/aiohttp_transport.py b/litellm/llms/custom_httpx/aiohttp_transport.py index fb98006c7e4..6cec1f4fe16 100644 --- a/litellm/llms/custom_httpx/aiohttp_transport.py +++ b/litellm/llms/custom_httpx/aiohttp_transport.py @@ -119,8 +119,13 @@ async def aclose(self) -> None: class AiohttpTransport(httpx.AsyncBaseTransport): - def __init__(self, client: Union[ClientSession, Callable[[], ClientSession]]) -> None: + def __init__( + self, + client: Union[ClientSession, Callable[[], ClientSession]], + owns_session: bool = True, + ) -> None: self.client = client + self._owns_session = owns_session ######################################################### # Class variables for proxy settings @@ -128,7 +133,7 @@ def __init__(self, client: Union[ClientSession, Callable[[], ClientSession]]) -> self.proxy_cache: Dict[str, Optional[str]] = {} async def aclose(self) -> None: - if isinstance(self.client, ClientSession): + if self._owns_session and isinstance(self.client, ClientSession): await self.client.close() @@ -144,10 +149,11 @@ def __init__( self, client: Union[ClientSession, Callable[[], ClientSession]], ssl_verify: Optional[Union[bool, ssl.SSLContext]] = None, + owns_session: bool = True, ): self.client = client self._ssl_verify = ssl_verify # Store for per-request SSL override - super().__init__(client=client) + super().__init__(client=client, owns_session=owns_session) # Store the client factory for recreating sessions when needed if callable(client): self._client_factory = client diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 5cf6efe5ba2..328097639e5 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -866,6 +866,7 @@ def _create_aiohttp_transport( return LiteLLMAiohttpTransport( client=shared_session, ssl_verify=ssl_for_transport, + owns_session=False, ) # Create new session only if none provided or existing one is invalid diff --git a/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py b/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py index 002fa81b9b5..6e2e60ba0dd 100644 --- a/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py +++ b/tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py @@ -12,10 +12,42 @@ from litellm.llms.custom_httpx.aiohttp_transport import ( AiohttpResponseStream, + AiohttpTransport, LiteLLMAiohttpTransport, ) +@pytest.mark.asyncio +async def test_aclose_does_not_close_shared_session(): + """Test that aclose() does not close a session it does not own (shared session).""" + session = aiohttp.ClientSession() + try: + transport = LiteLLMAiohttpTransport(client=session, owns_session=False) + await transport.aclose() + assert not session.closed, "Shared session should not be closed by transport" + finally: + await session.close() + + +@pytest.mark.asyncio +async def test_aclose_closes_owned_session(): + """Test that aclose() closes a session it owns.""" + session = aiohttp.ClientSession() + transport = LiteLLMAiohttpTransport(client=session, owns_session=True) + await transport.aclose() + assert session.closed, "Owned session should be closed by transport" + + +@pytest.mark.asyncio +async def test_owns_session_defaults_to_true(): + """Test that owns_session defaults to True for backwards compatibility.""" + session = aiohttp.ClientSession() + transport = AiohttpTransport(client=session) + assert transport._owns_session is True + await transport.aclose() + assert session.closed + + class MockAiohttpResponse: """Mock aiohttp ClientResponse for testing"""