Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,15 +1408,6 @@ async def stream_with_fallbacks():
async for item in model_response:
yield item
except MidStreamFallbackError as e:
# Check if fallbacks are disabled by user
if initial_kwargs.get("disable_fallbacks", False):
verbose_router_logger.info(
"Mid stream fallback disabled by user, re-raising original error"
)
if e.original_exception is not None:
raise e.original_exception
raise e

from litellm.main import stream_chunk_builder

complete_response_object = stream_chunk_builder(
Expand Down
185 changes: 0 additions & 185 deletions tests/test_litellm/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,191 +1171,6 @@ async def __anext__(self):
print("✓ Edge case tests passed!")


@pytest.mark.asyncio
async def test_acompletion_streaming_disable_fallbacks_midstream():
    """Test that disable_fallbacks=True prevents mid-stream fallback attempts.

    Three scenarios are exercised against
    ``router._acompletion_streaming_iterator``:

    1. ``disable_fallbacks=True`` and the ``MidStreamFallbackError`` carries an
       ``original_exception``: the original exception is expected to surface.
    2. ``disable_fallbacks=True`` and no ``original_exception``: the
       ``MidStreamFallbackError`` itself is expected to surface.
    3. ``disable_fallbacks=False``: the patched
       ``async_function_with_fallbacks_common_utils`` is expected to be called.
    """
    from unittest.mock import MagicMock

    from litellm.exceptions import MidStreamFallbackError

    # Set up router with fallback configuration
    router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4", "api_key": "fake-key-1"},
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "fake-key-2"},
            },
        ],
        fallbacks=[{"gpt-4": ["gpt-3.5-turbo"]}],
        set_verbose=True,
    )

    messages = [{"role": "user", "content": "Hello"}]

    # Test 1: disable_fallbacks=True with original_exception
    print("\n=== Test 1: disable_fallbacks=True with original_exception ===")

    # Create an original exception to wrap
    from litellm.llms.anthropic.common_utils import AnthropicError

    original_error = AnthropicError(
        status_code=500,
        message="An unexpected error occurred while processing the response",
    )

    # Create MidStreamFallbackError with original_exception
    error_with_original = MidStreamFallbackError(
        message="Connection lost",
        model="gpt-4",
        llm_provider="openai",
        generated_content="Hello",
        original_exception=original_error,
    )

    class AsyncIteratorWithError:
        """Async iterator that yields ``items`` and raises ``error`` at a given index.

        ``error`` is raised when the cursor reaches ``error_after_index``, i.e.
        after that many items have been yielded.
        """

        def __init__(self, items, error_after_index, error):
            self.items = items
            self.index = 0
            self.error_after_index = error_after_index
            self.error = error
            # NOTE(review): the attributes below are presumably read by the
            # router's mid-stream handling - confirm against router.py.
            self.chunks = []
            self.model = "gpt-4"
            self.custom_llm_provider = "openai"
            self.logging_obj = MagicMock()

        def __aiter__(self):
            return self

        async def __anext__(self):
            # The error check must come before the exhaustion check: every
            # scenario here uses one chunk with error_after_index=1, so after
            # the first chunk index == len(items). Checking exhaustion first
            # would end the stream cleanly and the error would never fire.
            if self.index == self.error_after_index:
                raise self.error
            if self.index >= len(self.items):
                raise StopAsyncIteration
            item = self.items[self.index]
            self.index += 1
            self.chunks.append(item)
            return item

    mock_chunks = [
        MagicMock(choices=[MagicMock(delta=MagicMock(content="Hello"))]),
    ]

    mock_error_response = AsyncIteratorWithError(
        mock_chunks, 1, error_with_original
    )  # Error after first chunk

    initial_kwargs = {"model": "gpt-4", "stream": True, "disable_fallbacks": True}

    # Mock the fallback function to ensure it's NOT called
    with patch.object(
        router,
        "async_function_with_fallbacks_common_utils",
        return_value=MagicMock(),
    ) as mock_fallback_utils:
        with pytest.raises(AnthropicError, match="An unexpected error occurred"):
            result = await router._acompletion_streaming_iterator(
                model_response=mock_error_response,
                messages=messages,
                initial_kwargs=initial_kwargs,
            )

            async for _ in result:
                pass  # Exception is expected to be raised while iterating

        # Verify fallback was NOT called
        mock_fallback_utils.assert_not_called()
        print("✓ Original exception raised correctly when disable_fallbacks=True")

    # Test 2: disable_fallbacks=True without original_exception
    print("\n=== Test 2: disable_fallbacks=True without original_exception ===")

    error_without_original = MidStreamFallbackError(
        message="Connection lost",
        model="gpt-4",
        llm_provider="openai",
        generated_content="Hello",
        original_exception=None,
    )

    mock_error_response_2 = AsyncIteratorWithError(
        mock_chunks, 1, error_without_original
    )

    with patch.object(
        router,
        "async_function_with_fallbacks_common_utils",
        return_value=MagicMock(),
    ) as mock_fallback_utils:
        with pytest.raises(MidStreamFallbackError, match="Connection lost"):
            result = await router._acompletion_streaming_iterator(
                model_response=mock_error_response_2,
                messages=messages,
                initial_kwargs=initial_kwargs,
            )

            async for _ in result:
                pass  # Exception is expected to be raised while iterating

        # Verify fallback was NOT called
        mock_fallback_utils.assert_not_called()
        print(
            "✓ MidStreamFallbackError raised correctly when no original_exception and disable_fallbacks=True"
        )

    # Test 3: disable_fallbacks=False (default behavior - fallback should work)
    print("\n=== Test 3: disable_fallbacks=False (fallback enabled) ===")

    error_for_fallback = MidStreamFallbackError(
        message="Connection lost",
        model="gpt-4",
        llm_provider="openai",
        generated_content="Hello",
    )

    mock_error_response_3 = AsyncIteratorWithError(mock_chunks, 1, error_for_fallback)

    # Mock successful fallback response
    class EmptyAsyncIterator:
        """Async iterator that is exhausted immediately."""

        def __aiter__(self):
            return self

        async def __anext__(self):
            raise StopAsyncIteration

    mock_fallback_response = EmptyAsyncIterator()

    initial_kwargs_fallback_enabled = {
        "model": "gpt-4",
        "stream": True,
        "disable_fallbacks": False,
    }

    with patch.object(
        router,
        "async_function_with_fallbacks_common_utils",
        return_value=mock_fallback_response,
    ) as mock_fallback_utils:
        collected_chunks = []
        result = await router._acompletion_streaming_iterator(
            model_response=mock_error_response_3,
            messages=messages,
            initial_kwargs=initial_kwargs_fallback_enabled,
        )

        async for chunk in result:
            collected_chunks.append(chunk)

        # Verify fallback WAS called
        assert mock_fallback_utils.called
        print("✓ Fallback called correctly when disable_fallbacks=False")

    print("\n=== All disable_fallbacks tests passed! ===")


@pytest.mark.asyncio
async def test_async_function_with_fallbacks_common_utils():
"""Test the async_function_with_fallbacks_common_utils method"""
Expand Down
Loading