diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 9dadffca351..2eb6cf65281 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -3749,7 +3749,7 @@ async def list_keys( else: admin_team_ids = None - if user_id is None and user_api_key_dict.user_role not in [ + if not user_id and user_api_key_dict.user_role not in [ LitellmUserRoles.PROXY_ADMIN.value, LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value, ]: diff --git a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py index 5720ff948a6..39f8d1cccb0 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -4140,6 +4140,72 @@ async def test_list_keys_with_invalid_status(): assert "deleted" in str(exc_info.value.message) +@pytest.mark.asyncio +async def test_list_keys_non_admin_user_id_auto_set(): + """ + Test that when a non-admin user calls list_keys with user_id=None, + the user_id is automatically set to the authenticated user's user_id. + """ + from unittest.mock import Mock, patch + + mock_prisma_client = AsyncMock() + + # Create a non-admin user with a user_id + test_user_id = "test-user-123" + mock_user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id=test_user_id, + ) + + # Mock user info returned by validate_key_list_check + mock_user_info = LiteLLM_UserTable( + user_id=test_user_id, + user_email="test@example.com", + teams=[], + organization_memberships=[], + ) + + # Mock _list_key_helper to capture the user_id argument + mock_list_key_helper = AsyncMock(return_value={ + "keys": [], + "total_count": 0, + "current_page": 1, + "total_pages": 0, + }) + + # Mock prisma_client to be non-None + with patch("litellm.proxy.proxy_server.prisma_client", mock_prisma_client): + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.validate_key_list_check", + return_value=mock_user_info, + ): + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.get_admin_team_ids", + return_value=[], + ): + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._list_key_helper", + mock_list_key_helper, + ): + mock_request = Mock() + + # Call list_keys with user_id=None + await list_keys( + request=mock_request, + user_api_key_dict=mock_user_api_key_dict, + user_id=None, # This should be auto-set to test_user_id + status=None, # Explicitly set status to None to avoid validation errors + ) + + # Verify that _list_key_helper was called with user_id set to the authenticated user's user_id + mock_list_key_helper.assert_called_once() + call_kwargs = mock_list_key_helper.call_args.kwargs + assert call_kwargs["user_id"] == test_user_id, ( + f"Expected user_id to be set to {test_user_id}, " + f"but got {call_kwargs.get('user_id')}" + ) + + @pytest.mark.asyncio async def test_generate_key_negative_max_budget(): """