From 58940533bb7926f7b22f12cd41ee88393fcbbd44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jesper=20Br=C3=A4nn?= Date: Wed, 7 Jan 2026 11:49:02 +0100 Subject: [PATCH] server/meter: update create_billing_entries compare with meter_events table MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We introduce the ability to ecompare billing entries fetched via meter_events with the previous functionality. This allows us to roll this out as a comparison initially. The ambition is that this will eliminate the expensive JSONB filter evaluation that was causing 703 second query times on large datasets. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- server/polar/meter/service.py | 107 +++++++++++++++++++++++++++-- server/tests/meter/test_service.py | 78 ++++++++++++++++++++- 2 files changed, 177 insertions(+), 8 deletions(-) diff --git a/server/polar/meter/service.py b/server/polar/meter/service.py index 897965156b..78e5aec6a8 100644 --- a/server/polar/meter/service.py +++ b/server/polar/meter/service.py @@ -3,6 +3,7 @@ from datetime import UTC, datetime from typing import Any +import logfire from sqlalchemy import ( ColumnElement, ColumnExpressionArgument, @@ -16,7 +17,7 @@ or_, select, ) -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import aliased, joinedload from polar.auth.models import AuthSubject, Organization, User from polar.billing_entry.repository import BillingEntryRepository @@ -35,6 +36,7 @@ Customer, Event, Meter, + MeterEvent, Product, ProductPriceMeteredUnit, SubscriptionProductPrice, @@ -466,9 +468,10 @@ async def _create_subscription_holder_billing_entry( BillingEntry.from_metered_event(customer, subscription_product_price, event) ) - async def create_billing_entries( + async def _create_billing_entries_old( self, session: AsyncSession, meter: Meter - ) -> Sequence[BillingEntry]: + ) -> tuple[Sequence[BillingEntry], Sequence[uuid.UUID], Event | None]: + """Old implementation using JSONB filters on events table.""" event_repository = EventRepository.from_session(session) statement = ( event_repository.get_base_statement() @@ -498,11 +501,12 @@ async def create_billing_entries( uuid.UUID, CustomerSubscriptionProductPrice | None ] = {} - entries: list[BillingEntry] = [] - updated_subscriptions: set[uuid.UUID] = set() + entries = [] + event_ids = [] last_event: Event | None = None async for event in event_repository.stream(statement): last_event = event + event_ids.append(event.id) customer = event.customer assert customer is not None @@ -530,11 +534,102 @@ async def create_billing_entries( customer_price.subscription_product_price, ) entries.append(entry) + + return entries, event_ids, last_event + + async def _create_billing_entries_new( + self, session: AsyncSession, meter: Meter + ) -> tuple[Sequence[uuid.UUID], Event | None]: + """New implementation using meter_events table.""" + event_repository = EventRepository.from_session(session) + BillableCustomer = aliased(Customer) + statement = ( + event_repository.get_base_statement() + .join(MeterEvent, MeterEvent.event_id == Event.id) + .join( + BillableCustomer, + or_( + BillableCustomer.id == MeterEvent.customer_id, + and_( + MeterEvent.customer_id.is_(None), + BillableCustomer.external_id == MeterEvent.external_customer_id, + BillableCustomer.organization_id == MeterEvent.organization_id, + ), + ), + ) + .where(MeterEvent.meter_id == meter.id) + .order_by(MeterEvent.ingested_at.asc()) + ) + last_billed_event = meter.last_billed_event + if last_billed_event is not None: + statement = statement.where( + MeterEvent.ingested_at > last_billed_event.ingested_at + ) + + event_ids = [] + last_event: Event | None = None + async for event in event_repository.stream(statement): + last_event = event + event_ids.append(event.id) + + return event_ids, last_event + + async def create_billing_entries( + self, session: AsyncSession, meter: Meter + ) -> Sequence[BillingEntry]: + last_billed_event = meter.last_billed_event + + with logfire.span("create_billing_entries.old", meter_id=str(meter.id)): + ( + entries, + old_event_ids, + old_last_event, + ) = await self._create_billing_entries_old(session, meter) + logfire.info( + "Old implementation completed", + event_count=len(old_event_ids), + last_event_id=str(old_last_event.id) if old_last_event else None, + ) + + with logfire.span("create_billing_entries.new", meter_id=str(meter.id)): + new_event_ids, new_last_event = await self._create_billing_entries_new( + session, meter + ) + logfire.info( + "New implementation completed", + event_count=len(new_event_ids), + last_event_id=str(new_last_event.id) if new_last_event else None, + ) + + old_set: set[uuid.UUID] = set(old_event_ids) + new_set: set[uuid.UUID] = set(new_event_ids) + if old_set != new_set: + only_in_old = old_set - new_set + only_in_new = new_set - old_set + logfire.error( + "Billing entries mismatch between old and new implementations", + meter_id=str(meter.id), + old_count=len(old_event_ids), + new_count=len(new_event_ids), + only_in_old_count=len(only_in_old), + only_in_new_count=len(only_in_new), + only_in_old=[str(e) for e in list(only_in_old)[:10]], + only_in_new=[str(e) for e in list(only_in_new)[:10]], + ) + else: + logfire.info( + "Billing entries match between implementations", + meter_id=str(meter.id), + count=len(old_event_ids), + ) + + updated_subscriptions: set[uuid.UUID] = set() + for entry in entries: if entry.subscription is not None: updated_subscriptions.add(entry.subscription.id) meter.last_billed_event = ( - last_event if last_event is not None else last_billed_event + old_last_event if old_last_event is not None else last_billed_event ) session.add(meter) diff --git a/server/tests/meter/test_service.py b/server/tests/meter/test_service.py index dc06e8aca5..19c1937a9a 100644 --- a/server/tests/meter/test_service.py +++ b/server/tests/meter/test_service.py @@ -10,6 +10,7 @@ from polar.auth.models import AuthSubject from polar.enums import SubscriptionRecurringInterval +from polar.event.service import event as event_service from polar.event.system import SystemEvent from polar.exceptions import PolarRequestValidationError from polar.kit.time_queries import TimeInterval @@ -27,6 +28,7 @@ Customer, Event, Meter, + MeterEvent, Organization, Product, Subscription, @@ -836,10 +838,10 @@ async def meter(save_fixture: SaveFixture, organization: Organization) -> Meter: @pytest_asyncio.fixture async def events( - save_fixture: SaveFixture, customer: Customer, meter: Meter + save_fixture: SaveFixture, session: AsyncSession, customer: Customer, meter: Meter ) -> list[Event]: timestamp = utc_now() - return [ + events = [ await create_event( save_fixture, timestamp=timestamp + timedelta(seconds=1), @@ -895,6 +897,8 @@ async def events( metadata={"units": 10, "meter_id": str(uuid.uuid4())}, ), ] + await event_service._create_meter_events(session, events) + return events @pytest_asyncio.fixture @@ -996,6 +1000,72 @@ async def test_last_billed_event( "subscription.update_meters", metered_subscription.id ) + async def test_external_customer_id_resolved_to_customer( + self, + enqueue_job_mock: AsyncMock, + save_fixture: SaveFixture, + session: AsyncSession, + organization: Organization, + ) -> None: + """Test that events with only external_customer_id are billed when customer exists.""" + customer = await create_customer( + save_fixture, + organization=organization, + external_id="ext_customer_123", + ) + + meter = await create_meter( + save_fixture, + organization=organization, + filter=Filter( + conjunction=FilterConjunction.and_, + clauses=[ + FilterClause( + property="name", + operator=FilterOperator.eq, + value=METER_TEST_EVENT, + ) + ], + ), + aggregation=CountAggregation(), + ) + + product = await create_product( + save_fixture, + organization=organization, + recurring_interval=SubscriptionRecurringInterval.month, + prices=[(meter, Decimal(100), None)], + ) + + subscription = await create_active_subscription( + save_fixture, product=product, customer=customer + ) + await session.refresh(subscription, ["subscription_product_prices"]) + + event = await create_event( + save_fixture, + organization=organization, + customer=None, + external_customer_id="ext_customer_123", + ) + + meter_event = MeterEvent( + meter_id=meter.id, + event_id=event.id, + customer_id=None, + external_customer_id="ext_customer_123", + organization_id=organization.id, + ingested_at=event.ingested_at, + timestamp=event.timestamp, + ) + await save_fixture(meter_event) + + entries = await meter_service.create_billing_entries(session, meter) + + assert len(entries) == 1 + assert entries[0].customer == customer + assert entries[0].subscription == subscription + @pytest.mark.asyncio class TestCreateBillingEntriesWithSeats: @@ -1083,6 +1153,7 @@ async def test_seat_holder_overage_charges_billing_manager( metadata={"tokens": 10, "model": "lite"}, ), ] + await event_service._create_meter_events(session, events) entries = await meter_service.create_billing_entries(session, meter) @@ -1172,6 +1243,7 @@ async def test_seat_holder_without_metered_pricing( metadata={"tokens": 20, "model": "lite"}, ), ] + await event_service._create_meter_events(session, events) entries = await meter_service.create_billing_entries(session, meter) @@ -1290,6 +1362,7 @@ async def test_multiple_seat_holders_same_subscription( metadata={"tokens": 15, "model": "lite"}, ), ] + await event_service._create_meter_events(session, events) entries = await meter_service.create_billing_entries(session, meter) @@ -1411,6 +1484,7 @@ async def test_billing_manager_is_seat_holder( metadata={"tokens": 15, "model": "lite"}, ), ] + await event_service._create_meter_events(session, events) entries = await meter_service.create_billing_entries(session, meter)