diff --git a/litellm/proxy/management_endpoints/common_daily_activity.py b/litellm/proxy/management_endpoints/common_daily_activity.py index 02961748e7c..a4fbeb7e28f 100644 --- a/litellm/proxy/management_endpoints/common_daily_activity.py +++ b/litellm/proxy/management_endpoints/common_daily_activity.py @@ -474,16 +474,21 @@ def _build_aggregated_sql_query( start_date: str, end_date: str, model: Optional[str], - api_key: Optional[str], + api_key: Optional[Union[str, List[str]]], exclude_entity_ids: Optional[List[str]] = None, timezone_offset_minutes: Optional[int] = None, + include_entity_id: bool = False, ) -> Tuple[str, List[Any]]: """Build a parameterized SQL GROUP BY query for aggregated daily activity. Groups by (date, api_key, model, model_group, custom_llm_provider, mcp_namespaced_tool_name, endpoint) with SUMs on all metric columns. - The entity_id column is intentionally omitted from GROUP BY to collapse - rows across entities — this is where the biggest row reduction comes from. + + When include_entity_id is False (default), the entity_id column is omitted + from GROUP BY to collapse rows across entities. + + When include_entity_id is True, the entity_id column is included in both + SELECT and GROUP BY, preserving per-entity breakdown in the results. Returns: Tuple of (sql_query, params_list) ready for prisma_client.db.query_raw(). @@ -538,14 +543,24 @@ def _build_aggregated_sql_query( # Optional api_key filter if api_key: - sql_conditions.append(f"api_key = ${p}") - sql_params.append(api_key) - p += 1 + if isinstance(api_key, list): + placeholders = ", ".join(f"${p + i}" for i in range(len(api_key))) + sql_conditions.append(f"api_key IN ({placeholders})") + sql_params.extend(api_key) + p += len(api_key) + else: + sql_conditions.append(f"api_key = ${p}") + sql_params.append(api_key) + p += 1 where_clause = " AND ".join(sql_conditions) + entity_select = f'"{entity_id_field}",' if include_entity_id else "" + entity_group_by = f'"{entity_id_field}",' if include_entity_id else "" + sql_query = f""" SELECT + {entity_select} date, api_key, model, @@ -563,7 +578,7 @@ def _build_aggregated_sql_query( SUM(failed_requests)::bigint AS failed_requests FROM "{pg_table}" WHERE {where_clause} - GROUP BY date, api_key, model, model_group, custom_llm_provider, + GROUP BY {entity_group_by} date, api_key, model, model_group, custom_llm_provider, mcp_namespaced_tool_name, endpoint ORDER BY date DESC """ @@ -735,9 +750,10 @@ async def get_daily_activity_aggregated( start_date: Optional[str], end_date: Optional[str], model: Optional[str], - api_key: Optional[str], + api_key: Optional[Union[str, List[str]]], exclude_entity_ids: Optional[List[str]] = None, timezone_offset_minutes: Optional[int] = None, + include_entity_breakdown: bool = False, ) -> SpendAnalyticsPaginatedResponse: """Aggregated variant that returns the full result set (no pagination). @@ -745,6 +761,11 @@ async def get_daily_activity_aggregated( all individual rows into Python. This collapses rows across entities (users/teams/orgs), reducing ~150k rows to ~2-3k grouped rows. + When include_entity_breakdown is True, the entity_id column is included + in the GROUP BY so that per-entity breakdown data is preserved in the + response (e.g. per-team spend). This is needed for entity-specific views + like the team usage dashboard. + Matches the response model of the paginated endpoint so the UI does not need to transform. """ if prisma_client is None: @@ -770,6 +791,7 @@ async def get_daily_activity_aggregated( api_key=api_key, exclude_entity_ids=exclude_entity_ids, timezone_offset_minutes=timezone_offset_minutes, + include_entity_id=include_entity_breakdown, ) # Execute GROUP BY query — returns pre-aggregated dicts @@ -780,13 +802,11 @@ async def get_daily_activity_aggregated( # Convert dicts to objects for compatibility with _aggregate_spend_records records = [SimpleNamespace(**row) for row in rows] - # entity_id_field=None skips entity breakdown (entity dimension was - # collapsed by the GROUP BY, so per-entity data is not available) aggregated = await _aggregate_spend_records( prisma_client=prisma_client, records=records, - entity_id_field=None, - entity_metadata_field=None, + entity_id_field=entity_id_field if include_entity_breakdown else None, + entity_metadata_field=entity_metadata_field if include_entity_breakdown else None, ) return SpendAnalyticsPaginatedResponse( diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 80d50f31a17..5e7a0931b2a 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -77,8 +77,8 @@ _upsert_budget_and_membership, _user_has_admin_view, ) -from litellm.proxy.management_endpoints.tag_management_endpoints import ( - get_daily_activity, +from litellm.proxy.management_endpoints.common_daily_activity import ( + get_daily_activity_aggregated, ) from litellm.proxy.management_helpers.object_permission_utils import ( _set_object_permission, @@ -3890,22 +3890,27 @@ async def get_team_daily_activity( page: int = 1, page_size: int = 10, exclude_team_ids: Optional[str] = None, + timezone: Optional[int] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ Get daily activity for specific teams or all teams. + Uses SQL GROUP BY to aggregate all matching rows without pagination, + ensuring accurate total spend regardless of data volume. + Args: team_ids (Optional[str]): Comma-separated list of team IDs to filter by. If not provided, returns data for all teams. start_date (Optional[str]): Start date for the activity period (YYYY-MM-DD). end_date (Optional[str]): End date for the activity period (YYYY-MM-DD). model (Optional[str]): Filter by model name. api_key (Optional[str]): Filter by API key. - page (int): Page number for pagination. - page_size (int): Number of items per page. + page (int): Deprecated, kept for backward compatibility. All results are returned in a single page. + page_size (int): Deprecated, kept for backward compatibility. exclude_team_ids (Optional[str]): Comma-separated list of team IDs to exclude. + timezone (Optional[int]): Timezone offset in minutes from UTC (e.g., 480 for PST). Returns: - SpendAnalyticsPaginatedResponse: Paginated response containing daily activity data. + SpendAnalyticsPaginatedResponse: Response containing daily activity data with per-team breakdown. """ from litellm.proxy.proxy_server import ( prisma_client, @@ -4009,17 +4014,17 @@ async def get_team_daily_activity( if final_api_key_filter is None and user_api_keys is not None: final_api_key_filter = user_api_keys - return await get_daily_activity( + return await get_daily_activity_aggregated( prisma_client=prisma_client, table_name="litellm_dailyteamspend", entity_id_field="team_id", entity_id=team_ids_list, entity_metadata_field=team_alias_metadata, - exclude_entity_ids=exclude_team_ids_list, start_date=start_date, end_date=end_date, model=model, api_key=final_api_key_filter, - page=page, - page_size=page_size, + exclude_entity_ids=exclude_team_ids_list, + timezone_offset_minutes=timezone, + include_entity_breakdown=True, ) diff --git a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py index b6ac974e2cf..0a2a7e0c432 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py @@ -5379,10 +5379,10 @@ async def test_get_team_daily_activity_non_admin_filters_by_user_api_keys( # Mock get_daily_activity to capture the api_key parameter with patch( - "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity", + "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity_aggregated", new_callable=AsyncMock, - ) as mock_get_daily_activity: - mock_get_daily_activity.return_value = MagicMock() + ) as mock_get_daily_activity_agg: + mock_get_daily_activity_agg.return_value = MagicMock() # Call the endpoint await get_team_daily_activity( @@ -5398,8 +5398,8 @@ async def test_get_team_daily_activity_non_admin_filters_by_user_api_keys( ) # Verify get_daily_activity was called with user's API keys as filter - mock_get_daily_activity.assert_called_once() - call_kwargs = mock_get_daily_activity.call_args[1] + mock_get_daily_activity_agg.assert_called_once() + call_kwargs = mock_get_daily_activity_agg.call_args[1] assert call_kwargs["api_key"] == ["user_key_1", "user_key_2"] assert call_kwargs["entity_id"] == [team_id] @@ -5464,10 +5464,10 @@ async def test_get_team_daily_activity_team_admin_sees_all_spend(mock_db_client) # Mock get_daily_activity to capture the api_key parameter with patch( - "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity", + "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity_aggregated", new_callable=AsyncMock, - ) as mock_get_daily_activity: - mock_get_daily_activity.return_value = MagicMock() + ) as mock_get_daily_activity_agg: + mock_get_daily_activity_agg.return_value = MagicMock() # Call the endpoint await get_team_daily_activity( @@ -5483,8 +5483,8 @@ async def test_get_team_daily_activity_team_admin_sees_all_spend(mock_db_client) ) # Verify get_daily_activity was called WITHOUT API key filtering - mock_get_daily_activity.assert_called_once() - call_kwargs = mock_get_daily_activity.call_args[1] + mock_get_daily_activity_agg.assert_called_once() + call_kwargs = mock_get_daily_activity_agg.call_args[1] assert call_kwargs["api_key"] is None assert call_kwargs["entity_id"] == [team_id] @@ -5553,10 +5553,10 @@ async def test_get_team_daily_activity_member_with_permission_sees_all_spend( # Mock get_daily_activity to capture the api_key parameter with patch( - "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity", + "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity_aggregated", new_callable=AsyncMock, - ) as mock_get_daily_activity: - mock_get_daily_activity.return_value = MagicMock() + ) as mock_get_daily_activity_agg: + mock_get_daily_activity_agg.return_value = MagicMock() # Call the endpoint await get_team_daily_activity( @@ -5572,8 +5572,8 @@ async def test_get_team_daily_activity_member_with_permission_sees_all_spend( ) # Verify get_daily_activity was called WITHOUT API key filtering - mock_get_daily_activity.assert_called_once() - call_kwargs = mock_get_daily_activity.call_args[1] + mock_get_daily_activity_agg.assert_called_once() + call_kwargs = mock_get_daily_activity_agg.call_args[1] assert call_kwargs["api_key"] is None assert call_kwargs["entity_id"] == [team_id] @@ -5652,10 +5652,10 @@ async def test_get_team_daily_activity_member_without_permission_filters_by_keys # Mock get_daily_activity to capture the api_key parameter with patch( - "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity", + "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity_aggregated", new_callable=AsyncMock, - ) as mock_get_daily_activity: - mock_get_daily_activity.return_value = MagicMock() + ) as mock_get_daily_activity_agg: + mock_get_daily_activity_agg.return_value = MagicMock() # Call the endpoint await get_team_daily_activity( @@ -5671,8 +5671,8 @@ async def test_get_team_daily_activity_member_without_permission_filters_by_keys ) # Verify get_daily_activity was called WITH API key filtering - mock_get_daily_activity.assert_called_once() - call_kwargs = mock_get_daily_activity.call_args[1] + mock_get_daily_activity_agg.assert_called_once() + call_kwargs = mock_get_daily_activity_agg.call_args[1] assert call_kwargs["api_key"] == ["user_key_abc", "user_key_def"] assert call_kwargs["entity_id"] == [team_id] @@ -5822,10 +5822,10 @@ async def test_get_team_daily_activity_non_admin_filters_by_user_api_keys( # Mock get_daily_activity to capture the api_key parameter with patch( - "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity", + "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity_aggregated", new_callable=AsyncMock, - ) as mock_get_daily_activity: - mock_get_daily_activity.return_value = MagicMock() + ) as mock_get_daily_activity_agg: + mock_get_daily_activity_agg.return_value = MagicMock() # Call the endpoint await get_team_daily_activity( @@ -5841,8 +5841,8 @@ async def test_get_team_daily_activity_non_admin_filters_by_user_api_keys( ) # Verify get_daily_activity was called with user's API keys as filter - mock_get_daily_activity.assert_called_once() - call_kwargs = mock_get_daily_activity.call_args[1] + mock_get_daily_activity_agg.assert_called_once() + call_kwargs = mock_get_daily_activity_agg.call_args[1] assert call_kwargs["api_key"] == ["user_key_1", "user_key_2"] assert call_kwargs["entity_id"] == [team_id] @@ -5907,10 +5907,10 @@ async def test_get_team_daily_activity_team_admin_sees_all_spend(mock_db_client) # Mock get_daily_activity to capture the api_key parameter with patch( - "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity", + "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity_aggregated", new_callable=AsyncMock, - ) as mock_get_daily_activity: - mock_get_daily_activity.return_value = MagicMock() + ) as mock_get_daily_activity_agg: + mock_get_daily_activity_agg.return_value = MagicMock() # Call the endpoint await get_team_daily_activity( @@ -5926,8 +5926,8 @@ async def test_get_team_daily_activity_team_admin_sees_all_spend(mock_db_client) ) # Verify get_daily_activity was called WITHOUT API key filtering - mock_get_daily_activity.assert_called_once() - call_kwargs = mock_get_daily_activity.call_args[1] + mock_get_daily_activity_agg.assert_called_once() + call_kwargs = mock_get_daily_activity_agg.call_args[1] assert call_kwargs["api_key"] is None assert call_kwargs["entity_id"] == [team_id] @@ -5939,6 +5939,56 @@ async def test_get_team_daily_activity_team_admin_sees_all_spend(mock_db_client) assert False, "API keys should not be fetched for team admin users" +@pytest.mark.asyncio +async def test_get_team_daily_activity_uses_aggregated_with_entity_breakdown( + mock_db_client, +): + """ + Test that /team/daily/activity calls get_daily_activity_aggregated + with include_entity_breakdown=True, timezone, and correct parameters. + """ + from litellm.proxy.management_endpoints.team_endpoints import ( + get_team_daily_activity, + ) + + user_api_key_dict = UserAPIKeyAuth( + user_id="admin_user", user_role=LitellmUserRoles.PROXY_ADMIN + ) + + # Mock the team table query for fetching team aliases + mock_db_client.db.litellm_teamtable.find_many = AsyncMock(return_value=[]) + + with patch( + "litellm.proxy.management_endpoints.team_endpoints.get_daily_activity_aggregated", + new_callable=AsyncMock, + ) as mock_get_daily_activity_agg: + mock_get_daily_activity_agg.return_value = MagicMock() + + await get_team_daily_activity( + team_ids="team_1,team_2", + start_date="2024-01-01", + end_date="2024-01-31", + model=None, + api_key=None, + page=1, + page_size=10, + exclude_team_ids="litellm-dashboard", + timezone=480, + user_api_key_dict=user_api_key_dict, + ) + + mock_get_daily_activity_agg.assert_called_once() + call_kwargs = mock_get_daily_activity_agg.call_args[1] + assert call_kwargs["table_name"] == "litellm_dailyteamspend" + assert call_kwargs["entity_id_field"] == "team_id" + assert call_kwargs["entity_id"] == ["team_1", "team_2"] + assert call_kwargs["exclude_entity_ids"] == ["litellm-dashboard"] + assert call_kwargs["start_date"] == "2024-01-01" + assert call_kwargs["end_date"] == "2024-01-31" + assert call_kwargs["timezone_offset_minutes"] == 480 + assert call_kwargs["include_entity_breakdown"] is True + + @pytest.mark.asyncio async def test_validate_and_populate_member_user_info_both_provided_match(): """