Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions server/polar/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from polar.benefit.grant.endpoints import router as benefit_grants_router
from polar.checkout.endpoints import router as checkout_router
from polar.checkout_link.endpoints import router as checkout_link_router
from polar.cli.endpoints import router as cli_router
from polar.custom_field.endpoints import router as custom_field_router
from polar.customer.endpoints import router as customer_router
from polar.customer_meter.endpoints import router as customer_meter_router
Expand Down Expand Up @@ -106,6 +107,8 @@
router.include_router(dispute_router)
# /checkouts
router.include_router(checkout_router)
# /cli
router.include_router(cli_router)
# /files
router.include_router(files_router)
# /metrics
Expand Down
1 change: 1 addition & 0 deletions server/polar/auth/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,4 +186,5 @@ async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None
with logfire.set_baggage(**auth_subject.log_context):
log.info("Authenticated subject", **auth_subject.log_context)
set_sentry_user(auth_subject)
# Other scope types (lifespan, etc.)
await self.app(scope, receive, send)
Empty file added server/polar/cli/__init__.py
Empty file.
19 changes: 19 additions & 0 deletions server/polar/cli/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Annotated

from fastapi import Depends

from polar.auth.dependencies import Authenticator
from polar.auth.models import AuthSubject
from polar.auth.scope import Scope
from polar.models.organization import Organization

_CLIRead = Authenticator(
required_scopes={
Scope.web_read,
Scope.web_write,
Scope.webhooks_read,
Scope.webhooks_write,
},
allowed_subjects={Organization},
)
CLIRead = Annotated[AuthSubject[Organization], Depends(_CLIRead)]
113 changes: 113 additions & 0 deletions server/polar/cli/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import base64
import json
from collections.abc import AsyncGenerator
from typing import Any

import structlog
from fastapi import Depends, Request
from sse_starlette.sse import EventSourceResponse
from standardwebhooks.webhooks import Webhook as StandardWebhook

from polar.cli import auth
from polar.cli.listener import mark_active, mark_inactive
from polar.eventstream.endpoints import subscribe
from polar.eventstream.service import Receivers
from polar.kit.utils import utc_now
from polar.openapi import APITag
from polar.postgres import AsyncSession, get_db_session
from polar.redis import Redis, get_redis
from polar.routing import APIRouter

log = structlog.get_logger()

router = APIRouter(prefix="/cli", tags=["cli", APITag.private])


async def transform_webhook_events(
organization_id: str, event_stream: AsyncGenerator[Any, Any]
) -> AsyncGenerator[Any, Any]:
"""
Transform webhook events before sending to CLI client.
Adds signed headers using organization_id as the secret.
"""
async for message in event_stream:
try:
event = json.loads(message)

# Check if this is a webhook event
if event.get("key") == "webhook.created":
payload_data = event.get("payload", {})
webhook_payload = payload_data.get("payload")
webhook_event_id = payload_data.get("webhook_event_id")

if webhook_payload and webhook_event_id:
ts = utc_now()

secret = str(organization_id).replace("-", "")

# Use organization_id as the signing secret
b64secret = base64.b64encode(secret.encode("utf-8")).decode("utf-8")

# Sign the payload
wh = StandardWebhook(b64secret)
signature = wh.sign(webhook_event_id, ts, webhook_payload)

# Add signed headers to the event
event["headers"] = {
"user-agent": "polar.sh webhooks",
"content-type": "application/json",
"webhook-id": webhook_event_id,
"webhook-timestamp": str(int(ts.timestamp())),
"webhook-signature": signature,
}

event["payload"]["payload"] = json.loads(webhook_payload)

yield json.dumps(event)
continue
except (json.JSONDecodeError, KeyError) as e:
log.warning("Failed to transform webhook event", error=str(e))

# Yield original message if not a webhook event or if transformation failed
yield message


@router.get("/listen")
async def listen(
request: Request,
auth_subject: auth.CLIRead,
redis: Redis = Depends(get_redis),
session: AsyncSession = Depends(get_db_session),
) -> EventSourceResponse:
org_id = auth_subject.subject.id

await mark_active(redis, org_id)

async def refresh_listener() -> None:
await mark_active(redis, org_id)

receivers = Receivers(organization_id=org_id)
event_stream = subscribe(
redis, receivers.get_channels(), request, on_iteration=refresh_listener
)
transformed_stream = transform_webhook_events(str(org_id), event_stream)

async def first_event_wrapper() -> AsyncGenerator[str]:
secret = str(org_id).replace("-", "")

# Send a first event announcing connection established
yield json.dumps(
{
"key": "connected",
"ts": str(utc_now()),
"secret": secret,
}
)

try:
async for message in transformed_stream:
yield message
finally:
await mark_inactive(redis, org_id)

return EventSourceResponse(first_event_wrapper())
22 changes: 22 additions & 0 deletions server/polar/cli/listener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from uuid import UUID

from polar.redis import Redis

LISTENER_KEY_PREFIX = "cli:listening"
LISTENER_TTL_SECONDS = 30


def _key(org_id: UUID) -> str:
return f"{LISTENER_KEY_PREFIX}:{org_id}"


async def mark_active(redis: Redis, org_id: UUID) -> None:
await redis.set(_key(org_id), "1", ex=LISTENER_TTL_SECONDS)


async def mark_inactive(redis: Redis, org_id: UUID) -> None:
await redis.delete(_key(org_id))


