78 changes: 56 additions & 22 deletions server/polar/meter/service.py
@@ -346,16 +346,24 @@ async def get_quantities(
day_column = interval.sql_date_trunc(Event.timestamp)
truncated_timestamp = interval.sql_date_trunc(timestamp_column)

# Determine the appropriate SQL function for the running total calculation.
# For summable aggregations (count, sum), we can sum the daily values.
# For non-summable aggregations (max, min), we must use the same aggregation
# function to get the correct total (e.g., max of daily maxes = overall max).
# Note: avg and unique require special handling that's not implemented here -
# avg would need weighted averages, unique would need to avoid double counting.
if meter.aggregation.is_summable():
total_agg_func = AggregationFunction.sum.get_sql_function
else:
total_agg_func = meter.aggregation.func.get_sql_function
# Determine if we can use the optimized CTE path for the total calculation.
# - Summable aggregations (count, sum): sum of daily values = correct total
# - Max/Min: max/min of daily values = correct total
# - Avg/Unique: cannot be computed from daily aggregates, need direct query
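#   e.g. the max of daily maxes [3, 7, 5] is 7, the true overall max, but
#   daily avgs drop the per-day event counts, so no aggregate of them can
#   recover the true average in general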
use_optimized_total = (
meter.aggregation.is_summable()
or meter.aggregation.func
in (
AggregationFunction.max,
AggregationFunction.min,
)
)

if use_optimized_total:
if meter.aggregation.is_summable():
total_agg_func = AggregationFunction.sum.get_sql_function
else:
total_agg_func = meter.aggregation.func.get_sql_function

if customer_aggregation_function is not None:
daily_metrics = cte(
@@ -375,16 +383,29 @@ async def get_quantities(
).label("quantity"),
).group_by(daily_metrics.c.day)
)

if use_optimized_total:
total_column = func.coalesce(
total_agg_func(daily_aggregated.c.quantity).over(
order_by=timestamp_column
),
0,
)
else:
# For avg/unique: compute total directly over all events via subquery
# This is slower but necessary for correctness
total_subquery = (
select(meter.aggregation.get_sql_column(Event))
.where(and_(*event_clauses))
.scalar_subquery()
)
total_column = func.coalesce(total_subquery, 0)
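# The subquery does not reference the outer series row, so it yields the
# same total for every timestamp (assumption: the planner evaluates an
# uncorrelated scalar subquery once rather than per row)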

statement = (
select(
timestamp_column.label("timestamp"),
func.coalesce(daily_aggregated.c.quantity, 0).label("quantity"),
func.coalesce(
total_agg_func(daily_aggregated.c.quantity).over(
order_by=timestamp_column
),
0,
).label("total"),
total_column.label("total"),
)
.select_from(
timestamp_series.join(
Expand All @@ -404,16 +425,29 @@ async def get_quantities(
.where(and_(*event_clauses))
.group_by(day_column)
)

if use_optimized_total:
total_column = func.coalesce(
total_agg_func(daily_metrics.c.quantity).over(
order_by=timestamp_column
),
0,
)
else:
# For avg/unique: compute total directly over all events via subquery
# This is slower but necessary for correctness
total_subquery = (
select(meter.aggregation.get_sql_column(Event))
.where(and_(*event_clauses))
.scalar_subquery()
)
total_column = func.coalesce(total_subquery, 0)

statement = (
select(
timestamp_column.label("timestamp"),
func.coalesce(daily_metrics.c.quantity, 0).label("quantity"),
func.coalesce(
total_agg_func(daily_metrics.c.quantity).over(
order_by=timestamp_column
),
0,
).label("total"),
total_column.label("total"),
)
.select_from(
timestamp_series.join(
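Why the optimized path is safe for sum/max/min but not for avg/unique can be checked in plain Python, independent of the SQL. A minimal sketch, with hypothetical values mirroring the two regression tests added below:

def avg(xs: list[float]) -> float:
    return sum(xs) / len(xs)

day1 = [10, 10, 10]  # three events
day2 = [50]          # one event
all_events = day1 + day2

# Composable: an aggregate of the daily aggregates equals the true total.
assert sum([sum(day1), sum(day2)]) == sum(all_events)  # 80
assert max([max(day1), max(day2)]) == max(all_events)  # 50

# Not composable: avg of daily avgs is wrong once daily counts differ.
assert avg([avg(day1), avg(day2)]) == 30.0             # (10 + 50) / 2
assert avg(all_events) == 20.0                         # 80 / 4

# Not composable: summing daily distinct counts double-counts overlap.
day1_users, day2_users = {"a", "b", "c"}, {"b", "c", "d"}
assert len(day1_users) + len(day2_users) == 6
assert len(day1_users | day2_users) == 4               # true unique total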
156 changes: 156 additions & 0 deletions server/tests/meter/test_service.py
@@ -21,6 +21,7 @@
AggregationFunction,
CountAggregation,
PropertyAggregation,
UniqueAggregation,
)
from polar.meter.filter import Filter, FilterClause, FilterConjunction, FilterOperator
from polar.meter.schemas import MeterCreate, MeterUpdate
@@ -443,6 +444,14 @@ async def test_interval(
pytest.param(
AggregationFunction.sum, [30, 0, 20], 50, id="sum aggregation"
),
# For avg: total should be the true average across ALL events, not the avg of daily avgs
# Day 1: avg(10, 20) = 15, Day 2: NULL (no events), Day 3: avg(15, 5) = 10
# True average: (10 + 20 + 15 + 5) / 4 = 12.5
# Note: the avg of daily avgs, (15 + 10) / 2, also happens to equal 12.5 here,
# so this case alone cannot distinguish the two; the dedicated test below uses
# unequal daily event counts to do so.
pytest.param(
AggregationFunction.avg, [15, 0, 10], 12.5, id="avg aggregation"
),
],
)
async def test_interval_non_summable_aggregation(
@@ -539,6 +548,153 @@ async def test_interval_non_summable_aggregation(
# not always SUM
assert result.total == expected_total

async def test_avg_aggregation_with_unequal_event_counts(
self,
save_fixture: SaveFixture,
session: AsyncSession,
customer: Customer,
) -> None:
"""Test that avg total is the true average, not average of daily averages.

With unequal event counts per day, avg of avgs differs from true avg:
- Day 1: 3 events [10, 10, 10] → daily avg = 10
- Day 2: 1 event [50] → daily avg = 50
- True average: (10 + 10 + 10 + 50) / 4 = 20
- Wrong (avg of daily avgs): (10 + 50) / 2 = 30
"""
past_timestamp = utc_now() - timedelta(days=1)
future_timestamp = utc_now() + timedelta(days=1)

# Day 1: 3 events with value 10 each
for _ in range(3):
await create_event(
save_fixture,
timestamp=past_timestamp,
organization=customer.organization,
customer=customer,
metadata={"tokens": 10, "model": "lite"},
)

# Day 2: 1 event with value 50
await create_event(
save_fixture,
timestamp=future_timestamp,
organization=customer.organization,
customer=customer,
metadata={"tokens": 50, "model": "lite"},
)

meter = await create_meter(
save_fixture,
name="Token Usage",
filter=Filter(
conjunction=FilterConjunction.and_,
clauses=[
FilterClause(
property="model", operator=FilterOperator.eq, value="lite"
)
],
),
aggregation=PropertyAggregation(
func=AggregationFunction.avg, property="tokens"
),
organization=customer.organization,
)

result = await meter_service.get_quantities(
session,
meter,
customer_id=[customer.id],
start_timestamp=past_timestamp,
end_timestamp=future_timestamp,
interval=TimeInterval.day,
timezone=ZoneInfo("UTC"),
)

assert len(result.quantities) == 3

[day1, day2, day3] = result.quantities
assert day1.quantity == 10 # avg(10, 10, 10) = 10
assert day2.quantity == 0 # no events
assert day3.quantity == 50 # avg(50) = 50

# True average: (10 + 10 + 10 + 50) / 4 = 20
# If we wrongly computed avg of daily avgs: (10 + 50) / 2 = 30
assert result.total == 20

async def test_unique_aggregation_across_days(
self,
save_fixture: SaveFixture,
session: AsyncSession,
customer: Customer,
) -> None:
"""Test that unique total counts distinct values across all days.

Same value appearing on multiple days should only be counted once:
- Day 1: user_ids ["a", "b", "c"] → daily unique = 3
- Day 2: user_ids ["b", "c", "d"] → daily unique = 3
- True unique across all days: ["a", "b", "c", "d"] = 4
- Wrong (sum of daily uniques): 3 + 3 = 6
"""
past_timestamp = utc_now() - timedelta(days=1)
future_timestamp = utc_now() + timedelta(days=1)

# Day 1: users a, b, c
for user_id in ["a", "b", "c"]:
await create_event(
save_fixture,
timestamp=past_timestamp,
organization=customer.organization,
customer=customer,
metadata={"user_id": user_id, "model": "lite"},
)

# Day 2: users b, c, d (b and c overlap with day 1)
for user_id in ["b", "c", "d"]:
await create_event(
save_fixture,
timestamp=future_timestamp,
organization=customer.organization,
customer=customer,
metadata={"user_id": user_id, "model": "lite"},
)

meter = await create_meter(
save_fixture,
name="Unique Users",
filter=Filter(
conjunction=FilterConjunction.and_,
clauses=[
FilterClause(
property="model", operator=FilterOperator.eq, value="lite"
)
],
),
aggregation=UniqueAggregation(property="user_id"),
organization=customer.organization,
)

result = await meter_service.get_quantities(
session,
meter,
customer_id=[customer.id],
start_timestamp=past_timestamp,
end_timestamp=future_timestamp,
interval=TimeInterval.day,
timezone=ZoneInfo("UTC"),
)

assert len(result.quantities) == 3

[day1, day2, day3] = result.quantities
assert day1.quantity == 3 # unique(a, b, c) = 3
assert day2.quantity == 0 # no events
assert day3.quantity == 3 # unique(b, c, d) = 3

# True unique: count(distinct a, b, c, d) = 4
# If we wrongly summed daily uniques: 3 + 3 = 6
assert result.total == 4

@pytest.mark.parametrize(
"property",
[