diff --git a/server/polar/event/service.py b/server/polar/event/service.py index 9bf9ee29f1..97edfbdf18 100644 --- a/server/polar/event/service.py +++ b/server/polar/event/service.py @@ -39,6 +39,8 @@ CustomerMeter, Event, EventClosure, + Meter, + MeterEvent, Organization, User, UserOrganization, @@ -59,6 +61,7 @@ StatisticsPeriod, ) from .sorting import EventNamesSortProperty, EventSortProperty +from .system import SystemEvent log: Logger = structlog.get_logger() @@ -731,6 +734,8 @@ async def ingested( if "_cost" in event.user_metadata: organization_ids_for_revops.add(event.organization_id) + await self._create_meter_events(session, events) + await self._activate_matching_customer_meters( session, repository, event_ids, customers ) @@ -742,6 +747,57 @@ async def ingested( organization_repository = OrganizationRepository.from_session(session) await organization_repository.enable_revops(organization_ids_for_revops) + async def _create_meter_events( + self, session: AsyncSession, events: Sequence[Event] + ) -> None: + if not events: + return + + events_by_org: dict[uuid.UUID, list[Event]] = {} + for event in events: + events_by_org.setdefault(event.organization_id, []).append(event) + + meter_repository = MeterRepository.from_session(session) + meter_event_rows: list[dict[str, Any]] = [] + + for org_id, org_events in events_by_org.items(): + meters = await meter_repository.get_all( + meter_repository.get_base_statement().where( + Meter.organization_id == org_id, + Meter.archived_at.is_(None), + ) + ) + + for event in org_events: + for meter in meters: + if self._event_matches_meter(event, meter): + meter_event_rows.append( + { + "meter_id": meter.id, + "event_id": event.id, + "customer_id": event.customer_id, + "external_customer_id": event.external_customer_id, + "organization_id": event.organization_id, + "ingested_at": event.ingested_at, + "timestamp": event.timestamp, + } + ) + + if meter_event_rows: + await session.execute( + insert(MeterEvent).values(meter_event_rows).on_conflict_do_nothing() + ) + + def _event_matches_meter(self, event: Event, meter: Meter) -> bool: + if ( + event.source == EventSource.system + and event.name in (SystemEvent.meter_credited, SystemEvent.meter_reset) + and event.user_metadata.get("meter_id") == str(meter.id) + ): + return True + + return meter.filter.matches(event) and meter.aggregation.matches(event) + async def _activate_matching_customer_meters( self, session: AsyncSession, diff --git a/server/polar/meter/aggregation.py b/server/polar/meter/aggregation.py index cdc19033a3..4585a9220a 100644 --- a/server/polar/meter/aggregation.py +++ b/server/polar/meter/aggregation.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from enum import StrEnum -from typing import Annotated, Any, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Discriminator, TypeAdapter from sqlalchemy import ( @@ -13,6 +15,9 @@ ) from sqlalchemy.dialects.postgresql import JSONB +if TYPE_CHECKING: + from polar.models import Event + class AggregationFunction(StrEnum): cnt = "count" # `count` is a reserved keyword, so we use `cnt` as key @@ -54,6 +59,9 @@ def is_summable(self) -> bool: """ return True + def matches(self, event: Event) -> bool: + return True + def _strip_metadata_prefix(value: str) -> str: prefix = "metadata." @@ -92,6 +100,12 @@ def is_summable(self) -> bool: """ return self.func == AggregationFunction.sum + def matches(self, event: Event) -> bool: + if self.property in ("name", "source", "timestamp"): + return True + value = event.user_metadata.get(self.property) + return isinstance(value, int | float) + class UniqueAggregation(BaseModel): func: Literal[AggregationFunction.unique] = AggregationFunction.unique @@ -112,6 +126,9 @@ def is_summable(self) -> bool: """ return False + def matches(self, event: Event) -> bool: + return True + _Aggregation = CountAggregation | PropertyAggregation | UniqueAggregation Aggregation = Annotated[_Aggregation, Discriminator("func")] diff --git a/server/polar/meter/filter.py b/server/polar/meter/filter.py index b40ff05196..750506d78a 100644 --- a/server/polar/meter/filter.py +++ b/server/polar/meter/filter.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from enum import StrEnum -from typing import Annotated, Any +from typing import TYPE_CHECKING, Annotated, Any from annotated_types import Ge, Le, MaxLen from pydantic import AfterValidator, BaseModel, ConfigDict @@ -17,6 +19,9 @@ ) from sqlalchemy.dialects.postgresql import JSONB +if TYPE_CHECKING: + from polar.models import Event + # PostgreSQL int4 range limits INT_MIN_VALUE = -2_147_483_648 INT_MAX_VALUE = 2_147_483_647 @@ -125,6 +130,39 @@ def _get_number_value(self) -> int: return 1 if self.value else 0 return self.value + def matches(self, event: Event) -> bool: + if self.property == "name": + actual_value: Any = event.name + elif self.property == "source": + actual_value = event.source + elif self.property == "timestamp": + actual_value = int(event.timestamp.timestamp()) + else: + actual_value = event.user_metadata.get(self.property) + if actual_value is None: + return False + + return self._compare(actual_value, self.value) + + def _compare(self, actual: Any, expected: str | int | bool) -> bool: + if self.operator == FilterOperator.eq: + return actual == expected + elif self.operator == FilterOperator.ne: + return actual != expected + elif self.operator == FilterOperator.gt: + return actual > expected + elif self.operator == FilterOperator.gte: + return actual >= expected + elif self.operator == FilterOperator.lt: + return actual < expected + elif self.operator == FilterOperator.lte: + return actual <= expected + elif self.operator == FilterOperator.like: + return str(expected) in str(actual) + elif self.operator == FilterOperator.not_like: + return str(expected) not in str(actual) + return False + class FilterConjunction(StrEnum): and_ = "and" @@ -133,7 +171,7 @@ class FilterConjunction(StrEnum): class Filter(BaseModel): conjunction: FilterConjunction - clauses: list["FilterClause | Filter"] + clauses: list[FilterClause | Filter] model_config = ConfigDict( # IMPORTANT: this ensures FastAPI doesn't generate `-Input` for output schemas @@ -147,6 +185,12 @@ def get_sql_clause(self, model: type[Any]) -> ColumnExpressionArgument[bool]: conjunction = and_ if self.conjunction == FilterConjunction.and_ else or_ return conjunction(*sql_clauses or (true(),)) + def matches(self, event: Event) -> bool: + results = [clause.matches(event) for clause in self.clauses] + if self.conjunction == FilterConjunction.and_: + return all(results) if results else True + return any(results) if results else True + class FilterType(TypeDecorator[Any]): impl = JSONB diff --git a/server/tests/event/test_service.py b/server/tests/event/test_service.py index 0ef35ffade..a062820f53 100644 --- a/server/tests/event/test_service.py +++ b/server/tests/event/test_service.py @@ -7,6 +7,7 @@ import pytest from pydantic import ValidationError from pytest_mock import MockerFixture +from sqlalchemy import func, select from polar.auth.models import AuthSubject, is_user from polar.event.repository import EventRepository @@ -30,6 +31,7 @@ CustomerMeter, EventType, Meter, + MeterEvent, Organization, Product, User, @@ -1404,6 +1406,66 @@ async def test_activates_matching_customer_meter_with_external_customer_id( await session.refresh(customer_meter) assert customer_meter.activated_at is not None + async def test_creates_meter_events_for_matching_events( + self, + save_fixture: SaveFixture, + session: AsyncSession, + organization: Organization, + customer: Customer, + ) -> None: + meter = Meter( + name="Test Meter", + organization=organization, + filter=Filter( + conjunction=FilterConjunction.and_, + clauses=[ + FilterClause( + property="name", operator=FilterOperator.eq, value="api.request" + ) + ], + ), + aggregation=PropertyAggregation( + func=AggregationFunction.sum, property="tokens" + ), + ) + await save_fixture(meter) + + matching_event = await create_event( + save_fixture, + customer=customer, + organization=organization, + source=EventSource.user, + name="api.request", + metadata={"tokens": 100}, + ) + non_matching_event = await create_event( + save_fixture, + customer=customer, + organization=organization, + source=EventSource.user, + name="other.event", + metadata={"tokens": 50}, + ) + + await event_service.ingested( + session, [matching_event.id, non_matching_event.id] + ) + + count_result = await session.execute( + select(func.count()) + .select_from(MeterEvent) + .where(MeterEvent.meter_id == meter.id) + ) + meter_events_count = count_result.scalar_one() + assert meter_events_count == 1 + + meter_event_result = await session.execute( + select(MeterEvent).where(MeterEvent.meter_id == meter.id) + ) + meter_event = meter_event_result.scalar_one() + assert meter_event.event_id == matching_event.id + assert meter_event.customer_id == customer.id + @pytest.mark.asyncio class TestSystemEvents: diff --git a/server/tests/meter/test_aggregation.py b/server/tests/meter/test_aggregation.py index 1c526dad7d..57b5df729b 100644 --- a/server/tests/meter/test_aggregation.py +++ b/server/tests/meter/test_aggregation.py @@ -9,6 +9,7 @@ UniqueAggregation, ) from polar.models import Event, Organization +from polar.models.event import EventSource from polar.postgres import AsyncSession from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import create_event @@ -189,3 +190,55 @@ async def test_unique( ) assert await _get_aggregation_result(session, aggregation) == 3.0 + + +class TestAggregationMatches: + def test_count_always_matches(self, organization: Organization) -> None: + agg = CountAggregation() + event = Event( + name="test", + organization_id=organization.id, + source=EventSource.user, + user_metadata={}, + ) + assert agg.matches(event) is True + + def test_property_matches_number(self, organization: Organization) -> None: + agg = PropertyAggregation(func=AggregationFunction.sum, property="amount") + event = Event( + name="test", + organization_id=organization.id, + source=EventSource.user, + user_metadata={"amount": 100}, + ) + assert agg.matches(event) is True + + def test_property_not_matches_string(self, organization: Organization) -> None: + agg = PropertyAggregation(func=AggregationFunction.sum, property="amount") + event = Event( + name="test", + organization_id=organization.id, + source=EventSource.user, + user_metadata={"amount": "invalid"}, + ) + assert agg.matches(event) is False + + def test_property_not_matches_missing(self, organization: Organization) -> None: + agg = PropertyAggregation(func=AggregationFunction.sum, property="amount") + event = Event( + name="test", + organization_id=organization.id, + source=EventSource.user, + user_metadata={}, + ) + assert agg.matches(event) is False + + def test_unique_always_matches(self, organization: Organization) -> None: + agg = UniqueAggregation(property="user_id") + event = Event( + name="test", + organization_id=organization.id, + source=EventSource.user, + user_metadata={}, + ) + assert agg.matches(event) is True diff --git a/server/tests/meter/test_filter.py b/server/tests/meter/test_filter.py index 33cb15454d..ed0e9f4274 100644 --- a/server/tests/meter/test_filter.py +++ b/server/tests/meter/test_filter.py @@ -7,6 +7,7 @@ from polar.kit.utils import utc_now from polar.meter.filter import Filter, FilterClause, FilterConjunction, FilterOperator from polar.models import Event, Organization +from polar.models.event import EventSource from polar.postgres import AsyncSession from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import create_event @@ -292,3 +293,121 @@ async def test_number_comparisons_clause( assert len(matching_events) == 1 assert matching_events[0].id == events[1].id + + +class TestFilterClauseMatches: + def test_matches_name_eq(self, organization: Organization) -> None: + clause = FilterClause( + property="name", operator=FilterOperator.eq, value="test.event" + ) + event = Event( + name="test.event", + organization_id=organization.id, + source=EventSource.user, + user_metadata={}, + ) + assert clause.matches(event) is True + + def test_matches_name_ne(self, organization: Organization) -> None: + clause = FilterClause( + property="name", operator=FilterOperator.eq, value="other.event" + ) + event = Event( + name="test.event", + organization_id=organization.id, + source=EventSource.user, + user_metadata={}, + ) + assert clause.matches(event) is False + + def test_matches_metadata_string(self, organization: Organization) -> None: + clause = FilterClause( + property="category", operator=FilterOperator.eq, value="api" + ) + event = Event( + name="test", + organization_id=organization.id, + source=EventSource.user, + user_metadata={"category": "api"}, + ) + assert clause.matches(event) is True + + def test_matches_metadata_missing(self, organization: Organization) -> None: + clause = FilterClause( + property="category", operator=FilterOperator.eq, value="api" + ) + event = Event( + name="test", + organization_id=organization.id, + source=EventSource.user, + user_metadata={}, + ) + assert clause.matches(event) is False + + def test_matches_metadata_number_gt(self, organization: Organization) -> None: + clause = FilterClause(property="amount", operator=FilterOperator.gt, value=100) + event = Event( + name="test", + organization_id=organization.id, + source=EventSource.user, + user_metadata={"amount": 150}, + ) + assert clause.matches(event) is True + + +class TestFilterMatches: + def test_matches_and_conjunction_all_true(self, organization: Organization) -> None: + filter = Filter( + conjunction=FilterConjunction.and_, + clauses=[ + FilterClause(property="name", operator=FilterOperator.eq, value="test"), + FilterClause( + property="category", operator=FilterOperator.eq, value="api" + ), + ], + ) + event = Event( + name="test", + organization_id=organization.id, + source=EventSource.user, + user_metadata={"category": "api"}, + ) + assert filter.matches(event) is True + + def test_matches_and_conjunction_one_false( + self, organization: Organization + ) -> None: + filter = Filter( + conjunction=FilterConjunction.and_, + clauses=[ + FilterClause(property="name", operator=FilterOperator.eq, value="test"), + FilterClause( + property="category", operator=FilterOperator.eq, value="other" + ), + ], + ) + event = Event( + name="test", + organization_id=organization.id, + source=EventSource.user, + user_metadata={"category": "api"}, + ) + assert filter.matches(event) is False + + def test_matches_or_conjunction(self, organization: Organization) -> None: + filter = Filter( + conjunction=FilterConjunction.or_, + clauses=[ + FilterClause(property="name", operator=FilterOperator.eq, value="test"), + FilterClause( + property="name", operator=FilterOperator.eq, value="other" + ), + ], + ) + event = Event( + name="test", + organization_id=organization.id, + source=EventSource.user, + user_metadata={}, + ) + assert filter.matches(event) is True