Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions litellm/litellm_core_utils/streaming_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import traceback
from typing import Any, Callable, Dict, List, Optional, Union, cast

import anyio
import httpx
from pydantic import BaseModel

Expand Down Expand Up @@ -156,6 +157,27 @@ def __iter__(self):
def __aiter__(self):
return self

async def aclose(self):
    """Release the wrapped completion stream's underlying HTTP connection.

    Safe to call multiple times and safe when no stream is attached;
    subsequent calls are no-ops because the reference is cleared first.
    """
    stream = self.completion_stream
    if stream is None:
        return
    # Clear the reference up front so a concurrent/second aclose() no-ops.
    self.completion_stream = None
    # Shield from anyio cancellation so cleanup awaits can complete.
    # Without this, CancelledError is thrown into every await during
    # task group cancellation, preventing HTTP connection release.
    with anyio.CancelScope(shield=True):
        try:
            if hasattr(stream, "aclose"):
                # Async close (httpx/aiohttp style response streams).
                await stream.aclose()
            elif hasattr(stream, "close"):
                # Sync close; some implementations return an awaitable.
                maybe_awaitable = stream.close()
                if maybe_awaitable is not None:
                    await maybe_awaitable
        except BaseException as e:
            # Best-effort cleanup: never let a close failure propagate.
            verbose_logger.debug(
                "CustomStreamWrapper.aclose: error closing completion_stream: %s",
                e,
            )

