diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 045d2fd5f14..9ae95085f55 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -228,6 +228,7 @@ class KeyManagementRoutes(str, enum.Enum): KEY_BLOCK = "/key/block" KEY_UNBLOCK = "/key/unblock" KEY_BULK_UPDATE = "/key/bulk_update" + KEY_RESET_SPEND = "/key/{key_id}/reset_spend" # info and health routes KEY_INFO = "/key/info" @@ -987,6 +988,10 @@ class RegenerateKeyRequest(GenerateKeyRequest): new_master_key: Optional[str] = None +class ResetSpendRequest(LiteLLMPydanticObjectBase): + reset_to: float + + class KeyRequest(LiteLLMPydanticObjectBase): keys: Optional[List[str]] = None key_aliases: Optional[List[str]] = None diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 278971a91a5..d1840363009 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -3373,6 +3373,163 @@ async def regenerate_key_fn( raise handle_exception_on_proxy(e) +async def _check_proxy_or_team_admin_for_key( + key_in_db: LiteLLM_VerificationToken, + user_api_key_dict: UserAPIKeyAuth, + prisma_client: PrismaClient, + user_api_key_cache: DualCache, +) -> None: + if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value: + return + + if key_in_db.team_id is not None: + team_table = await get_team_object( + team_id=key_in_db.team_id, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + check_db_only=True, + ) + if team_table is not None: + if _is_user_team_admin( + user_api_key_dict=user_api_key_dict, + team_obj=team_table, + ): + return + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={"error": "You must be a proxy admin or team admin to reset key spend"}, + ) + + +def _validate_reset_spend_value( + reset_to: Any, key_in_db: LiteLLM_VerificationToken +) -> float: + if not isinstance(reset_to, (int, float)): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "reset_to must be a float"}, + ) + + reset_to = float(reset_to) + + if reset_to < 0: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "reset_to must be >= 0"}, + ) + + current_spend = key_in_db.spend or 0.0 + if reset_to > current_spend: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": f"reset_to ({reset_to}) must be <= current spend ({current_spend})"}, + ) + + max_budget = key_in_db.max_budget + if key_in_db.litellm_budget_table is not None: + budget_max_budget = getattr(key_in_db.litellm_budget_table, "max_budget", None) + if budget_max_budget is not None: + if max_budget is None or budget_max_budget < max_budget: + max_budget = budget_max_budget + + if max_budget is not None and reset_to > max_budget: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": f"reset_to ({reset_to}) must be <= budget ({max_budget})"}, + ) + + return reset_to + + +@router.post( + "/key/{key:path}/reset_spend", + tags=["key management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def reset_key_spend_fn( + key: str, + data: ResetSpendRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +) -> Dict[str, Any]: + try: + from litellm.proxy.proxy_server import ( + hash_token, + prisma_client, + proxy_logging_obj, + user_api_key_cache, + ) + + if prisma_client is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": "DB not connected. prisma_client is None"}, + ) + + if "sk" not in key: + hashed_api_key = key + else: + hashed_api_key = hash_token(key) + + _key_in_db = await prisma_client.db.litellm_verificationtoken.find_unique( + where={"token": hashed_api_key}, + include={"litellm_budget_table": True}, + ) + if _key_in_db is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": f"Key {key} not found."}, + ) + + current_spend = _key_in_db.spend or 0.0 + reset_to = _validate_reset_spend_value(data.reset_to, _key_in_db) + + await _check_proxy_or_team_admin_for_key( + key_in_db=_key_in_db, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + ) + + updated_key = await prisma_client.db.litellm_verificationtoken.update( + where={"token": hashed_api_key}, + data={"spend": reset_to}, + ) + + if updated_key is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": "Failed to update key spend"}, + ) + + await _delete_cache_key_object( + hashed_token=hashed_api_key, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + ) + + max_budget = updated_key.max_budget + budget_reset_at = updated_key.budget_reset_at + + return { + "key_hash": hashed_api_key, + "spend": reset_to, + "previous_spend": current_spend, + "max_budget": max_budget, + "budget_reset_at": budget_reset_at, + } + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception("Error resetting key spend: %s", e) + raise handle_exception_on_proxy(e) + + async def validate_key_list_check( user_api_key_dict: UserAPIKeyAuth, user_id: Optional[str], diff --git a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py index 3638fd7e2c9..e90fb277eed 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -19,10 +19,12 @@ LiteLLM_BudgetTable, LiteLLM_OrganizationTable, LiteLLM_TeamTableCachedObj, + LiteLLM_UserTable, LiteLLM_VerificationToken, LitellmUserRoles, Member, ProxyException, + ResetSpendRequest, UpdateKeyRequest, ) from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth @@ -37,6 +39,7 @@ _save_deleted_verification_token_records, _transform_verification_tokens_to_deleted_records, _validate_max_budget, + _validate_reset_spend_value, can_modify_verification_token, check_org_key_model_specific_limits, check_team_key_model_specific_limits, @@ -44,6 +47,8 @@ generate_key_helper_fn, list_keys, prepare_key_update_data, + reset_key_spend_fn, + validate_key_list_check, validate_key_team_change, ) from litellm.proxy.proxy_server import app @@ -4690,3 +4695,699 @@ async def test_bulk_update_keys_partial_failures(monkeypatch): assert response.successful_updates[0].key == "test-key-1" assert response.failed_updates[0].key == "non-existent-key" assert "Key not found" in response.failed_updates[0].failed_reason + + +@pytest.mark.parametrize( + "reset_to,key_spend,key_max_budget,budget_max_budget,expected_error", + [ + ("not_a_number", 100.0, None, None, "reset_to must be a float"), + (None, 100.0, None, None, "reset_to must be a float"), + ([], 100.0, None, None, "reset_to must be a float"), + ({}, 100.0, None, None, "reset_to must be a float"), + (-1.0, 100.0, None, None, "reset_to must be >= 0"), + (-0.1, 100.0, None, None, "reset_to must be >= 0"), + (101.0, 100.0, None, None, "reset_to (101.0) must be <= current spend (100.0)"), + (150.0, 100.0, None, None, "reset_to (150.0) must be <= current spend (100.0)"), + (50.0, 100.0, 30.0, None, "reset_to (50.0) must be <= budget (30.0)"), + ], +) +def test_validate_reset_spend_value_invalid( + reset_to, key_spend, key_max_budget, budget_max_budget, expected_error +): + key_in_db = LiteLLM_VerificationToken( + token="test-token", + user_id="test-user", + spend=key_spend, + max_budget=key_max_budget, + litellm_budget_table=LiteLLM_BudgetTable( + budget_id="test-budget", max_budget=budget_max_budget + ).dict() + if budget_max_budget is not None + else None, + ) + + with pytest.raises(HTTPException) as exc_info: + _validate_reset_spend_value(reset_to, key_in_db) + + assert exc_info.value.status_code == 400 + assert expected_error in str(exc_info.value.detail) + + +@pytest.mark.parametrize( + "reset_to,key_spend,key_max_budget,budget_max_budget", + [ + (0.0, 100.0, None, None), + (0, 100.0, None, None), + (50.0, 100.0, None, None), + (100.0, 100.0, None, None), + (25.0, 100.0, 50.0, None), + (0.0, 0.0, None, None), + (10.5, 50.0, 20.0, None), + ], +) +def test_validate_reset_spend_value_valid( + reset_to, key_spend, key_max_budget, budget_max_budget +): + key_in_db = LiteLLM_VerificationToken( + token="test-token", + user_id="test-user", + spend=key_spend, + max_budget=key_max_budget, + litellm_budget_table=LiteLLM_BudgetTable( + budget_id="test-budget", max_budget=budget_max_budget + ).dict() + if budget_max_budget is not None + else None, + ) + + result = _validate_reset_spend_value(reset_to, key_in_db) + assert result == float(reset_to) + + +def test_validate_reset_spend_value_no_budget_table(): + key_in_db = LiteLLM_VerificationToken( + token="test-token", + user_id="test-user", + spend=100.0, + max_budget=50.0, + litellm_budget_table=None, + ) + + result = _validate_reset_spend_value(25.0, key_in_db) + assert result == 25.0 + + +def test_validate_reset_spend_value_none_spend(): + key_in_db = LiteLLM_VerificationToken( + token="test-token", + user_id="test-user", + spend=0.0, + max_budget=None, + litellm_budget_table=None, + ) + + result = _validate_reset_spend_value(0.0, key_in_db) + assert result == 0.0 + + with pytest.raises(HTTPException) as exc_info: + _validate_reset_spend_value(1.0, key_in_db) + assert exc_info.value.status_code == 400 + assert "must be <= current spend" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_reset_key_spend_success(monkeypatch): + mock_prisma_client = MagicMock() + mock_user_api_key_cache = MagicMock() + mock_proxy_logging_obj = MagicMock() + + hashed_key = "hashed-test-key" + key_in_db = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + spend=100.0, + max_budget=200.0, + litellm_budget_table=None, + ) + + updated_key = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + spend=50.0, + max_budget=200.0, + budget_reset_at=None, + ) + + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_in_db + ) + mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock( + return_value=updated_key + ) + + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj + ) + + with patch( + "litellm.proxy.proxy_server.hash_token" + ) as mock_hash_token, patch( + "litellm.proxy.management_endpoints.key_management_endpoints._check_proxy_or_team_admin_for_key" + ) as mock_check_admin, patch( + "litellm.proxy.management_endpoints.key_management_endpoints._delete_cache_key_object" + ) as mock_delete_cache: + mock_hash_token.return_value = hashed_key + mock_check_admin.return_value = None + mock_delete_cache.return_value = None + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-admin", + user_id="admin-user", + ) + + response = await reset_key_spend_fn( + key="sk-test-key", + data=ResetSpendRequest(reset_to=50.0), + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert response["spend"] == 50.0 + assert response["previous_spend"] == 100.0 + assert response["key_hash"] == hashed_key + assert response["max_budget"] == 200.0 + mock_prisma_client.db.litellm_verificationtoken.update.assert_called_once() + mock_delete_cache.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_reset_key_spend_success_team_admin(monkeypatch): + """Test that team admin can reset key spend for keys in their team.""" + mock_prisma_client = MagicMock() + mock_user_api_key_cache = MagicMock() + mock_proxy_logging_obj = MagicMock() + + hashed_key = "hashed-test-key" + team_id = "test-team-123" + key_in_db = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + team_id=team_id, + spend=100.0, + max_budget=200.0, + litellm_budget_table=None, + ) + + updated_key = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + team_id=team_id, + spend=50.0, + max_budget=200.0, + budget_reset_at=None, + ) + + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_in_db + ) + mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock( + return_value=updated_key + ) + + # Set up team table with user as admin + team_table = LiteLLM_TeamTableCachedObj( + team_id=team_id, + team_alias="test-team", + tpm_limit=None, + rpm_limit=None, + max_budget=None, + spend=0.0, + models=[], + blocked=False, + members_with_roles=[ + Member(user_id="team-admin-user", role="admin"), + Member(user_id="test-user", role="user"), + ], + ) + + async def mock_get_team_object(*args, **kwargs): + return team_table + + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj + ) + monkeypatch.setattr( + "litellm.proxy.management_endpoints.key_management_endpoints.get_team_object", + mock_get_team_object, + ) + + with patch( + "litellm.proxy.proxy_server.hash_token" + ) as mock_hash_token, patch( + "litellm.proxy.management_endpoints.key_management_endpoints._delete_cache_key_object" + ) as mock_delete_cache: + mock_hash_token.return_value = hashed_key + mock_delete_cache.return_value = None + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-team-admin", + user_id="team-admin-user", + ) + + response = await reset_key_spend_fn( + key="sk-test-key", + data=ResetSpendRequest(reset_to=50.0), + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert response["spend"] == 50.0 + assert response["previous_spend"] == 100.0 + assert response["key_hash"] == hashed_key + assert response["max_budget"] == 200.0 + mock_prisma_client.db.litellm_verificationtoken.update.assert_called_once() + mock_delete_cache.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_reset_key_spend_key_not_found(monkeypatch): + mock_prisma_client = MagicMock() + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=None + ) + + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + + with patch("litellm.proxy.proxy_server.hash_token") as mock_hash_token: + mock_hash_token.return_value = "hashed-key" + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-admin", + user_id="admin-user", + ) + + with pytest.raises(HTTPException) as exc_info: + await reset_key_spend_fn( + key="sk-test-key", + data=ResetSpendRequest(reset_to=50.0), + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert exc_info.value.status_code == 404 + assert "Key not found" in str(exc_info.value.detail) or "Key sk-test-key not found" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_reset_key_spend_db_not_connected(monkeypatch): + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-admin", + user_id="admin-user", + ) + + with pytest.raises(HTTPException) as exc_info: + await reset_key_spend_fn( + key="sk-test-key", + data=ResetSpendRequest(reset_to=50.0), + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert exc_info.value.status_code == 500 + assert "DB not connected" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_reset_key_spend_validation_error(monkeypatch): + mock_prisma_client = MagicMock() + key_in_db = LiteLLM_VerificationToken( + token="hashed-key", + user_id="test-user", + spend=100.0, + max_budget=None, + litellm_budget_table=None, + ) + + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_in_db + ) + + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + + with patch("litellm.proxy.proxy_server.hash_token") as mock_hash_token: + mock_hash_token.return_value = "hashed-key" + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-admin", + user_id="admin-user", + ) + + with pytest.raises(HTTPException) as exc_info: + await reset_key_spend_fn( + key="sk-test-key", + data=ResetSpendRequest(reset_to=150.0), + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert exc_info.value.status_code == 400 + assert "must be <= current spend" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_reset_key_spend_authorization_failure(monkeypatch): + mock_prisma_client = MagicMock() + mock_user_api_key_cache = MagicMock() + + hashed_key = "hashed-test-key" + key_in_db = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + team_id="team-1", + spend=100.0, + max_budget=None, + litellm_budget_table=None, + ) + + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_in_db + ) + + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + + with patch("litellm.proxy.proxy_server.hash_token") as mock_hash_token, patch( + "litellm.proxy.management_endpoints.key_management_endpoints._check_proxy_or_team_admin_for_key" + ) as mock_check_admin: + mock_hash_token.return_value = hashed_key + mock_check_admin.side_effect = HTTPException( + status_code=403, detail={"error": "Not authorized"} + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + api_key="sk-user", + user_id="user-1", + ) + + with pytest.raises(HTTPException) as exc_info: + await reset_key_spend_fn( + key="sk-test-key", + data=ResetSpendRequest(reset_to=50.0), + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert exc_info.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_reset_key_spend_hashed_key(monkeypatch): + mock_prisma_client = MagicMock() + mock_user_api_key_cache = MagicMock() + mock_proxy_logging_obj = MagicMock() + + hashed_key = "already-hashed-key" + key_in_db = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + spend=100.0, + max_budget=None, + litellm_budget_table=None, + ) + + updated_key = LiteLLM_VerificationToken( + token=hashed_key, + user_id="test-user", + spend=50.0, + max_budget=None, + budget_reset_at=None, + ) + + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_in_db + ) + mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock( + return_value=updated_key + ) + + monkeypatch.setattr( + "litellm.proxy.proxy_server.prisma_client", mock_prisma_client + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache + ) + monkeypatch.setattr( + "litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj + ) + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._check_proxy_or_team_admin_for_key" + ) as mock_check_admin, patch( + "litellm.proxy.management_endpoints.key_management_endpoints._delete_cache_key_object" + ) as mock_delete_cache: + mock_check_admin.return_value = None + mock_delete_cache.return_value = None + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-admin", + user_id="admin-user", + ) + + response = await reset_key_spend_fn( + key=hashed_key, + data=ResetSpendRequest(reset_to=50.0), + user_api_key_dict=user_api_key_dict, + litellm_changed_by=None, + ) + + assert response["spend"] == 50.0 + mock_prisma_client.db.litellm_verificationtoken.find_unique.assert_called_once_with( + where={"token": hashed_key}, include={"litellm_budget_table": True} + ) + + +@pytest.mark.asyncio +async def test_validate_key_list_check_proxy_admin(): + mock_prisma_client = AsyncMock() + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + user_id="admin-user", + ) + + result = await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id=None, + organization_id=None, + key_alias=None, + key_hash=None, + prisma_client=mock_prisma_client, + ) + + assert result is None + + +@pytest.mark.asyncio +async def test_validate_key_list_check_team_admin_success(): + mock_prisma_client = AsyncMock() + user_info = LiteLLM_UserTable( + user_id="test-user", + user_email="test@example.com", + teams=["team-1"], + organization_memberships=[], + ) + + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=user_info + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="test-user", + ) + + result = await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id="team-1", + organization_id=None, + key_alias=None, + key_hash=None, + prisma_client=mock_prisma_client, + ) + + assert result is not None + assert result.user_id == "test-user" + + +@pytest.mark.asyncio +async def test_validate_key_list_check_team_admin_fail(): + mock_prisma_client = AsyncMock() + user_info = LiteLLM_UserTable( + user_id="test-user", + user_email="test@example.com", + teams=["team-1"], + organization_memberships=[], + ) + + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=user_info + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="test-user", + ) + + with pytest.raises(ProxyException) as exc_info: + await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id="team-2", + organization_id=None, + key_alias=None, + key_hash=None, + prisma_client=mock_prisma_client, + ) + + assert exc_info.value.code == "403" or exc_info.value.code == 403 + assert "not authorized to check this team's keys" in exc_info.value.message + + +@pytest.mark.asyncio +async def test_validate_key_list_check_key_hash_authorized(): + mock_prisma_client = AsyncMock() + user_info = LiteLLM_UserTable( + user_id="test-user", + user_email="test@example.com", + teams=[], + organization_memberships=[], + ) + + key_info = LiteLLM_VerificationToken( + token="hashed-key", + user_id="test-user", + ) + + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=user_info + ) + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_info + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="test-user", + ) + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._can_user_query_key_info" + ) as mock_can_query: + mock_can_query.return_value = True + + result = await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id=None, + organization_id=None, + key_alias=None, + key_hash="hashed-key", + prisma_client=mock_prisma_client, + ) + + assert result is not None + assert result.user_id == "test-user" + + +@pytest.mark.asyncio +async def test_validate_key_list_check_key_hash_unauthorized(): + mock_prisma_client = AsyncMock() + user_info = LiteLLM_UserTable( + user_id="test-user", + user_email="test@example.com", + teams=[], + organization_memberships=[], + ) + + key_info = LiteLLM_VerificationToken( + token="hashed-key", + user_id="other-user", + ) + + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=user_info + ) + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + return_value=key_info + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="test-user", + ) + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._can_user_query_key_info" + ) as mock_can_query: + mock_can_query.return_value = False + + with pytest.raises(HTTPException) as exc_info: + await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id=None, + organization_id=None, + key_alias=None, + key_hash="hashed-key", + prisma_client=mock_prisma_client, + ) + + assert exc_info.value.status_code == 403 + assert "not allowed to access this key's info" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_validate_key_list_check_key_hash_not_found(): + mock_prisma_client = AsyncMock() + user_info = LiteLLM_UserTable( + user_id="test-user", + user_email="test@example.com", + teams=[], + organization_memberships=[], + ) + + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=user_info + ) + mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock( + side_effect=Exception("Key not found") + ) + + user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="test-user", + ) + + with pytest.raises(ProxyException) as exc_info: + await validate_key_list_check( + user_api_key_dict=user_api_key_dict, + user_id=None, + team_id=None, + organization_id=None, + key_alias=None, + key_hash="non-existent-key", + prisma_client=mock_prisma_client, + ) + + assert exc_info.value.code == "403" or exc_info.value.code == 403 + assert "Key Hash not found" in exc_info.value.message