diff --git a/.circleci/config.yml b/.circleci/config.yml index d1c8fc07ec4..de4b15f9e22 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1343,6 +1343,7 @@ jobs: - run: python ./tests/documentation_tests/test_circular_imports.py - run: python ./tests/code_coverage_tests/prevent_key_leaks_in_exceptions.py - run: python ./tests/code_coverage_tests/check_unsafe_enterprise_import.py + - run: python ./tests/code_coverage_tests/ban_copy_deepcopy_kwargs.py - run: helm lint ./deploy/charts/litellm-helm db_migration_disable_update_check: diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index 86e7eb89a21..13a2e554f12 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -18,24 +18,22 @@ def safe_divide_seconds( - seconds: float, - denominator: float, - default: Optional[float] = None + seconds: float, denominator: float, default: Optional[float] = None ) -> Optional[float]: """ Safely divide seconds by denominator, handling zero division. - + Args: seconds: Time duration in seconds denominator: The divisor (e.g., number of tokens) default: Value to return if division by zero (defaults to None) - + Returns: The result of the division as a float (seconds per unit), or default if denominator is zero """ if denominator <= 0: return default - + return float(seconds / denominator) @@ -203,3 +201,50 @@ def preserve_upstream_non_openai_attributes( for key, value in original_chunk.model_dump().items(): if key not in expected_keys: setattr(model_response, key, value) + + +def safe_deep_copy(data): + """ + Safe Deep Copy + + The LiteLLM Request has some object that can-not be pickled / deep copied + + Use this function to safely deep copy the LiteLLM Request + """ + import copy + + import litellm + + if litellm.safe_memory_mode is True: + return data + + litellm_parent_otel_span: Optional[Any] = None + # Step 1: Remove the litellm_parent_otel_span + litellm_parent_otel_span = None + if isinstance(data, dict): + # remove litellm_parent_otel_span since this is not picklable + if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]: + litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span") + data["metadata"]["litellm_parent_otel_span"] = "placeholder" + if ( + "litellm_metadata" in data + and "litellm_parent_otel_span" in data["litellm_metadata"] + ): + litellm_parent_otel_span = data["litellm_metadata"].pop( + "litellm_parent_otel_span" + ) + data["litellm_metadata"]["litellm_parent_otel_span"] = "placeholder" + new_data = copy.deepcopy(data) + + # Step 2: re-add the litellm_parent_otel_span after doing a deep copy + if isinstance(data, dict) and litellm_parent_otel_span is not None: + if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]: + data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span + if ( + "litellm_metadata" in data + and "litellm_parent_otel_span" in data["litellm_metadata"] + ): + data["litellm_metadata"][ + "litellm_parent_otel_span" + ] = litellm_parent_otel_span + return new_data diff --git a/litellm/litellm_core_utils/fallback_utils.py b/litellm/litellm_core_utils/fallback_utils.py index d5610d5fddf..a5b0c85c816 100644 --- a/litellm/litellm_core_utils/fallback_utils.py +++ b/litellm/litellm_core_utils/fallback_utils.py @@ -1,9 +1,9 @@ import uuid -from copy import deepcopy from typing import Optional import litellm from litellm._logging import verbose_logger +from litellm.litellm_core_utils.core_helpers import safe_deep_copy from .asyncify import run_async_function @@ -41,7 +41,7 @@ async def async_completion_with_fallbacks(**kwargs): most_recent_exception_str: Optional[str] = None for fallback in fallbacks: try: - completion_kwargs = deepcopy(base_kwargs) + completion_kwargs = safe_deep_copy(base_kwargs) # Handle dictionary fallback configurations if isinstance(fallback, dict): model = fallback.pop("model", original_model) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 8a63ed7b1f7..25843ec8fe8 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,9 +1,8 @@ model_list: - - model_name: genai/test/* + - model_name: "gpt-4o-mini-openai" litellm_params: - model: openai/* - api_base: https://api.openai.com + model: gpt-4o-mini api_key: os.environ/OPENAI_API_KEY -litellm_settings: - check_provider_endpoint: true \ No newline at end of file +router_settings: + model_group_alias: {"gpt-4o": "gpt-4o-mini-openai"} \ No newline at end of file diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a95a8a8835f..df1aa9c5353 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -272,9 +272,6 @@ def generate_feedback_box(): from litellm.proxy.management_endpoints.tag_management_endpoints import ( router as tag_management_router, ) -from litellm.proxy.management_endpoints.user_agent_analytics_endpoints import ( - router as user_agent_analytics_router, -) from litellm.proxy.management_endpoints.team_callback_endpoints import ( router as team_callback_router, ) @@ -287,6 +284,9 @@ def generate_feedback_box(): get_disabled_non_admin_personal_key_creation, ) from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router +from litellm.proxy.management_endpoints.user_agent_analytics_endpoints import ( + router as user_agent_analytics_router, +) from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMiddleware from litellm.proxy.openai_files_endpoints.files_endpoints import ( @@ -2213,7 +2213,9 @@ def _init_non_llm_configs(self, config: dict): litellm_settings = config.get("litellm_settings", {}) mcp_aliases = litellm_settings.get("mcp_aliases", None) - global_mcp_server_manager.load_servers_from_config(mcp_servers_config, mcp_aliases) + global_mcp_server_manager.load_servers_from_config( + mcp_servers_config, mcp_aliases + ) ## VECTOR STORES vector_store_registry_config = config.get("vector_store_registry", None) @@ -3246,7 +3248,6 @@ async def async_data_generator( "async_data_generator: received streaming chunk - {}".format(chunk) ) - ### CALL HOOKS ### - modify outgoing data chunk = await proxy_logging_obj.async_post_call_streaming_hook( user_api_key_dict=user_api_key_dict, @@ -3255,7 +3256,6 @@ async def async_data_generator( str_so_far=str_so_far, ) - if isinstance(chunk, (ModelResponse, ModelResponseStream)): response_str = litellm.get_response_string(response_obj=chunk) str_so_far += response_str diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index fb6ae851254..9bdece6148c 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -52,11 +52,6 @@ ModelResponseStream, Router, ) -from litellm.types.mcp import ( - MCPPreCallRequestObject, - MCPPreCallResponseObject, - MCPDuringCallResponseObject, -) from litellm._logging import verbose_proxy_logger from litellm._service_logger import ServiceLogging, ServiceTypes from litellm.caching.caching import DualCache, RedisCache @@ -93,6 +88,11 @@ from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup from litellm.secret_managers.main import str_to_bool from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES +from litellm.types.mcp import ( + MCPDuringCallResponseObject, + MCPPreCallRequestObject, + MCPPreCallResponseObject, +) from litellm.types.utils import CallTypes, LLMResponseTypes, LoggedLiteLLMParams if TYPE_CHECKING: @@ -118,33 +118,6 @@ def print_verbose(print_statement): print(f"LiteLLM Proxy: {print_statement}") # noqa -def safe_deep_copy(data): - """ - Safe Deep Copy - - The LiteLLM Request has some object that can-not be pickled / deep copied - - Use this function to safely deep copy the LiteLLM Request - """ - if litellm.safe_memory_mode is True: - return data - - litellm_parent_otel_span: Optional[Any] = None - # Step 1: Remove the litellm_parent_otel_span - litellm_parent_otel_span = None - if isinstance(data, dict): - # remove litellm_parent_otel_span since this is not picklable - if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]: - litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span") - new_data = copy.deepcopy(data) - - # Step 2: re-add the litellm_parent_otel_span after doing a deep copy - if isinstance(data, dict) and litellm_parent_otel_span is not None: - if "metadata" in data: - data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span - return new_data - - class InternalUsageCache: def __init__(self, dual_cache: DualCache): self.dual_cache: DualCache = dual_cache @@ -474,11 +447,11 @@ async def update_request_status( ) async def async_pre_mcp_tool_call_hook( - self, - kwargs: dict, - request_obj: Any, - start_time: datetime, - end_time: datetime, + self, + kwargs: dict, + request_obj: Any, + start_time: datetime, + end_time: datetime, ) -> Optional[Any]: """ Pre MCP Tool Call Hook @@ -489,7 +462,7 @@ async def async_pre_mcp_tool_call_hook( from litellm.types.mcp import MCPPreCallRequestObject, MCPPreCallResponseObject callbacks = self.get_combined_callback_list( - dynamic_success_callbacks=getattr(self, 'dynamic_success_callbacks', None), + dynamic_success_callbacks=getattr(self, "dynamic_success_callbacks", None), global_callbacks=litellm.success_callback, ) @@ -500,7 +473,7 @@ async def async_pre_mcp_tool_call_hook( arguments=kwargs.get("arguments", {}), server_name=kwargs.get("server_name"), user_api_key_auth=kwargs.get("user_api_key_auth"), - hidden_params=HiddenParams() + hidden_params=HiddenParams(), ) for callback in callbacks: @@ -537,10 +510,10 @@ def get_combined_callback_list( return global_callbacks return list(set(dynamic_success_callbacks + global_callbacks)) - - def _parse_pre_mcp_call_hook_response( - self, response: MCPPreCallResponseObject, original_request: MCPPreCallRequestObject + self, + response: MCPPreCallResponseObject, + original_request: MCPPreCallRequestObject, ) -> Dict[str, Any]: """ Parse the response from the pre_mcp_tool_call_hook @@ -551,18 +524,19 @@ def _parse_pre_mcp_call_hook_response( """ result = { "should_proceed": response.should_proceed, - "modified_arguments": response.modified_arguments or original_request.arguments, + "modified_arguments": response.modified_arguments + or original_request.arguments, "error_message": response.error_message, "hidden_params": response.hidden_params, } return result async def async_during_mcp_tool_call_hook( - self, - kwargs: dict, - request_obj: Any, - start_time: datetime, - end_time: datetime, + self, + kwargs: dict, + request_obj: Any, + start_time: datetime, + end_time: datetime, ) -> Optional[Any]: """ During MCP Tool Call Hook @@ -570,10 +544,13 @@ async def async_during_mcp_tool_call_hook( Use this for concurrent monitoring and validation during tool execution. """ from litellm.types.llms.base import HiddenParams - from litellm.types.mcp import MCPDuringCallResponseObject, MCPDuringCallRequestObject + from litellm.types.mcp import ( + MCPDuringCallRequestObject, + MCPDuringCallResponseObject, + ) callbacks = self.get_combined_callback_list( - dynamic_success_callbacks=getattr(self, 'dynamic_success_callbacks', None), + dynamic_success_callbacks=getattr(self, "dynamic_success_callbacks", None), global_callbacks=litellm.success_callback, ) @@ -584,7 +561,7 @@ async def async_during_mcp_tool_call_hook( arguments=kwargs.get("arguments", {}), server_name=kwargs.get("server_name"), start_time=start_time.timestamp() if start_time else None, - hidden_params=HiddenParams() + hidden_params=HiddenParams(), ) for callback in callbacks: @@ -603,7 +580,9 @@ async def async_during_mcp_tool_call_hook( # this allows for execution control decisions ###################################################################### if response is not None: - return self._parse_during_mcp_call_hook_response(response=response) + return self._parse_during_mcp_call_hook_response( + response=response + ) except Exception as e: verbose_proxy_logger.exception( "LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {}".format( @@ -613,7 +592,7 @@ async def async_during_mcp_tool_call_hook( return None def _parse_during_mcp_call_hook_response( - self, response: MCPDuringCallResponseObject + self, response: MCPDuringCallResponseObject ) -> Dict[str, Any]: """ Parse the response from the during_mcp_tool_call_hook @@ -1382,9 +1361,15 @@ def __init__( from prisma import Prisma # type: ignore except Exception as e: verbose_proxy_logger.error(f"Failed to import Prisma client: {e}") - verbose_proxy_logger.error("This usually means 'prisma generate' hasn't been run yet.") - verbose_proxy_logger.error("Please run 'prisma generate' to generate the Prisma client.") - raise Exception("Unable to find Prisma binaries. Please run 'prisma generate' first.") + verbose_proxy_logger.error( + "This usually means 'prisma generate' hasn't been run yet." + ) + verbose_proxy_logger.error( + "Please run 'prisma generate' to generate the Prisma client." + ) + raise Exception( + "Unable to find Prisma binaries. Please run 'prisma generate' first." + ) if http_client is not None: self.db = PrismaWrapper( original_prisma=Prisma(http=http_client), diff --git a/litellm/router.py b/litellm/router.py index d02392c59f2..52d48fc71ec 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -2914,7 +2914,9 @@ async def _acreate_file( ) async def create_file_for_deployment(deployment: dict) -> OpenAIFileObject: - kwargs_copy = copy.deepcopy(kwargs) + from litellm.litellm_core_utils.core_helpers import safe_deep_copy + + kwargs_copy = safe_deep_copy(kwargs) self._update_kwargs_with_deployment( deployment=deployment, kwargs=kwargs_copy, @@ -3165,6 +3167,8 @@ async def aretrieve_batch( async def try_retrieve_batch(model_name: DeploymentTypedDict): try: + from litellm.litellm_core_utils.core_helpers import safe_deep_copy + model = model_name["litellm_params"].get("model") data = model_name["litellm_params"].copy() custom_llm_provider = data.get("custom_llm_provider") @@ -3178,7 +3182,7 @@ async def try_retrieve_batch(model_name: DeploymentTypedDict): _, custom_llm_provider, _, _ = get_llm_provider( # type: ignore model=model ) - new_kwargs = copy.deepcopy(kwargs) + new_kwargs = safe_deep_copy(kwargs) self._update_kwargs_with_deployment( deployment=cast(dict, model_name), kwargs=new_kwargs, @@ -6008,6 +6012,7 @@ def get_settings(self): "context_window_fallbacks", "model_group_retry_policy", "retry_policy", + "model_group_alias", ] for var in vars_to_include: @@ -6037,6 +6042,7 @@ def update_settings(self, **kwargs): "fallbacks", "context_window_fallbacks", "model_group_retry_policy", + "model_group_alias", ] _int_settings = [ diff --git a/litellm/types/router.py b/litellm/types/router.py index 11076205339..864fdbf79b8 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -89,6 +89,7 @@ class UpdateRouterConfig(BaseModel): retry_after: Optional[float] = None fallbacks: Optional[List[dict]] = None context_window_fallbacks: Optional[List[dict]] = None + model_group_alias: Optional[Dict[str, Union[str, Dict]]] = {} model_config = ConfigDict(protected_namespaces=()) @@ -209,7 +210,6 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams): model_info: Optional[Dict] = None mock_response: Optional[Union[str, ModelResponse, Exception, Any]] = None - # auto-router params auto_router_config_path: Optional[str] = None auto_router_config: Optional[str] = None @@ -343,7 +343,7 @@ def __init__( if max_retries is not None and isinstance(max_retries, str): max_retries = int(max_retries) # cast to int args["max_retries"] = max_retries - super().__init__(**{ **args, **params }) + super().__init__(**{**args, **params}) def __contains__(self, key): # Define custom behavior for the 'in' operator @@ -776,9 +776,11 @@ def extract_bool_param(name: str) -> Optional[bool]: ), ) + class ModelGroupSettings(BaseModel): forward_client_headers_to_llm_api: Optional[List[str]] = None + class PreRoutingHookResponse(BaseModel): """ Response object from the pre-routing hook. @@ -787,5 +789,6 @@ class PreRoutingHookResponse(BaseModel): Add fields that you expect to be modified by the pre-routing hook. """ + model: str - messages: Optional[List[Dict[str, str]]] \ No newline at end of file + messages: Optional[List[Dict[str, str]]] diff --git a/litellm/utils.py b/litellm/utils.py index 94c5e1a1c83..9e9a112f304 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -681,7 +681,9 @@ def function_setup( # noqa: PLR0915 if add_breadcrumb: try: - details_to_log = copy.deepcopy(kwargs) + from litellm.litellm_core_utils.core_helpers import safe_deep_copy + + details_to_log = safe_deep_copy(kwargs) except Exception: details_to_log = kwargs diff --git a/tests/code_coverage_tests/ban_copy_deepcopy_kwargs.py b/tests/code_coverage_tests/ban_copy_deepcopy_kwargs.py new file mode 100644 index 00000000000..77c019c7ddb --- /dev/null +++ b/tests/code_coverage_tests/ban_copy_deepcopy_kwargs.py @@ -0,0 +1,142 @@ +import ast +import os + + +class CopyDeepcopyKwargsDetector(ast.NodeVisitor): + def __init__(self): + self.violations = [] + + def visit_Call(self, node): + # Check if this is a copy.deepcopy call + if self._is_copy_deepcopy_call(node): + # Check if any argument contains 'kwargs' in its name + for arg in node.args: + if self._is_kwargs_related(arg): + # Get line number and argument name for reporting + arg_name = self._get_arg_name(arg) + self.violations.append( + { + "line": node.lineno, + "arg_name": arg_name, + "full_call": ( + ast.unparse(node) + if hasattr(ast, "unparse") + else str(node) + ), + } + ) + + self.generic_visit(node) + + def _is_copy_deepcopy_call(self, node): + """Check if this is a copy.deepcopy() call""" + if isinstance(node.func, ast.Attribute): + # Case: copy.deepcopy() + if ( + isinstance(node.func.value, ast.Name) + and node.func.value.id == "copy" + and node.func.attr == "deepcopy" + ): + return True + elif isinstance(node.func, ast.Name): + # Case: deepcopy() (if imported as 'from copy import deepcopy') + if node.func.id == "deepcopy": + return True + return False + + def _is_kwargs_related(self, arg): + """Check if the argument is kwargs-related""" + if isinstance(arg, ast.Name): + # Direct variable names containing 'kwargs' + return "kwargs" in arg.id.lower() + elif isinstance(arg, ast.Subscript): + # Handle cases like kwargs['key'] + if isinstance(arg.value, ast.Name): + return "kwargs" in arg.value.id.lower() + elif isinstance(arg, ast.Attribute): + # Handle cases like self.kwargs + return "kwargs" in arg.attr.lower() + return False + + def _get_arg_name(self, arg): + """Get a readable name for the argument""" + if isinstance(arg, ast.Name): + return arg.id + elif isinstance(arg, ast.Subscript) and isinstance(arg.value, ast.Name): + return f"{arg.value.id}[...]" + elif isinstance(arg, ast.Attribute): + return f"...{arg.attr}" + else: + return "unknown_kwargs_variable" + + +def find_copy_deepcopy_kwargs_in_file(file_path): + """Find copy.deepcopy usage with kwargs in a single file""" + try: + with open(file_path, "r", encoding="utf-8") as file: + tree = ast.parse(file.read(), filename=file_path) + detector = CopyDeepcopyKwargsDetector() + detector.visit(tree) + return detector.violations + except Exception as e: + print(f"Error parsing {file_path}: {e}") + return [] + + +def find_copy_deepcopy_kwargs_in_directory(directory): + """Find copy.deepcopy usage with kwargs in all Python files in directory""" + violations = {} + + for root, _, files in os.walk(directory): + for file in files: + if file.endswith(".py"): + file_path = os.path.join(root, file) + print(f"Checking file: {file_path}") + file_violations = find_copy_deepcopy_kwargs_in_file(file_path) + if file_violations: + violations[file_path] = file_violations + + return violations + + +if __name__ == "__main__": + # Check for copy.deepcopy(kwargs) usage in the litellm directory + directory_path = "./litellm" + violations = find_copy_deepcopy_kwargs_in_directory(directory_path) + + print("\n" + "=" * 80) + print("COPY.DEEPCOPY KWARGS VIOLATIONS FOUND:") + print("=" * 80) + + if violations: + total_violations = 0 + for file_path, file_violations in violations.items(): + print(f"\nšŸ“ File: {file_path}") + for violation in file_violations: + total_violations += 1 + print( + f" āŒ Line {violation['line']}: copy.deepcopy({violation['arg_name']})" + ) + print(f" Full call: {violation['full_call']}") + + print(f"\n{'='*80}") + print(f"🚨 TOTAL VIOLATIONS: {total_violations}") + print("🚨 USE safe_deep_copy() INSTEAD OF copy.deepcopy() FOR KWARGS!") + print("🚨 Available imports:") + print(" - from litellm.proxy.utils import safe_deep_copy") + print(" - from litellm.litellm_core_utils.core_helpers import safe_deep_copy") + print("=" * 80) + + # Get first violation for the exception message + first_file = list(violations.keys())[0] + first_violation = violations[first_file][0] + + raise Exception( + f"🚨 Found {total_violations} copy.deepcopy(kwargs) violations! " + f"First violation: {first_file}:{first_violation['line']} - " + f"copy.deepcopy({first_violation['arg_name']}). " + f"Use safe_deep_copy() instead to handle non-serializable objects like OTEL spans." + ) + else: + print("āœ… No copy.deepcopy(kwargs) violations found!") + print("āœ… All kwargs copying appears to use safe_deep_copy() correctly.") diff --git a/ui/litellm-dashboard/src/components/generic_key_value_manager.tsx b/ui/litellm-dashboard/src/components/generic_key_value_manager.tsx new file mode 100644 index 00000000000..520893d14ca --- /dev/null +++ b/ui/litellm-dashboard/src/components/generic_key_value_manager.tsx @@ -0,0 +1,339 @@ +import React, { useState, useEffect, useCallback } from "react"; +import { + Card, + Title, + Text, + Table, + TableHead, + TableRow, + TableHeaderCell, + TableCell, + TableBody, +} from "@tremor/react"; +import { message, Input } from "antd"; +import { EditOutlined, DeleteOutlined, SaveOutlined, CloseOutlined } from "@ant-design/icons"; +import { ChevronDownIcon, ChevronRightIcon, PlusCircleIcon } from "@heroicons/react/outline"; + +interface KeyValueItem { + id?: string; + key: string; + value: string; +} + +interface GenericKeyValueManagerProps { + title: string; + description: string; + keyLabel: string; + valueLabel: string; + keyPlaceholder: string; + valuePlaceholder: string; + items: KeyValueItem[]; + onItemsChange: (items: KeyValueItem[]) => void; + onSave?: () => Promise; + showSaveButton?: boolean; + isCollapsible?: boolean; + defaultExpanded?: boolean; + configExample?: React.ReactNode; + additionalActions?: (item: KeyValueItem) => React.ReactNode; +} + +const GenericKeyValueManager: React.FC = ({ + title, + description, + keyLabel, + valueLabel, + keyPlaceholder, + valuePlaceholder, + items, + onItemsChange, + onSave, + showSaveButton = true, + isCollapsible = false, + defaultExpanded = true, + configExample, + additionalActions, +}) => { + const [newKey, setNewKey] = useState(""); + const [newValue, setNewValue] = useState(""); + const [editingItem, setEditingItem] = useState(null); + const [editingKey, setEditingKey] = useState(""); + const [editingValue, setEditingValue] = useState(""); + const [isExpanded, setIsExpanded] = useState(defaultExpanded); + + const generateId = () => Math.random().toString(36).substr(2, 9); + + const handleAddItem = useCallback(() => { + if (newKey.trim() && newValue.trim()) { + const newItem: KeyValueItem = { + id: generateId(), + key: newKey.trim(), + value: newValue.trim(), + }; + onItemsChange([...items, newItem]); + setNewKey(""); + setNewValue(""); + } else { + message.error(`Please provide both ${keyLabel.toLowerCase()} and ${valueLabel.toLowerCase()}`); + } + }, [newKey, newValue, items, onItemsChange, keyLabel, valueLabel]); + + const handleEditItem = useCallback((item: KeyValueItem) => { + setEditingItem({ ...item }); + setEditingKey(item.key); + setEditingValue(item.value); + }, []); + + const handleSaveEdit = useCallback(() => { + if (editingKey.trim() && editingValue.trim()) { + const updatedItems = items.map((item) => + item.id === editingItem?.id ? { ...item, key: editingKey.trim(), value: editingValue.trim() } : item + ); + onItemsChange(updatedItems); + setEditingItem(null); + setEditingKey(""); + setEditingValue(""); + } else { + message.error(`Please provide both ${keyLabel.toLowerCase()} and ${valueLabel.toLowerCase()}`); + } + }, [editingKey, editingValue, items, editingItem, onItemsChange, keyLabel, valueLabel]); + + const handleCancelEdit = useCallback(() => { + setEditingItem(null); + setEditingKey(""); + setEditingValue(""); + }, []); + + const handleDeleteItem = useCallback((id: string) => { + const updatedItems = items.filter((item) => item.id !== id); + onItemsChange(updatedItems); + }, [items, onItemsChange]); + + const handleSave = useCallback(async () => { + if (onSave) { + try { + await onSave(); + } catch (error) { + console.error("Failed to save:", error); + } + } + }, [onSave]); + + const ContentSection = useCallback(() => ( +
+ {/* Add New Item Section */} + + Add New {keyLabel} +
+
+ + setNewKey(e.target.value)} + placeholder={keyPlaceholder} + size="middle" + /> +
+
+ + setNewValue(e.target.value)} + placeholder={valuePlaceholder} + size="middle" + /> +
+
+ +
+
+
+ + {/* Manage Existing Items Section */} + +
+ Manage Existing {keyLabel}s + {showSaveButton && ( + + )} +
+ +
+
+ + + + {keyLabel} + {valueLabel} + Actions + + + + {items.map((item) => ( + + {editingItem && editingItem.id === item.id ? ( + <> + + setEditingKey(e.target.value)} + size="small" + /> + + + setEditingValue(e.target.value)} + size="small" + /> + + +
+ + +
+
+ + ) : ( + <> + + {item.key} + + + {item.value} + + +
+ {additionalActions && additionalActions(item)} + + +
+
+ + )} +
+ ))} + {items.length === 0 && ( + + + No {keyLabel.toLowerCase()}s added yet. Add a new {keyLabel.toLowerCase()} above. + + + )} +
+
+
+
+
+ + {/* Configuration Example */} + {configExample && ( + + Configuration Example + {configExample} + + )} +
+ ), [ + keyLabel, + valueLabel, + keyPlaceholder, + valuePlaceholder, + newKey, + newValue, + items, + editingItem, + editingKey, + editingValue, + showSaveButton, + configExample, + additionalActions, + handleAddItem, + handleSave, + handleEditItem, + handleSaveEdit, + handleCancelEdit, + handleDeleteItem, + ]); + + if (isCollapsible) { + return ( + +
setIsExpanded(!isExpanded)} + > +
+ {title} +

{description}

+
+
+ {isExpanded ? ( + + ) : ( + + )} +
+
+ + {isExpanded && ( +
+ +
+ )} +
+ ); + } + + return ( +
+
+ {title} + {description} +
+
+ +
+
+ ); +}; + +export default GenericKeyValueManager; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/model_dashboard.tsx b/ui/litellm-dashboard/src/components/model_dashboard.tsx index 98cc3b14305..c66ec768449 100644 --- a/ui/litellm-dashboard/src/components/model_dashboard.tsx +++ b/ui/litellm-dashboard/src/components/model_dashboard.tsx @@ -73,6 +73,7 @@ import { ModelDataTable } from "./model_dashboard/table"; import { columns } from "./model_dashboard/columns"; import HealthCheckComponent from "./model_dashboard/HealthCheckComponent"; import PassThroughSettings from "./pass_through_settings"; +import ModelGroupAliasSettings from "./model_group_alias_settings"; import { all_admin_roles } from "@/utils/roles"; import { Table as TableInstance } from "@tanstack/react-table"; @@ -197,6 +198,9 @@ const ModelDashboard: React.FC = ({ const [credentialsList, setCredentialsList] = useState([]); + // Model Group Alias state + const [modelGroupAlias, setModelGroupAlias] = useState<{[key: string]: string}>({}); + // Add state for advanced settings visibility const [showAdvancedSettings, setShowAdvancedSettings] = useState(false); @@ -479,6 +483,8 @@ const ModelDashboard: React.FC = ({ } }; + + useEffect(() => { if (!accessToken || !token || !userRole || !userID) { return; @@ -646,6 +652,10 @@ const ModelDashboard: React.FC = ({ setModelGroupRetryPolicy(model_group_retry_policy); setGlobalRetryPolicy(router_settings.retry_policy); setDefaultRetry(default_retries); + + // Set model group alias + const model_group_alias = router_settings.model_group_alias || {}; + setModelGroupAlias(model_group_alias); } catch (error) { console.error("There was an error fetching the model data", error); } @@ -1095,6 +1105,9 @@ const ModelDashboard: React.FC = ({ {all_admin_roles.includes(userRole) && ( Model Retry Settings )} + {all_admin_roles.includes(userRole) && ( + Model Group Alias + )}
@@ -1859,6 +1872,13 @@ const ModelDashboard: React.FC = ({ Save + + + )} diff --git a/ui/litellm-dashboard/src/components/model_group_alias_settings.tsx b/ui/litellm-dashboard/src/components/model_group_alias_settings.tsx new file mode 100644 index 00000000000..b1131c36117 --- /dev/null +++ b/ui/litellm-dashboard/src/components/model_group_alias_settings.tsx @@ -0,0 +1,370 @@ +import React, { useState, useEffect } from "react"; +import { message } from "antd"; +import { PlusCircleIcon, PencilIcon, TrashIcon, ChevronDownIcon, ChevronRightIcon } from "@heroicons/react/outline"; +import { setCallbacksCall } from "./networking"; +import { + Card, + Title, + Text, + Table, + TableHead, + TableHeaderCell, + TableBody, + TableRow, + TableCell +} from "@tremor/react"; + +interface ModelGroupAliasSettingsProps { + accessToken: string; + initialModelGroupAlias?: { [key: string]: string }; + onAliasUpdate?: (updatedAlias: { [key: string]: string }) => void; +} + +interface AliasItem { + id: string; + aliasName: string; + targetModelGroup: string; +} + +const ModelGroupAliasSettings: React.FC = ({ + accessToken, + initialModelGroupAlias = {}, + onAliasUpdate, +}) => { + const [aliases, setAliases] = useState([]); + const [newAlias, setNewAlias] = useState({ aliasName: "", targetModelGroup: "" }); + const [editingAlias, setEditingAlias] = useState(null); + const [isExpanded, setIsExpanded] = useState(true); + + useEffect(() => { + // Convert object to array for display + const aliasArray = Object.entries(initialModelGroupAlias).map(([aliasName, targetModelGroup], index) => ({ + id: `${index}-${aliasName}`, + aliasName, + targetModelGroup, + })); + setAliases(aliasArray); + }, [initialModelGroupAlias]); + + const saveAliasesToBackend = async (updatedAliases: AliasItem[]) => { + if (!accessToken) { + console.error("Access token is missing"); + return false; + } + + try { + // Convert array back to object format + const aliasObject: { [key: string]: string } = {}; + updatedAliases.forEach(alias => { + aliasObject[alias.aliasName] = alias.targetModelGroup; + }); + + const payload = { + router_settings: { + model_group_alias: aliasObject, + }, + }; + + console.log("Saving model group alias:", aliasObject); + await setCallbacksCall(accessToken, payload); + + if (onAliasUpdate) { + onAliasUpdate(aliasObject); + } + + return true; + } catch (error) { + console.error("Failed to save model group alias settings:", error); + message.error("Failed to save model group alias settings"); + return false; + } + }; + + const handleAddAlias = async () => { + if (!newAlias.aliasName || !newAlias.targetModelGroup) { + message.error("Please provide both alias name and target model group"); + return; + } + + // Check for duplicate alias names + if (aliases.some(alias => alias.aliasName === newAlias.aliasName)) { + message.error("An alias with this name already exists"); + return; + } + + const newAliasObj: AliasItem = { + id: `${Date.now()}-${newAlias.aliasName}`, + aliasName: newAlias.aliasName, + targetModelGroup: newAlias.targetModelGroup, + }; + + const updatedAliases = [...aliases, newAliasObj]; + + if (await saveAliasesToBackend(updatedAliases)) { + setAliases(updatedAliases); + setNewAlias({ aliasName: "", targetModelGroup: "" }); + message.success("Alias added successfully"); + } + }; + + const handleEditAlias = (alias: AliasItem) => { + setEditingAlias({ ...alias }); + }; + + const handleUpdateAlias = async () => { + if (!editingAlias) return; + + if (!editingAlias.aliasName || !editingAlias.targetModelGroup) { + message.error("Please provide both alias name and target model group"); + return; + } + + // Check for duplicate alias names (excluding current alias) + if (aliases.some(alias => alias.id !== editingAlias.id && alias.aliasName === editingAlias.aliasName)) { + message.error("An alias with this name already exists"); + return; + } + + const updatedAliases = aliases.map(alias => + alias.id === editingAlias.id ? editingAlias : alias + ); + + if (await saveAliasesToBackend(updatedAliases)) { + setAliases(updatedAliases); + setEditingAlias(null); + message.success("Alias updated successfully"); + } + }; + + const handleCancelEdit = () => { + setEditingAlias(null); + }; + + const deleteAlias = async (aliasId: string) => { + const updatedAliases = aliases.filter(alias => alias.id !== aliasId); + + if (await saveAliasesToBackend(updatedAliases)) { + setAliases(updatedAliases); + message.success("Alias deleted successfully"); + } + }; + + // Convert current aliases to object for config example + const aliasObject = aliases.reduce((acc, alias) => { + acc[alias.aliasName] = alias.targetModelGroup; + return acc; + }, {} as { [key: string]: string }); + + return ( + +
setIsExpanded(!isExpanded)} + > +
+ Model Group Alias Settings +

Create aliases for your model groups to simplify API calls. For example, you can create an alias 'gpt-4o' that points to 'gpt-4o-mini-openai' model group.

+
+
+ {isExpanded ? ( + + ) : ( + + )} +
+
+ + {isExpanded && ( +
+
+ Add New Alias +
+
+ + + setNewAlias({ + ...newAlias, + aliasName: e.target.value, + }) + } + placeholder="e.g., gpt-4o" + className="w-full px-3 py-2 border border-gray-300 rounded-md text-sm" + /> +
+
+ + + setNewAlias({ + ...newAlias, + targetModelGroup: e.target.value, + }) + } + placeholder="e.g., gpt-4o-mini-openai" + className="w-full px-3 py-2 border border-gray-300 rounded-md text-sm" + /> +
+
+ +
+
+
+ + + Manage Existing Aliases + +
+
+ + + + + Alias Name + + + Target Model Group + + + Actions + + + + + {aliases.map((alias) => ( + + {editingAlias && editingAlias.id === alias.id ? ( + <> + + + setEditingAlias({ + ...editingAlias, + aliasName: e.target.value, + }) + } + className="w-full px-2 py-1 border border-gray-300 rounded-md text-sm" + /> + + + + setEditingAlias({ + ...editingAlias, + targetModelGroup: e.target.value, + }) + } + className="w-full px-2 py-1 border border-gray-300 rounded-md text-sm" + /> + + +
+ + +
+
+ + ) : ( + <> + + {alias.aliasName} + + + {alias.targetModelGroup} + + +
+ + +
+
+ + )} +
+ ))} + {aliases.length === 0 && ( + + + No aliases added yet. Add a new alias above. + + + )} +
+
+
+
+ + {/* Configuration Example */} + + Configuration Example + + Here's how your current aliases would look in the config.yaml: + +
+
+ router_settings: +
+   model_group_alias: + {Object.keys(aliasObject).length === 0 ? ( + +
+     # No aliases configured yet +
+ ) : ( + Object.entries(aliasObject).map(([key, value]) => ( + +
+     "{key}": "{value}" +
+ )) + )} +
+
+
+
+ )} +
+ ); +}; + +export default ModelGroupAliasSettings; \ No newline at end of file