def check_send_stream_usage(self, stream_options: Optional[dict]):
return (
stream_options is not None
Expand Down
2 changes: 1 addition & 1 deletion litellm/llms/custom_httpx/aiohttp_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ async def handle_async_request(
return httpx.Response(
status_code=response.status,
headers=response.headers,
content=AiohttpResponseStream(response),
stream=AiohttpResponseStream(response),
request=request,
)

Expand Down
3 changes: 3 additions & 0 deletions litellm/llms/openai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ def completion( # type: ignore # noqa: PLR0915
organization=organization,
drop_params=drop_params,
stream_options=stream_options,
shared_session=shared_session,
)
else:
return self.acompletion(
Expand Down Expand Up @@ -1063,6 +1064,7 @@ async def async_streaming(
headers=None,
drop_params: Optional[bool] = None,
stream_options: Optional[dict] = None,
shared_session: Optional["ClientSession"] = None,
):
response = None
data = provider_config.transform_request(
Expand All @@ -1087,6 +1089,7 @@ async def async_streaming(
max_retries=max_retries,
organization=organization,
client=client,
shared_session=shared_session,
)
## LOGGING
logging_obj.pre_call(
Expand Down
15 changes: 15 additions & 0 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import anyio
import asyncio
import copy
import enum
Expand Down Expand Up @@ -5168,6 +5169,20 @@ async def async_data_generator(
)
error_returned = json.dumps({"error": proxy_exception.to_dict()})
yield f"data: {error_returned}\n\n"
finally:
# Close the response stream to release the underlying HTTP connection
# back to the connection pool. This prevents pool exhaustion when
# clients disconnect mid-stream.
# Shield from cancellation so the close awaits can complete.
with anyio.CancelScope(shield=True):
if hasattr(response, "aclose"):
try:
await response.aclose()
except BaseException as e:
verbose_proxy_logger.debug(
"async_data_generator: error closing response stream: %s",
e,
)


def select_data_generator(
Expand Down
26 changes: 26 additions & 0 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
cast,
)

import anyio
import httpx
import openai
from openai import AsyncOpenAI
Expand Down Expand Up @@ -1561,6 +1562,7 @@ async def __anext__(self):
return await self._async_generator.__anext__()

async def stream_with_fallbacks():
fallback_response = None # Track for cleanup in finally
try:
async for item in model_response:
yield item
Expand Down Expand Up @@ -1659,6 +1661,30 @@ async def stream_with_fallbacks():
f"Fallback also failed: {fallback_error}"
)
raise fallback_error
finally:
# Close the underlying streams to release HTTP connections
# back to the connection pool when the generator is closed
# (e.g. on client disconnect).
# Shield from anyio cancellation so the awaits can complete.
with anyio.CancelScope(shield=True):
if hasattr(model_response, "aclose"):
try:
await model_response.aclose()
except BaseException as e:
verbose_router_logger.debug(
"stream_with_fallbacks: error closing model_response: %s",
e,
)
if fallback_response is not None and hasattr(
fallback_response, "aclose"
):
try:
await fallback_response.aclose()
except BaseException as e:
verbose_router_logger.debug(
"stream_with_fallbacks: error closing fallback_response: %s",
e,
)

return FallbackStreamWrapper(stream_with_fallbacks())

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pydantic = "^2.5.0"
jsonschema = ">=4.23.0,<5.0.0"
numpydoc = {version = "*", optional = true} # used in utils.py

uvicorn = {version = "^0.31.1", optional = true}
uvicorn = {version = ">=0.32.1,<1.0.0", optional = true}
uvloop = {version = "^0.21.0", optional = true, markers="sys_platform != 'win32'"}
gunicorn = {version = "^23.0.0", optional = true}
fastapi = {version = ">=0.120.1", optional = true}
Expand Down
49 changes: 48 additions & 1 deletion tests/test_litellm/litellm_core_utils/test_streaming_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import sys
import time
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest

Expand Down Expand Up @@ -1185,3 +1185,50 @@ def test_is_chunk_non_empty_with_valid_tool_calls(
)
is True
)


@pytest.mark.asyncio
async def test_custom_stream_wrapper_aclose():
    """aclose() must delegate to the wrapped completion_stream's aclose()."""
    underlying_stream = AsyncMock()
    underlying_stream.aclose = AsyncMock()

    stream_wrapper = CustomStreamWrapper(
        completion_stream=underlying_stream,
        model=None,
        logging_obj=MagicMock(),
        custom_llm_provider=None,
    )

    await stream_wrapper.aclose()

    underlying_stream.aclose.assert_awaited_once()


@pytest.mark.asyncio
async def test_custom_stream_wrapper_aclose_no_underlying():
    """aclose() must be a no-op when the stream exposes no aclose method."""
    # spec=[] yields a mock with no attributes at all, so neither
    # aclose nor close exists on it.
    bare_stream = MagicMock(spec=[])

    stream_wrapper = CustomStreamWrapper(
        completion_stream=bare_stream,
        model=None,
        logging_obj=MagicMock(),
        custom_llm_provider=None,
    )

    # Must complete without raising.
    await stream_wrapper.aclose()


@pytest.mark.asyncio
async def test_custom_stream_wrapper_aclose_none_stream():
    """aclose() must be a no-op when completion_stream is None."""
    stream_wrapper = CustomStreamWrapper(
        completion_stream=None,
        model=None,
        logging_obj=MagicMock(),
        custom_llm_provider=None,
    )

    # Must complete without raising.
    await stream_wrapper.aclose()
151 changes: 151 additions & 0 deletions tests/test_litellm/proxy/test_proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3519,6 +3519,157 @@ def test_invitation_endpoints_non_admin_denied(
assert "not allowed" in str(error_content).lower()


@pytest.mark.asyncio
async def test_async_data_generator_cleanup_on_early_exit():
    """
    Test that async_data_generator calls response.aclose() in the finally block
    when the generator is abandoned mid-stream (client disconnect).
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    user_key = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }
    chunks = [
        {"choices": [{"delta": {"content": "Hello"}}]},
        {"choices": [{"delta": {"content": " world"}}]},
        {"choices": [{"delta": {"content": " more"}}]},
    ]

    async def fake_streaming_iterator(*args, **kwargs):
        for chunk in chunks:
            yield chunk

    logging_mock = MagicMock(spec=ProxyLogging)
    logging_mock.async_post_call_streaming_iterator_hook = fake_streaming_iterator
    logging_mock.async_post_call_streaming_hook = AsyncMock(
        side_effect=lambda **kwargs: kwargs.get("response")
    )
    logging_mock.post_call_failure_hook = AsyncMock()

    # Response object carrying an async aclose() we can observe.
    response_mock = MagicMock()
    response_mock.aclose = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_mock):
        generator = async_data_generator(response_mock, user_key, request_payload)

        # Pull one chunk, then abandon the stream (client disconnect).
        first = await generator.__anext__()
        assert first.startswith("data: ")

        # ASGI servers close abandoned generators this way.
        await generator.aclose()

    # The HTTP connection must be released via response.aclose().
    response_mock.aclose.assert_awaited_once()


@pytest.mark.asyncio
async def test_async_data_generator_cleanup_on_normal_completion():
    """
    Test that async_data_generator calls response.aclose() even on normal completion.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    user_key = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }
    chunks = [
        {"choices": [{"delta": {"content": "Hello"}}]},
    ]

    async def fake_streaming_iterator(*args, **kwargs):
        for chunk in chunks:
            yield chunk

    logging_mock = MagicMock(spec=ProxyLogging)
    logging_mock.async_post_call_streaming_iterator_hook = fake_streaming_iterator
    logging_mock.async_post_call_streaming_hook = AsyncMock(
        side_effect=lambda **kwargs: kwargs.get("response")
    )
    logging_mock.post_call_failure_hook = AsyncMock()

    response_mock = MagicMock()
    response_mock.aclose = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_mock):
        emitted = [
            data
            async for data in async_data_generator(
                response_mock, user_key, request_payload
            )
        ]

    # Normal completion must end with the [DONE] sentinel.
    assert any("[DONE]" in d for d in emitted)
    # The finally block must still release the connection.
    response_mock.aclose.assert_awaited_once()


@pytest.mark.asyncio
async def test_async_data_generator_cleanup_on_midstream_error():
    """
    Test that async_data_generator calls response.aclose() via finally block
    even when an exception occurs mid-stream.
    """
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import async_data_generator
    from litellm.proxy.utils import ProxyLogging

    user_key = MagicMock(spec=UserAPIKeyAuth)
    request_payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "test"}],
    }

    async def failing_iterator(*args, **kwargs):
        # Yield once, then simulate the upstream dropping the connection.
        yield {"choices": [{"delta": {"content": "Hello"}}]}
        raise RuntimeError("upstream connection reset")

    logging_mock = MagicMock(spec=ProxyLogging)
    logging_mock.async_post_call_streaming_iterator_hook = failing_iterator
    logging_mock.async_post_call_streaming_hook = AsyncMock(
        side_effect=lambda **kwargs: kwargs.get("response")
    )
    logging_mock.post_call_failure_hook = AsyncMock()

    response_mock = MagicMock()
    response_mock.aclose = AsyncMock()

    with patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_mock):
        emitted = [
            data
            async for data in async_data_generator(
                response_mock, user_key, request_payload
            )
        ]

    # One data chunk followed by at least an error chunk.
    assert len(emitted) >= 2
    assert any("error" in d for d in emitted)
    # Cleanup must run despite the mid-stream failure.
    response_mock.aclose.assert_awaited_once()


# ============================================================================
# store_model_in_db DB Config Override Tests
# ============================================================================
Expand Down
Loading
Loading