Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions litellm/llms/custom_httpx/aiohttp_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,17 @@ async def aclose(self) -> None:


class AiohttpTransport(httpx.AsyncBaseTransport):
    """httpx async transport that sends requests through an aiohttp session.

    Args:
        client: Either a ready ``ClientSession`` or a zero-argument callable
            that returns one.
        _owns_session: Whether this transport is allowed to close ``client``
            in ``aclose()``. Defaults to ``True`` so existing callers keep
            the previous behaviour; pass ``False`` for a session that is
            shared with, and closed by, someone else.
    """

    def __init__(
        self,
        client: Union[ClientSession, Callable[[], ClientSession]],
        _owns_session: bool = True,
    ) -> None:
        self.client = client
        # Ownership flag consulted by aclose(); a transport must not close a
        # session it does not own.
        self._owns_session = _owns_session

        #########################################################
        # Instance-level cache for proxy settings
        #########################################################
        self.proxy_cache: Dict[str, Optional[str]] = {}

    async def aclose(self) -> None:
        """Close the underlying session, but only if this transport owns it.

        Nothing happens when ``client`` is still a factory callable (no
        session object to close) or when the session is not owned.
        """
        if self._owns_session and isinstance(self.client, ClientSession):
            await self.client.close()


Expand All @@ -147,7 +148,11 @@ def __init__(
):
self.client = client
self._ssl_verify = ssl_verify # Store for per-request SSL override
super().__init__(client=client)
# If a pre-existing ClientSession is passed in, we don't own it
# and should not close it. If a factory is passed, we own sessions
# we create from it.
owns_session = not isinstance(client, ClientSession)
super().__init__(client=client, _owns_session=owns_session)
# Store the client factory for recreating sessions when needed
if callable(client):
self._client_factory = client
Expand All @@ -167,6 +172,7 @@ def _get_valid_client_session(self) -> ClientSession:
self.client = self._client_factory()
else:
self.client = ClientSession()
self._owns_session = True # We created this session, so we own it
# Don't return yet - check if the newly created session is valid

# Check if the session itself is closed
Expand All @@ -177,6 +183,7 @@ def _get_valid_client_session(self) -> ClientSession:
self.client = self._client_factory()
else:
self.client = ClientSession()
self._owns_session = True # We created this session, so we own it
return self.client

# Check if the existing session is still valid for the current event loop
Expand All @@ -186,30 +193,33 @@ def _get_valid_client_session(self) -> ClientSession:

# If session is from a different or closed loop, recreate it
if session_loop is None or session_loop != current_loop or session_loop.is_closed():
# Close old session to prevent leaks
# Close old session to prevent leaks (only if we own it)
old_session = self.client
try:
if not old_session.closed:
try:
asyncio.create_task(old_session.close())
except RuntimeError:
# Different event loop - can't schedule task, rely on GC
verbose_logger.debug("Old session from different loop, relying on GC")
except Exception as e:
verbose_logger.debug(f"Error closing old session: {e}")
if self._owns_session:
try:
if not old_session.closed:
try:
asyncio.create_task(old_session.close())
except RuntimeError:
# Different event loop - can't schedule task, rely on GC
verbose_logger.debug("Old session from different loop, relying on GC")
except Exception as e:
verbose_logger.debug(f"Error closing old session: {e}")

# Create a new session in the current event loop
if hasattr(self, "_client_factory") and callable(self._client_factory):
self.client = self._client_factory()
else:
self.client = ClientSession()
self._owns_session = True # We created this session, so we own it

except (RuntimeError, AttributeError):
# If we can't check the loop or session is invalid, recreate it
if hasattr(self, "_client_factory") and callable(self._client_factory):
self.client = self._client_factory()
else:
self.client = ClientSession()
self._owns_session = True # We created this session, so we own it

return self.client

Expand Down Expand Up @@ -306,6 +316,7 @@ async def handle_async_request(
self.client = self._client_factory()
else:
self.client = ClientSession()
self._owns_session = True # We created this session, so we own it
client_session = self.client

# Retry the request with the new session
Expand Down
48 changes: 44 additions & 4 deletions tests/test_litellm/llms/custom_httpx/test_aiohttp_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,12 @@ async def streaming_handler(request):
# but each chunk arrives quickly
response = web.StreamResponse()
await response.prepare(request)

# Send 5 chunks over 0.5 seconds total (0.1s between chunks)
for i in range(5):
await asyncio.sleep(0.05) # Less than sock_read timeout
await response.write(f"chunk{i}\n".encode())

await response.write_eof()
return response

Expand Down Expand Up @@ -436,12 +436,12 @@ def factory():
# This should succeed without timing out
response = await transport.handle_async_request(request)
assert response.status_code == 200

# Read the streaming response
chunks = []
async for chunk in response.aiter_bytes():
chunks.append(chunk)

# Verify we got all chunks
full_response = b"".join(chunks).decode()
assert "chunk0" in full_response
Expand Down Expand Up @@ -512,3 +512,43 @@ def factory():
assert counts["requests"] == 2 # First request failed, second succeeded
assert counts["sessions"] == 2 # Created 2 sessions for retry
assert response.status_code == 200


@pytest.mark.asyncio
async def test_shared_session_not_closed_by_transport():
    """A caller-supplied ClientSession must survive transport.aclose() (#21116)."""
    import aiohttp

    # The context manager closes the shared session on exit, whether the
    # assertions below pass or fail.
    async with aiohttp.ClientSession() as shared_session:
        transport = LiteLLMAiohttpTransport(client=shared_session)

        # A session handed in from outside is not owned by the transport.
        assert transport._owns_session is False

        # Closing the transport must leave the shared session open.
        await transport.aclose()
        assert not shared_session.closed, "Transport should not close a shared ClientSession it does not own"


@pytest.mark.asyncio
async def test_factory_created_session_closed_by_transport():
    """Test that aclose() DOES close a session created from a factory."""
    import aiohttp

    def factory():
        return aiohttp.ClientSession()

    transport = LiteLLMAiohttpTransport(client=factory)  # type: ignore

    # Transport should own sessions it creates
    assert transport._owns_session is True

    # Force session creation
    session = transport._get_valid_client_session()
    try:
        # aclose() should close the session we created
        await transport.aclose()
        assert session.closed, "Transport should close a session it created from a factory"
    finally:
        # If the assertion above fails, the session is still open; close it
        # here so a failing test does not leak an unclosed ClientSession.
        if not session.closed:
            await session.close()
Loading