diff --git a/litellm/exceptions.py b/litellm/exceptions.py index c2443626b8d..f5fcded5134 100644 --- a/litellm/exceptions.py +++ b/litellm/exceptions.py @@ -16,6 +16,22 @@ from litellm.types.utils import LiteLLMCommonStrings +_MINIMAL_ERROR_RESPONSES = {} + + +def _get_minimal_error_response(status_code: int = 400) -> httpx.Response: + """Get a cached minimal httpx.Response for error cases, keyed by status code.""" + response = _MINIMAL_ERROR_RESPONSES.get(status_code) + if response is None: + response = httpx.Response( + status_code=status_code, + request=httpx.Request( + method="GET", url="https://litellm.ai" + ), + ) + _MINIMAL_ERROR_RESPONSES[status_code] = response + return response + class AuthenticationError(openai.AuthenticationError): # type: ignore def __init__( @@ -127,16 +143,13 @@ def __init__( self.litellm_debug_info = litellm_debug_info self.max_retries = max_retries self.num_retries = num_retries - _response_headers = ( - getattr(response, "headers", None) if response is not None else None - ) - self.response = httpx.Response( - status_code=self.status_code, - headers=_response_headers, - request=httpx.Request( - method="GET", url="https://litellm.ai" - ), # mock request object - ) + if ( + isinstance(response, httpx.Response) # also rejects None + and getattr(response, "_request", None) is not None # Response.request raises RuntimeError when unset + ): + self.response = response + else: + self.response = _get_minimal_error_response(self.status_code) super().__init__( self.message, response=self.response, body=body ) # Call the base class constructor with the parameters it needs diff --git a/litellm/litellm_core_utils/get_litellm_params.py b/litellm/litellm_core_utils/get_litellm_params.py index 0d35cfa3140..e290101f8bf 100644 --- a/litellm/litellm_core_utils/get_litellm_params.py +++ b/litellm/litellm_core_utils/get_litellm_params.py @@ -93,8 +93,9 @@ def get_litellm_params( "text_completion": text_completion, "azure_ad_token_provider": azure_ad_token_provider, "user_continue_message": user_continue_message, - "base_model": base_model 
- or _get_base_model_from_litellm_call_metadata(metadata=metadata), + "base_model": base_model or ( + _get_base_model_from_litellm_call_metadata(metadata=metadata) if metadata else None + ), "litellm_trace_id": litellm_trace_id, "litellm_session_id": litellm_session_id, "hf_model_name": hf_model_name, diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py index 807b66faec1..718773a1b16 100644 --- a/litellm/litellm_core_utils/get_llm_provider_logic.py +++ b/litellm/litellm_core_utils/get_llm_provider_logic.py @@ -1,7 +1,5 @@ from typing import Optional, Tuple -import httpx - import litellm from litellm.constants import REPLICATE_MODEL_NAME_WITH_ID_LENGTH from litellm.llms.openai_like.json_loader import JSONProviderRegistry @@ -453,11 +451,7 @@ def get_llm_provider( # noqa: PLR0915 raise litellm.exceptions.BadRequestError( # type: ignore message=error_str, model=model, - response=httpx.Response( - status_code=400, - content=error_str, - request=httpx.Request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore - ), + response=None, llm_provider="", ) if api_base is not None and not isinstance(api_base, str): @@ -481,11 +475,7 @@ def get_llm_provider( # noqa: PLR0915 raise litellm.exceptions.BadRequestError( # type: ignore message=f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}", model=model, - response=httpx.Response( - status_code=400, - content=error_str, - request=httpx.Request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore - ), + response=None, llm_provider="", ) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 572c797aa55..d3bcfe8200e 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -325,12 +325,12 @@ def __init__( messages = new_messages self.model = model - self.messages = 
copy.deepcopy(messages) + self.messages = copy.deepcopy(messages) if messages is not None else None self.stream = stream self.start_time = start_time # log the call start time self.call_type = call_type self.litellm_call_id = litellm_call_id - self.litellm_trace_id: str = litellm_trace_id or str(uuid.uuid4()) + self.litellm_trace_id: str = litellm_trace_id if litellm_trace_id else str(uuid.uuid4()) self.function_id = function_id self.streaming_chunks: List[Any] = [] # for generating complete stream response self.sync_streaming_chunks: List[ diff --git a/litellm/utils.py b/litellm/utils.py index 06eceefe0cc..11389dd1af9 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -771,7 +771,8 @@ def function_setup( # noqa: PLR0915 function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None ## LAZY LOAD COROUTINE CHECKER ## - get_coroutine_checker = getattr(sys.modules[__name__], "get_coroutine_checker") + get_coroutine_checker_fn = getattr(sys.modules[__name__], "get_coroutine_checker") + coroutine_checker = get_coroutine_checker_fn() ## DYNAMIC CALLBACKS ## dynamic_callbacks: Optional[ @@ -825,7 +826,7 @@ def function_setup( # noqa: PLR0915 if len(litellm.input_callback) > 0: removed_async_items = [] for index, callback in enumerate(litellm.input_callback): # type: ignore - if get_coroutine_checker().is_async_callable(callback): + if coroutine_checker.is_async_callable(callback): litellm._async_input_callback.append(callback) removed_async_items.append(index) @@ -835,7 +836,7 @@ def function_setup( # noqa: PLR0915 if len(litellm.success_callback) > 0: removed_async_items = [] for index, callback in enumerate(litellm.success_callback): # type: ignore - if get_coroutine_checker().is_async_callable(callback): + if coroutine_checker.is_async_callable(callback): litellm.logging_callback_manager.add_litellm_async_success_callback( callback ) @@ -860,7 +861,7 @@ def function_setup( # noqa: PLR0915 if len(litellm.failure_callback) > 0: removed_async_items = [] for 
index, callback in enumerate(litellm.failure_callback): # type: ignore - if get_coroutine_checker().is_async_callable(callback): + if coroutine_checker.is_async_callable(callback): litellm.logging_callback_manager.add_litellm_async_failure_callback( callback ) @@ -893,7 +894,7 @@ def function_setup( # noqa: PLR0915 removed_async_items = [] for index, callback in enumerate(kwargs["success_callback"]): if ( - get_coroutine_checker().is_async_callable(callback) + coroutine_checker.is_async_callable(callback) or callback == "dynamodb" or callback == "s3" ):