diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index c854d81ec71..f1f2c259f17 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -227,6 +227,7 @@ class KeyManagementRoutes(str, enum.Enum): KEY_REGENERATE_WITH_PATH_PARAM = "/key/{key_id}/regenerate" KEY_BLOCK = "/key/block" KEY_UNBLOCK = "/key/unblock" + KEY_BULK_UPDATE = "/key/bulk_update" # info and health routes KEY_INFO = "/key/info" @@ -494,6 +495,7 @@ class LiteLLMRoutes(enum.Enum): KeyManagementRoutes.KEY_LIST.value, KeyManagementRoutes.KEY_BLOCK.value, KeyManagementRoutes.KEY_UNBLOCK.value, + KeyManagementRoutes.KEY_BULK_UPDATE.value, ] management_routes = [ diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index ab87e862ea6..380e8bddc99 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -37,6 +37,13 @@ ) from litellm.proxy._types import * from litellm.proxy._types import LiteLLM_VerificationToken +from litellm.types.proxy.management_endpoints.key_management_endpoints import ( + BulkUpdateKeyRequest, + BulkUpdateKeyRequestItem, + BulkUpdateKeyResponse, + FailedKeyUpdate, + SuccessfulKeyUpdate, +) from litellm.proxy.auth.auth_checks import ( _cache_key_object, _delete_cache_key_object, @@ -1438,6 +1445,211 @@ def is_different_team( return data.team_id != existing_key_row.team_id +def _validate_max_budget(max_budget: Optional[float]) -> None: + """ + Validate that max_budget is not negative. + + Args: + max_budget: The max_budget value to validate + + Raises: + HTTPException: If max_budget is negative + """ + if max_budget is not None and max_budget < 0: + raise HTTPException( + status_code=400, + detail={ + "error": f"max_budget cannot be negative. Received: {max_budget}" + }, + ) + + +async def _get_and_validate_existing_key( + token: str, prisma_client: Optional[PrismaClient] +) -> LiteLLM_VerificationToken: + """ + Get existing key from database and validate it exists. + + Args: + token: The key token to look up + prisma_client: Prisma client instance + + Returns: + LiteLLM_VerificationToken: The existing key row + + Raises: + HTTPException: If key is not found + """ + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": "Database not connected"}, + ) + + existing_key_row = await prisma_client.get_data( + token=token, + table_name="key", + query_type="find_unique", + ) + + if existing_key_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Key not found: {token}"}, + ) + + return existing_key_row + + +async def _process_single_key_update( + key_update_item: BulkUpdateKeyRequestItem, + user_api_key_dict: UserAPIKeyAuth, + litellm_changed_by: Optional[str], + prisma_client: Optional[PrismaClient], + user_api_key_cache: DualCache, + proxy_logging_obj: Any, + llm_router: Optional[Router], +) -> Dict[str, Any]: + """ + Process a single key update with all validations and checks. + + This function encapsulates all the logic for updating a single key, + including validation, permission checks, team checks, and database updates. + + Args: + key_update_item: The key update request item + user_api_key_dict: The authenticated user's API key info + litellm_changed_by: Optional header for tracking who made the change + prisma_client: Prisma client instance + user_api_key_cache: User API key cache + proxy_logging_obj: Proxy logging object + llm_router: LLM router instance + + Returns: + Dict containing the updated key information + + Raises: + HTTPException: For various validation and permission errors + """ + # Validate max_budget + _validate_max_budget(key_update_item.max_budget) + + # Get and validate existing key + existing_key_row = await _get_and_validate_existing_key( + token=key_update_item.key, + prisma_client=prisma_client, + ) + + # Check team member permissions + await TeamMemberPermissionChecks.can_team_member_execute_key_management_endpoint( + user_api_key_dict=user_api_key_dict, + route=KeyManagementRoutes.KEY_UPDATE, + prisma_client=prisma_client, + existing_key_row=existing_key_row, + user_api_key_cache=user_api_key_cache, + ) + + # Create UpdateKeyRequest from BulkUpdateKeyRequestItem + update_key_request = UpdateKeyRequest( + key=key_update_item.key, + budget_id=key_update_item.budget_id, + max_budget=key_update_item.max_budget, + team_id=key_update_item.team_id, + tags=key_update_item.tags, + ) + + # Get team object and check team limits if team_id is provided + team_obj: Optional[LiteLLM_TeamTableCachedObj] = None + if update_key_request.team_id is not None: + team_obj = await get_team_object( + team_id=update_key_request.team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + check_db_only=True, + ) + + if team_obj is not None and prisma_client is not None: + await _check_team_key_limits( + team_table=team_obj, + data=update_key_request, + prisma_client=prisma_client, + ) + + # Validate team change if team is being changed + if is_different_team( + data=update_key_request, existing_key_row=existing_key_row + ): + if llm_router is None: + raise HTTPException( + status_code=400, + detail={ + "error": "LLM router not found. Please set it up by passing in a valid config.yaml or adding models via the UI." + }, + ) + if team_obj is None: + raise HTTPException( + status_code=500, + detail={ + "error": "Team object not found for team change validation" + }, + ) + validate_key_team_change( + key=existing_key_row, + team=team_obj, + change_initiated_by=user_api_key_dict, + llm_router=llm_router, + ) + + # Prepare update data + non_default_values = await prepare_key_update_data( + data=update_key_request, existing_key_row=existing_key_row + ) + + # Update key in database + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": "Database not connected"}, + ) + + _data = {**non_default_values, "token": key_update_item.key} + response = await prisma_client.update_data( + token=key_update_item.key, data=_data + ) + + # Delete cache + await _delete_cache_key_object( + hashed_token=hash_token(key_update_item.key), + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + # Trigger async hook + asyncio.create_task( + KeyManagementEventHooks.async_key_updated_hook( + data=update_key_request, + existing_key_row=existing_key_row, + response=response, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, + ) + ) + + if response is None: + raise ValueError("Failed to update key got response = None") + + # Extract and format updated key info + updated_key_info = response.get("data", {}) + if hasattr(updated_key_info, "model_dump"): + updated_key_info = updated_key_info.model_dump() + elif hasattr(updated_key_info, "dict"): + updated_key_info = updated_key_info.dict() + + updated_key_info.pop("token", None) + + return updated_key_info + + @router.post( "/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)] ) @@ -1684,6 +1896,167 @@ async def update_key_fn( ) +@router.post( + "/key/bulk_update", + tags=["key management"], + dependencies=[Depends(user_api_key_auth)], + response_model=BulkUpdateKeyResponse, +) +@management_endpoint_wrapper +async def bulk_update_keys( + data: BulkUpdateKeyRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Bulk update multiple keys at once. + + This endpoint allows updating multiple keys in a single request. Each key update + is processed independently - if some updates fail, others will still succeed. + + Parameters: + - keys: List[BulkUpdateKeyRequestItem] - List of key update requests, each containing: + - key: str - The key identifier (token) to update + - budget_id: Optional[str] - Budget ID associated with the key + - max_budget: Optional[float] - Max budget for key + - team_id: Optional[str] - Team ID associated with key + - tags: Optional[List[str]] - Tags for organizing keys + + Returns: + - total_requested: int - Total number of keys requested for update + - successful_updates: List[SuccessfulKeyUpdate] - List of successfully updated keys with their updated info + - failed_updates: List[FailedKeyUpdate] - List of failed updates with key_info and failed_reason + + Example request: + ```bash + curl --location 'http://0.0.0.0:4000/key/bulk_update' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "keys": [ + { + "key": "sk-1234", + "max_budget": 100.0, + "team_id": "team-123", + "tags": ["production", "api"] + }, + { + "key": "sk-5678", + "budget_id": "budget-456", + "tags": ["staging"] + } + ] + }' + ``` + """ + from litellm.proxy.proxy_server import ( + llm_router, + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) + + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value: + raise HTTPException( + status_code=403, + detail={ + "error": "Only proxy admins can perform bulk key updates" + }, + ) + + if prisma_client is None: + raise HTTPException( + status_code=500, + detail={"error": "Database not connected"}, + ) + + if not data.keys: + raise HTTPException( + status_code=400, + detail={"error": "No keys provided for update"}, + ) + + MAX_BATCH_SIZE = 500 + if len(data.keys) > MAX_BATCH_SIZE: + raise HTTPException( + status_code=400, + detail={ + "error": f"Maximum {MAX_BATCH_SIZE} keys can be updated at once. Found {len(data.keys)} keys." + }, + ) + + successful_updates: List[SuccessfulKeyUpdate] = [] + failed_updates: List[FailedKeyUpdate] = [] + + for key_update_item in data.keys: + try: + # Process single key update using reusable function + updated_key_info = await _process_single_key_update( + key_update_item=key_update_item, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=litellm_changed_by, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + llm_router=llm_router, + ) + + successful_updates.append( + SuccessfulKeyUpdate( + key=key_update_item.key, + key_info=updated_key_info, + ) + ) + + except Exception as e: + verbose_proxy_logger.exception( + f"Failed to update key {key_update_item.key}: {e}" + ) + + if isinstance(e, HTTPException): + error_detail = e.detail + if isinstance(error_detail, dict): + error_message = error_detail.get("error", str(e)) + else: + error_message = str(error_detail) + else: + error_message = str(e) + + key_info = None + try: + existing_key_row = await prisma_client.get_data( + token=key_update_item.key, + table_name="key", + query_type="find_unique", + ) + if existing_key_row is not None: + if hasattr(existing_key_row, "model_dump"): + key_info = existing_key_row.model_dump() + elif hasattr(existing_key_row, "dict"): + key_info = existing_key_row.dict() + if key_info: + key_info.pop("token", None) + except Exception: + pass + + failed_updates.append( + FailedKeyUpdate( + key=key_update_item.key, + key_info=key_info, + failed_reason=error_message, + ) + ) + + return BulkUpdateKeyResponse( + total_requested=len(data.keys), + successful_updates=successful_updates, + failed_updates=failed_updates, + ) + + def validate_key_team_change( key: LiteLLM_VerificationToken, team: LiteLLM_TeamTable, diff --git a/litellm/types/proxy/management_endpoints/key_management_endpoints.py b/litellm/types/proxy/management_endpoints/key_management_endpoints.py new file mode 100644 index 00000000000..b1d25455d18 --- /dev/null +++ b/litellm/types/proxy/management_endpoints/key_management_endpoints.py @@ -0,0 +1,42 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel + + +class BulkUpdateKeyRequestItem(BaseModel): + """Individual key update request item""" + + key: str # Key identifier (token) + budget_id: Optional[str] = None # Budget ID associated with the key + max_budget: Optional[float] = None # Max budget for key + team_id: Optional[str] = None # Team ID associated with key + tags: Optional[List[str]] = None # Tags for organizing keys + + +class BulkUpdateKeyRequest(BaseModel): + """Request for bulk key updates""" + + keys: List[BulkUpdateKeyRequestItem] + + +class SuccessfulKeyUpdate(BaseModel): + """Successfully updated key with its updated information""" + + key: str + key_info: Dict[str, Any] + + +class FailedKeyUpdate(BaseModel): + """Failed key update with reason""" + + key: str + key_info: Optional[Dict[str, Any]] = None + failed_reason: str + + +class BulkUpdateKeyResponse(BaseModel): + """Response for bulk key update operations""" + + total_requested: int + successful_updates: List[SuccessfulKeyUpdate] + failed_updates: List[FailedKeyUpdate] diff --git a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py index 7d31f762096..a57378e579c 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -30,10 +30,13 @@ _check_org_key_limits, _check_team_key_limits, _common_key_generation_helper, + _get_and_validate_existing_key, _list_key_helper, _persist_deleted_verification_tokens, + _process_single_key_update, _save_deleted_verification_token_records, _transform_verification_tokens_to_deleted_records, + _validate_max_budget, can_modify_verification_token, check_org_key_model_specific_limits, check_team_key_model_specific_limits, @@ -4223,3 +4226,467 @@ async def test_update_key_with_router_settings(monkeypatch): # Verify router_settings can be deserialized and matches input deserialized_settings = json.loads(result["router_settings"]) assert deserialized_settings == router_settings_data + + +@pytest.mark.asyncio +async def test_validate_max_budget(): + """ + Test _validate_max_budget helper function. + + Tests: + 1. Positive max_budget should pass + 2. Zero max_budget should pass + 3. Negative max_budget should raise HTTPException + 4. None max_budget should pass + """ + from fastapi import HTTPException + + # Test Case 1: Positive max_budget should pass + try: + _validate_max_budget(100.0) + _validate_max_budget(0.0) + except HTTPException: + pytest.fail("_validate_max_budget raised HTTPException for valid values") + + # Test Case 2: None max_budget should pass + try: + _validate_max_budget(None) + except HTTPException: + pytest.fail("_validate_max_budget raised HTTPException for None") + + # Test Case 3: Negative max_budget should raise HTTPException + with pytest.raises(HTTPException) as exc_info: + _validate_max_budget(-10.0) + + assert exc_info.value.status_code == 400 + assert "max_budget cannot be negative" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_get_and_validate_existing_key(): + """ + Test _get_and_validate_existing_key helper function. + + Tests: + 1. Successfully retrieve existing key + 2. Key not found raises HTTPException + 3. Database not connected raises HTTPException + """ + from fastapi import HTTPException + + # Test Case 1: Successfully retrieve existing key + mock_prisma_client = AsyncMock() + mock_key = LiteLLM_VerificationToken( + token="test-key-123", + user_id="user-123", + models=["gpt-4"], + team_id=None, + ) + mock_prisma_client.get_data = AsyncMock(return_value=mock_key) + + result = await _get_and_validate_existing_key( + token="test-key-123", + prisma_client=mock_prisma_client, + ) + + assert result == mock_key + mock_prisma_client.get_data.assert_called_once_with( + token="test-key-123", + table_name="key", + query_type="find_unique", + ) + + # Test Case 2: Key not found raises HTTPException + mock_prisma_client.get_data = AsyncMock(return_value=None) + + with pytest.raises(HTTPException) as exc_info: + await _get_and_validate_existing_key( + token="non-existent-key", + prisma_client=mock_prisma_client, + ) + + assert exc_info.value.status_code == 404 + assert "Key not found" in str(exc_info.value.detail) + + # Test Case 3: Database not connected raises HTTPException + with pytest.raises(HTTPException) as exc_info: + await _get_and_validate_existing_key( + token="test-key-123", + prisma_client=None, + ) + + assert exc_info.value.status_code == 500 + assert "Database not connected" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_process_single_key_update(): + """ + Test _process_single_key_update helper function. + + Tests successful key update with all validations passing. + """ + from litellm.types.proxy.management_endpoints.key_management_endpoints import ( + BulkUpdateKeyRequestItem, + ) + + # Setup mocks + mock_prisma_client = AsyncMock() + mock_user_api_key_cache = MagicMock() + mock_proxy_logging_obj = MagicMock() + mock_llm_router = MagicMock() + + # Mock existing key + existing_key = LiteLLM_VerificationToken( + token="test-key-123", + user_id="user-123", + models=["gpt-4"], + team_id=None, + max_budget=None, + tags=None, + ) + + # Mock updated key response + updated_key_data = { + "user_id": "user-123", + "models": ["gpt-4"], + "team_id": None, + "max_budget": 100.0, + "tags": ["production"], + } + + mock_prisma_client.get_data = AsyncMock(return_value=existing_key) + mock_updated_key_obj = MagicMock() + mock_updated_key_obj.model_dump.return_value = updated_key_data + mock_prisma_client.update_data = AsyncMock( + return_value={"data": mock_updated_key_obj} + ) + + # Mock prepare_key_update_data + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.prepare_key_update_data" + ) as mock_prepare: + mock_prepare.return_value = {"max_budget": 100.0, "tags": ["production"]} + + # Mock TeamMemberPermissionChecks + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.TeamMemberPermissionChecks.can_team_member_execute_key_management_endpoint" + ) as mock_permission_check: + mock_permission_check.return_value = None + + # Mock _delete_cache_key_object + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._delete_cache_key_object" + ) as mock_delete_cache: + mock_delete_cache.return_value = None + + # Mock hash_token (imported from litellm.proxy._types) + with patch( + "litellm.proxy._types.hash_token" + ) as mock_hash: + mock_hash.return_value = "hashed-test-key-123" + + # Mock KeyManagementEventHooks + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.KeyManagementEventHooks.async_key_updated_hook" + ): + # Create update request + key_update_item = BulkUpdateKeyRequestItem( + key="test-key-123", + max_budget=100.0, + tags=["production"], + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-admin", + user_id="admin-user", + ) + + # Call the function + result = await _process_single_key_update( + key_update_item=key_update_item, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + prisma_client=mock_prisma_client, + user_api_key_cache=mock_user_api_key_cache, + proxy_logging_obj=mock_proxy_logging_obj, + llm_router=mock_llm_router, + ) + + # Verify results + assert result is not None + assert "token" not in result # Token should be removed + assert result.get("max_budget") == 100.0 + assert result.get("tags") == ["production"] + + # Verify mocks were called + mock_prisma_client.get_data.assert_called_once() + mock_prisma_client.update_data.assert_called_once() + mock_delete_cache.assert_called_once() + + +@pytest.mark.asyncio +async def test_bulk_update_keys_success(monkeypatch): + """ + Test /key/bulk_update endpoint with successful updates. + + Tests: + 1. Multiple keys updated successfully + 2. Response contains correct counts and data + """ + from litellm.types.proxy.management_endpoints.key_management_endpoints import ( + BulkUpdateKeyRequest, + BulkUpdateKeyRequestItem, + ) + from litellm.proxy.management_endpoints.key_management_endpoints import ( + bulk_update_keys, + ) + from litellm.proxy.proxy_server import ( + llm_router, + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) + + # Setup mocks + mock_prisma_client = AsyncMock() + mock_user_api_key_cache = MagicMock() + mock_proxy_logging_obj = MagicMock() + mock_llm_router = MagicMock() + + # Mock existing keys + existing_key_1 = LiteLLM_VerificationToken( + token="test-key-1", + user_id="user-123", + models=["gpt-4"], + team_id=None, + max_budget=None, + ) + existing_key_2 = LiteLLM_VerificationToken( + token="test-key-2", + user_id="user-123", + models=["gpt-3.5-turbo"], + team_id=None, + max_budget=50.0, + ) + + # Mock updated key responses + updated_key_1_data = { + "user_id": "user-123", + "models": ["gpt-4"], + "max_budget": 100.0, + "tags": ["production"], + } + updated_key_2_data = { + "user_id": "user-123", + "models": ["gpt-3.5-turbo"], + "max_budget": 200.0, + "tags": ["staging"], + } + + mock_prisma_client.get_data = AsyncMock( + side_effect=[existing_key_1, existing_key_2] + ) + mock_updated_key_1_obj = MagicMock() + mock_updated_key_1_obj.model_dump.return_value = updated_key_1_data + mock_updated_key_2_obj = MagicMock() + mock_updated_key_2_obj.model_dump.return_value = updated_key_2_data + mock_prisma_client.update_data = AsyncMock( + side_effect=[ + {"data": mock_updated_key_1_obj}, + {"data": mock_updated_key_2_obj}, + ] + ) + + # Patch dependencies + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj + ) + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", mock_llm_router) + + # Mock helper functions + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.prepare_key_update_data" + ) as mock_prepare: + mock_prepare.side_effect = [ + {"max_budget": 100.0, "tags": ["production"]}, + {"max_budget": 200.0, "tags": ["staging"]}, + ] + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.TeamMemberPermissionChecks.can_team_member_execute_key_management_endpoint" + ): + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._delete_cache_key_object" + ): + with patch( + "litellm.proxy._types.hash_token" + ) as mock_hash: + mock_hash.side_effect = ["hashed-key-1", "hashed-key-2"] + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.KeyManagementEventHooks.async_key_updated_hook" + ): + # Create request + request_data = BulkUpdateKeyRequest( + keys=[ + BulkUpdateKeyRequestItem( + key="test-key-1", + max_budget=100.0, + tags=["production"], + ), + BulkUpdateKeyRequestItem( + key="test-key-2", + max_budget=200.0, + tags=["staging"], + ), + ] + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-admin", + user_id="admin-user", + ) + + # Call endpoint + response = await bulk_update_keys( + data=request_data, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + # Verify response + assert response.total_requested == 2 + assert len(response.successful_updates) == 2 + assert len(response.failed_updates) == 0 + assert response.successful_updates[0].key == "test-key-1" + assert response.successful_updates[1].key == "test-key-2" + + +@pytest.mark.asyncio +async def test_bulk_update_keys_partial_failures(monkeypatch): + """ + Test /key/bulk_update endpoint with partial failures. + + Tests: + 1. Some keys update successfully, others fail + 2. Response contains both successful and failed updates + 3. Failed updates include error messages + """ + from litellm.types.proxy.management_endpoints.key_management_endpoints import ( + BulkUpdateKeyRequest, + BulkUpdateKeyRequestItem, + ) + from litellm.proxy.management_endpoints.key_management_endpoints import ( + bulk_update_keys, + ) + + # Setup mocks + mock_prisma_client = AsyncMock() + mock_user_api_key_cache = MagicMock() + mock_proxy_logging_obj = MagicMock() + mock_llm_router = MagicMock() + + # Mock existing keys + existing_key_1 = LiteLLM_VerificationToken( + token="test-key-1", + user_id="user-123", + models=["gpt-4"], + team_id=None, + max_budget=None, + ) + + # Mock updated key response for successful update + updated_key_1_data = { + "user_id": "user-123", + "models": ["gpt-4"], + "max_budget": 100.0, + "tags": ["production"], + } + + # First key exists, second key doesn't exist + mock_prisma_client.get_data = AsyncMock( + side_effect=[existing_key_1, None] # Second key not found + ) + mock_updated_key_1_obj = MagicMock() + mock_updated_key_1_obj.model_dump.return_value = updated_key_1_data + mock_prisma_client.update_data = AsyncMock( + return_value={"data": mock_updated_key_1_obj} + ) + + # Patch dependencies + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj + ) + monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", mock_llm_router) + + # Mock helper functions + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.prepare_key_update_data" + ) as mock_prepare: + mock_prepare.return_value = {"max_budget": 100.0, "tags": ["production"]} + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.TeamMemberPermissionChecks.can_team_member_execute_key_management_endpoint" + ): + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._delete_cache_key_object" + ): + with patch( + "litellm.proxy._types.hash_token" + ) as mock_hash: + mock_hash.return_value = "hashed-key-1" + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.KeyManagementEventHooks.async_key_updated_hook" + ): + # Create request with one valid and one invalid key + request_data = BulkUpdateKeyRequest( + keys=[ + BulkUpdateKeyRequestItem( + key="test-key-1", + max_budget=100.0, + tags=["production"], + ), + BulkUpdateKeyRequestItem( + key="non-existent-key", + max_budget=200.0, + tags=["staging"], + ), + ] + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-admin", + user_id="admin-user", + ) + + # Call endpoint + response = await bulk_update_keys( + data=request_data, + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + # Verify response + assert response.total_requested == 2 + assert len(response.successful_updates) == 1 + assert len(response.failed_updates) == 1 + assert response.successful_updates[0].key == "test-key-1" + assert response.failed_updates[0].key == "non-existent-key" + assert "Key not found" in response.failed_updates[0].failed_reason