Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 101 additions & 6 deletions server/polar/meter/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import UTC, datetime
from typing import Any

import logfire
from sqlalchemy import (
ColumnElement,
ColumnExpressionArgument,
Expand All @@ -16,7 +17,7 @@
or_,
select,
)
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import aliased, joinedload

from polar.auth.models import AuthSubject, Organization, User
from polar.billing_entry.repository import BillingEntryRepository
Expand All @@ -35,6 +36,7 @@
Customer,
Event,
Meter,
MeterEvent,
Product,
ProductPriceMeteredUnit,
SubscriptionProductPrice,
Expand Down Expand Up @@ -466,9 +468,10 @@ async def _create_subscription_holder_billing_entry(
BillingEntry.from_metered_event(customer, subscription_product_price, event)
)

async def create_billing_entries(
async def _create_billing_entries_old(
self, session: AsyncSession, meter: Meter
) -> Sequence[BillingEntry]:
) -> tuple[Sequence[BillingEntry], Sequence[uuid.UUID], Event | None]:
"""Old implementation using JSONB filters on events table."""
event_repository = EventRepository.from_session(session)
statement = (
event_repository.get_base_statement()
Expand Down Expand Up @@ -498,11 +501,12 @@ async def create_billing_entries(
uuid.UUID, CustomerSubscriptionProductPrice | None
] = {}

entries: list[BillingEntry] = []
updated_subscriptions: set[uuid.UUID] = set()
entries = []
event_ids = []
last_event: Event | None = None
async for event in event_repository.stream(statement):
last_event = event
event_ids.append(event.id)
customer = event.customer
assert customer is not None

Expand Down Expand Up @@ -530,11 +534,102 @@ async def create_billing_entries(
customer_price.subscription_product_price,
)
entries.append(entry)

return entries, event_ids, last_event

async def _create_billing_entries_new(
self, session: AsyncSession, meter: Meter
) -> tuple[Sequence[uuid.UUID], Event | None]:
"""New implementation using meter_events table."""
event_repository = EventRepository.from_session(session)
BillableCustomer = aliased(Customer)
statement = (
event_repository.get_base_statement()
.join(MeterEvent, MeterEvent.event_id == Event.id)
.join(
BillableCustomer,
or_(
BillableCustomer.id == MeterEvent.customer_id,
and_(
MeterEvent.customer_id.is_(None),
BillableCustomer.external_id == MeterEvent.external_customer_id,
BillableCustomer.organization_id == MeterEvent.organization_id,
),
),
)
.where(MeterEvent.meter_id == meter.id)
.order_by(MeterEvent.ingested_at.asc())
)
last_billed_event = meter.last_billed_event
if last_billed_event is not None:
statement = statement.where(
MeterEvent.ingested_at > last_billed_event.ingested_at
)

event_ids = []
last_event: Event | None = None
async for event in event_repository.stream(statement):
last_event = event
event_ids.append(event.id)

return event_ids, last_event

async def create_billing_entries(
self, session: AsyncSession, meter: Meter
) -> Sequence[BillingEntry]:
last_billed_event = meter.last_billed_event

with logfire.span("create_billing_entries.old", meter_id=str(meter.id)):
(
entries,
old_event_ids,
old_last_event,
) = await self._create_billing_entries_old(session, meter)
logfire.info(
"Old implementation completed",
event_count=len(old_event_ids),
last_event_id=str(old_last_event.id) if old_last_event else None,
)

with logfire.span("create_billing_entries.new", meter_id=str(meter.id)):
new_event_ids, new_last_event = await self._create_billing_entries_new(
session, meter
)
logfire.info(
"New implementation completed",
event_count=len(new_event_ids),
last_event_id=str(new_last_event.id) if new_last_event else None,
)

old_set: set[uuid.UUID] = set(old_event_ids)
new_set: set[uuid.UUID] = set(new_event_ids)
if old_set != new_set:
only_in_old = old_set - new_set
only_in_new = new_set - old_set
logfire.error(
"Billing entries mismatch between old and new implementations",
meter_id=str(meter.id),
old_count=len(old_event_ids),
new_count=len(new_event_ids),
only_in_old_count=len(only_in_old),
only_in_new_count=len(only_in_new),
only_in_old=[str(e) for e in list(only_in_old)[:10]],
only_in_new=[str(e) for e in list(only_in_new)[:10]],
)
else:
logfire.info(
"Billing entries match between implementations",
meter_id=str(meter.id),
count=len(old_event_ids),
)

updated_subscriptions: set[uuid.UUID] = set()
for entry in entries:
if entry.subscription is not None:
updated_subscriptions.add(entry.subscription.id)

meter.last_billed_event = (
last_event if last_event is not None else last_billed_event
old_last_event if old_last_event is not None else last_billed_event
)
session.add(meter)

Expand Down
78 changes: 76 additions & 2 deletions server/tests/meter/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from polar.auth.models import AuthSubject
from polar.enums import SubscriptionRecurringInterval
from polar.event.service import event as event_service
from polar.event.system import SystemEvent
from polar.exceptions import PolarRequestValidationError
from polar.kit.time_queries import TimeInterval
Expand All @@ -27,6 +28,7 @@
Customer,
Event,
Meter,
MeterEvent,
Organization,
Product,
Subscription,
Expand Down Expand Up @@ -836,10 +838,10 @@ async def meter(save_fixture: SaveFixture, organization: Organization) -> Meter:

@pytest_asyncio.fixture
async def events(
save_fixture: SaveFixture, customer: Customer, meter: Meter
save_fixture: SaveFixture, session: AsyncSession, customer: Customer, meter: Meter
) -> list[Event]:
timestamp = utc_now()
return [
events = [
await create_event(
save_fixture,
timestamp=timestamp + timedelta(seconds=1),
Expand Down Expand Up @@ -895,6 +897,8 @@ async def events(
metadata={"units": 10, "meter_id": str(uuid.uuid4())},
),
]
await event_service._create_meter_events(session, events)
return events


@pytest_asyncio.fixture
Expand Down Expand Up @@ -996,6 +1000,72 @@ async def test_last_billed_event(
"subscription.update_meters", metered_subscription.id
)

async def test_external_customer_id_resolved_to_customer(
self,
enqueue_job_mock: AsyncMock,
save_fixture: SaveFixture,
session: AsyncSession,
organization: Organization,
) -> None:
"""Test that events with only external_customer_id are billed when customer exists."""
customer = await create_customer(
save_fixture,
organization=organization,
external_id="ext_customer_123",
)

meter = await create_meter(
save_fixture,
organization=organization,
filter=Filter(
conjunction=FilterConjunction.and_,
clauses=[
FilterClause(
property="name",
operator=FilterOperator.eq,
value=METER_TEST_EVENT,
)
],
),
aggregation=CountAggregation(),
)

product = await create_product(
save_fixture,
organization=organization,
recurring_interval=SubscriptionRecurringInterval.month,
prices=[(meter, Decimal(100), None)],
)

subscription = await create_active_subscription(
save_fixture, product=product, customer=customer
)
await session.refresh(subscription, ["subscription_product_prices"])

event = await create_event(
save_fixture,
organization=organization,
customer=None,
external_customer_id="ext_customer_123",
)

meter_event = MeterEvent(
meter_id=meter.id,
event_id=event.id,
customer_id=None,
external_customer_id="ext_customer_123",
organization_id=organization.id,
ingested_at=event.ingested_at,
timestamp=event.timestamp,
)
await save_fixture(meter_event)

entries = await meter_service.create_billing_entries(session, meter)

assert len(entries) == 1
assert entries[0].customer == customer
assert entries[0].subscription == subscription


@pytest.mark.asyncio
class TestCreateBillingEntriesWithSeats:
Expand Down Expand Up @@ -1083,6 +1153,7 @@ async def test_seat_holder_overage_charges_billing_manager(
metadata={"tokens": 10, "model": "lite"},
),
]
await event_service._create_meter_events(session, events)

entries = await meter_service.create_billing_entries(session, meter)

Expand Down Expand Up @@ -1172,6 +1243,7 @@ async def test_seat_holder_without_metered_pricing(
metadata={"tokens": 20, "model": "lite"},
),
]
await event_service._create_meter_events(session, events)

entries = await meter_service.create_billing_entries(session, meter)

Expand Down Expand Up @@ -1290,6 +1362,7 @@ async def test_multiple_seat_holders_same_subscription(
metadata={"tokens": 15, "model": "lite"},
),
]
await event_service._create_meter_events(session, events)

entries = await meter_service.create_billing_entries(session, meter)

Expand Down Expand Up @@ -1411,6 +1484,7 @@ async def test_billing_manager_is_seat_holder(
metadata={"tokens": 15, "model": "lite"},
),
]
await event_service._create_meter_events(session, events)

entries = await meter_service.create_billing_entries(session, meter)

Expand Down