diff --git a/litellm/router.py b/litellm/router.py index 8a1ac8c07f9..0b07d5ed8c1 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1408,15 +1408,6 @@ async def stream_with_fallbacks(): async for item in model_response: yield item except MidStreamFallbackError as e: - # Check if fallbacks are disabled by user - if initial_kwargs.get("disable_fallbacks", False): - verbose_router_logger.info( - "Mid stream fallback disabled by user, re-raising original error" - ) - if e.original_exception is not None: - raise e.original_exception - raise e - from litellm.main import stream_chunk_builder complete_response_object = stream_chunk_builder( diff --git a/tests/test_litellm/test_router.py b/tests/test_litellm/test_router.py index 6279e96305f..08ae804ea80 100644 --- a/tests/test_litellm/test_router.py +++ b/tests/test_litellm/test_router.py @@ -1171,191 +1171,6 @@ async def __anext__(self): print("✓ Edge case tests passed!") -@pytest.mark.asyncio -async def test_acompletion_streaming_disable_fallbacks_midstream(): - """Test that disable_fallbacks=True prevents mid-stream fallback attempts.""" - from unittest.mock import MagicMock - - from litellm.exceptions import MidStreamFallbackError - - # Set up router with fallback configuration - router = litellm.Router( - model_list=[ - { - "model_name": "gpt-4", - "litellm_params": {"model": "gpt-4", "api_key": "fake-key-1"}, - }, - { - "model_name": "gpt-3.5-turbo", - "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "fake-key-2"}, - }, - ], - fallbacks=[{"gpt-4": ["gpt-3.5-turbo"]}], - set_verbose=True, - ) - - messages = [{"role": "user", "content": "Hello"}] - - # Test 1: disable_fallbacks=True with original_exception - print("\n=== Test 1: disable_fallbacks=True with original_exception ===") - - # Create an original exception to wrap - from litellm.llms.anthropic.common_utils import AnthropicError - - original_error = AnthropicError( - status_code=500, - message="An unexpected error occurred while processing the response", - ) - - # Create MidStreamFallbackError with original_exception - error_with_original = MidStreamFallbackError( - message="Connection lost", - model="gpt-4", - llm_provider="openai", - generated_content="Hello", - original_exception=original_error, - ) - - class AsyncIteratorWithError: - def __init__(self, items, error_after_index, error): - self.items = items - self.index = 0 - self.error_after_index = error_after_index - self.error = error - self.chunks = [] - self.model = "gpt-4" - self.custom_llm_provider = "openai" - self.logging_obj = MagicMock() - - def __aiter__(self): - return self - - async def __anext__(self): - if self.index >= len(self.items): - raise StopAsyncIteration - if self.index == self.error_after_index: - raise self.error - item = self.items[self.index] - self.index += 1 - self.chunks.append(item) - return item - - mock_chunks = [ - MagicMock(choices=[MagicMock(delta=MagicMock(content="Hello"))]), - ] - - mock_error_response = AsyncIteratorWithError( - mock_chunks, 1, error_with_original - ) # Error after first chunk - - initial_kwargs = {"model": "gpt-4", "stream": True, "disable_fallbacks": True} - - # Mock the fallback function to ensure it's NOT called - with patch.object( - router, - "async_function_with_fallbacks_common_utils", - return_value=MagicMock(), - ) as mock_fallback_utils: - with pytest.raises(AnthropicError, match="An unexpected error occurred"): - result = await router._acompletion_streaming_iterator( - model_response=mock_error_response, - messages=messages, - initial_kwargs=initial_kwargs, - ) - - async for chunk in result: - pass # Should not reach here; exception should be raised - - # Verify fallback was NOT called - mock_fallback_utils.assert_not_called() - print("✓ Original exception raised correctly when disable_fallbacks=True") - - # Test 2: disable_fallbacks=True without original_exception - print("\n=== Test 2: disable_fallbacks=True without original_exception ===") - - error_without_original = MidStreamFallbackError( - message="Connection lost", - model="gpt-4", - llm_provider="openai", - generated_content="Hello", - original_exception=None, - ) - - mock_error_response_2 = AsyncIteratorWithError( - mock_chunks, 1, error_without_original - ) - - with patch.object( - router, - "async_function_with_fallbacks_common_utils", - return_value=MagicMock(), - ) as mock_fallback_utils: - with pytest.raises(MidStreamFallbackError, match="Connection lost"): - result = await router._acompletion_streaming_iterator( - model_response=mock_error_response_2, - messages=messages, - initial_kwargs=initial_kwargs, - ) - - async for chunk in result: - pass # Should not reach here - - # Verify fallback was NOT called - mock_fallback_utils.assert_not_called() - print( - "✓ MidStreamFallbackError raised correctly when no original_exception and disable_fallbacks=True" - ) - - # Test 3: disable_fallbacks=False (default behavior - fallback should work) - print("\n=== Test 3: disable_fallbacks=False (fallback enabled) ===") - - error_for_fallback = MidStreamFallbackError( - message="Connection lost", - model="gpt-4", - llm_provider="openai", - generated_content="Hello", - ) - - mock_error_response_3 = AsyncIteratorWithError(mock_chunks, 1, error_for_fallback) - - # Mock successful fallback response - class EmptyAsyncIterator: - def __aiter__(self): - return self - - async def __anext__(self): - raise StopAsyncIteration - - mock_fallback_response = EmptyAsyncIterator() - - initial_kwargs_fallback_enabled = { - "model": "gpt-4", - "stream": True, - "disable_fallbacks": False, - } - - with patch.object( - router, - "async_function_with_fallbacks_common_utils", - return_value=mock_fallback_response, - ) as mock_fallback_utils: - collected_chunks = [] - result = await router._acompletion_streaming_iterator( - model_response=mock_error_response_3, - messages=messages, - initial_kwargs=initial_kwargs_fallback_enabled, - ) - - async for chunk in result: - collected_chunks.append(chunk) - - # Verify fallback WAS called - assert mock_fallback_utils.called - print("✓ Fallback called correctly when disable_fallbacks=False") - - print("\n=== All disable_fallbacks tests passed! ===") - - @pytest.mark.asyncio async def test_async_function_with_fallbacks_common_utils(): """Test the async_function_with_fallbacks_common_utils method"""