diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 7a6752fbff8..1f739a60b40 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -8,6 +8,7 @@ import traceback from typing import Any, Callable, Dict, List, Optional, Union, cast +import anyio import httpx from pydantic import BaseModel @@ -156,6 +157,27 @@ def __iter__(self): def __aiter__(self): return self + async def aclose(self): + if self.completion_stream is not None: + stream_to_close = self.completion_stream + self.completion_stream = None + # Shield from anyio cancellation so cleanup awaits can complete. + # Without this, CancelledError is thrown into every await during + # task group cancellation, preventing HTTP connection release. + with anyio.CancelScope(shield=True): + try: + if hasattr(stream_to_close, "aclose"): + await stream_to_close.aclose() + elif hasattr(stream_to_close, "close"): + result = stream_to_close.close() + if result is not None: + await result + except BaseException as e: + verbose_logger.debug( + "CustomStreamWrapper.aclose: error closing completion_stream: %s", + e, + ) + def check_send_stream_usage(self, stream_options: Optional[dict]): return ( stream_options is not None diff --git a/litellm/llms/custom_httpx/aiohttp_transport.py b/litellm/llms/custom_httpx/aiohttp_transport.py index 6cec1f4fe16..60f34a2a825 100644 --- a/litellm/llms/custom_httpx/aiohttp_transport.py +++ b/litellm/llms/custom_httpx/aiohttp_transport.py @@ -330,7 +330,7 @@ async def handle_async_request( return httpx.Response( status_code=response.status, headers=response.headers, - content=AiohttpResponseStream(response), + stream=AiohttpResponseStream(response), request=request, ) diff --git a/litellm/llms/openai/openai.py b/litellm/llms/openai/openai.py index da87852dff5..c7524925bd0 100644 --- a/litellm/llms/openai/openai.py +++ b/litellm/llms/openai/openai.py @@ -693,6 +693,7 @@ def completion( # type: ignore # noqa: PLR0915 organization=organization, drop_params=drop_params, stream_options=stream_options, + shared_session=shared_session, ) else: return self.acompletion( @@ -1063,6 +1064,7 @@ async def async_streaming( headers=None, drop_params: Optional[bool] = None, stream_options: Optional[dict] = None, + shared_session: Optional["ClientSession"] = None, ): response = None data = provider_config.transform_request( @@ -1087,6 +1089,7 @@ async def async_streaming( max_retries=max_retries, organization=organization, client=client, + shared_session=shared_session, ) ## LOGGING logging_obj.pre_call( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 75d7fcc4f3a..c1181dd52c2 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1,3 +1,4 @@ +import anyio import asyncio import copy import enum @@ -5168,6 +5169,20 @@ async def async_data_generator( ) error_returned = json.dumps({"error": proxy_exception.to_dict()}) yield f"data: {error_returned}\n\n" + finally: + # Close the response stream to release the underlying HTTP connection + # back to the connection pool. This prevents pool exhaustion when + # clients disconnect mid-stream. + # Shield from cancellation so the close awaits can complete. + with anyio.CancelScope(shield=True): + if hasattr(response, "aclose"): + try: + await response.aclose() + except BaseException as e: + verbose_proxy_logger.debug( + "async_data_generator: error closing response stream: %s", + e, + ) def select_data_generator( diff --git a/litellm/router.py b/litellm/router.py index 2858a6a32f2..022b707df25 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -33,6 +33,7 @@ cast, ) +import anyio import httpx import openai from openai import AsyncOpenAI @@ -1561,6 +1562,7 @@ async def __anext__(self): return await self._async_generator.__anext__() async def stream_with_fallbacks(): + fallback_response = None # Track for cleanup in finally try: async for item in model_response: yield item @@ -1659,6 +1661,30 @@ async def stream_with_fallbacks(): f"Fallback also failed: {fallback_error}" ) raise fallback_error + finally: + # Close the underlying streams to release HTTP connections + # back to the connection pool when the generator is closed + # (e.g. on client disconnect). + # Shield from anyio cancellation so the awaits can complete. + with anyio.CancelScope(shield=True): + if hasattr(model_response, "aclose"): + try: + await model_response.aclose() + except BaseException as e: + verbose_router_logger.debug( + "stream_with_fallbacks: error closing model_response: %s", + e, + ) + if fallback_response is not None and hasattr( + fallback_response, "aclose" + ): + try: + await fallback_response.aclose() + except BaseException as e: + verbose_router_logger.debug( + "stream_with_fallbacks: error closing fallback_response: %s", + e, + ) return FallbackStreamWrapper(stream_with_fallbacks()) diff --git a/pyproject.toml b/pyproject.toml index e1eb60f3db2..469a892fbc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ pydantic = "^2.5.0" jsonschema = ">=4.23.0,<5.0.0" numpydoc = {version = "*", optional = true} # used in utils.py -uvicorn = {version = "^0.31.1", optional = true} +uvicorn = {version = ">=0.32.1,<1.0.0", optional = true} uvloop = {version = "^0.21.0", optional = true, markers="sys_platform != 'win32'"} gunicorn = {version = "^23.0.0", optional = true} fastapi = {version = ">=0.120.1", optional = true} diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py index ec2f528a35d..73ecaa20a2d 100644 --- a/tests/test_litellm/litellm_core_utils/test_streaming_handler.py +++ b/tests/test_litellm/litellm_core_utils/test_streaming_handler.py @@ -2,7 +2,7 @@ import os import sys import time -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest @@ -1185,3 +1185,50 @@ def test_is_chunk_non_empty_with_valid_tool_calls( ) is True ) + + +@pytest.mark.asyncio +async def test_custom_stream_wrapper_aclose(): + """Test that aclose() delegates to the underlying completion_stream's aclose()""" + mock_stream = AsyncMock() + mock_stream.aclose = AsyncMock() + + wrapper = CustomStreamWrapper( + completion_stream=mock_stream, + model=None, + logging_obj=MagicMock(), + custom_llm_provider=None, + ) + + await wrapper.aclose() + mock_stream.aclose.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_custom_stream_wrapper_aclose_no_underlying(): + """Test that aclose() is safe when completion_stream has no aclose method""" + mock_stream = MagicMock(spec=[]) # No aclose attribute + + wrapper = CustomStreamWrapper( + completion_stream=mock_stream, + model=None, + logging_obj=MagicMock(), + custom_llm_provider=None, + ) + + # Should not raise + await wrapper.aclose() + + +@pytest.mark.asyncio +async def test_custom_stream_wrapper_aclose_none_stream(): + """Test that aclose() is safe when completion_stream is None""" + wrapper = CustomStreamWrapper( + completion_stream=None, + model=None, + logging_obj=MagicMock(), + custom_llm_provider=None, + ) + + # Should not raise + await wrapper.aclose() diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py index ab414db3569..5f54c151d83 100644 --- a/tests/test_litellm/proxy/test_proxy_server.py +++ b/tests/test_litellm/proxy/test_proxy_server.py @@ -3519,6 +3519,157 @@ def test_invitation_endpoints_non_admin_denied( assert "not allowed" in str(error_content).lower() +@pytest.mark.asyncio +async def test_async_data_generator_cleanup_on_early_exit(): + """ + Test that async_data_generator calls response.aclose() in the finally block + when the generator is abandoned mid-stream (client disconnect). + """ + from litellm.proxy._types import UserAPIKeyAuth + from litellm.proxy.proxy_server import async_data_generator + from litellm.proxy.utils import ProxyLogging + + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + mock_request_data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "test"}], + } + + mock_chunks = [ + {"choices": [{"delta": {"content": "Hello"}}]}, + {"choices": [{"delta": {"content": " world"}}]}, + {"choices": [{"delta": {"content": " more"}}]}, + ] + + mock_proxy_logging_obj = MagicMock(spec=ProxyLogging) + + async def mock_streaming_iterator(*args, **kwargs): + for chunk in mock_chunks: + yield chunk + + mock_proxy_logging_obj.async_post_call_streaming_iterator_hook = ( + mock_streaming_iterator + ) + mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock( + side_effect=lambda **kwargs: kwargs.get("response") + ) + mock_proxy_logging_obj.post_call_failure_hook = AsyncMock() + + # Create a mock response with aclose + mock_response = MagicMock() + mock_response.aclose = AsyncMock() + + with patch("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj): + # Consume only the first chunk then abandon the generator (simulates client disconnect) + gen = async_data_generator( + mock_response, mock_user_api_key_dict, mock_request_data + ) + first_chunk = await gen.__anext__() + assert first_chunk.startswith("data: ") + + # Close the generator early (simulates what ASGI does on client disconnect) + await gen.aclose() + + # Verify aclose was called on the response to release the HTTP connection + mock_response.aclose.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_async_data_generator_cleanup_on_normal_completion(): + """ + Test that async_data_generator calls response.aclose() even on normal completion. + """ + from litellm.proxy._types import UserAPIKeyAuth + from litellm.proxy.proxy_server import async_data_generator + from litellm.proxy.utils import ProxyLogging + + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + mock_request_data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "test"}], + } + + mock_chunks = [ + {"choices": [{"delta": {"content": "Hello"}}]}, + ] + + mock_proxy_logging_obj = MagicMock(spec=ProxyLogging) + + async def mock_streaming_iterator(*args, **kwargs): + for chunk in mock_chunks: + yield chunk + + mock_proxy_logging_obj.async_post_call_streaming_iterator_hook = ( + mock_streaming_iterator + ) + mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock( + side_effect=lambda **kwargs: kwargs.get("response") + ) + mock_proxy_logging_obj.post_call_failure_hook = AsyncMock() + + mock_response = MagicMock() + mock_response.aclose = AsyncMock() + + with patch("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj): + yielded_data = [] + async for data in async_data_generator( + mock_response, mock_user_api_key_dict, mock_request_data + ): + yielded_data.append(data) + + # Should have completed normally with [DONE] + assert any("[DONE]" in d for d in yielded_data) + # aclose should still be called via finally block + mock_response.aclose.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_async_data_generator_cleanup_on_midstream_error(): + """ + Test that async_data_generator calls response.aclose() via finally block + even when an exception occurs mid-stream. + """ + from litellm.proxy._types import UserAPIKeyAuth + from litellm.proxy.proxy_server import async_data_generator + from litellm.proxy.utils import ProxyLogging + + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + mock_request_data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "test"}], + } + + mock_proxy_logging_obj = MagicMock(spec=ProxyLogging) + + async def mock_streaming_iterator_with_error(*args, **kwargs): + yield {"choices": [{"delta": {"content": "Hello"}}]} + raise RuntimeError("upstream connection reset") + + mock_proxy_logging_obj.async_post_call_streaming_iterator_hook = ( + mock_streaming_iterator_with_error + ) + mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock( + side_effect=lambda **kwargs: kwargs.get("response") + ) + mock_proxy_logging_obj.post_call_failure_hook = AsyncMock() + + mock_response = MagicMock() + mock_response.aclose = AsyncMock() + + with patch("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj): + yielded_data = [] + async for data in async_data_generator( + mock_response, mock_user_api_key_dict, mock_request_data + ): + yielded_data.append(data) + + # Should have yielded data chunk and then an error chunk + assert len(yielded_data) >= 2 + assert any("error" in d for d in yielded_data) + # aclose must still be called via finally block despite the error + mock_response.aclose.assert_awaited_once() + + # ============================================================================ # store_model_in_db DB Config Override Tests # ============================================================================ diff --git a/tests/test_litellm/test_streaming_connection_cleanup.py b/tests/test_litellm/test_streaming_connection_cleanup.py new file mode 100644 index 00000000000..677046bc66c --- /dev/null +++ b/tests/test_litellm/test_streaming_connection_cleanup.py @@ -0,0 +1,391 @@ +""" +Regression tests for streaming connection pool leak fix. +""" + +import asyncio +import os +import sys +from unittest.mock import MagicMock, patch + +import anyio +import httpx +import pytest + +sys.path.insert(0, os.path.abspath("../..")) + +from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper +from litellm.llms.custom_httpx.aiohttp_transport import ( + AiohttpResponseStream, + LiteLLMAiohttpTransport, +) + + +# ── aiohttp transport layer tests ────────────────────────────── + + +@pytest.mark.asyncio +async def test_aiohttp_transport_response_uses_stream_not_content(): + """handle_async_request must use stream= so aclose() propagates to AiohttpResponseStream.""" + + class FakeSession: + closed = False + + def __init__(self): + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = None + + def request(self, **kwargs): + class Resp: + status = 200 + headers = {} + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + @property + def content(self): + class C: + async def iter_chunked(self, size): + yield b"data" + + return C() + + return Resp() + + transport = LiteLLMAiohttpTransport(client=lambda: FakeSession()) # type: ignore + response = await transport.handle_async_request( + httpx.Request("GET", "http://example.com") + ) + + assert isinstance(response.stream, AiohttpResponseStream) + + +@pytest.mark.asyncio +async def test_aiohttp_response_stream_aclose_releases_connection(): + """AiohttpResponseStream.aclose() must call __aexit__ on the aiohttp response.""" + aexit_called = False + + class MockResponse: + status = 200 + headers = {} + + @property + def content(self): + class C: + async def iter_chunked(self, size): + yield b"data" + + return C() + + async def __aexit__(self, *args): + nonlocal aexit_called + aexit_called = True + + stream = AiohttpResponseStream(MockResponse()) # type: ignore + await stream.aclose() + assert aexit_called + + +# ── CustomStreamWrapper.aclose() tests ───────────────────────── + + +@pytest.mark.asyncio +async def test_aclose_falls_back_to_close(): + """OpenAI's AsyncStream has close() but not aclose(). Must fall back.""" + close_called = False + + class FakeAsyncStream: + async def close(self): + nonlocal close_called + close_called = True + + wrapper = CustomStreamWrapper( + completion_stream=FakeAsyncStream(), + model=None, + logging_obj=MagicMock(), + custom_llm_provider=None, + ) + + await wrapper.aclose() + assert close_called + + +@pytest.mark.asyncio +async def test_aclose_prefers_aclose_over_close(): + """When both aclose() and close() exist, aclose() should be preferred.""" + aclose_called = False + close_called = False + + class FakeStream: + async def aclose(self): + nonlocal aclose_called + aclose_called = True + + async def close(self): + nonlocal close_called + close_called = True + + wrapper = CustomStreamWrapper( + completion_stream=FakeStream(), + model=None, + logging_obj=MagicMock(), + custom_llm_provider=None, + ) + + await wrapper.aclose() + assert aclose_called + assert not close_called + + +@pytest.mark.asyncio +async def test_aclose_completes_under_cancellation(): + """aclose() must shield cleanup from CancelledError so streams actually close.""" + aclose_completed = False + + class SlowCloseStream: + async def aclose(self): + await anyio.sleep(0) + nonlocal aclose_completed + aclose_completed = True + + wrapper = CustomStreamWrapper( + completion_stream=SlowCloseStream(), + model=None, + logging_obj=MagicMock(), + custom_llm_provider=None, + ) + + with anyio.CancelScope() as scope: + scope.cancel() + await wrapper.aclose() + + assert aclose_completed + + +# ── Router stream_with_fallbacks cleanup tests ────────────────── + + +@pytest.mark.asyncio +async def test_stream_with_fallbacks_closes_stream_on_generator_close(): + """Closing the FallbackStreamWrapper must aclose() the underlying model_response + via stream_with_fallbacks' finally block.""" + from litellm.router import Router + + stream_closed = False + + class FakeStream(CustomStreamWrapper): + def __init__(self): + super().__init__( + completion_stream=None, + model="test-model", + logging_obj=MagicMock(), + custom_llm_provider="openai", + ) + self._items = ["chunk1", "chunk2", "chunk3"] + self._index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self._index >= len(self._items): + raise StopAsyncIteration + item = self._items[self._index] + self._index += 1 + return item + + async def aclose(self): + nonlocal stream_closed + stream_closed = True + + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "openai/test", + "api_key": "fake", + }, + } + ] + ) + + fake_stream = FakeStream() + + # Call _acompletion_streaming_iterator directly so we go through + # stream_with_fallbacks and its finally block + result = await router._acompletion_streaming_iterator( + model_response=fake_stream, + messages=[{"role": "user", "content": "hi"}], + initial_kwargs={"model": "test-model"}, + ) + + # Consume one chunk then close (simulates client disconnect) + async for _ in result: + break + await result.aclose() + + assert stream_closed, "model_response stream was not closed by stream_with_fallbacks finally block" + + +@pytest.mark.asyncio +async def test_stream_with_fallbacks_closes_stream_on_normal_completion(): + """stream_with_fallbacks must aclose() model_response even on normal completion.""" + from litellm.router import Router + + stream_closed = False + + class FakeStream(CustomStreamWrapper): + def __init__(self): + super().__init__( + completion_stream=None, + model="test-model", + logging_obj=MagicMock(), + custom_llm_provider="openai", + ) + self._items = ["chunk1"] + self._index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self._index >= len(self._items): + raise StopAsyncIteration + item = self._items[self._index] + self._index += 1 + return item + + async def aclose(self): + nonlocal stream_closed + stream_closed = True + + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "openai/test", + "api_key": "fake", + }, + } + ] + ) + + fake_stream = FakeStream() + + result = await router._acompletion_streaming_iterator( + model_response=fake_stream, + messages=[{"role": "user", "content": "hi"}], + initial_kwargs={"model": "test-model"}, + ) + + # Exhaust the stream fully + async for _ in result: + pass + await result.aclose() + + assert stream_closed, "model_response stream was not closed after normal completion" + + +@pytest.mark.asyncio +async def test_stream_with_fallbacks_closes_both_on_fallback_disconnect(): + """When a fallback is triggered and the client disconnects during fallback + iteration, both model_response and fallback_response must be closed.""" + from litellm.exceptions import MidStreamFallbackError + from litellm.router import Router + + model_closed = False + fallback_closed = False + + class FakeModelStream(CustomStreamWrapper): + """Stream that raises MidStreamFallbackError immediately to trigger fallback.""" + + def __init__(self): + super().__init__( + completion_stream=None, + model="test-model", + logging_obj=MagicMock(), + custom_llm_provider="openai", + ) + self.chunks = [] + + def __aiter__(self): + return self + + async def __anext__(self): + raise MidStreamFallbackError( + message="test mid-stream error", + model="test-model", + llm_provider="openai", + generated_content="", + ) + + async def aclose(self): + nonlocal model_closed + model_closed = True + + class FakeFallbackStream: + """Fallback stream that yields chunks.""" + + def __init__(self): + self._items = ["fb1", "fb2", "fb3"] + self._index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self._index >= len(self._items): + raise StopAsyncIteration + item = self._items[self._index] + self._index += 1 + return item + + async def aclose(self): + nonlocal fallback_closed + fallback_closed = True + + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "openai/test", + "api_key": "fake", + }, + } + ] + ) + + fake_model_stream = FakeModelStream() + fake_fallback_stream = FakeFallbackStream() + + # Mock async_function_with_fallbacks_common_utils to return the fallback stream + # instead of actually calling through the full fallback machinery + with patch.object( + router, + "async_function_with_fallbacks_common_utils", + return_value=fake_fallback_stream, + ): + result = await router._acompletion_streaming_iterator( + model_response=fake_model_stream, + messages=[{"role": "user", "content": "hi"}], + initial_kwargs={ + "model": "test-model", + "fallbacks": ["other-model"], + }, + ) + + # Consume one fallback chunk then close (simulates client disconnect) + async for _ in result: + break + await result.aclose() + + assert model_closed, "model_response stream was not closed" + assert fallback_closed, "fallback_response stream was not closed"