diff --git a/server/polar/meter/service.py b/server/polar/meter/service.py index fa71ad2e38..f965c26a89 100644 --- a/server/polar/meter/service.py +++ b/server/polar/meter/service.py @@ -339,6 +339,17 @@ async def get_quantities( day_column = interval.sql_date_trunc(Event.timestamp) truncated_timestamp = interval.sql_date_trunc(timestamp_column) + # Determine the appropriate SQL function for the running total calculation. + # For summable aggregations (count, sum), we can sum the daily values. + # For non-summable aggregations (max, min), we must use the same aggregation + # function to get the correct total (e.g., max of daily maxes = overall max). + # Note: avg and unique require special handling that's not implemented here - + # avg would need weighted averages, unique would need to avoid double counting. + if meter.aggregation.is_summable(): + total_agg_func = AggregationFunction.sum.get_sql_function + else: + total_agg_func = meter.aggregation.func.get_sql_function + if customer_aggregation_function is not None: daily_metrics = cte( select( @@ -362,7 +373,7 @@ async def get_quantities( timestamp_column.label("timestamp"), func.coalesce(daily_aggregated.c.quantity, 0).label("quantity"), func.coalesce( - func.sum(daily_aggregated.c.quantity).over( + total_agg_func(daily_aggregated.c.quantity).over( order_by=timestamp_column ), 0, @@ -391,7 +402,7 @@ async def get_quantities( timestamp_column.label("timestamp"), func.coalesce(daily_metrics.c.quantity, 0).label("quantity"), func.coalesce( - func.sum(daily_metrics.c.quantity).over( + total_agg_func(daily_metrics.c.quantity).over( order_by=timestamp_column ), 0, diff --git a/server/tests/meter/test_service.py b/server/tests/meter/test_service.py index 3c8202618b..dc06e8aca5 100644 --- a/server/tests/meter/test_service.py +++ b/server/tests/meter/test_service.py @@ -1,6 +1,7 @@ import uuid from datetime import timedelta from decimal import Decimal +from typing import Literal from unittest.mock import AsyncMock import 
@pytest.mark.parametrize(
    ("aggregation_func", "expected_daily_quantities", "expected_total"),
    [
        # For max: total should be max across all days with data
        # (NULLs from empty days are ignored).
        # Day 1: max(10, 20) = 20, Day 2: NULL (no events), Day 3: max(15, 5) = 15
        # Total: max(20, NULL, 15) = 20, NOT sum(20 + 0 + 15) = 35
        pytest.param(
            AggregationFunction.max, [20, 0, 15], 20, id="max aggregation"
        ),
        # For min: total should be min across all days with data
        # (NULLs from empty days are ignored).
        # Day 1: min(10, 20) = 10, Day 2: NULL (no events), Day 3: min(15, 5) = 5
        # Total: min(10, NULL, 5) = 5 (SQL MIN ignores NULLs)
        pytest.param(AggregationFunction.min, [10, 0, 5], 5, id="min aggregation"),
        # For sum: control case. Sum is summable, so summing the daily
        # values is the correct total.
        pytest.param(
            AggregationFunction.sum, [30, 0, 20], 50, id="sum aggregation"
        ),
    ],
)
async def test_interval_non_summable_aggregation(
    self,
    aggregation_func: Literal[
        AggregationFunction.sum,
        AggregationFunction.max,
        AggregationFunction.min,
        AggregationFunction.avg,
    ],
    expected_daily_quantities: list[int],
    expected_total: int,
    save_fixture: SaveFixture,
    session: AsyncSession,
    customer: Customer,
) -> None:
    """Check the multi-day total for non-summable aggregations.

    Regression test for the bug introduced in commit 668ea64, where the
    total was always computed with SUM, even for non-summable aggregations.
    The parametrization exercises MAX and MIN (non-summable) and SUM as a
    summable control case. AVG is not exercised here.

    Events are created on the first and third day of a three-day window,
    with the middle day left empty. The test asserts both the per-day
    quantities and the overall total.
    """
    # Derive every anchor from a single clock reading. Separate utc_now()
    # calls could straddle a day boundary and produce a four-day range,
    # which would break the three-bucket assertion below.
    now = utc_now()
    past_timestamp = now - timedelta(days=1)
    future_timestamp = now + timedelta(days=1)

    # Day 1 (past): two events with values 10 and 20.
    # Day 2 (today): no events.
    # Day 3 (future): two events with values 15 and 5.
    events = [
        (past_timestamp, 10),
        (past_timestamp, 20),
        (future_timestamp, 15),
        (future_timestamp, 5),
    ]
    for timestamp, tokens in events:
        await create_event(
            save_fixture,
            timestamp=timestamp,
            organization=customer.organization,
            customer=customer,
            metadata={"tokens": tokens, "model": "lite"},
        )

    meter = await create_meter(
        save_fixture,
        name="Token Usage",
        filter=Filter(
            conjunction=FilterConjunction.and_,
            clauses=[
                FilterClause(
                    property="model", operator=FilterOperator.eq, value="lite"
                )
            ],
        ),
        aggregation=PropertyAggregation(func=aggregation_func, property="tokens"),
        organization=customer.organization,
    )

    result = await meter_service.get_quantities(
        session,
        meter,
        customer_id=[customer.id],
        start_timestamp=past_timestamp,
        end_timestamp=future_timestamp,
        interval=TimeInterval.day,
    )

    assert len(result.quantities) == 3

    [yesterday_quantity, today_quantity, tomorrow_quantity] = result.quantities

    assert yesterday_quantity.quantity == expected_daily_quantities[0]
    assert today_quantity.quantity == expected_daily_quantities[1]
    assert tomorrow_quantity.quantity == expected_daily_quantities[2]

    # This is the key assertion - total should use the meter's aggregation,
    # not always SUM.
    assert result.total == expected_total
@pytest.mark.parametrize( "property", [