From eecbc5e2b5e6227264013c1439d9da3b8c78594b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 10:35:05 -0700 Subject: [PATCH 01/14] fix(router.py): add acompletion_streaming_iterator inside router allows router to catch errors mid-stream for fallbacks Work for https://github.com/BerriAI/litellm/issues/6532 --- litellm/proxy/_new_secret_config.yaml | 13 +++++++++---- litellm/router.py | 21 +++++++++++++++++++-- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 8a63ed7b1f7..abac8253cd7 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,9 +1,14 @@ model_list: - - model_name: genai/test/* + - model_name: fake-anthropic-server litellm_params: - model: openai/* - api_base: https://api.openai.com + model: anthropic/my-fake-anthropic-server + api_base: http://localhost:8090 + api_key: os.environ/ANTHROPIC_API_KEY + - model_name: fake-openai-server + litellm_params: + model: openai/my-fake-openai-server + api_base: http://localhost:8090 api_key: os.environ/OPENAI_API_KEY litellm_settings: - check_provider_endpoint: true \ No newline at end of file + fallbacks: [{"fake-anthropic-server": ["fake-openai-server"]}] \ No newline at end of file diff --git a/litellm/router.py b/litellm/router.py index d02392c59f2..86cc2342914 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -23,6 +23,7 @@ from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, Callable, Dict, List, @@ -146,7 +147,7 @@ from litellm.types.utils import GenericBudgetConfigType, LiteLLMBatch from litellm.types.utils import ModelInfo from litellm.types.utils import ModelInfo as ModelMapInfo -from litellm.types.utils import StandardLoggingPayload +from litellm.types.utils import ModelResponseStream, StandardLoggingPayload from litellm.utils import ( CustomStreamWrapper, EmbeddingResponse, @@ -1078,9 +1079,22 @@ async def acompletion( ) raise e + async def _acompletion_streaming_iterator( + self, model_response: CustomStreamWrapper + ) -> AsyncGenerator[ModelResponseStream, None]: + """ + Helper to iterate over a streaming response. + + Catches errors for fallbacks + """ + async for item in model_response: + yield item + async def _acompletion( self, model: str, messages: List[Dict[str, str]], **kwargs - ) -> Union[ModelResponse, CustomStreamWrapper]: + ) -> Union[ + ModelResponse, CustomStreamWrapper, AsyncGenerator[ModelResponseStream, None] + ]: """ - Get an available deployment - call it with a semaphore over the call @@ -1199,6 +1213,9 @@ async def _acompletion( parent_otel_span=parent_otel_span, ) + if isinstance(response, CustomStreamWrapper): + return self._acompletion_streaming_iterator(model_response=response) + return response except litellm.Timeout as e: deployment_request_timeout_param = _timeout_debug_deployment_dict.get( From 298233cab30ec394f3d7a7387e6f20a5e9a8a030 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 11:20:06 -0700 Subject: [PATCH 02/14] fix(router.py): working mid-stream fallbacks --- litellm/exceptions.py | 62 +++++++++++++++ .../litellm_core_utils/streaming_handler.py | 30 ++++--- litellm/router.py | 79 ++++++++++++++----- 3 files changed, 141 insertions(+), 30 deletions(-) diff --git a/litellm/exceptions.py b/litellm/exceptions.py index 9f3411143a6..153230518cc 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -829,3 +829,65 @@ def __init__( self.guardrail_name = guardrail_name self.message = f"Blocked entity detected: {entity_type} by Guardrail: {guardrail_name}. This entity is not allowed to be used in this request." super().__init__(self.message) + + +class MidStreamFallbackError(ServiceUnavailableError): # type: ignore + def __init__( + self, + message: str, + model: str, + llm_provider: str, + original_exception: Optional[Exception] = None, + response: Optional[httpx.Response] = None, + litellm_debug_info: Optional[str] = None, + max_retries: Optional[int] = None, + num_retries: Optional[int] = None, + generated_content: str = "", + is_pre_first_chunk: bool = False, + ): + self.status_code = 503 # Service Unavailable + self.message = f"litellm.MidStreamFallbackError: {message}" + self.model = model + self.llm_provider = llm_provider + self.original_exception = original_exception + self.litellm_debug_info = litellm_debug_info + self.max_retries = max_retries + self.num_retries = num_retries + self.generated_content = generated_content + self.is_pre_first_chunk = is_pre_first_chunk + + # Create a response if one wasn't provided + if response is None: + self.response = httpx.Response( + status_code=self.status_code, + request=httpx.Request( + method="POST", + url=f"https://{llm_provider}.com/v1/", + ), + ) + else: + self.response = response + + # Call the parent constructor + super().__init__( + message=self.message, + llm_provider=llm_provider, + model=model, + response=self.response, + litellm_debug_info=self.litellm_debug_info, + max_retries=self.max_retries, + num_retries=self.num_retries, + ) + + def __str__(self): + _message = self.message + if self.num_retries: + _message += f" LiteLLM Retried: {self.num_retries} times" + if self.max_retries: + _message += f", LiteLLM Max Retries: {self.max_retries}" + if self.original_exception: + _message += f" Original exception: {type(self.original_exception).__name__}: {str(self.original_exception)}" + return _message + + def __repr__(self): + return self.__str__() diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index 17f5a8d1deb..2e9e6770a1d 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -940,8 +940,8 @@ def _optional_combine_thinking_block_in_choices( and not self.sent_last_thinking_block and model_response.choices[0].delta.content ): - model_response.choices[0].delta.content = ( - "" + (model_response.choices[0].delta.content or "") + model_response.choices[0].delta.content = "" + ( + model_response.choices[0].delta.content or "" ) self.sent_last_thinking_block = True @@ -1841,13 +1841,25 @@ async def __anext__(self): # noqa: PLR0915 self.logging_obj.async_failure_handler(e, traceback_exception) # type: ignore ) ## Map to OpenAI Exception - raise exception_type( - model=self.model, - custom_llm_provider=self.custom_llm_provider, - original_exception=e, - completion_kwargs={}, - extra_kwargs={}, - ) + try: + exception_type( + model=self.model, + custom_llm_provider=self.custom_llm_provider, + original_exception=e, + completion_kwargs={}, + extra_kwargs={}, + ) + except Exception as e: + from litellm.exceptions import MidStreamFallbackError + + raise MidStreamFallbackError( + message=str(e), + model=self.model, + llm_provider=self.custom_llm_provider or "anthropic", + original_exception=e, + generated_content=self.response_uptil_now, + is_pre_first_chunk=not self.sent_first_chunk, + ) @staticmethod def _strip_sse_data_from_chunk(chunk: Optional[str]) -> Optional[str]: diff --git a/litellm/router.py b/litellm/router.py index 86cc2342914..87cae1cb458 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1080,20 +1080,55 @@ async def acompletion( raise e async def _acompletion_streaming_iterator( - self, model_response: CustomStreamWrapper - ) -> AsyncGenerator[ModelResponseStream, None]: + self, model_response: CustomStreamWrapper, input_kwargs: dict + ) -> AsyncGenerator[Optional[ModelResponseStream]]: """ Helper to iterate over a streaming response. Catches errors for fallbacks """ - async for item in model_response: - yield item + from litellm.exceptions import MidStreamFallbackError + + try: + async for item in model_response: + yield item + except MidStreamFallbackError as e: + # Prepare fallback request + fallback_kwargs = input_kwargs.copy() + fallback_kwargs["model"] = "gpt-3.5-turbo" + fallback_kwargs["messages"] = fallback_kwargs["messages"] + [ + { + "role": "user", + "content": f"The following is the beginning of an assistant's response. Continue from where it left off: {e.generated_content}", + } + ] + + # Ensure streaming is enabled for fallback + fallback_kwargs["stream"] = True + + try: + # Make fallback call + fallback_response = await litellm.acompletion(**fallback_kwargs) + + # If fallback returns a streaming response, iterate over it + if hasattr(fallback_response, "__aiter__"): + async for fallback_item in fallback_response: # type: ignore + yield fallback_item + else: + # If fallback returns a non-streaming response, yield None + yield None + + except Exception as fallback_error: + # If fallback also fails, log and re-raise original error + verbose_router_logger.error(f"Fallback also failed: {fallback_error}") + raise e # Re-raise the original error async def _acompletion( self, model: str, messages: List[Dict[str, str]], **kwargs ) -> Union[ - ModelResponse, CustomStreamWrapper, AsyncGenerator[ModelResponseStream, None] + ModelResponse, + CustomStreamWrapper, + AsyncGenerator[Optional[ModelResponseStream]], ]: """ - Get an available deployment @@ -1148,15 +1183,15 @@ async def _acompletion( ) self.total_calls[model_name] += 1 - _response = litellm.acompletion( - **{ - **data, - "messages": messages, - "caching": self.cache_responses, - "client": model_client, - **kwargs, - } - ) + input_kwargs = { + **data, + "messages": messages, + "caching": self.cache_responses, + "client": model_client, + **kwargs, + } + + _response = litellm.acompletion(**input_kwargs) logging_obj: Optional[LiteLLMLogging] = kwargs.get( "litellm_logging_obj", None @@ -1214,7 +1249,9 @@ async def _acompletion( ) if isinstance(response, CustomStreamWrapper): - return self._acompletion_streaming_iterator(model_response=response) + return self._acompletion_streaming_iterator( + model_response=response, input_kwargs=input_kwargs + ) return response except litellm.Timeout as e: @@ -3615,18 +3652,18 @@ async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915 if isinstance(e, litellm.ContextWindowExceededError): if context_window_fallbacks is not None: - fallback_model_group: Optional[List[str]] = ( + context_window_fallback_model_group: Optional[List[str]] = ( self._get_fallback_model_group_from_fallbacks( fallbacks=context_window_fallbacks, model_group=model_group, ) ) - if fallback_model_group is None: + if context_window_fallback_model_group is None: raise original_exception input_kwargs.update( { - "fallback_model_group": fallback_model_group, + "fallback_model_group": context_window_fallback_model_group, "original_model_group": original_model_group, } ) @@ -3651,18 +3688,18 @@ async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915 e.message += "\n{}".format(error_message) elif isinstance(e, litellm.ContentPolicyViolationError): if content_policy_fallbacks is not None: - fallback_model_group: Optional[List[str]] = ( + content_policy_fallback_model_group: Optional[List[str]] = ( self._get_fallback_model_group_from_fallbacks( fallbacks=content_policy_fallbacks, model_group=model_group, ) ) - if fallback_model_group is None: + if content_policy_fallback_model_group is None: raise original_exception input_kwargs.update( { - "fallback_model_group": fallback_model_group, + "fallback_model_group": content_policy_fallback_model_group, "original_model_group": original_model_group, } ) From 16bbb3841648a537265937ce3699b917b0656c4f Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 11:32:04 -0700 Subject: [PATCH 03/14] fix(router.py): more iterations --- litellm/router.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 87cae1cb458..84ba8c5da1f 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1081,7 +1081,7 @@ async def acompletion( async def _acompletion_streaming_iterator( self, model_response: CustomStreamWrapper, input_kwargs: dict - ) -> AsyncGenerator[Optional[ModelResponseStream]]: + ) -> AsyncGenerator[Optional[ModelResponseStream], None]: """ Helper to iterate over a streaming response. @@ -1094,12 +1094,16 @@ async def _acompletion_streaming_iterator( yield item except MidStreamFallbackError as e: # Prepare fallback request + litellm._turn_on_debug() fallback_kwargs = input_kwargs.copy() - fallback_kwargs["model"] = "gpt-3.5-turbo" + fallback_kwargs["model"] = "gpt-4o-mini" + fallback_kwargs["api_base"] = None + fallback_kwargs["api_key"] = None fallback_kwargs["messages"] = fallback_kwargs["messages"] + [ { - "role": "user", - "content": f"The following is the beginning of an assistant's response. Continue from where it left off: {e.generated_content}", + "role": "assistant", + "content": e.generated_content, + "prefix": True, } ] @@ -1128,7 +1132,7 @@ async def _acompletion( ) -> Union[ ModelResponse, CustomStreamWrapper, - AsyncGenerator[Optional[ModelResponseStream]], + AsyncGenerator[Optional[ModelResponseStream], None], ]: """ - Get an available deployment From 11a313cf8672e7d1a462a0240a6b1588aecfb2da Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 12:13:16 -0700 Subject: [PATCH 04/14] fix(router.py): working mid-stream fallbacks with fallbacks set on router --- litellm/proxy/_new_secret_config.yaml | 6 +- litellm/router.py | 427 ++++++++++++++------------ 2 files changed, 238 insertions(+), 195 deletions(-) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index abac8253cd7..ca3acad6c24 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -9,6 +9,10 @@ model_list: model: openai/my-fake-openai-server api_base: http://localhost:8090 api_key: os.environ/OPENAI_API_KEY + - model_name: real-openai-server + litellm_params: + model: openai/gpt-4o-mini + api_key: os.environ/OPENAI_API_KEY litellm_settings: - fallbacks: [{"fake-anthropic-server": ["fake-openai-server"]}] \ No newline at end of file + fallbacks: [{"fake-anthropic-server": ["real-openai-server"]}] \ No newline at end of file diff --git a/litellm/router.py b/litellm/router.py index 84ba8c5da1f..9a2e3bc79fa 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1080,12 +1080,15 @@ async def acompletion( raise e async def _acompletion_streaming_iterator( - self, model_response: CustomStreamWrapper, input_kwargs: dict + self, + model_response: CustomStreamWrapper, + messages: List[Dict[str, str]], + initial_kwargs: dict, ) -> AsyncGenerator[Optional[ModelResponseStream], None]: """ Helper to iterate over a streaming response. - Catches errors for fallbacks + Catches errors for fallbacks using the router's fallback system """ from litellm.exceptions import MidStreamFallbackError @@ -1093,26 +1096,35 @@ async def _acompletion_streaming_iterator( async for item in model_response: yield item except MidStreamFallbackError as e: - # Prepare fallback request - litellm._turn_on_debug() - fallback_kwargs = input_kwargs.copy() - fallback_kwargs["model"] = "gpt-4o-mini" - fallback_kwargs["api_base"] = None - fallback_kwargs["api_key"] = None - fallback_kwargs["messages"] = fallback_kwargs["messages"] + [ - { - "role": "assistant", - "content": e.generated_content, - "prefix": True, - } - ] - - # Ensure streaming is enabled for fallback - fallback_kwargs["stream"] = True - try: - # Make fallback call - fallback_response = await litellm.acompletion(**fallback_kwargs) + # Use the router's fallback system + model_group = cast(str, initial_kwargs.get("model")) + fallbacks: Optional[List] = initial_kwargs.get( + "fallbacks", self.fallbacks + ) + context_window_fallbacks: Optional[List] = initial_kwargs.get( + "context_window_fallbacks", self.context_window_fallbacks + ) + content_policy_fallbacks: Optional[List] = initial_kwargs.get( + "content_policy_fallbacks", self.content_policy_fallbacks + ) + initial_kwargs["original_function"] = self._acompletion + initial_kwargs["messages"] = messages + self._update_kwargs_before_fallbacks( + model=model_group, kwargs=initial_kwargs + ) + fallback_response = ( + await self.async_function_with_fallbacks_common_utils( + e=e, + disable_fallbacks=False, + fallbacks=fallbacks, + context_window_fallbacks=context_window_fallbacks, + content_policy_fallbacks=content_policy_fallbacks, + model_group=model_group, + args=(), + kwargs=initial_kwargs, + ) + ) # If fallback returns a streaming response, iterate over it if hasattr(fallback_response, "__aiter__"): @@ -1145,9 +1157,12 @@ async def _acompletion( {} ) # this is a temporary dict to debug timeout issues try: + input_kwargs_for_streaming_fallback = kwargs.copy() + input_kwargs_for_streaming_fallback["model"] = model verbose_router_logger.debug( f"Inside _acompletion()- model: {model}; kwargs: {kwargs}" ) + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) start_time = time.time() deployment = await self.async_get_available_deployment( @@ -1254,7 +1269,9 @@ async def _acompletion( if isinstance(response, CustomStreamWrapper): return self._acompletion_streaming_iterator( - model_response=response, input_kwargs=input_kwargs + model_response=response, + messages=messages, + initial_kwargs=input_kwargs_for_streaming_fallback, ) return response @@ -3571,78 +3588,77 @@ async def _pass_through_assistants_endpoint_factory( #### [END] ASSISTANTS API #### - @tracer.wrap() - async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915 + async def async_function_with_fallbacks_common_utils( # noqa: PLR0915 + self, + e: Exception, + disable_fallbacks: Optional[bool], + fallbacks: Optional[List], + context_window_fallbacks: Optional[List], + content_policy_fallbacks: Optional[List], + model_group: Optional[str], + args: tuple, + kwargs: dict, + ): """ - Try calling the function_with_retries - If it fails after num_retries, fall back to another model group + Common utilities for async_function_with_fallbacks """ - model_group: Optional[str] = kwargs.get("model") - disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False) - fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks) - context_window_fallbacks: Optional[List] = kwargs.get( - "context_window_fallbacks", self.context_window_fallbacks - ) - content_policy_fallbacks: Optional[List] = kwargs.get( - "content_policy_fallbacks", self.content_policy_fallbacks - ) + verbose_router_logger.debug(f"Traceback{traceback.format_exc()}") + original_exception = e + fallback_model_group = None + original_model_group: Optional[str] = kwargs.get("model") # type: ignore + fallback_failure_exception_str = "" - mock_timeout = kwargs.pop("mock_timeout", None) + if disable_fallbacks is True or original_model_group is None: + raise e + + input_kwargs = { + "litellm_router": self, + "original_exception": original_exception, + **kwargs, + } + + if "max_fallbacks" not in input_kwargs: + input_kwargs["max_fallbacks"] = self.max_fallbacks + if "fallback_depth" not in input_kwargs: + input_kwargs["fallback_depth"] = 0 try: - self._handle_mock_testing_fallbacks( - kwargs=kwargs, - model_group=model_group, - fallbacks=fallbacks, - context_window_fallbacks=context_window_fallbacks, - content_policy_fallbacks=content_policy_fallbacks, - ) + verbose_router_logger.info("Trying to fallback b/w models") - if mock_timeout is not None: - response = await self.async_function_with_retries( - *args, **kwargs, mock_timeout=mock_timeout - ) - else: - response = await self.async_function_with_retries(*args, **kwargs) - verbose_router_logger.debug(f"Async Response: {response}") - response = add_fallback_headers_to_response( - response=response, - attempted_fallbacks=0, + # check if client-side fallbacks are used (e.g. fallbacks = ["gpt-3.5-turbo", "claude-3-haiku"] or fallbacks=[{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}] + is_non_standard_fallback_format = _check_non_standard_fallback_format( + fallbacks=fallbacks ) - return response - except Exception as e: - verbose_router_logger.debug(f"Traceback{traceback.format_exc()}") - original_exception = e - fallback_model_group = None - original_model_group: Optional[str] = kwargs.get("model") # type: ignore - fallback_failure_exception_str = "" - - if disable_fallbacks is True or original_model_group is None: - raise e - input_kwargs = { - "litellm_router": self, - "original_exception": original_exception, - **kwargs, - } + if is_non_standard_fallback_format: + input_kwargs.update( + { + "fallback_model_group": fallbacks, + "original_model_group": original_model_group, + } + ) - if "max_fallbacks" not in input_kwargs: - input_kwargs["max_fallbacks"] = self.max_fallbacks - if "fallback_depth" not in input_kwargs: - input_kwargs["fallback_depth"] = 0 + response = await run_async_fallback( + *args, + **input_kwargs, + ) - try: - verbose_router_logger.info("Trying to fallback b/w models") + return response - # check if client-side fallbacks are used (e.g. fallbacks = ["gpt-3.5-turbo", "claude-3-haiku"] or fallbacks=[{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey, how's it going?"}]}] - is_non_standard_fallback_format = _check_non_standard_fallback_format( - fallbacks=fallbacks - ) + if isinstance(e, litellm.ContextWindowExceededError): + if context_window_fallbacks is not None: + context_window_fallback_model_group: Optional[List[str]] = ( + self._get_fallback_model_group_from_fallbacks( + fallbacks=context_window_fallbacks, + model_group=model_group, + ) + ) + if context_window_fallback_model_group is None: + raise original_exception - if is_non_standard_fallback_format: input_kwargs.update( { - "fallback_model_group": fallbacks, + "fallback_model_group": context_window_fallback_model_group, "original_model_group": original_model_group, } ) @@ -3651,107 +3667,34 @@ async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915 *args, **input_kwargs, ) - return response - if isinstance(e, litellm.ContextWindowExceededError): - if context_window_fallbacks is not None: - context_window_fallback_model_group: Optional[List[str]] = ( - self._get_fallback_model_group_from_fallbacks( - fallbacks=context_window_fallbacks, - model_group=model_group, - ) - ) - if context_window_fallback_model_group is None: - raise original_exception - - input_kwargs.update( - { - "fallback_model_group": context_window_fallback_model_group, - "original_model_group": original_model_group, - } - ) - - response = await run_async_fallback( - *args, - **input_kwargs, - ) - return response - - else: - error_message = "model={}. context_window_fallbacks={}. fallbacks={}.\n\nSet 'context_window_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format( - model_group, context_window_fallbacks, fallbacks - ) - verbose_router_logger.info( - msg="Got 'ContextWindowExceededError'. No context_window_fallback set. Defaulting \ - to fallbacks, if available.{}".format( - error_message - ) - ) - - e.message += "\n{}".format(error_message) - elif isinstance(e, litellm.ContentPolicyViolationError): - if content_policy_fallbacks is not None: - content_policy_fallback_model_group: Optional[List[str]] = ( - self._get_fallback_model_group_from_fallbacks( - fallbacks=content_policy_fallbacks, - model_group=model_group, - ) - ) - if content_policy_fallback_model_group is None: - raise original_exception - - input_kwargs.update( - { - "fallback_model_group": content_policy_fallback_model_group, - "original_model_group": original_model_group, - } - ) - - response = await run_async_fallback( - *args, - **input_kwargs, - ) - return response - else: - error_message = "model={}. content_policy_fallback={}. fallbacks={}.\n\nSet 'content_policy_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format( - model_group, content_policy_fallbacks, fallbacks - ) - verbose_router_logger.info( - msg="Got 'ContentPolicyViolationError'. No content_policy_fallback set. Defaulting \ - to fallbacks, if available.{}".format( - error_message - ) + else: + error_message = "model={}. context_window_fallbacks={}. fallbacks={}.\n\nSet 'context_window_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format( + model_group, context_window_fallbacks, fallbacks + ) + verbose_router_logger.info( + msg="Got 'ContextWindowExceededError'. No context_window_fallback set. Defaulting \ + to fallbacks, if available.{}".format( + error_message ) - - e.message += "\n{}".format(error_message) - if fallbacks is not None and model_group is not None: - verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") - ( - fallback_model_group, - generic_fallback_idx, - ) = get_fallback_model_group( - fallbacks=fallbacks, # if fallbacks = [{"gpt-3.5-turbo": ["claude-3-haiku"]}] - model_group=cast(str, model_group), ) - ## if none, check for generic fallback - if ( - fallback_model_group is None - and generic_fallback_idx is not None - ): - fallback_model_group = fallbacks[generic_fallback_idx]["*"] - if fallback_model_group is None: - verbose_router_logger.info( - f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" + e.message += "\n{}".format(error_message) + elif isinstance(e, litellm.ContentPolicyViolationError): + if content_policy_fallbacks is not None: + content_policy_fallback_model_group: Optional[List[str]] = ( + self._get_fallback_model_group_from_fallbacks( + fallbacks=content_policy_fallbacks, + model_group=model_group, ) - if hasattr(original_exception, "message"): - original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" # type: ignore + ) + if content_policy_fallback_model_group is None: raise original_exception input_kwargs.update( { - "fallback_model_group": fallback_model_group, + "fallback_model_group": content_policy_fallback_model_group, "original_model_group": original_model_group, } ) @@ -3760,36 +3703,132 @@ async def async_function_with_fallbacks(self, *args, **kwargs): # noqa: PLR0915 *args, **input_kwargs, ) - return response - except Exception as new_exception: - parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) - verbose_router_logger.error( - "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format( - str(new_exception), - traceback.format_exc(), - await _async_get_cooldown_deployments_with_debug_info( - litellm_router_instance=self, - parent_otel_span=parent_otel_span, - ), + else: + error_message = "model={}. content_policy_fallback={}. fallbacks={}.\n\nSet 'content_policy_fallback' - https://docs.litellm.ai/docs/routing#fallbacks".format( + model_group, content_policy_fallbacks, fallbacks + ) + verbose_router_logger.info( + msg="Got 'ContentPolicyViolationError'. No content_policy_fallback set. Defaulting \ + to fallbacks, if available.{}".format( + error_message + ) ) - ) - fallback_failure_exception_str = str(new_exception) - if hasattr(original_exception, "message"): - # add the available fallbacks to the exception - original_exception.message += ". Received Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore - model_group, + e.message += "\n{}".format(error_message) + if fallbacks is not None and model_group is not None: + verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}") + ( fallback_model_group, + generic_fallback_idx, + ) = get_fallback_model_group( + fallbacks=fallbacks, # if fallbacks = [{"gpt-3.5-turbo": ["claude-3-haiku"]}] + model_group=cast(str, model_group), ) - if len(fallback_failure_exception_str) > 0: - original_exception.message += ( # type: ignore - "\nError doing the fallback: {}".format( - fallback_failure_exception_str - ) + ## if none, check for generic fallback + if fallback_model_group is None and generic_fallback_idx is not None: + fallback_model_group = fallbacks[generic_fallback_idx]["*"] + + if fallback_model_group is None: + verbose_router_logger.info( + f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" ) + if hasattr(original_exception, "message"): + original_exception.message += f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" # type: ignore + raise original_exception + + input_kwargs.update( + { + "fallback_model_group": fallback_model_group, + "original_model_group": original_model_group, + } + ) - raise original_exception + response = await run_async_fallback( + *args, + **input_kwargs, + ) + + return response + except Exception as new_exception: + parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) + verbose_router_logger.error( + "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format( + str(new_exception), + traceback.format_exc(), + await _async_get_cooldown_deployments_with_debug_info( + litellm_router_instance=self, + parent_otel_span=parent_otel_span, + ), + ) + ) + fallback_failure_exception_str = str(new_exception) + + if hasattr(original_exception, "message"): + # add the available fallbacks to the exception + original_exception.message += ". Received Model Group={}\nAvailable Model Group Fallbacks={}".format( # type: ignore + model_group, + fallback_model_group, + ) + if len(fallback_failure_exception_str) > 0: + original_exception.message += ( # type: ignore + "\nError doing the fallback: {}".format( + fallback_failure_exception_str + ) + ) + + raise original_exception + + @tracer.wrap() + async def async_function_with_fallbacks(self, *args, **kwargs): + """ + Try calling the function_with_retries + If it fails after num_retries, fall back to another model group + """ + model_group: Optional[str] = kwargs.get("model") + disable_fallbacks: Optional[bool] = kwargs.pop("disable_fallbacks", False) + fallbacks: Optional[List] = kwargs.get("fallbacks", self.fallbacks) + context_window_fallbacks: Optional[List] = kwargs.get( + "context_window_fallbacks", self.context_window_fallbacks + ) + content_policy_fallbacks: Optional[List] = kwargs.get( + "content_policy_fallbacks", self.content_policy_fallbacks + ) + + mock_timeout = kwargs.pop("mock_timeout", None) + + try: + self._handle_mock_testing_fallbacks( + kwargs=kwargs, + model_group=model_group, + fallbacks=fallbacks, + context_window_fallbacks=context_window_fallbacks, + content_policy_fallbacks=content_policy_fallbacks, + ) + + if mock_timeout is not None: + response = await self.async_function_with_retries( + *args, **kwargs, mock_timeout=mock_timeout + ) + else: + response = await self.async_function_with_retries(*args, **kwargs) + verbose_router_logger.debug(f"Async Response: {response}") + response = add_fallback_headers_to_response( + response=response, + attempted_fallbacks=0, + ) + return response + except Exception as e: + return await self.async_function_with_fallbacks_common_utils( + e, + disable_fallbacks, + fallbacks, + context_window_fallbacks, + content_policy_fallbacks, + model_group, + args, + kwargs, + ) def _handle_mock_testing_fallbacks( self, From acbc1127b14dd4df8149da937e931c2c85786c15 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 12:28:23 -0700 Subject: [PATCH 05/14] fix(router.py): pass prior content back in new request as assistant prefix message --- litellm/router.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index 9a2e3bc79fa..f850a3d0e26 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1109,7 +1109,13 @@ async def _acompletion_streaming_iterator( "content_policy_fallbacks", self.content_policy_fallbacks ) initial_kwargs["original_function"] = self._acompletion - initial_kwargs["messages"] = messages + initial_kwargs["messages"] = messages + [ + { + "role": "assistant", + "content": e.generated_content, + "prefix": True, + } + ] self._update_kwargs_before_fallbacks( model=model_group, kwargs=initial_kwargs ) From 17732df4f4b64ab47ea61805ad57d660f2ae41d5 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 12:49:42 -0700 Subject: [PATCH 06/14] fix(router.py): add a system prompt to help guide non-prefix supporting models to use the continued text correctly --- litellm/router.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index f850a3d0e26..eab0f7bdff9 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1110,11 +1110,15 @@ async def _acompletion_streaming_iterator( ) initial_kwargs["original_function"] = self._acompletion initial_kwargs["messages"] = messages + [ + { + "role": "system", + "content": "You are a helpful assistant. You are given a message and you need to respond to it. You are also given a generated content. You need to respond to the message in continuation of the generated content. Do not repeat the same content. Your response should be in continuation of this text: ", + }, { "role": "assistant", "content": e.generated_content, "prefix": True, - } + }, ] self._update_kwargs_before_fallbacks( model=model_group, kwargs=initial_kwargs From a12554de87a55dfd5f149311341e14e4c60982c4 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 12:54:32 -0700 Subject: [PATCH 07/14] fix(common_utils.py): support converting `prefix: true` for non-prefix supporting models --- .../prompt_templates/common_utils.py | 38 +++++++++++++++++++ ...ore_utils_prompt_templates_common_utils.py | 18 +++++++++ 2 files changed, 56 insertions(+) diff --git a/litellm/litellm_core_utils/prompt_templates/common_utils.py b/litellm/litellm_core_utils/prompt_templates/common_utils.py index 258601ff5a0..827d28598ec 100644 --- a/litellm/litellm_core_utils/prompt_templates/common_utils.py +++ b/litellm/litellm_core_utils/prompt_templates/common_utils.py @@ -822,3 +822,41 @@ def set_last_user_message( messages.reverse() messages.append({"role": "user", "content": content}) return messages + + +def convert_prefix_message_to_non_prefix_messages( + messages: List[AllMessageValues], +) -> List[AllMessageValues]: + """ + For models that don't support {prefix: true} in messages, we need to convert the prefix message to a non-prefix message. + + Use prompt: + + {"role": "assistant", "content": "value", "prefix": true} -> [ + { + "role": "system", + "content": "You are a helpful assistant. You are given a message and you need to respond to it. You are also given a generated content. You need to respond to the message in continuation of the generated content. Do not repeat the same content. Your response should be in continuation of this text: ", + }, + { + "role": "assistant", + "content": message["content"], + }, + ] + + do this in place + """ + new_messages: List[AllMessageValues] = [] + for message in messages: + if message.get("prefix"): + new_messages.append( + { + "role": "system", + "content": "You are a helpful assistant. You are given a message and you need to respond to it. You are also given a generated content. You need to respond to the message in continuation of the generated content. Do not repeat the same content. Your response should be in continuation of this text: ", + } + ) + new_messages.append( + {**{k: v for k, v in message.items() if k != "prefix"}} # type: ignore + ) + else: + new_messages.append(message) + return new_messages diff --git a/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_common_utils.py b/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_common_utils.py index 1d349a44e5c..980693aa73a 100644 --- a/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_common_utils.py +++ b/tests/test_litellm/litellm_core_utils/prompt_templates/test_litellm_core_utils_prompt_templates_common_utils.py @@ -125,3 +125,21 @@ def test_handle_any_messages_to_chat_completion_str_messages_conversion_complex( result = handle_any_messages_to_chat_completion_str_messages_conversion(message) assert len(result) == 1 assert result[0]["input"] == json.dumps(message) + + +def test_convert_prefix_message_to_non_prefix_messages(): + from litellm.litellm_core_utils.prompt_templates.common_utils import ( + convert_prefix_message_to_non_prefix_messages, + ) + + messages = [ + {"role": "assistant", "content": "value", "prefix": True}, + ] + result = convert_prefix_message_to_non_prefix_messages(messages) + assert result == [ + { + "role": "system", + "content": "You are a helpful assistant. You are given a message and you need to respond to it. You are also given a generated content. You need to respond to the message in continuation of the generated content. Do not repeat the same content. Your response should be in continuation of this text: ", + }, + {"role": "assistant", "content": "value"}, + ] From fd13348ff089d03406c138ae20996ee7d0679bd0 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 13:03:08 -0700 Subject: [PATCH 08/14] fix: reduce LOC in function --- litellm/router.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index eab0f7bdff9..85e275b3a3c 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1169,9 +1169,6 @@ async def _acompletion( try: input_kwargs_for_streaming_fallback = kwargs.copy() input_kwargs_for_streaming_fallback["model"] = model - verbose_router_logger.debug( - f"Inside _acompletion()- model: {model}; kwargs: {kwargs}" - ) parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs) start_time = time.time() From 8cfcdd605c08cd51d30de518cb4b6d9ef9193d8a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 13:15:58 -0700 Subject: [PATCH 09/14] test(test_router.py): add unit tests for new function --- tests/test_litellm/test_router.py | 257 ++++++++++++++++++++++++++++++ 1 file changed, 257 insertions(+) diff --git a/tests/test_litellm/test_router.py b/tests/test_litellm/test_router.py index cacdd51ee90..4d75f372e18 100644 --- a/tests/test_litellm/test_router.py +++ b/tests/test_litellm/test_router.py @@ -1064,3 +1064,260 @@ def test_router_get_model_access_groups_team_only_models(): model_name="gpt-3.5-turbo", team_id="team_1" ) assert list(access_groups.keys()) == ["default-models"] + + +@pytest.mark.asyncio +async def test_acompletion_streaming_iterator(): + """Test _acompletion_streaming_iterator for normal streaming and fallback behavior.""" + from unittest.mock import AsyncMock, MagicMock + + from litellm.exceptions import MidStreamFallbackError + from litellm.types.utils import ModelResponseStream + + # Helper class for creating async iterators + class AsyncIterator: + def __init__(self, items, error_after=None): + self.items = items + self.index = 0 + self.error_after = error_after + + def __aiter__(self): + return self + + async def __anext__(self): + if self.error_after is not None and self.index >= self.error_after: + raise self.error_after + if self.index >= len(self.items): + raise StopAsyncIteration + item = self.items[self.index] + self.index += 1 + return item + + # Set up router with fallback configuration + router = litellm.Router( + model_list=[ + { + "model_name": "gpt-4", + "litellm_params": {"model": "gpt-4", "api_key": "fake-key-1"}, + }, + { + "model_name": "gpt-3.5-turbo", + "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "fake-key-2"}, + }, + ], + fallbacks=[{"gpt-4": ["gpt-3.5-turbo"]}], + set_verbose=True, + ) + + # Test data + messages = [{"role": "user", "content": "Hello"}] + initial_kwargs = {"model": "gpt-4", "stream": True, "temperature": 0.7} + + # Test 1: Successful streaming (no errors) + print("\n=== Test 1: Successful streaming ===") + + # Mock successful streaming response + mock_chunks = [ + MagicMock(choices=[MagicMock(delta=MagicMock(content="Hello"))]), + MagicMock(choices=[MagicMock(delta=MagicMock(content=" there"))]), + MagicMock(choices=[MagicMock(delta=MagicMock(content="!"))]), + ] + + mock_response = AsyncIterator(mock_chunks) + + # Collect streamed chunks + collected_chunks = [] + async for chunk in router._acompletion_streaming_iterator( + model_response=mock_response, messages=messages, initial_kwargs=initial_kwargs + ): + collected_chunks.append(chunk) + + assert len(collected_chunks) == 3 + assert all(chunk in mock_chunks for chunk in collected_chunks) + print("✓ Successfully streamed all chunks") + + # Test 2: MidStreamFallbackError with fallback + print("\n=== Test 2: MidStreamFallbackError with fallback ===") + + # Create error that should trigger after first chunk + error = MidStreamFallbackError( + message="Connection lost", + model="gpt-4", + llm_provider="openai", + generated_content="Hello", + ) + + class AsyncIteratorWithError: + def __init__(self, items, error_after_index): + self.items = items + self.index = 0 + self.error_after_index = error_after_index + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.items): + raise StopAsyncIteration + if self.index == self.error_after_index: + raise error + item = self.items[self.index] + self.index += 1 + return item + + mock_error_response = AsyncIteratorWithError( + mock_chunks, 1 + ) # Error after first chunk + + # Mock the fallback response + fallback_chunks = [ + MagicMock(choices=[MagicMock(delta=MagicMock(content=" world"))]), + MagicMock(choices=[MagicMock(delta=MagicMock(content="!"))]), + ] + + mock_fallback_response = AsyncIterator(fallback_chunks) + + # Mock the fallback function + with patch.object( + router, + "async_function_with_fallbacks_common_utils", + return_value=mock_fallback_response, + ) as mock_fallback_utils: + + collected_chunks = [] + async for chunk in router._acompletion_streaming_iterator( + model_response=mock_error_response, + messages=messages, + initial_kwargs=initial_kwargs, + ): + collected_chunks.append(chunk) + + # Verify fallback was called + assert mock_fallback_utils.called + call_args = mock_fallback_utils.call_args + + # Check that generated content was added to messages + fallback_kwargs = call_args.kwargs["kwargs"] + modified_messages = fallback_kwargs["messages"] + + # Should have original message + system message + assistant message with prefix + assert len(modified_messages) == 3 + assert modified_messages[0] == {"role": "user", "content": "Hello"} + assert modified_messages[1]["role"] == "system" + assert "continuation" in modified_messages[1]["content"] + assert modified_messages[2]["role"] == "assistant" + assert modified_messages[2]["content"] == "Hello" + assert modified_messages[2]["prefix"] == True + + # Verify fallback parameters + assert call_args.kwargs["disable_fallbacks"] == False + assert call_args.kwargs["model_group"] == "gpt-4" + + # Should get original chunk + fallback chunks + assert len(collected_chunks) == 3 # 1 original + 2 fallback + print("✓ Fallback system called correctly with proper message modification") + + # Test 3: Fallback failure + print("\n=== Test 3: Fallback failure ===") + + mock_error_response_2 = AsyncIteratorWithError(mock_chunks, 1) # Same error pattern + + # Mock fallback failure + fallback_error = Exception("Fallback also failed") + with patch.object( + router, "async_function_with_fallbacks_common_utils", side_effect=fallback_error + ): + + collected_chunks = [] + original_error = None + + try: + async for chunk in router._acompletion_streaming_iterator( + model_response=mock_error_response_2, + messages=messages, + initial_kwargs=initial_kwargs, + ): + collected_chunks.append(chunk) + except MidStreamFallbackError as e: + original_error = e + + # Should re-raise original MidStreamFallbackError, not fallback error + assert original_error is not None + assert isinstance(original_error, MidStreamFallbackError) + assert original_error.generated_content == "Hello" + print("✓ Original error re-raised when fallback fails") + + print("\n=== All tests passed! ===") + + +@pytest.mark.asyncio +async def test_acompletion_streaming_iterator_edge_cases(): + """Test edge cases for _acompletion_streaming_iterator.""" + from unittest.mock import MagicMock + + from litellm.exceptions import MidStreamFallbackError + + router = litellm.Router( + model_list=[ + { + "model_name": "gpt-4", + "litellm_params": {"model": "gpt-4", "api_key": "fake-key"}, + } + ], + set_verbose=True, + ) + + messages = [{"role": "user", "content": "Test"}] + initial_kwargs = {"model": "gpt-4", "stream": True} + + # Test: Empty generated content + empty_error = MidStreamFallbackError( + message="Error", + model="gpt-4", + llm_provider="openai", + generated_content="", # Empty content + ) + + class AsyncIteratorImmediateError: + def __aiter__(self): + return self + + async def __anext__(self): + raise empty_error + + mock_response = AsyncIteratorImmediateError() + + # Mock empty fallback response using AsyncIterator + class EmptyAsyncIterator: + def __aiter__(self): + return self + + async def __anext__(self): + raise StopAsyncIteration + + mock_fallback_response = EmptyAsyncIterator() + + with patch.object( + router, + "async_function_with_fallbacks_common_utils", + return_value=mock_fallback_response, + ) as mock_fallback_utils: + + collected_chunks = [] + async for chunk in router._acompletion_streaming_iterator( + model_response=mock_response, + messages=messages, + initial_kwargs=initial_kwargs, + ): + collected_chunks.append(chunk) + + # Should still call fallback even with empty content + assert mock_fallback_utils.called + fallback_kwargs = mock_fallback_utils.call_args.kwargs["kwargs"] + modified_messages = fallback_kwargs["messages"] + + # Should have assistant message with empty content + assert modified_messages[2]["content"] == "" + print("✓ Handles empty generated content correctly") + + print("✓ Edge case tests passed!") From 0a1716491092cc4f2fa64ae5112ef95c1635b9e7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 13:53:05 -0700 Subject: [PATCH 10/14] test: add basic unit test --- tests/test_litellm/test_router.py | 47 ++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/tests/test_litellm/test_router.py b/tests/test_litellm/test_router.py index 4d75f372e18..1837ab3c56f 100644 --- a/tests/test_litellm/test_router.py +++ b/tests/test_litellm/test_router.py @@ -2,7 +2,7 @@ import json import os import sys -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi.testclient import TestClient @@ -13,6 +13,7 @@ import litellm +from litellm.router_utils.fallback_event_handlers import run_async_fallback def test_update_kwargs_does_not_mutate_defaults_and_merges_metadata(): @@ -1321,3 +1322,47 @@ async def __anext__(self): print("✓ Handles empty generated content correctly") print("✓ Edge case tests passed!") + + +@pytest.mark.asyncio +async def test_async_function_with_fallbacks_common_utils(): + """Test the async_function_with_fallbacks_common_utils method""" + # Create a basic router for testing + router = litellm.Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + }, + } + ], + max_fallbacks=5, + ) + + # Test case 1: disable_fallbacks=True should raise original exception + test_exception = Exception("Test error") + with pytest.raises(Exception, match="Test error"): + await router.async_function_with_fallbacks_common_utils( + e=test_exception, + disable_fallbacks=True, + fallbacks=None, + context_window_fallbacks=None, + content_policy_fallbacks=None, + model_group="gpt-3.5-turbo", + args=(), + kwargs=MagicMock(), + ) + + # Test case 2: original_model_group=None should raise original exception + with pytest.raises(Exception, match="Test error"): + await router.async_function_with_fallbacks_common_utils( + e=test_exception, + disable_fallbacks=False, + fallbacks=None, + context_window_fallbacks=None, + content_policy_fallbacks=None, + model_group="gpt-3.5-turbo", + args=(), + kwargs={}, # No model key + ) From 2e89d50f95d7254c4b14422dadc472a16e8d627c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 14:46:24 -0700 Subject: [PATCH 11/14] fix(router.py): ensure return type of fallback stream is compatible with CustomStreamWrapper prevent client code from breaking --- litellm/router.py | 137 +++++++++++------- .../test_router_batch_completion.py | 1 + 2 files changed, 82 insertions(+), 56 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 85e275b3a3c..b9feb8409c1 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1084,7 +1084,7 @@ async def _acompletion_streaming_iterator( model_response: CustomStreamWrapper, messages: List[Dict[str, str]], initial_kwargs: dict, - ) -> AsyncGenerator[Optional[ModelResponseStream], None]: + ) -> CustomStreamWrapper: """ Helper to iterate over a streaming response. @@ -1092,69 +1092,92 @@ async def _acompletion_streaming_iterator( """ from litellm.exceptions import MidStreamFallbackError - try: - async for item in model_response: - yield item - except MidStreamFallbackError as e: - try: - # Use the router's fallback system - model_group = cast(str, initial_kwargs.get("model")) - fallbacks: Optional[List] = initial_kwargs.get( - "fallbacks", self.fallbacks - ) - context_window_fallbacks: Optional[List] = initial_kwargs.get( - "context_window_fallbacks", self.context_window_fallbacks - ) - content_policy_fallbacks: Optional[List] = initial_kwargs.get( - "content_policy_fallbacks", self.content_policy_fallbacks + class FallbackStreamWrapper(CustomStreamWrapper): + def __init__(self, async_generator): + # Copy attributes from the original model_response + super().__init__( + completion_stream=async_generator, + model=getattr(model_response, "model", ""), + custom_llm_provider=getattr( + model_response, "custom_llm_provider", "" + ), + logging_obj=getattr(model_response, "logging_obj", None), ) - initial_kwargs["original_function"] = self._acompletion - initial_kwargs["messages"] = messages + [ - { - "role": "system", - "content": "You are a helpful assistant. You are given a message and you need to respond to it. You are also given a generated content. You need to respond to the message in continuation of the generated content. Do not repeat the same content. Your response should be in continuation of this text: ", - }, - { - "role": "assistant", - "content": e.generated_content, - "prefix": True, - }, - ] - self._update_kwargs_before_fallbacks( - model=model_group, kwargs=initial_kwargs - ) - fallback_response = ( - await self.async_function_with_fallbacks_common_utils( - e=e, - disable_fallbacks=False, - fallbacks=fallbacks, - context_window_fallbacks=context_window_fallbacks, - content_policy_fallbacks=content_policy_fallbacks, - model_group=model_group, - args=(), - kwargs=initial_kwargs, + self._async_generator = async_generator + + def __aiter__(self): + return self._async_generator + + async def __anext__(self): + return await self._async_generator.__anext__() + + async def stream_with_fallbacks(): + try: + async for item in model_response: + yield item + except MidStreamFallbackError as e: + try: + # Use the router's fallback system + model_group = cast(str, initial_kwargs.get("model")) + fallbacks: Optional[List] = initial_kwargs.get( + "fallbacks", self.fallbacks + ) + context_window_fallbacks: Optional[List] = initial_kwargs.get( + "context_window_fallbacks", self.context_window_fallbacks + ) + content_policy_fallbacks: Optional[List] = initial_kwargs.get( + "content_policy_fallbacks", self.content_policy_fallbacks + ) + initial_kwargs["original_function"] = self._acompletion + initial_kwargs["messages"] = messages + [ + { + "role": "system", + "content": "You are a helpful assistant. You are given a message and you need to respond to it. You are also given a generated content. You need to respond to the message in continuation of the generated content. Do not repeat the same content. Your response should be in continuation of this text: ", + }, + { + "role": "assistant", + "content": e.generated_content, + "prefix": True, + }, + ] + self._update_kwargs_before_fallbacks( + model=model_group, kwargs=initial_kwargs + ) + fallback_response = ( + await self.async_function_with_fallbacks_common_utils( + e=e, + disable_fallbacks=False, + fallbacks=fallbacks, + context_window_fallbacks=context_window_fallbacks, + content_policy_fallbacks=content_policy_fallbacks, + model_group=model_group, + args=(), + kwargs=initial_kwargs, + ) ) - ) - # If fallback returns a streaming response, iterate over it - if hasattr(fallback_response, "__aiter__"): - async for fallback_item in fallback_response: # type: ignore - yield fallback_item - else: - # If fallback returns a non-streaming response, yield None - yield None + # If fallback returns a streaming response, iterate over it + if hasattr(fallback_response, "__aiter__"): + async for fallback_item in fallback_response: # type: ignore + yield fallback_item + else: + # If fallback returns a non-streaming response, yield None + yield None - except Exception as fallback_error: - # If fallback also fails, log and re-raise original error - verbose_router_logger.error(f"Fallback also failed: {fallback_error}") - raise e # Re-raise the original error + except Exception as fallback_error: + # If fallback also fails, log and re-raise original error + verbose_router_logger.error( + f"Fallback also failed: {fallback_error}" + ) + raise e # Re-raise the original error + + return FallbackStreamWrapper(stream_with_fallbacks()) async def _acompletion( self, model: str, messages: List[Dict[str, str]], **kwargs ) -> Union[ ModelResponse, CustomStreamWrapper, - AsyncGenerator[Optional[ModelResponseStream], None], ]: """ - Get an available deployment @@ -1275,7 +1298,7 @@ async def _acompletion( ) if isinstance(response, CustomStreamWrapper): - return self._acompletion_streaming_iterator( + return await self._acompletion_streaming_iterator( model_response=response, messages=messages, initial_kwargs=input_kwargs_for_streaming_fallback, @@ -1659,7 +1682,8 @@ async def _async_completion_no_exceptions( Wrapper around self.acompletion that catches exceptions and returns them as a result """ try: - return await self.acompletion(model=model, messages=messages, stream=stream, **kwargs) # type: ignore + result = await self.acompletion(model=model, messages=messages, stream=stream, **kwargs) # type: ignore + return result except asyncio.CancelledError: verbose_router_logger.debug( "Received 'task.cancel'. Cancelling call w/ model={}.".format(model) @@ -1707,6 +1731,7 @@ async def check_response(task: asyncio.Task): ) for completed_task in done: result = await check_response(completed_task) + if result is not None: # Return the first successful result result._hidden_params["fastest_response_batch_completion"] = True diff --git a/tests/local_testing/test_router_batch_completion.py b/tests/local_testing/test_router_batch_completion.py index 6fedb82a553..534db7ed010 100644 --- a/tests/local_testing/test_router_batch_completion.py +++ b/tests/local_testing/test_router_batch_completion.py @@ -130,6 +130,7 @@ async def test_batch_completion_fastest_response_unit_test(): @pytest.mark.asyncio async def test_batch_completion_fastest_response_streaming(): litellm.set_verbose = True + litellm._turn_on_debug() router = litellm.Router( model_list=[ From b08cb138c80e50ea59d4f652f8f19420b705b00d Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 14:51:31 -0700 Subject: [PATCH 12/14] fix: cleanup --- litellm/router.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index b9feb8409c1..468bb2c4ed0 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -1097,11 +1097,9 @@ def __init__(self, async_generator): # Copy attributes from the original model_response super().__init__( completion_stream=async_generator, - model=getattr(model_response, "model", ""), - custom_llm_provider=getattr( - model_response, "custom_llm_provider", "" - ), - logging_obj=getattr(model_response, "logging_obj", None), + model=model_response.model, + custom_llm_provider=model_response.custom_llm_provider, + logging_obj=model_response.logging_obj, ) self._async_generator = async_generator From a14262471572e782f77312525856120e5e9e9965 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 15:01:46 -0700 Subject: [PATCH 13/14] test: update test --- tests/test_litellm/test_router.py | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/tests/test_litellm/test_router.py b/tests/test_litellm/test_router.py index 1837ab3c56f..4795fdc372c 100644 --- a/tests/test_litellm/test_router.py +++ b/tests/test_litellm/test_router.py @@ -1126,11 +1126,17 @@ async def __anext__(self): mock_response = AsyncIterator(mock_chunks) + setattr(mock_response, "model", "gpt-4") + setattr(mock_response, "custom_llm_provider", "openai") + setattr(mock_response, "logging_obj", MagicMock()) + + result = await router._acompletion_streaming_iterator( + model_response=mock_response, messages=messages, initial_kwargs=initial_kwargs + ) + # Collect streamed chunks collected_chunks = [] - async for chunk in router._acompletion_streaming_iterator( - model_response=mock_response, messages=messages, initial_kwargs=initial_kwargs - ): + async for chunk in result: collected_chunks.append(chunk) assert len(collected_chunks) == 3 @@ -1170,6 +1176,10 @@ async def __anext__(self): mock_chunks, 1 ) # Error after first chunk + setattr(mock_error_response, "model", "gpt-4") + setattr(mock_error_response, "custom_llm_provider", "openai") + setattr(mock_error_response, "logging_obj", MagicMock()) + # Mock the fallback response fallback_chunks = [ MagicMock(choices=[MagicMock(delta=MagicMock(content=" world"))]), @@ -1186,11 +1196,13 @@ async def __anext__(self): ) as mock_fallback_utils: collected_chunks = [] - async for chunk in router._acompletion_streaming_iterator( + result = await router._acompletion_streaming_iterator( model_response=mock_error_response, messages=messages, initial_kwargs=initial_kwargs, - ): + ) + + async for chunk in result: collected_chunks.append(chunk) # Verify fallback was called @@ -1231,13 +1243,18 @@ async def __anext__(self): collected_chunks = [] original_error = None + setattr(mock_error_response_2, "model", "gpt-4") + setattr(mock_error_response_2, "custom_llm_provider", "openai") + setattr(mock_error_response_2, "logging_obj", MagicMock()) try: - async for chunk in router._acompletion_streaming_iterator( + result = await router._acompletion_streaming_iterator( model_response=mock_error_response_2, messages=messages, initial_kwargs=initial_kwargs, - ): + ) + + async for chunk in result: collected_chunks.append(chunk) except MidStreamFallbackError as e: original_error = e From aadbbd5d783dcbfd98ef16cad43678a92101cea8 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 30 Jul 2025 15:03:42 -0700 Subject: [PATCH 14/14] fix: fix linting error --- litellm/router.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index 468bb2c4ed0..d85c5340605 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -23,7 +23,6 @@ from typing import ( TYPE_CHECKING, Any, - AsyncGenerator, Callable, Dict, List, @@ -147,7 +146,7 @@ from litellm.types.utils import GenericBudgetConfigType, LiteLLMBatch from litellm.types.utils import ModelInfo from litellm.types.utils import ModelInfo as ModelMapInfo -from litellm.types.utils import ModelResponseStream, StandardLoggingPayload +from litellm.types.utils import StandardLoggingPayload from litellm.utils import ( CustomStreamWrapper, EmbeddingResponse,