Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions litellm/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@

from litellm.types.utils import LiteLLMCommonStrings

# Lazily-built singleton; populated on first call to _get_minimal_error_response().
_MINIMAL_ERROR_RESPONSE: Optional[httpx.Response] = None


def _get_minimal_error_response() -> httpx.Response:
    """Return a cached minimal ``httpx.Response`` for error cases.

    The response is a 400 with a mock GET request attached, constructed
    once and reused on every subsequent call so exception creation does
    not pay the cost of building a fresh Response/Request pair each time.
    Note: all callers receive the same shared object.
    """
    global _MINIMAL_ERROR_RESPONSE
    if _MINIMAL_ERROR_RESPONSE is not None:
        return _MINIMAL_ERROR_RESPONSE
    mock_request = httpx.Request(method="GET", url="https://litellm.ai")
    _MINIMAL_ERROR_RESPONSE = httpx.Response(
        status_code=400,
        request=mock_request,
    )
    return _MINIMAL_ERROR_RESPONSE


class AuthenticationError(openai.AuthenticationError): # type: ignore
def __init__(
Expand Down Expand Up @@ -127,16 +142,15 @@ def __init__(
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
_response_headers = (
getattr(response, "headers", None) if response is not None else None
)
self.response = httpx.Response(
status_code=self.status_code,
headers=_response_headers,
request=httpx.Request(
method="GET", url="https://litellm.ai"
), # mock request object
)
if (
response is not None
and isinstance(response, httpx.Response)
and hasattr(response, "request")
and response.request is not None
):
self.response = response
else:
self.response = _get_minimal_error_response()
super().__init__(
self.message, response=self.response, body=body
) # Call the base class constructor with the parameters it needs
Expand Down
5 changes: 3 additions & 2 deletions litellm/litellm_core_utils/get_litellm_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ def get_litellm_params(
"text_completion": text_completion,
"azure_ad_token_provider": azure_ad_token_provider,
"user_continue_message": user_continue_message,
"base_model": base_model
or _get_base_model_from_litellm_call_metadata(metadata=metadata),
"base_model": base_model or (
_get_base_model_from_litellm_call_metadata(metadata=metadata) if metadata else None
),
"litellm_trace_id": litellm_trace_id,
"litellm_session_id": litellm_session_id,
"hf_model_name": hf_model_name,
Expand Down
14 changes: 2 additions & 12 deletions litellm/litellm_core_utils/get_llm_provider_logic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Optional, Tuple

import httpx

import litellm
from litellm.constants import REPLICATE_MODEL_NAME_WITH_ID_LENGTH
from litellm.llms.openai_like.json_loader import JSONProviderRegistry
Expand Down Expand Up @@ -453,11 +451,7 @@ def get_llm_provider( # noqa: PLR0915
raise litellm.exceptions.BadRequestError( # type: ignore
message=error_str,
model=model,
response=httpx.Response(
status_code=400,
content=error_str,
request=httpx.Request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore
),
response=None,
llm_provider="",
)
if api_base is not None and not isinstance(api_base, str):
Expand All @@ -481,11 +475,7 @@ def get_llm_provider( # noqa: PLR0915
raise litellm.exceptions.BadRequestError( # type: ignore
message=f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}",
model=model,
response=httpx.Response(
status_code=400,
content=error_str,
request=httpx.Request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore
),
response=None,
llm_provider="",
)

Expand Down
4 changes: 2 additions & 2 deletions litellm/litellm_core_utils/litellm_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,12 +325,12 @@ def __init__(
messages = new_messages

self.model = model
self.messages = copy.deepcopy(messages)
self.messages = copy.deepcopy(messages) if messages is not None else None
self.stream = stream
self.start_time = start_time # log the call start time
self.call_type = call_type
self.litellm_call_id = litellm_call_id
self.litellm_trace_id: str = litellm_trace_id or str(uuid.uuid4())
self.litellm_trace_id: str = litellm_trace_id if litellm_trace_id else str(uuid.uuid4())
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[
Expand Down
11 changes: 6 additions & 5 deletions litellm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,8 @@ def function_setup( # noqa: PLR0915
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None

## LAZY LOAD COROUTINE CHECKER ##
get_coroutine_checker = getattr(sys.modules[__name__], "get_coroutine_checker")
get_coroutine_checker_fn = getattr(sys.modules[__name__], "get_coroutine_checker")
coroutine_checker = get_coroutine_checker_fn()

## DYNAMIC CALLBACKS ##
dynamic_callbacks: Optional[
Expand Down Expand Up @@ -825,7 +826,7 @@ def function_setup( # noqa: PLR0915
if len(litellm.input_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.input_callback): # type: ignore
if get_coroutine_checker().is_async_callable(callback):
if coroutine_checker.is_async_callable(callback):
litellm._async_input_callback.append(callback)
removed_async_items.append(index)

Expand All @@ -835,7 +836,7 @@ def function_setup( # noqa: PLR0915
if len(litellm.success_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.success_callback): # type: ignore
if get_coroutine_checker().is_async_callable(callback):
if coroutine_checker.is_async_callable(callback):
litellm.logging_callback_manager.add_litellm_async_success_callback(
callback
)
Expand All @@ -860,7 +861,7 @@ def function_setup( # noqa: PLR0915
if len(litellm.failure_callback) > 0:
removed_async_items = []
for index, callback in enumerate(litellm.failure_callback): # type: ignore
if get_coroutine_checker().is_async_callable(callback):
if coroutine_checker.is_async_callable(callback):
litellm.logging_callback_manager.add_litellm_async_failure_callback(
callback
)
Expand Down Expand Up @@ -893,7 +894,7 @@ def function_setup( # noqa: PLR0915
removed_async_items = []
for index, callback in enumerate(kwargs["success_callback"]):
if (
get_coroutine_checker().is_async_callable(callback)
coroutine_checker.is_async_callable(callback)
or callback == "dynamodb"
or callback == "s3"
):
Expand Down
Loading