Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,7 @@ jobs:
- run: python ./tests/documentation_tests/test_circular_imports.py
- run: python ./tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py
- run: python ./tests/code_coverage_tests/check_unsafe_enterprise_import.py
- run: python ./tests/code_coverage_tests/ban_copy_deepcopy_kwargs.py
- run: helm lint ./deploy/charts/litellm-helm

db_migration_disable_update_check:
Expand Down
57 changes: 51 additions & 6 deletions litellm/litellm_core_utils/core_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,22 @@


def safe_divide_seconds(
    seconds: float, denominator: float, default: Optional[float] = None
) -> Optional[float]:
    """
    Safely divide seconds by denominator, handling zero (and negative) divisors.

    Args:
        seconds: Time duration in seconds
        denominator: The divisor (e.g., number of tokens)
        default: Value to return when the denominator is not positive
            (defaults to None)

    Returns:
        seconds / denominator as a float (seconds per unit), or ``default``
        when ``denominator <= 0``
    """
    # Guard against ZeroDivisionError and nonsensical negative divisors.
    if denominator <= 0:
        return default

    return float(seconds / denominator)


Expand Down Expand Up @@ -203,3 +201,50 @@ def preserve_upstream_non_openai_attributes(
for key, value in original_chunk.model_dump().items():
if key not in expected_keys:
setattr(model_response, key, value)


def safe_deep_copy(data):
    """
    Safe Deep Copy

    The LiteLLM request dict may carry an OpenTelemetry span under
    ``metadata["litellm_parent_otel_span"]`` and/or
    ``litellm_metadata["litellm_parent_otel_span"]``; spans are not
    picklable, so a plain ``copy.deepcopy`` would raise.

    This helper swaps each span out for a ``"placeholder"`` string, deep
    copies the data, then restores each span — into BOTH the original
    object and the returned copy — tracking spans per container so a span
    in ``metadata`` is never clobbered by one in ``litellm_metadata``.

    Returns ``data`` unchanged (no copy) when ``litellm.safe_memory_mode``
    is enabled.
    """
    import copy

    import litellm

    if litellm.safe_memory_mode is True:
        return data

    # Track each removed span by the container it came from, so spans in
    # "metadata" and "litellm_metadata" are restored independently.
    saved_spans = {}

    # Step 1: Replace each (unpicklable) span with a placeholder string.
    if isinstance(data, dict):
        for container_key in ("metadata", "litellm_metadata"):
            if (
                container_key in data
                and "litellm_parent_otel_span" in data[container_key]
            ):
                saved_spans[container_key] = data[container_key].pop(
                    "litellm_parent_otel_span"
                )
                data[container_key]["litellm_parent_otel_span"] = "placeholder"

    new_data = copy.deepcopy(data)

    # Step 2: Restore each span after the deep copy — into the original
    # AND into the copy, so the returned dict does not leak "placeholder".
    if isinstance(data, dict):
        for container_key, span in saved_spans.items():
            data[container_key]["litellm_parent_otel_span"] = span
            new_data[container_key]["litellm_parent_otel_span"] = span
    return new_data
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Deep Copy Fails Span Handling, Data Integrity

The safe_deep_copy function contains two logic errors:

  1. Incorrect Span Handling: If litellm_parent_otel_span exists in both metadata and litellm_metadata dictionaries, the single variable used to temporarily store the span is overwritten. This causes the first span to be lost, and the second span is then incorrectly restored to both locations in the original data object, leading to potential incorrect OTEL tracing behavior.
  2. Incomplete Deep Copy: The function replaces OTEL spans with "placeholder" strings before deep copying. However, it restores the original spans only to the original data object, while returning new_data which still contains these "placeholder" strings. This results in the returned deep copy being incomplete and potentially unusable for consumers expecting valid span objects.
Locations (1)
Fix in Cursor Fix in Web

4 changes: 2 additions & 2 deletions litellm/litellm_core_utils/fallback_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import uuid
from copy import deepcopy
from typing import Optional

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import safe_deep_copy

from .asyncify import run_async_function

Expand Down Expand Up @@ -41,7 +41,7 @@ async def async_completion_with_fallbacks(**kwargs):
most_recent_exception_str: Optional[str] = None
for fallback in fallbacks:
try:
completion_kwargs = deepcopy(base_kwargs)
completion_kwargs = safe_deep_copy(base_kwargs)
# Handle dictionary fallback configurations
if isinstance(fallback, dict):
model = fallback.pop("model", original_model)
Expand Down
9 changes: 4 additions & 5 deletions litellm/proxy/_new_secret_config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
model_list:
- model_name: genai/test/*
- model_name: "gpt-4o-mini-openai"
litellm_params:
model: openai/*
api_base: https://api.openai.com
model: gpt-4o-mini
api_key: os.environ/OPENAI_API_KEY

litellm_settings:
check_provider_endpoint: true
router_settings:
model_group_alias: {"gpt-4o": "gpt-4o-mini-openai"}
12 changes: 6 additions & 6 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,6 @@ def generate_feedback_box():
from litellm.proxy.management_endpoints.tag_management_endpoints import (
router as tag_management_router,
)
from litellm.proxy.management_endpoints.user_agent_analytics_endpoints import (
router as user_agent_analytics_router,
)
from litellm.proxy.management_endpoints.team_callback_endpoints import (
router as team_callback_router,
)
Expand All @@ -287,6 +284,9 @@ def generate_feedback_box():
get_disabled_non_admin_personal_key_creation,
)
from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
from litellm.proxy.management_endpoints.user_agent_analytics_endpoints import (
router as user_agent_analytics_router,
)
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMiddleware
from litellm.proxy.openai_files_endpoints.files_endpoints import (
Expand Down Expand Up @@ -2213,7 +2213,9 @@ def _init_non_llm_configs(self, config: dict):
litellm_settings = config.get("litellm_settings", {})
mcp_aliases = litellm_settings.get("mcp_aliases", None)

global_mcp_server_manager.load_servers_from_config(mcp_servers_config, mcp_aliases)
global_mcp_server_manager.load_servers_from_config(
mcp_servers_config, mcp_aliases
)

## VECTOR STORES
vector_store_registry_config = config.get("vector_store_registry", None)
Expand Down Expand Up @@ -3246,7 +3248,6 @@ async def async_data_generator(
"async_data_generator: received streaming chunk - {}".format(chunk)
)


### CALL HOOKS ### - modify outgoing data
chunk = await proxy_logging_obj.async_post_call_streaming_hook(
user_api_key_dict=user_api_key_dict,
Expand All @@ -3255,7 +3256,6 @@ async def async_data_generator(
str_so_far=str_so_far,
)


if isinstance(chunk, (ModelResponse, ModelResponseStream)):
response_str = litellm.get_response_string(response_obj=chunk)
str_so_far += response_str
Expand Down
97 changes: 41 additions & 56 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@
ModelResponseStream,
Router,
)
from litellm.types.mcp import (
MCPPreCallRequestObject,
MCPPreCallResponseObject,
MCPDuringCallResponseObject,
)
from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging, ServiceTypes
from litellm.caching.caching import DualCache, RedisCache
Expand Down Expand Up @@ -93,6 +88,11 @@
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
from litellm.secret_managers.main import str_to_bool
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
from litellm.types.mcp import (
MCPDuringCallResponseObject,
MCPPreCallRequestObject,
MCPPreCallResponseObject,
)
from litellm.types.utils import CallTypes, LLMResponseTypes, LoggedLiteLLMParams

if TYPE_CHECKING:
Expand All @@ -118,33 +118,6 @@ def print_verbose(print_statement):
print(f"LiteLLM Proxy: {print_statement}") # noqa


def safe_deep_copy(data):
    """
    Safe Deep Copy

    The LiteLLM Request has some object that can-not be pickled / deep copied

    Use this function to safely deep copy the LiteLLM Request
    """
    # In safe-memory mode, skip copying entirely and hand back the original.
    if litellm.safe_memory_mode is True:
        return data

    litellm_parent_otel_span: Optional[Any] = None
    # Step 1: Remove the litellm_parent_otel_span
    litellm_parent_otel_span = None
    if isinstance(data, dict):
        # remove litellm_parent_otel_span since this is not picklable
        if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
            litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span")
    # deepcopy now succeeds because the unpicklable span has been popped out.
    new_data = copy.deepcopy(data)

    # Step 2: re-add the litellm_parent_otel_span after doing a deep copy
    # NOTE(review): the span is restored only into the ORIGINAL `data`;
    # the returned `new_data` was copied after the pop, so it is returned
    # without the "litellm_parent_otel_span" key at all.
    if isinstance(data, dict) and litellm_parent_otel_span is not None:
        if "metadata" in data:
            data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
    return new_data


class InternalUsageCache:
def __init__(self, dual_cache: DualCache):
self.dual_cache: DualCache = dual_cache
Expand Down Expand Up @@ -474,11 +447,11 @@ async def update_request_status(
)

async def async_pre_mcp_tool_call_hook(
self,
kwargs: dict,
request_obj: Any,
start_time: datetime,
end_time: datetime,
self,
kwargs: dict,
request_obj: Any,
start_time: datetime,
end_time: datetime,
) -> Optional[Any]:
"""
Pre MCP Tool Call Hook
Expand All @@ -489,7 +462,7 @@ async def async_pre_mcp_tool_call_hook(
from litellm.types.mcp import MCPPreCallRequestObject, MCPPreCallResponseObject

callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=getattr(self, 'dynamic_success_callbacks', None),
dynamic_success_callbacks=getattr(self, "dynamic_success_callbacks", None),
global_callbacks=litellm.success_callback,
)

Expand All @@ -500,7 +473,7 @@ async def async_pre_mcp_tool_call_hook(
arguments=kwargs.get("arguments", {}),
server_name=kwargs.get("server_name"),
user_api_key_auth=kwargs.get("user_api_key_auth"),
hidden_params=HiddenParams()
hidden_params=HiddenParams(),
)

for callback in callbacks:
Expand Down Expand Up @@ -537,10 +510,10 @@ def get_combined_callback_list(
return global_callbacks
return list(set(dynamic_success_callbacks + global_callbacks))



def _parse_pre_mcp_call_hook_response(
self, response: MCPPreCallResponseObject, original_request: MCPPreCallRequestObject
self,
response: MCPPreCallResponseObject,
original_request: MCPPreCallRequestObject,
) -> Dict[str, Any]:
"""
Parse the response from the pre_mcp_tool_call_hook
Expand All @@ -551,29 +524,33 @@ def _parse_pre_mcp_call_hook_response(
"""
result = {
"should_proceed": response.should_proceed,
"modified_arguments": response.modified_arguments or original_request.arguments,
"modified_arguments": response.modified_arguments
or original_request.arguments,
"error_message": response.error_message,
"hidden_params": response.hidden_params,
}
return result

async def async_during_mcp_tool_call_hook(
self,
kwargs: dict,
request_obj: Any,
start_time: datetime,
end_time: datetime,
self,
kwargs: dict,
request_obj: Any,
start_time: datetime,
end_time: datetime,
) -> Optional[Any]:
"""
During MCP Tool Call Hook

Use this for concurrent monitoring and validation during tool execution.
"""
from litellm.types.llms.base import HiddenParams
from litellm.types.mcp import MCPDuringCallResponseObject, MCPDuringCallRequestObject
from litellm.types.mcp import (
MCPDuringCallRequestObject,
MCPDuringCallResponseObject,
)

callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=getattr(self, 'dynamic_success_callbacks', None),
dynamic_success_callbacks=getattr(self, "dynamic_success_callbacks", None),
global_callbacks=litellm.success_callback,
)

Expand All @@ -584,7 +561,7 @@ async def async_during_mcp_tool_call_hook(
arguments=kwargs.get("arguments", {}),
server_name=kwargs.get("server_name"),
start_time=start_time.timestamp() if start_time else None,
hidden_params=HiddenParams()
hidden_params=HiddenParams(),
)

for callback in callbacks:
Expand All @@ -603,7 +580,9 @@ async def async_during_mcp_tool_call_hook(
# this allows for execution control decisions
######################################################################
if response is not None:
return self._parse_during_mcp_call_hook_response(response=response)
return self._parse_during_mcp_call_hook_response(
response=response
)
except Exception as e:
verbose_proxy_logger.exception(
"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format(
Expand All @@ -613,7 +592,7 @@ async def async_during_mcp_tool_call_hook(
return None

def _parse_during_mcp_call_hook_response(
self, response: MCPDuringCallResponseObject
self, response: MCPDuringCallResponseObject
) -> Dict[str, Any]:
"""
Parse the response from the during_mcp_tool_call_hook
Expand Down Expand Up @@ -1382,9 +1361,15 @@ def __init__(
from prisma import Prisma # type: ignore
except Exception as e:
verbose_proxy_logger.error(f"Failed to import Prisma client: {e}")
verbose_proxy_logger.error("This usually means 'prisma generate' hasn't been run yet.")
verbose_proxy_logger.error("Please run 'prisma generate' to generate the Prisma client.")
raise Exception("Unable to find Prisma binaries. Please run 'prisma generate' first.")
verbose_proxy_logger.error(
"This usually means 'prisma generate' hasn't been run yet."
)
verbose_proxy_logger.error(
"Please run 'prisma generate' to generate the Prisma client."
)
raise Exception(
"Unable to find Prisma binaries. Please run 'prisma generate' first."
)
if http_client is not None:
self.db = PrismaWrapper(
original_prisma=Prisma(http=http_client),
Expand Down
Loading
Loading