Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions server/polar/event/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
CustomerMeter,
Event,
EventClosure,
Meter,
MeterEvent,
Organization,
User,
UserOrganization,
Expand All @@ -59,6 +61,7 @@
StatisticsPeriod,
)
from .sorting import EventNamesSortProperty, EventSortProperty
from .system import SystemEvent

log: Logger = structlog.get_logger()

Expand Down Expand Up @@ -731,6 +734,8 @@ async def ingested(
if "_cost" in event.user_metadata:
organization_ids_for_revops.add(event.organization_id)

await self._create_meter_events(session, events)

await self._activate_matching_customer_meters(
session, repository, event_ids, customers
)
Expand All @@ -742,6 +747,57 @@ async def ingested(
organization_repository = OrganizationRepository.from_session(session)
await organization_repository.enable_revops(organization_ids_for_revops)

async def _create_meter_events(
self, session: AsyncSession, events: Sequence[Event]
) -> None:
if not events:
return

events_by_org: dict[uuid.UUID, list[Event]] = {}
for event in events:
events_by_org.setdefault(event.organization_id, []).append(event)

meter_repository = MeterRepository.from_session(session)
meter_event_rows: list[dict[str, Any]] = []

for org_id, org_events in events_by_org.items():
meters = await meter_repository.get_all(
meter_repository.get_base_statement().where(
Meter.organization_id == org_id,
Meter.archived_at.is_(None),
)
)

for event in org_events:
for meter in meters:
if self._event_matches_meter(event, meter):
meter_event_rows.append(
{
"meter_id": meter.id,
"event_id": event.id,
"customer_id": event.customer_id,
"external_customer_id": event.external_customer_id,
"organization_id": event.organization_id,
"ingested_at": event.ingested_at,
"timestamp": event.timestamp,
}
)

if meter_event_rows:
await session.execute(
insert(MeterEvent).values(meter_event_rows).on_conflict_do_nothing()
)

def _event_matches_meter(self, event: Event, meter: Meter) -> bool:
if (
event.source == EventSource.system
and event.name in (SystemEvent.meter_credited, SystemEvent.meter_reset)
and event.user_metadata.get("meter_id") == str(meter.id)
):
return True

return meter.filter.matches(event) and meter.aggregation.matches(event)

async def _activate_matching_customer_meters(
self,
session: AsyncSession,
Expand Down
19 changes: 18 additions & 1 deletion server/polar/meter/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from enum import StrEnum
from typing import Annotated, Any, Literal
from typing import TYPE_CHECKING, Annotated, Any, Literal

from pydantic import AfterValidator, BaseModel, Discriminator, TypeAdapter
from sqlalchemy import (
Expand All @@ -13,6 +15,9 @@
)
from sqlalchemy.dialects.postgresql import JSONB

if TYPE_CHECKING:
from polar.models import Event


class AggregationFunction(StrEnum):
cnt = "count" # `count` is a reserved keyword, so we use `cnt` as key
Expand Down Expand Up @@ -54,6 +59,9 @@ def is_summable(self) -> bool:
"""
return True

def matches(self, event: Event) -> bool:
return True


def _strip_metadata_prefix(value: str) -> str:
prefix = "metadata."
Expand Down Expand Up @@ -92,6 +100,12 @@ def is_summable(self) -> bool:
"""
return self.func == AggregationFunction.sum

def matches(self, event: Event) -> bool:
if self.property in ("name", "source", "timestamp"):
return True
value = event.user_metadata.get(self.property)
return isinstance(value, int | float)


class UniqueAggregation(BaseModel):
func: Literal[AggregationFunction.unique] = AggregationFunction.unique
Expand All @@ -112,6 +126,9 @@ def is_summable(self) -> bool:
"""
return False

def matches(self, event: Event) -> bool:
return True


_Aggregation = CountAggregation | PropertyAggregation | UniqueAggregation
Aggregation = Annotated[_Aggregation, Discriminator("func")]
Expand Down
48 changes: 46 additions & 2 deletions server/polar/meter/filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from enum import StrEnum
from typing import Annotated, Any
from typing import TYPE_CHECKING, Annotated, Any

from annotated_types import Ge, Le, MaxLen
from pydantic import AfterValidator, BaseModel, ConfigDict
Expand All @@ -17,6 +19,9 @@
)
from sqlalchemy.dialects.postgresql import JSONB

if TYPE_CHECKING:
from polar.models import Event

# PostgreSQL int4 range limits
INT_MIN_VALUE = -2_147_483_648
INT_MAX_VALUE = 2_147_483_647
Expand Down Expand Up @@ -125,6 +130,39 @@ def _get_number_value(self) -> int:
return 1 if self.value else 0
return self.value

def matches(self, event: Event) -> bool:
if self.property == "name":
actual_value: Any = event.name
elif self.property == "source":
actual_value = event.source
elif self.property == "timestamp":
actual_value = int(event.timestamp.timestamp())
else:
actual_value = event.user_metadata.get(self.property)
if actual_value is None:
return False

return self._compare(actual_value, self.value)

def _compare(self, actual: Any, expected: str | int | bool) -> bool:
if self.operator == FilterOperator.eq:
return actual == expected
elif self.operator == FilterOperator.ne:
return actual != expected
elif self.operator == FilterOperator.gt:
return actual > expected
elif self.operator == FilterOperator.gte:
return actual >= expected
elif self.operator == FilterOperator.lt:
return actual < expected
elif self.operator == FilterOperator.lte:
return actual <= expected
elif self.operator == FilterOperator.like:
return str(expected) in str(actual)
elif self.operator == FilterOperator.not_like:
return str(expected) not in str(actual)
return False


class FilterConjunction(StrEnum):
and_ = "and"
Expand All @@ -133,7 +171,7 @@ class FilterConjunction(StrEnum):

class Filter(BaseModel):
conjunction: FilterConjunction
clauses: list["FilterClause | Filter"]
clauses: list[FilterClause | Filter]

model_config = ConfigDict(
# IMPORTANT: this ensures FastAPI doesn't generate `-Input` for output schemas
Expand All @@ -147,6 +185,12 @@ def get_sql_clause(self, model: type[Any]) -> ColumnExpressionArgument[bool]:
conjunction = and_ if self.conjunction == FilterConjunction.and_ else or_
return conjunction(*sql_clauses or (true(),))

def matches(self, event: Event) -> bool:
results = [clause.matches(event) for clause in self.clauses]
if self.conjunction == FilterConjunction.and_:
return all(results) if results else True
return any(results) if results else True


class FilterType(TypeDecorator[Any]):
impl = JSONB
Expand Down
62 changes: 62 additions & 0 deletions server/tests/event/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
from pydantic import ValidationError
from pytest_mock import MockerFixture
from sqlalchemy import func, select

from polar.auth.models import AuthSubject, is_user
from polar.event.repository import EventRepository
Expand All @@ -30,6 +31,7 @@
CustomerMeter,
EventType,
Meter,
MeterEvent,
Organization,
Product,
User,
Expand Down Expand Up @@ -1404,6 +1406,66 @@ async def test_activates_matching_customer_meter_with_external_customer_id(
await session.refresh(customer_meter)
assert customer_meter.activated_at is not None

async def test_creates_meter_events_for_matching_events(
self,
save_fixture: SaveFixture,
session: AsyncSession,
organization: Organization,
customer: Customer,
) -> None:
meter = Meter(
name="Test Meter",
organization=organization,
filter=Filter(
conjunction=FilterConjunction.and_,
clauses=[
FilterClause(
property="name", operator=FilterOperator.eq, value="api.request"
)
],
),
aggregation=PropertyAggregation(
func=AggregationFunction.sum, property="tokens"
),
)
await save_fixture(meter)

matching_event = await create_event(
save_fixture,
customer=customer,
organization=organization,
source=EventSource.user,
name="api.request",
metadata={"tokens": 100},
)
non_matching_event = await create_event(
save_fixture,
customer=customer,
organization=organization,
source=EventSource.user,
name="other.event",
metadata={"tokens": 50},
)

await event_service.ingested(
session, [matching_event.id, non_matching_event.id]
)

count_result = await session.execute(
select(func.count())
.select_from(MeterEvent)
.where(MeterEvent.meter_id == meter.id)
)
meter_events_count = count_result.scalar_one()
assert meter_events_count == 1

meter_event_result = await session.execute(
select(MeterEvent).where(MeterEvent.meter_id == meter.id)
)
meter_event = meter_event_result.scalar_one()
assert meter_event.event_id == matching_event.id
assert meter_event.customer_id == customer.id


@pytest.mark.asyncio
class TestSystemEvents:
Expand Down
53 changes: 53 additions & 0 deletions server/tests/meter/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
UniqueAggregation,
)
from polar.models import Event, Organization
from polar.models.event import EventSource
from polar.postgres import AsyncSession
from tests.fixtures.database import SaveFixture
from tests.fixtures.random_objects import create_event
Expand Down Expand Up @@ -189,3 +190,55 @@ async def test_unique(
)

assert await _get_aggregation_result(session, aggregation) == 3.0


class TestAggregationMatches:
def test_count_always_matches(self, organization: Organization) -> None:
agg = CountAggregation()
event = Event(
name="test",
organization_id=organization.id,
source=EventSource.user,
user_metadata={},
)
assert agg.matches(event) is True

def test_property_matches_number(self, organization: Organization) -> None:
agg = PropertyAggregation(func=AggregationFunction.sum, property="amount")
event = Event(
name="test",
organization_id=organization.id,
source=EventSource.user,
user_metadata={"amount": 100},
)
assert agg.matches(event) is True

def test_property_not_matches_string(self, organization: Organization) -> None:
agg = PropertyAggregation(func=AggregationFunction.sum, property="amount")
event = Event(
name="test",
organization_id=organization.id,
source=EventSource.user,
user_metadata={"amount": "invalid"},
)
assert agg.matches(event) is False

def test_property_not_matches_missing(self, organization: Organization) -> None:
agg = PropertyAggregation(func=AggregationFunction.sum, property="amount")
event = Event(
name="test",
organization_id=organization.id,
source=EventSource.user,
user_metadata={},
)
assert agg.matches(event) is False

def test_unique_always_matches(self, organization: Organization) -> None:
agg = UniqueAggregation(property="user_id")
event = Event(
name="test",
organization_id=organization.id,
source=EventSource.user,
user_metadata={},
)
assert agg.matches(event) is True
Loading