async def has_active_listener(redis: Redis, org_id: UUID) -> bool:
return await redis.exists(_key(org_id)) > 0
6 changes: 5 additions & 1 deletion server/polar/eventstream/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any

import structlog
Expand Down Expand Up @@ -51,6 +51,7 @@ async def subscribe(
redis: Redis,
channels: list[str],
request: Request,
on_iteration: Callable[[], Awaitable[None]] | None = None,
) -> AsyncGenerator[Any, Any]:
async with redis.pubsub() as pubsub:
await pubsub.subscribe(*channels)
Expand All @@ -60,6 +61,9 @@ async def subscribe(
await pubsub.close()
break

if on_iteration is not None:
await on_iteration()

try:
message = await pubsub.get_message(
ignore_subscribe_messages=True,
Expand Down
46 changes: 46 additions & 0 deletions server/polar/webhook/eventstream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from enum import StrEnum
from uuid import UUID

from polar.cli.listener import has_active_listener
from polar.eventstream.service import publish
from polar.kit.utils import generate_uuid
from polar.redis import Redis, create_redis


class WebhookEvent(StrEnum):
webhook_created = "webhook.created"


_check_redis: Redis | None = None


def _get_check_redis() -> Redis:
global _check_redis
if _check_redis is None:
_check_redis = create_redis("app")
return _check_redis


async def publish_webhook_event(
organization_id: UUID,
payload: str,
) -> None:
"""
Publish a webhook event to the eventstream for CLI listeners.

Args:
organization_id: The organization to publish the event for
payload: The raw JSON string of the webhook payload
"""
redis = _get_check_redis()
if not await has_active_listener(redis, organization_id):
return

await publish(
WebhookEvent.webhook_created,
{
"webhook_event_id": str(generate_uuid()),
"payload": payload,
},
organization_id=organization_id,
)
13 changes: 10 additions & 3 deletions server/polar/webhook/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@
from polar.user_organization.service import (
user_organization as user_organization_service,
)
from polar.worker import enqueue_job

from .repository import (
from polar.webhook.repository import (
WebhookDeliveryRepository,
WebhookEndpointRepository,
WebhookEventRepository,
)
from polar.worker import enqueue_job

from .eventstream import publish_webhook_event
from .schemas import WebhookEndpointCreate, WebhookEndpointUpdate
from .webhooks import SkipEvent, UnsupportedTarget, WebhookPayloadTypeAdapter

Expand Down Expand Up @@ -719,6 +720,12 @@ async def send(
{"type": event, "timestamp": now, "data": data}
)

# Publish to eventstream for CLI listeners, regardless of webhook endpoints
await publish_webhook_event(
organization_id=target.id,
payload=payload.get_raw_payload(),
)

events: list[WebhookEvent] = []
for endpoint in await self._get_event_target_endpoints(
session, event=event, target=target
Expand Down
59 changes: 59 additions & 0 deletions server/polar/webhook/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,62 @@ async def webhook_event_archive() -> None:
return await webhook_service.archive_events(
session, older_than=utc_now() - settings.WEBHOOK_EVENT_RETENTION_PERIOD
)


@actor(actor_name="webhook_event.publish", priority=TaskPriority.MEDIUM)
async def webhook_event_publish(webhook_event_id: UUID, organization_id: UUID) -> None:
"""
Publish a webhook event to the eventstream for CLI listeners.

Args:
webhook_event_id: ID of the webhook event to publish
organization_id: ID of the organization (used as signing secret)
"""
async with AsyncSessionMaker() as session:
return await _webhook_event_publish(
session, webhook_event_id=webhook_event_id, organization_id=organization_id
)


async def _webhook_event_publish(
session: AsyncSession, *, webhook_event_id: UUID, organization_id: UUID
) -> None:
from polar.eventstream.service import publish
from polar.webhook.eventstream import WebhookEvent

repository = WebhookEventRepository.from_session(session)
event = await repository.get_by_id(
webhook_event_id, options=repository.get_eager_options()
)
if event is None:
log.warning(
"Webhook event not found for eventstream publishing",
webhook_event_id=webhook_event_id,
)
return

if event.payload is None:
log.debug(
"Webhook event has no payload, skipping eventstream publish",
webhook_event_id=webhook_event_id,
)
return

try:
# Publish raw webhook event data to eventstream
# The CLI endpoint will apply transformation (headers, signing) when sending to client
await publish(
WebhookEvent.webhook_created,
{
"webhook_event_id": str(event.id),
"payload": event.payload,
},
organization_id=organization_id,
)
except Exception as e:
log.warning(
"Failed to publish webhook to eventstream",
webhook_event_id=webhook_event_id,
organization_id=organization_id,
error=str(e),
)
Empty file added server/tests/cli/__init__.py
Empty file.
25 changes: 25 additions & 0 deletions server/tests/cli/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from collections.abc import AsyncIterator
from unittest.mock import AsyncMock

import pytest
import pytest_asyncio

from polar.kit.db.postgres import AsyncSession


@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
async def initialize_test_database() -> AsyncIterator[None]:
"""Override: CLI tests don't need the database."""
yield


@pytest_asyncio.fixture
async def session() -> AsyncIterator[AsyncSession]:
"""Override: provide a mock session so patch_middlewares doesn't hit Postgres."""
yield AsyncMock(spec=AsyncSession)


@pytest.fixture(autouse=True)
def patch_middlewares() -> None:
"""Override: CLI tests don't need worker middleware patching."""
pass
Loading