From 0236dab1f411914358a2290050e14cc36462bb34 Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Fri, 12 Dec 2025 12:09:26 +0100 Subject: [PATCH 01/18] add simple listen-handler --- server/polar/api.py | 3 +++ server/polar/cli/__init__.py | 0 server/polar/cli/endpoints.py | 25 +++++++++++++++++++++++++ server/uv.lock | 2 ++ 4 files changed, 30 insertions(+) create mode 100644 server/polar/cli/__init__.py create mode 100644 server/polar/cli/endpoints.py diff --git a/server/polar/api.py b/server/polar/api.py index a1b53af0dc..78ab683bd3 100644 --- a/server/polar/api.py +++ b/server/polar/api.py @@ -6,6 +6,7 @@ from polar.benefit.grant.endpoints import router as benefit_grants_router from polar.checkout.endpoints import router as checkout_router from polar.checkout_link.endpoints import router as checkout_link_router +from polar.cli.endpoints import router as cli_router from polar.custom_field.endpoints import router as custom_field_router from polar.customer.endpoints import router as customer_router from polar.customer_meter.endpoints import router as customer_meter_router @@ -106,6 +107,8 @@ router.include_router(dispute_router) # /checkouts router.include_router(checkout_router) +# /cli +router.include_router(cli_router) # /files router.include_router(files_router) # /metrics diff --git a/server/polar/cli/__init__.py b/server/polar/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/server/polar/cli/endpoints.py b/server/polar/cli/endpoints.py new file mode 100644 index 0000000000..d06d3f205d --- /dev/null +++ b/server/polar/cli/endpoints.py @@ -0,0 +1,25 @@ +import asyncio +from datetime import UTC, datetime + +from fastapi import WebSocket, WebSocketDisconnect + +from polar.routing import APIRouter + +router = APIRouter(prefix="/cli", tags=["cli"]) + +@router.websocket("/listen") +async def listen(websocket: WebSocket) -> None: + await websocket.accept() + + try: + while True: + message = { + "timestamp": datetime.now(UTC).isoformat(), + "status": "ok", + "message": "Periodic update from Polar CLI" + } + + await websocket.send_json(message) + await asyncio.sleep(30) + except WebSocketDisconnect: + pass diff --git a/server/uv.lock b/server/uv.lock index 290bfced27..02aed9887c 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -1939,6 +1939,7 @@ dependencies = [ { name = "taskipy" }, { name = "typer" }, { name = "uvicorn", extra = ["standard"] }, + { name = "websockets" }, ] [package.dev-dependencies] @@ -2024,6 +2025,7 @@ requires-dist = [ { name = "taskipy", specifier = ">=1.10.3" }, { name = "typer", specifier = ">=0.12.5" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.31.1" }, + { name = "websockets", specifier = ">=15.0.1" }, ] [package.metadata.requires-dev] From aaaa620657d76e38a245525f00452c3cec9f8d15 Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Fri, 12 Dec 2025 16:54:14 +0100 Subject: [PATCH 02/18] server: experimental cli listen support --- server/polar/auth/dependencies.py | 9 +- server/polar/auth/middlewares.py | 110 ++++++++++++++++++++---- server/polar/cli/endpoints.py | 124 ++++++++++++++++++++++++---- server/polar/postgres.py | 10 ++- server/polar/redis.py | 4 +- server/polar/webhook/eventstream.py | 46 +++++++++++ server/polar/webhook/service.py | 18 +++- 7 files changed, 277 insertions(+), 44 deletions(-) create mode 100644 server/polar/webhook/eventstream.py diff --git a/server/polar/auth/dependencies.py b/server/polar/auth/dependencies.py index 5654087355..4e431242cb 100644 --- a/server/polar/auth/dependencies.py +++ b/server/polar/auth/dependencies.py @@ -2,9 +2,10 @@ from inspect import Parameter, Signature from typing import Annotated, Any -from fastapi import Depends, Request, Security +from fastapi import Depends, Security from fastapi.security import HTTPBearer, OpenIdConnect from makefun import with_signature +from starlette.requests import HTTPConnection from polar.auth.scope import RESERVED_SCOPES, Scope from polar.exceptions import Unauthorized @@ -74,7 +75,7 @@ def _get_auth_subject_factory( Parameter( name="request", kind=Parameter.POSITIONAL_OR_KEYWORD, - annotation=Request, + annotation=HTTPConnection, ) ] if User in allowed_subjects or Organization in allowed_subjects: @@ -121,7 +122,9 @@ def _get_auth_subject_factory( signature = Signature(parameters) @with_signature(signature) - async def get_auth_subject(request: Request, **kwargs: Any) -> AuthSubject[Subject]: + async def get_auth_subject( + request: HTTPConnection, **kwargs: Any + ) -> AuthSubject[Subject]: try: return request.state.auth_subject except AttributeError as e: diff --git a/server/polar/auth/middlewares.py b/server/polar/auth/middlewares.py index 7870ee554b..b6e58a946c 100644 --- a/server/polar/auth/middlewares.py +++ b/server/polar/auth/middlewares.py @@ -60,6 +60,18 @@ def get_bearer_token(request: Request) -> str | None: return value +def get_bearer_token_from_websocket(scope: ASGIScope) -> str | None: + """Extract bearer token from WebSocket connection (headers or query params).""" + # Try to get token from Authorization header + headers = dict(scope.get("headers", [])) + authorization = headers.get(b"authorization", b"").decode("utf-8") + scheme, value = get_authorization_scheme_param(authorization) + if scheme and value and scheme.lower() == "bearer" and value.isascii(): + return value + + return None + + async def get_oauth2_token(session: AsyncSession, value: str) -> OAuth2Token | None: return await oauth2_token_service.get_by_access_token(session, value) @@ -163,27 +175,89 @@ async def get_auth_subject( return AuthSubject(Anonymous(), set(), None) -class AuthSubjectMiddleware: - def __init__(self, app: ASGIApp) -> None: - self.app = app +async def get_auth_subject_from_websocket( + scope: ASGIScope, session: AsyncSession +) -> AuthSubject[Subject]: + """Get auth subject from WebSocket connection.""" + token = get_bearer_token_from_websocket(scope) + if token is not None: + if is_registration_token_prefix(token): + return AuthSubject(Anonymous(), set(), None) - async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None: - if scope["type"] != "http": - await self.app(scope, receive, send) - return + customer_session = await get_customer_session(session, token) + if customer_session: + return AuthSubject( + customer_session.customer, + {Scope.customer_portal_write}, + customer_session, + ) + + organization_access_token = await get_organization_access_token(session, token) + if organization_access_token: + return AuthSubject( + organization_access_token.organization, + organization_access_token.scopes, + organization_access_token, + ) - session: AsyncSession = scope["state"]["async_session"] - request = Request(scope) + oauth2_token = await get_oauth2_token(session, token) + if oauth2_token: + return AuthSubject(oauth2_token.sub, oauth2_token.scopes, oauth2_token) - try: - auth_subject = await get_auth_subject(request, session) - except OAuth2Error as e: - response = await oauth2_error_exception_handler(request, e) - return await response(scope, receive, send) + personal_access_token = await get_personal_access_token(session, token) + if personal_access_token: + return AuthSubject( + personal_access_token.user, + personal_access_token.scopes, + personal_access_token, + ) + + raise InvalidTokenError() + + # WebSockets don't support user sessions (cookies), so return Anonymous + return AuthSubject(Anonymous(), set(), None) - scope["state"]["auth_subject"] = auth_subject - with logfire.set_baggage(**auth_subject.log_context): - log.info("Authenticated subject", **auth_subject.log_context) - set_sentry_user(auth_subject) +class AuthSubjectMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None: + if scope["type"] == "http": + session: AsyncSession = scope["state"]["async_session"] + request = Request(scope) + + try: + auth_subject = await get_auth_subject(request, session) + except OAuth2Error as e: + response = await oauth2_error_exception_handler(request, e) + return await response(scope, receive, send) + + scope["state"]["auth_subject"] = auth_subject + + with logfire.set_baggage(**auth_subject.log_context): + log.info("Authenticated subject", **auth_subject.log_context) + set_sentry_user(auth_subject) + await self.app(scope, receive, send) + + elif scope["type"] == "websocket": + session: AsyncSession = scope["state"]["async_session"] + + try: + auth_subject = await get_auth_subject_from_websocket(scope, session) + except OAuth2Error as e: + # For WebSocket, we can't return an HTTP response + # The error will be handled when the connection is accepted + log.warning("WebSocket authentication failed", error=str(e)) + auth_subject = AuthSubject(Anonymous(), set(), None) + + scope["state"]["auth_subject"] = auth_subject + + with logfire.set_baggage(**auth_subject.log_context): + log.info("Authenticated WebSocket subject", **auth_subject.log_context) + set_sentry_user(auth_subject) + await self.app(scope, receive, send) + + else: + # Other scope types (lifespan, etc.) await self.app(scope, receive, send) diff --git a/server/polar/cli/endpoints.py b/server/polar/cli/endpoints.py index d06d3f205d..e6ba88a1c7 100644 --- a/server/polar/cli/endpoints.py +++ b/server/polar/cli/endpoints.py @@ -1,25 +1,121 @@ -import asyncio -from datetime import UTC, datetime +import json -from fastapi import WebSocket, WebSocketDisconnect +import structlog +from fastapi import Depends, WebSocket, WebSocketDisconnect +from polar.auth.models import Organization, User, is_anonymous +from polar.auth.scope import Scope +from polar.eventstream.service import Receivers +from polar.postgres import AsyncSession, get_db_session +from polar.redis import Redis, create_redis, get_redis from polar.routing import APIRouter router = APIRouter(prefix="/cli", tags=["cli"]) +log = structlog.get_logger() + + @router.websocket("/listen") -async def listen(websocket: WebSocket) -> None: +async def listen( + websocket: WebSocket, + session: AsyncSession = Depends(get_db_session), + redis: Redis = Depends(get_redis), +) -> None: + """ + WebSocket endpoint that listens to webhook events for authenticated organizations. + Clients must authenticate using a token in query parameters (?token=...) or Authorization header. + The organization is inferred from the authenticated subject. + """ + + # Get auth_subject from WebSocket state (set by AuthSubjectMiddleware) + try: + auth_subject = websocket.state.auth_subject + except AttributeError: + await websocket.close(code=4001, reason="Authentication required") + return + + # Check if authenticated + if is_anonymous(auth_subject): + await websocket.close(code=4001, reason="Authentication required") + return + + # Verify required scopes + required_scopes = { + Scope.web_read, + Scope.web_write, + Scope.webhooks_read, + Scope.webhooks_write, + } + if not (auth_subject.scopes & required_scopes): + await websocket.close(code=4003, reason="Insufficient permissions") + return + + # Check subject type + if not isinstance(auth_subject.subject, (User, Organization)): + await websocket.close(code=4002, reason="Invalid subject type") + return + + # Get organization ID + if isinstance(auth_subject.subject, Organization): + organization_id = auth_subject.subject.id + elif isinstance(auth_subject.subject, User): + from polar.user_organization.service import ( + user_organization as user_organization_service, + ) + + user_organizations = await user_organization_service.list_by_user_id( + session, auth_subject.subject.id + ) + if not user_organizations: + await websocket.close(code=4003, reason="User has no organizations") + return + organization_id = user_organizations[0].organization_id + else: + await websocket.close(code=4002, reason="Invalid subject type") + return + await websocket.accept() + redis = create_redis("app") + + # Use eventstream channel format + receivers = Receivers(organization_id=organization_id) + channels = receivers.get_channels() try: - while True: - message = { - "timestamp": datetime.now(UTC).isoformat(), - "status": "ok", - "message": "Periodic update from Polar CLI" - } - - await websocket.send_json(message) - await asyncio.sleep(30) + pubsub = redis.pubsub() + await pubsub.subscribe(*channels) + + log.info( + "WebSocket client subscribed to eventstream channels", + organization_id=str(organization_id), + channels=channels, + ) + + async for message in pubsub.listen(): + if message["type"] == "message": + try: + event_data = json.loads(message["data"]) + # Extract webhook payload from eventstream event + if "payload" in event_data and "key" in event_data: + # This is an eventstream event, send the nested payload + await websocket.send_json(event_data["payload"]) + else: + # Fallback for any non-eventstream format + await websocket.send_json(event_data) + except json.JSONDecodeError as e: + log.warning("Failed to decode event message", error=str(e)) + except Exception as e: + log.warning("Failed to send event to client", error=str(e)) + break + except WebSocketDisconnect: - pass + log.info( + "WebSocket client disconnected from eventstream", + organization_id=str(organization_id), + ) + except Exception as e: + log.error("WebSocket error", error=str(e), organization_id=str(organization_id)) + finally: + await pubsub.unsubscribe(*channels) + await pubsub.close() + await redis.close() diff --git a/server/polar/postgres.py b/server/polar/postgres.py index ae2e2f2488..579e98c33f 100644 --- a/server/polar/postgres.py +++ b/server/polar/postgres.py @@ -1,7 +1,7 @@ from collections.abc import AsyncGenerator from typing import Literal -from fastapi import Request +from starlette.requests import HTTPConnection from starlette.types import ASGIApp, Receive, Scope, Send from polar.config import settings @@ -72,11 +72,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) -async def get_db_sessionmaker(request: Request) -> AsyncSessionMaker: +async def get_db_sessionmaker(request: HTTPConnection) -> AsyncSessionMaker: return request.state.async_sessionmaker -async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession]: +async def get_db_session(request: HTTPConnection) -> AsyncGenerator[AsyncSession]: try: session = request.state.async_session except AttributeError as e: @@ -94,7 +94,9 @@ async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession]: await session.commit() -async def get_db_read_session(request: Request) -> AsyncGenerator[AsyncReadSession]: +async def get_db_read_session( + request: HTTPConnection, +) -> AsyncGenerator[AsyncReadSession]: sessionmaker: AsyncReadSessionMaker = request.state.async_read_sessionmaker async with sessionmaker() as session: yield session diff --git a/server/polar/redis.py b/server/polar/redis.py index 7f9e4afd74..c184ce7f44 100644 --- a/server/polar/redis.py +++ b/server/polar/redis.py @@ -1,10 +1,10 @@ from typing import TYPE_CHECKING, Literal import redis.asyncio as _async_redis -from fastapi import Request from redis import ConnectionError, RedisError, TimeoutError from redis.asyncio.retry import Retry from redis.backoff import default_backoff +from starlette.requests import HTTPConnection from polar.config import settings @@ -32,7 +32,7 @@ def create_redis(process_name: ProcessName) -> Redis: ) -async def get_redis(request: Request) -> Redis: +async def get_redis(request: HTTPConnection) -> Redis: return request.state.redis diff --git a/server/polar/webhook/eventstream.py b/server/polar/webhook/eventstream.py new file mode 100644 index 0000000000..51e254e519 --- /dev/null +++ b/server/polar/webhook/eventstream.py @@ -0,0 +1,46 @@ +from datetime import datetime +from enum import StrEnum +from typing import Any, TypedDict +from uuid import UUID + +from polar.eventstream.service import publish + + +class WebhookEvent(StrEnum): + webhook_created = "webhook.created" + + +class WebhookEventPayload(TypedDict): + type: str + timestamp: str + organization_id: str + payload: dict[str, Any] + + +async def publish_webhook_event( + organization_id: UUID, + event_type: str, + timestamp: datetime, + payload: dict[str, Any], +) -> None: + """ + Publish a webhook event to the eventstream for an organization. + + Args: + organization_id: The organization to publish the event to + event_type: The webhook event type (e.g., 'checkout.created') + timestamp: The timestamp of the event + payload: The webhook payload + """ + event_payload: dict[str, Any] = { + "type": event_type, + "timestamp": timestamp.isoformat(), + "organization_id": str(organization_id), + "payload": payload, + } + + await publish( + WebhookEvent.webhook_created, + event_payload, + organization_id=organization_id, + ) diff --git a/server/polar/webhook/service.py b/server/polar/webhook/service.py index 6c74c52baa..ec7d9333ad 100644 --- a/server/polar/webhook/service.py +++ b/server/polar/webhook/service.py @@ -49,13 +49,14 @@ from polar.user_organization.service import ( user_organization as user_organization_service, ) -from polar.worker import enqueue_job - -from .repository import ( +from polar.webhook.eventstream import publish_webhook_event +from polar.webhook.repository import ( WebhookDeliveryRepository, WebhookEndpointRepository, WebhookEventRepository, ) +from polar.worker import enqueue_job + from .schemas import WebhookEndpointCreate, WebhookEndpointUpdate from .webhooks import SkipEvent, UnsupportedTarget, WebhookPayloadTypeAdapter @@ -742,6 +743,17 @@ async def send( except SkipEvent: continue + # Publish webhook event to eventstream for CLI listeners + try: + await publish_webhook_event( + organization_id=target.id, + event_type=event, + timestamp=now, + payload=payload.model_dump(mode="json"), + ) + except Exception as e: + log.warning("Failed to publish webhook to eventstream", error=str(e)) + return events async def archive_events( From 860f23eadd7db8f1d6ece08c121b6dc4e11ab431 Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Mon, 15 Dec 2025 09:59:58 +0100 Subject: [PATCH 03/18] make cli api router private --- server/polar/cli/endpoints.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/polar/cli/endpoints.py b/server/polar/cli/endpoints.py index e6ba88a1c7..8eb4417475 100644 --- a/server/polar/cli/endpoints.py +++ b/server/polar/cli/endpoints.py @@ -6,11 +6,12 @@ from polar.auth.models import Organization, User, is_anonymous from polar.auth.scope import Scope from polar.eventstream.service import Receivers +from polar.openapi import APITag from polar.postgres import AsyncSession, get_db_session from polar.redis import Redis, create_redis, get_redis from polar.routing import APIRouter -router = APIRouter(prefix="/cli", tags=["cli"]) +router = APIRouter(prefix="/cli", tags=["cli", APITag.private]) log = structlog.get_logger() From f299fd2b58a5d925d22cde8c8305d8195d5d0ca1 Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Mon, 15 Dec 2025 10:38:27 +0100 Subject: [PATCH 04/18] adjust comment --- server/polar/auth/middlewares.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/polar/auth/middlewares.py b/server/polar/auth/middlewares.py index b6e58a946c..87f1a12ac2 100644 --- a/server/polar/auth/middlewares.py +++ b/server/polar/auth/middlewares.py @@ -61,7 +61,7 @@ def get_bearer_token(request: Request) -> str | None: def get_bearer_token_from_websocket(scope: ASGIScope) -> str | None: - """Extract bearer token from WebSocket connection (headers or query params).""" + """Extract bearer token from WebSocket connection.""" # Try to get token from Authorization header headers = dict(scope.get("headers", [])) authorization = headers.get(b"authorization", b"").decode("utf-8") From 78c0b5b70e178f37617376b2478794ebabc1cbeb Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Mon, 15 Dec 2025 10:40:42 +0100 Subject: [PATCH 05/18] fix session variable --- server/polar/auth/middlewares.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/polar/auth/middlewares.py b/server/polar/auth/middlewares.py index 87f1a12ac2..f6c3fe4ffd 100644 --- a/server/polar/auth/middlewares.py +++ b/server/polar/auth/middlewares.py @@ -223,8 +223,9 @@ def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None: + session: AsyncSession = scope["state"]["async_session"] + if scope["type"] == "http": - session: AsyncSession = scope["state"]["async_session"] request = Request(scope) try: @@ -241,8 +242,6 @@ async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None await self.app(scope, receive, send) elif scope["type"] == "websocket": - session: AsyncSession = scope["state"]["async_session"] - try: auth_subject = await get_auth_subject_from_websocket(scope, session) except OAuth2Error as e: From eae3ed15cfeafc235e156181c990ece8af28e2f3 Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Tue, 16 Dec 2025 16:42:17 +0100 Subject: [PATCH 06/18] server: fix session in middleware --- server/polar/auth/middlewares.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/server/polar/auth/middlewares.py b/server/polar/auth/middlewares.py index f6c3fe4ffd..03a21acc79 100644 --- a/server/polar/auth/middlewares.py +++ b/server/polar/auth/middlewares.py @@ -223,9 +223,9 @@ def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None: - session: AsyncSession = scope["state"]["async_session"] - if scope["type"] == "http": + session: AsyncSession = scope["state"]["async_session"] + request = Request(scope) try: @@ -242,6 +242,8 @@ async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None await self.app(scope, receive, send) elif scope["type"] == "websocket": + session: AsyncSession = scope["state"]["async_session"] + try: auth_subject = await get_auth_subject_from_websocket(scope, session) except OAuth2Error as e: From a76d227dfae3e55ea826bf9bc62207ef9245b69d Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Wed, 17 Dec 2025 11:19:17 +0100 Subject: [PATCH 07/18] use sse instead of websockets for cli listen --- server/polar/auth/middlewares.py | 105 ++++---------------------- server/polar/cli/auth.py | 19 +++++ server/polar/cli/endpoints.py | 125 ++++--------------------------- 3 files changed, 50 insertions(+), 199 deletions(-) create mode 100644 server/polar/cli/auth.py diff --git a/server/polar/auth/middlewares.py b/server/polar/auth/middlewares.py index 03a21acc79..58b51b2747 100644 --- a/server/polar/auth/middlewares.py +++ b/server/polar/auth/middlewares.py @@ -60,18 +60,6 @@ def get_bearer_token(request: Request) -> str | None: return value -def get_bearer_token_from_websocket(scope: ASGIScope) -> str | None: - """Extract bearer token from WebSocket connection.""" - # Try to get token from Authorization header - headers = dict(scope.get("headers", [])) - authorization = headers.get(b"authorization", b"").decode("utf-8") - scheme, value = get_authorization_scheme_param(authorization) - if scheme and value and scheme.lower() == "bearer" and value.isascii(): - return value - - return None - - async def get_oauth2_token(session: AsyncSession, value: str) -> OAuth2Token | None: return await oauth2_token_service.get_by_access_token(session, value) @@ -175,90 +163,29 @@ async def get_auth_subject( return AuthSubject(Anonymous(), set(), None) -async def get_auth_subject_from_websocket( - scope: ASGIScope, session: AsyncSession -) -> AuthSubject[Subject]: - """Get auth subject from WebSocket connection.""" - token = get_bearer_token_from_websocket(scope) - if token is not None: - if is_registration_token_prefix(token): - return AuthSubject(Anonymous(), set(), None) - - customer_session = await get_customer_session(session, token) - if customer_session: - return AuthSubject( - customer_session.customer, - {Scope.customer_portal_write}, - customer_session, - ) - - organization_access_token = await get_organization_access_token(session, token) - if organization_access_token: - return AuthSubject( - organization_access_token.organization, - organization_access_token.scopes, - organization_access_token, - ) - - oauth2_token = await get_oauth2_token(session, token) - if oauth2_token: - return AuthSubject(oauth2_token.sub, oauth2_token.scopes, oauth2_token) - - personal_access_token = await get_personal_access_token(session, token) - if personal_access_token: - return AuthSubject( - personal_access_token.user, - personal_access_token.scopes, - personal_access_token, - ) - - raise InvalidTokenError() - - # WebSockets don't support user sessions (cookies), so return Anonymous - return AuthSubject(Anonymous(), set(), None) - - class AuthSubjectMiddleware: def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None: - if scope["type"] == "http": - session: AsyncSession = scope["state"]["async_session"] - - request = Request(scope) - - try: - auth_subject = await get_auth_subject(request, session) - except OAuth2Error as e: - response = await oauth2_error_exception_handler(request, e) - return await response(scope, receive, send) - - scope["state"]["auth_subject"] = auth_subject - - with logfire.set_baggage(**auth_subject.log_context): - log.info("Authenticated subject", **auth_subject.log_context) - set_sentry_user(auth_subject) - await self.app(scope, receive, send) + if scope["type"] != "http": + await self.app(scope, receive, send) + return - elif scope["type"] == "websocket": - session: AsyncSession = scope["state"]["async_session"] + session: AsyncSession = scope["state"]["async_session"] + request = Request(scope) - try: - auth_subject = await get_auth_subject_from_websocket(scope, session) - except OAuth2Error as e: - # For WebSocket, we can't return an HTTP response - # The error will be handled when the connection is accepted - log.warning("WebSocket authentication failed", error=str(e)) - auth_subject = AuthSubject(Anonymous(), set(), None) + try: + auth_subject = await get_auth_subject(request, session) + except OAuth2Error as e: + response = await oauth2_error_exception_handler(request, e) + return await response(scope, receive, send) - scope["state"]["auth_subject"] = auth_subject + scope["state"]["auth_subject"] = auth_subject - with logfire.set_baggage(**auth_subject.log_context): - log.info("Authenticated WebSocket subject", **auth_subject.log_context) - set_sentry_user(auth_subject) - await self.app(scope, receive, send) + with logfire.set_baggage(**auth_subject.log_context): + log.info("Authenticated subject", **auth_subject.log_context) + set_sentry_user(auth_subject) - else: - # Other scope types (lifespan, etc.) - await self.app(scope, receive, send) + # Other scope types (lifespan, etc.) + await self.app(scope, receive, send) diff --git a/server/polar/cli/auth.py b/server/polar/cli/auth.py new file mode 100644 index 0000000000..60241c39d5 --- /dev/null +++ b/server/polar/cli/auth.py @@ -0,0 +1,19 @@ +from typing import Annotated + +from fastapi import Depends + +from polar.auth.dependencies import Authenticator +from polar.auth.models import AuthSubject +from polar.auth.scope import Scope +from polar.models.organization import Organization + +_CLIRead = Authenticator( + required_scopes={ + Scope.web_read, + Scope.web_write, + Scope.webhooks_read, + Scope.webhooks_write, + }, + allowed_subjects={Organization}, +) +CLIRead = Annotated[AuthSubject[Organization], Depends(_CLIRead)] diff --git a/server/polar/cli/endpoints.py b/server/polar/cli/endpoints.py index 8eb4417475..8e1aa0c320 100644 --- a/server/polar/cli/endpoints.py +++ b/server/polar/cli/endpoints.py @@ -1,122 +1,27 @@ -import json - import structlog -from fastapi import Depends, WebSocket, WebSocketDisconnect +from fastapi import Depends, Request +from sse_starlette.sse import EventSourceResponse -from polar.auth.models import Organization, User, is_anonymous -from polar.auth.scope import Scope +from polar.cli import auth +from polar.eventstream.endpoints import subscribe from polar.eventstream.service import Receivers from polar.openapi import APITag from polar.postgres import AsyncSession, get_db_session -from polar.redis import Redis, create_redis, get_redis +from polar.redis import Redis, get_redis from polar.routing import APIRouter -router = APIRouter(prefix="/cli", tags=["cli", APITag.private]) - log = structlog.get_logger() -@router.websocket("/listen") -async def listen( - websocket: WebSocket, - session: AsyncSession = Depends(get_db_session), - redis: Redis = Depends(get_redis), -) -> None: - """ - WebSocket endpoint that listens to webhook events for authenticated organizations. - Clients must authenticate using a token in query parameters (?token=...) or Authorization header. - The organization is inferred from the authenticated subject. - """ - - # Get auth_subject from WebSocket state (set by AuthSubjectMiddleware) - try: - auth_subject = websocket.state.auth_subject - except AttributeError: - await websocket.close(code=4001, reason="Authentication required") - return - - # Check if authenticated - if is_anonymous(auth_subject): - await websocket.close(code=4001, reason="Authentication required") - return - - # Verify required scopes - required_scopes = { - Scope.web_read, - Scope.web_write, - Scope.webhooks_read, - Scope.webhooks_write, - } - if not (auth_subject.scopes & required_scopes): - await websocket.close(code=4003, reason="Insufficient permissions") - return - - # Check subject type - if not isinstance(auth_subject.subject, (User, Organization)): - await websocket.close(code=4002, reason="Invalid subject type") - return - - # Get organization ID - if isinstance(auth_subject.subject, Organization): - organization_id = auth_subject.subject.id - elif isinstance(auth_subject.subject, User): - from polar.user_organization.service import ( - user_organization as user_organization_service, - ) - - user_organizations = await user_organization_service.list_by_user_id( - session, auth_subject.subject.id - ) - if not user_organizations: - await websocket.close(code=4003, reason="User has no organizations") - return - organization_id = user_organizations[0].organization_id - else: - await websocket.close(code=4002, reason="Invalid subject type") - return - - await websocket.accept() - redis = create_redis("app") - - # Use eventstream channel format - receivers = Receivers(organization_id=organization_id) - channels = receivers.get_channels() - - try: - pubsub = redis.pubsub() - await pubsub.subscribe(*channels) - - log.info( - "WebSocket client subscribed to eventstream channels", - organization_id=str(organization_id), - channels=channels, - ) +router = APIRouter(prefix="/cli", tags=["cli", APITag.private]) - async for message in pubsub.listen(): - if message["type"] == "message": - try: - event_data = json.loads(message["data"]) - # Extract webhook payload from eventstream event - if "payload" in event_data and "key" in event_data: - # This is an eventstream event, send the nested payload - await websocket.send_json(event_data["payload"]) - else: - # Fallback for any non-eventstream format - await websocket.send_json(event_data) - except json.JSONDecodeError as e: - log.warning("Failed to decode event message", error=str(e)) - except Exception as e: - log.warning("Failed to send event to client", error=str(e)) - break - except WebSocketDisconnect: - log.info( - "WebSocket client disconnected from eventstream", - organization_id=str(organization_id), - ) - except Exception as e: - log.error("WebSocket error", error=str(e), organization_id=str(organization_id)) - finally: - await pubsub.unsubscribe(*channels) - await pubsub.close() - await redis.close() +@router.get("/listen") +async def listen( + request: Request, + auth_subject: auth.CLIRead, + redis: Redis = Depends(get_redis), + session: AsyncSession = Depends(get_db_session), +) -> EventSourceResponse: + receivers = Receivers(organization_id=auth_subject.subject.id) + return EventSourceResponse(subscribe(redis, receivers.get_channels(), request)) From ff312371c051bb461c339d2bd44b7521f83fa525 Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Wed, 17 Dec 2025 11:22:53 +0100 Subject: [PATCH 08/18] revert typeds --- server/polar/auth/dependencies.py | 9 +++------ server/polar/auth/middlewares.py | 5 ++--- server/polar/postgres.py | 8 ++++---- server/polar/redis.py | 4 ++-- 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/server/polar/auth/dependencies.py b/server/polar/auth/dependencies.py index 4e431242cb..5654087355 100644 --- a/server/polar/auth/dependencies.py +++ b/server/polar/auth/dependencies.py @@ -2,10 +2,9 @@ from inspect import Parameter, Signature from typing import Annotated, Any -from fastapi import Depends, Security +from fastapi import Depends, Request, Security from fastapi.security import HTTPBearer, OpenIdConnect from makefun import with_signature -from starlette.requests import HTTPConnection from polar.auth.scope import RESERVED_SCOPES, Scope from polar.exceptions import Unauthorized @@ -75,7 +74,7 @@ def _get_auth_subject_factory( Parameter( name="request", kind=Parameter.POSITIONAL_OR_KEYWORD, - annotation=HTTPConnection, + annotation=Request, ) ] if User in allowed_subjects or Organization in allowed_subjects: @@ -122,9 +121,7 @@ def _get_auth_subject_factory( signature = Signature(parameters) @with_signature(signature) - async def get_auth_subject( - request: HTTPConnection, **kwargs: Any - ) -> AuthSubject[Subject]: + async def get_auth_subject(request: Request, **kwargs: Any) -> AuthSubject[Subject]: try: return request.state.auth_subject except AttributeError as e: diff --git a/server/polar/auth/middlewares.py b/server/polar/auth/middlewares.py index 58b51b2747..53a8ab3fb3 100644 --- a/server/polar/auth/middlewares.py +++ b/server/polar/auth/middlewares.py @@ -186,6 +186,5 @@ async def __call__(self, scope: ASGIScope, receive: Receive, send: Send) -> None with logfire.set_baggage(**auth_subject.log_context): log.info("Authenticated subject", **auth_subject.log_context) set_sentry_user(auth_subject) - - # Other scope types (lifespan, etc.) - await self.app(scope, receive, send) + # Other scope types (lifespan, etc.) + await self.app(scope, receive, send) diff --git a/server/polar/postgres.py b/server/polar/postgres.py index 579e98c33f..7c77623116 100644 --- a/server/polar/postgres.py +++ b/server/polar/postgres.py @@ -1,7 +1,7 @@ from collections.abc import AsyncGenerator from typing import Literal -from starlette.requests import HTTPConnection +from fastapi import Request from starlette.types import ASGIApp, Receive, Scope, Send from polar.config import settings @@ -72,11 +72,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) -async def get_db_sessionmaker(request: HTTPConnection) -> AsyncSessionMaker: +async def get_db_sessionmaker(request: Request) -> AsyncSessionMaker: return request.state.async_sessionmaker -async def get_db_session(request: HTTPConnection) -> AsyncGenerator[AsyncSession]: +async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession]: try: session = request.state.async_session except AttributeError as e: @@ -95,7 +95,7 @@ async def get_db_session(request: HTTPConnection) -> AsyncGenerator[AsyncSession async def get_db_read_session( - request: HTTPConnection, + request: Request, ) -> AsyncGenerator[AsyncReadSession]: sessionmaker: AsyncReadSessionMaker = request.state.async_read_sessionmaker async with sessionmaker() as session: diff --git a/server/polar/redis.py b/server/polar/redis.py index c184ce7f44..7f9e4afd74 100644 --- a/server/polar/redis.py +++ b/server/polar/redis.py @@ -1,10 +1,10 @@ from typing import TYPE_CHECKING, Literal import redis.asyncio as _async_redis +from fastapi import Request from redis import ConnectionError, RedisError, TimeoutError from redis.asyncio.retry import Retry from redis.backoff import default_backoff -from starlette.requests import HTTPConnection from polar.config import settings @@ -32,7 +32,7 @@ def create_redis(process_name: ProcessName) -> Redis: ) -async def get_redis(request: HTTPConnection) -> Redis: +async def get_redis(request: Request) -> Redis: return request.state.redis From ee447f24705aa3541aafed33bea2384eb6fab23e Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Wed, 17 Dec 2025 11:23:58 +0100 Subject: [PATCH 09/18] remove uv sync --- server/uv.lock | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/uv.lock b/server/uv.lock index 02aed9887c..290bfced27 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -1939,7 +1939,6 @@ dependencies = [ { name = "taskipy" }, { name = "typer" }, { name = "uvicorn", extra = ["standard"] }, - { name = "websockets" }, ] [package.dev-dependencies] @@ -2025,7 +2024,6 @@ requires-dist = [ { name = "taskipy", specifier = ">=1.10.3" }, { name = "typer", specifier = ">=0.12.5" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.31.1" }, - { name = "websockets", specifier = ">=15.0.1" }, ] [package.metadata.requires-dev] From 35330b1ec8ee68e50c1c2675ed98370c97038a3c Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Wed, 17 Dec 2025 13:11:50 +0100 Subject: [PATCH 10/18] adjust sse for webhooks --- server/polar/webhook/eventstream.py | 13 ++++--------- server/polar/webhook/service.py | 26 ++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/server/polar/webhook/eventstream.py b/server/polar/webhook/eventstream.py index 51e254e519..0d01719f03 100644 --- a/server/polar/webhook/eventstream.py +++ b/server/polar/webhook/eventstream.py @@ -1,4 +1,4 @@ -from datetime import datetime +from collections.abc import Mapping from enum import StrEnum from typing import Any, TypedDict from uuid import UUID @@ -19,24 +19,19 @@ class WebhookEventPayload(TypedDict): async def publish_webhook_event( organization_id: UUID, - event_type: str, - timestamp: datetime, payload: dict[str, Any], + headers: Mapping[str, str], ) -> None: """ Publish a webhook event to the eventstream for an organization. Args: - organization_id: The organization to publish the event to - event_type: The webhook event type (e.g., 'checkout.created') - timestamp: The timestamp of the event + headers: The headers of the event payload: The webhook payload """ event_payload: dict[str, Any] = { - "type": event_type, - "timestamp": timestamp.isoformat(), - "organization_id": str(organization_id), "payload": payload, + "headers": headers, } await publish( diff --git a/server/polar/webhook/service.py b/server/polar/webhook/service.py index ec7d9333ad..3c105e4bb0 100644 --- a/server/polar/webhook/service.py +++ b/server/polar/webhook/service.py @@ -1,6 +1,7 @@ +import base64 import datetime import json -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Literal, cast, overload from uuid import UUID @@ -8,6 +9,7 @@ from sqlalchemy import CursorResult, String, desc, func, or_, select, text, update from sqlalchemy import cast as sql_cast from sqlalchemy.orm import joinedload +from standardwebhooks.webhooks import Webhook as StandardWebhook from polar.auth.models import AuthSubject from polar.checkout.eventstream import CheckoutEvent, publish_checkout_event @@ -745,11 +747,27 @@ async def send( # Publish webhook event to eventstream for CLI listeners try: + ts = utc_now() + + b64secret = base64.b64encode( + event.webhook_endpoint.secret.encode("utf-8") + ).decode("utf-8") + + # Sign the payload + wh = StandardWebhook(b64secret) + signature = wh.sign(str(event.id), ts, event.payload) + + headers: Mapping[str, str] = { + "user-agent": "polar.sh webhooks", + "content-type": "application/json", + "webhook-id": str(event.id), + "webhook-timestamp": str(int(ts.timestamp())), + "webhook-signature": signature, + } await publish_webhook_event( + payload=event.payload, + headers=headers, organization_id=target.id, - event_type=event, - timestamp=now, - payload=payload.model_dump(mode="json"), ) except Exception as e: log.warning("Failed to publish webhook to eventstream", error=str(e)) From 07b70cd34e173c22fa59bb2214ce6feb72bec35b Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Thu, 18 Dec 2025 22:20:06 +0100 Subject: [PATCH 11/18] server: refactor webhook listener to support secrets --- server/polar/cli/endpoints.py | 76 ++++++++++++++++++++++++++++++++- server/polar/webhook/service.py | 38 +++-------------- server/polar/webhook/tasks.py | 59 +++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 32 deletions(-) diff --git a/server/polar/cli/endpoints.py b/server/polar/cli/endpoints.py index 8e1aa0c320..f74210dd80 100644 --- a/server/polar/cli/endpoints.py +++ b/server/polar/cli/endpoints.py @@ -1,10 +1,17 @@ +import base64 +import json +from collections.abc import AsyncGenerator +from typing import Any + import structlog from fastapi import Depends, Request from sse_starlette.sse import EventSourceResponse +from standardwebhooks.webhooks import Webhook as StandardWebhook from polar.cli import auth from polar.eventstream.endpoints import subscribe from polar.eventstream.service import Receivers +from polar.kit.utils import utc_now from polar.openapi import APITag from polar.postgres import AsyncSession, get_db_session from polar.redis import Redis, get_redis @@ -16,6 +23,55 @@ router = APIRouter(prefix="/cli", tags=["cli", APITag.private]) +async def transform_webhook_events( + organization_id: str, event_stream: AsyncGenerator[Any, Any] +) -> AsyncGenerator[Any, Any]: + """ + Transform webhook events before sending to CLI client. + Adds signed headers using organization_id as the secret. + """ + async for message in event_stream: + try: + event = json.loads(message) + + # Check if this is a webhook event + if event.get("key") == "webhook.created": + payload_data = event.get("payload", {}) + webhook_payload = payload_data.get("payload") + webhook_event_id = payload_data.get("webhook_event_id") + + if webhook_payload and webhook_event_id: + ts = utc_now() + + # Use organization_id as the signing secret + b64secret = base64.b64encode( + organization_id.encode("utf-8") + ).decode("utf-8") + + # Sign the payload + wh = StandardWebhook(b64secret) + signature = wh.sign(webhook_event_id, ts, webhook_payload) + + # Add signed headers to the event + event["headers"] = { + "user-agent": "polar.sh webhooks", + "content-type": "application/json", + "webhook-id": webhook_event_id, + "webhook-timestamp": str(int(ts.timestamp())), + "webhook-signature": signature, + } + + event["payload"]["payload"] = json.loads(webhook_payload) + + yield json.dumps(event) + continue + except (json.JSONDecodeError, KeyError) as e: + log.warning("Failed to transform webhook event", error=str(e)) + + # Yield original message if not a webhook event or if transformation failed + yield message + + @router.get("/listen") async def listen( request: Request, @@ -24,4 +80,22 @@ async def listen( session: AsyncSession = Depends(get_db_session), ) -> EventSourceResponse: receivers = Receivers(organization_id=auth_subject.subject.id) - return EventSourceResponse(subscribe(redis, receivers.get_channels(), request)) + event_stream = subscribe(redis, receivers.get_channels(), request) + transformed_stream = transform_webhook_events( + str(auth_subject.subject.id), event_stream + ) + + async def first_event_wrapper(): + # Send a first event announcing connection established + yield json.dumps( + { + "key": "connected", + "ts": str(utc_now()), + "secret": str(auth_subject.subject.id), + } + ) + + async for message in transformed_stream: + yield message + + return EventSourceResponse(first_event_wrapper()) diff --git a/server/polar/webhook/service.py b/server/polar/webhook/service.py index 3c105e4bb0..b1002bff51 100644 --- a/server/polar/webhook/service.py +++ b/server/polar/webhook/service.py @@ -1,7 +1,6 @@ -import base64 import datetime import json -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from typing import Literal, cast, overload from uuid import UUID @@ -9,7 +8,6 @@ from sqlalchemy import CursorResult, String, desc, func, or_, select, text, update from sqlalchemy import cast as sql_cast from sqlalchemy.orm import joinedload -from standardwebhooks.webhooks import Webhook as StandardWebhook from polar.auth.models import AuthSubject from polar.checkout.eventstream import CheckoutEvent, publish_checkout_event @@ -51,7 +49,6 @@ from polar.user_organization.service import ( user_organization as user_organization_service, ) -from polar.webhook.eventstream import publish_webhook_event from polar.webhook.repository import ( WebhookDeliveryRepository, WebhookEndpointRepository, @@ -738,6 +735,12 @@ async def send( events.append(event_type) await session.flush() enqueue_job("webhook_event.send", webhook_event_id=event_type.id) + # Publish webhook event to eventstream for CLI listeners + enqueue_job( + "webhook_event.publish", + webhook_event_id=event_type.id, + organization_id=target.id, + ) except UnsupportedTarget as e: # Log the error but do not raise to not fail the whole request log.error(e.message) @@ -745,33 +748,6 @@ async def send( except SkipEvent: continue - # Publish webhook event to eventstream for CLI listeners - try: - ts = utc_now() - - b64secret = base64.b64encode( - event.webhook_endpoint.secret.encode("utf-8") - ).decode("utf-8") - - # Sign the payload - wh = StandardWebhook(b64secret) - signature = wh.sign(str(event.id), ts, event.payload) - - headers: Mapping[str, str] = { - "user-agent": "polar.sh webhooks", - "content-type": "application/json", - "webhook-id": str(event.id), - "webhook-timestamp": str(int(ts.timestamp())), - "webhook-signature": signature, - } - await publish_webhook_event( - payload=event.payload, - headers=headers, - organization_id=target.id, - ) - except Exception as e: - log.warning("Failed to publish webhook to eventstream", error=str(e)) - return events async def archive_events( diff --git a/server/polar/webhook/tasks.py b/server/polar/webhook/tasks.py index 4af4691c7d..b5acde7908 100644 --- a/server/polar/webhook/tasks.py +++ b/server/polar/webhook/tasks.py @@ -208,3 +208,62 @@ async def webhook_event_archive() -> None: return await webhook_service.archive_events( session, older_than=utc_now() - settings.WEBHOOK_EVENT_RETENTION_PERIOD ) + + +@actor(actor_name="webhook_event.publish", priority=TaskPriority.MEDIUM) +async def webhook_event_publish(webhook_event_id: UUID, organization_id: UUID) -> None: + """ + Publish a webhook event to the eventstream for CLI listeners. + + Args: + webhook_event_id: ID of the webhook event to publish + organization_id: ID of the organization (used as signing secret) + """ + async with AsyncSessionMaker() as session: + return await _webhook_event_publish( + session, webhook_event_id=webhook_event_id, organization_id=organization_id + ) + + +async def _webhook_event_publish( + session: AsyncSession, *, webhook_event_id: UUID, organization_id: UUID +) -> None: + from polar.eventstream.service import publish + from polar.webhook.eventstream import WebhookEvent + + repository = WebhookEventRepository.from_session(session) + event = await repository.get_by_id( + webhook_event_id, options=repository.get_eager_options() + ) + if event is None: + log.warning( + "Webhook event not found for eventstream publishing", + webhook_event_id=webhook_event_id, + ) + return + + if event.payload is None: + log.debug( + "Webhook event has no payload, skipping eventstream publish", + webhook_event_id=webhook_event_id, + ) + return + + try: + # Publish raw webhook event data to eventstream + # The CLI endpoint will apply transformation (headers, signing) when sending to client + await publish( + WebhookEvent.webhook_created, + { + "webhook_event_id": str(event.id), + "payload": event.payload, + }, + organization_id=organization_id, + ) + except Exception as e: + log.warning( + "Failed to publish webhook to eventstream", + webhook_event_id=webhook_event_id, + organization_id=organization_id, + error=str(e), + ) From 4ff8b603a2715c8f09d937173235922468d3b05b Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Fri, 19 Dec 2025 13:54:30 +0100 Subject: [PATCH 12/18] cli: fix --- server/polar/cli/endpoints.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/server/polar/cli/endpoints.py b/server/polar/cli/endpoints.py index f74210dd80..8175cac3f2 100644 --- a/server/polar/cli/endpoints.py +++ b/server/polar/cli/endpoints.py @@ -19,7 +19,6 @@ log = structlog.get_logger() - router = APIRouter(prefix="/cli", tags=["cli", APITag.private]) @@ -43,10 +42,10 @@ async def transform_webhook_events( if webhook_payload and webhook_event_id: ts = utc_now() + secret = str(organization_id).replace("-", "") + # Use organization_id as the signing secret - b64secret = base64.b64encode( - organization_id.encode("utf-8") - ).decode("utf-8") + b64secret = base64.b64encode(secret.encode("utf-8")).decode("utf-8") # Sign the payload wh = StandardWebhook(b64secret) @@ -86,12 +85,14 @@ async def listen( ) async def first_event_wrapper(): + secret = str(auth_subject.subject.id).replace("-", "") + # Send a first event announcing connection established yield json.dumps( { "key": "connected", "ts": str(utc_now()), - "secret": str(auth_subject.subject.id), + "secret": secret, } ) From 916a8c480a605698c65a71df343183ce846928c1 Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Thu, 5 Feb 2026 10:41:31 +0100 Subject: [PATCH 13/18] make sure to send webhook events regardless of configured endpoints --- server/polar/webhook/eventstream.py | 29 +++++++++-------------------- server/polar/webhook/service.py | 13 +++++++------ 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/server/polar/webhook/eventstream.py b/server/polar/webhook/eventstream.py index 0d01719f03..36fa7028c0 100644 --- a/server/polar/webhook/eventstream.py +++ b/server/polar/webhook/eventstream.py @@ -1,41 +1,30 @@ -from collections.abc import Mapping from enum import StrEnum -from typing import Any, TypedDict from uuid import UUID from polar.eventstream.service import publish +from polar.kit.utils import generate_uuid class WebhookEvent(StrEnum): webhook_created = "webhook.created" -class WebhookEventPayload(TypedDict): - type: str - timestamp: str - organization_id: str - payload: dict[str, Any] - - async def publish_webhook_event( organization_id: UUID, - payload: dict[str, Any], - headers: Mapping[str, str], + payload: str, ) -> None: """ - Publish a webhook event to the eventstream for an organization. + Publish a webhook event to the eventstream for CLI listeners. Args: - headers: The headers of the event - payload: The webhook payload + organization_id: The organization to publish the event for + payload: The raw JSON string of the webhook payload """ - event_payload: dict[str, Any] = { - "payload": payload, - "headers": headers, - } - await publish( WebhookEvent.webhook_created, - event_payload, + { + "webhook_event_id": str(generate_uuid()), + "payload": payload, + }, organization_id=organization_id, ) diff --git a/server/polar/webhook/service.py b/server/polar/webhook/service.py index b1002bff51..acd3fcd419 100644 --- a/server/polar/webhook/service.py +++ b/server/polar/webhook/service.py @@ -56,6 +56,7 @@ ) from polar.worker import enqueue_job +from .eventstream import publish_webhook_event from .schemas import WebhookEndpointCreate, WebhookEndpointUpdate from .webhooks import SkipEvent, UnsupportedTarget, WebhookPayloadTypeAdapter @@ -719,6 +720,12 @@ async def send( {"type": event, "timestamp": now, "data": data} ) + # Publish to eventstream for CLI listeners, regardless of webhook endpoints + await publish_webhook_event( + organization_id=target.id, + payload=payload.get_raw_payload(), + ) + events: list[WebhookEvent] = [] for endpoint in await self._get_event_target_endpoints( session, event=event, target=target @@ -735,12 +742,6 @@ async def send( events.append(event_type) await session.flush() enqueue_job("webhook_event.send", webhook_event_id=event_type.id) - # Publish webhook event to eventstream for CLI listeners - enqueue_job( - "webhook_event.publish", - webhook_event_id=event_type.id, - organization_id=target.id, - ) except UnsupportedTarget as e: # Log the error but do not raise to not fail the whole request log.error(e.message) From a27f3ca49d8094d77afb52988fba161e8cdf91c9 Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Thu, 5 Feb 2026 11:06:49 +0100 Subject: [PATCH 14/18] fix typing --- server/polar/cli/endpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/polar/cli/endpoints.py b/server/polar/cli/endpoints.py index 8175cac3f2..846bfc2110 100644 --- a/server/polar/cli/endpoints.py +++ b/server/polar/cli/endpoints.py @@ -84,7 +84,7 @@ async def listen( str(auth_subject.subject.id), event_stream ) - async def first_event_wrapper(): + async def first_event_wrapper() -> AsyncGenerator[str]: secret = str(auth_subject.subject.id).replace("-", "") # Send a first event announcing connection established From c88a5b7fb903bea03cb950ba1ffb49b0d09ebbad Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Thu, 5 Feb 2026 13:57:50 +0100 Subject: [PATCH 15/18] fix listen endpoint --- server/polar/cli/endpoints.py | 25 ++++++++++++++++++------- server/polar/cli/listener.py | 22 ++++++++++++++++++++++ server/polar/eventstream/endpoints.py | 6 +++++- server/polar/webhook/eventstream.py | 16 ++++++++++++++++ 4 files changed, 61 insertions(+), 8 deletions(-) create mode 100644 server/polar/cli/listener.py diff --git a/server/polar/cli/endpoints.py b/server/polar/cli/endpoints.py index 846bfc2110..a144beed0a 100644 --- a/server/polar/cli/endpoints.py +++ b/server/polar/cli/endpoints.py @@ -9,6 +9,7 @@ from standardwebhooks.webhooks import Webhook as StandardWebhook from polar.cli import auth +from polar.cli.listener import mark_active, mark_inactive from polar.eventstream.endpoints import subscribe from polar.eventstream.service import Receivers from polar.kit.utils import utc_now @@ -78,14 +79,21 @@ async def listen( redis: Redis = Depends(get_redis), session: AsyncSession = Depends(get_db_session), ) -> EventSourceResponse: - receivers = Receivers(organization_id=auth_subject.subject.id) - event_stream = subscribe(redis, receivers.get_channels(), request) - transformed_stream = transform_webhook_events( - str(auth_subject.subject.id), event_stream + org_id = auth_subject.subject.id + + await mark_active(redis, org_id) + + async def refresh_listener() -> None: + await mark_active(redis, org_id) + + receivers = Receivers(organization_id=org_id) + event_stream = subscribe( + redis, receivers.get_channels(), request, on_iteration=refresh_listener ) + transformed_stream = transform_webhook_events(str(org_id), event_stream) async def first_event_wrapper() -> AsyncGenerator[str]: - secret = str(auth_subject.subject.id).replace("-", "") + secret = str(org_id).replace("-", "") # Send a first event announcing connection established yield json.dumps( @@ -96,7 +104,10 @@ async def first_event_wrapper() -> AsyncGenerator[str]: } ) - async for message in transformed_stream: - yield message + try: + async for message in transformed_stream: + yield message + finally: + await mark_inactive(redis, org_id) return EventSourceResponse(first_event_wrapper()) diff --git a/server/polar/cli/listener.py b/server/polar/cli/listener.py new file mode 100644 index 0000000000..449d4ec84a --- /dev/null +++ b/server/polar/cli/listener.py @@ -0,0 +1,22 @@ +from uuid import UUID + +from polar.redis import Redis + +LISTENER_KEY_PREFIX = "cli:listening" +LISTENER_TTL_SECONDS = 30 + + +def _key(org_id: UUID) -> str: + return f"{LISTENER_KEY_PREFIX}:{org_id}" + + +async def mark_active(redis: Redis, org_id: UUID) -> None: + await redis.set(_key(org_id), "1", ex=LISTENER_TTL_SECONDS) + + +async def mark_inactive(redis: Redis, org_id: UUID) -> None: + await redis.delete(_key(org_id)) + + +async def has_active_listener(redis: Redis, org_id: UUID) -> bool: + return await redis.exists(_key(org_id)) > 0 diff --git a/server/polar/eventstream/endpoints.py b/server/polar/eventstream/endpoints.py index f5581df725..0c461883fb 100644 --- a/server/polar/eventstream/endpoints.py +++ b/server/polar/eventstream/endpoints.py @@ -1,5 +1,5 @@ import asyncio -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any import structlog @@ -51,6 +51,7 @@ async def subscribe( redis: Redis, channels: list[str], request: Request, + on_iteration: Callable[[], Awaitable[None]] | None = None, ) -> AsyncGenerator[Any, Any]: async with redis.pubsub() as pubsub: await pubsub.subscribe(*channels) @@ -60,6 +61,9 @@ async def subscribe( await pubsub.close() break + if on_iteration is not None: + await on_iteration() + try: message = await pubsub.get_message( ignore_subscribe_messages=True, diff --git a/server/polar/webhook/eventstream.py b/server/polar/webhook/eventstream.py index 36fa7028c0..6300cdf48e 100644 --- a/server/polar/webhook/eventstream.py +++ b/server/polar/webhook/eventstream.py @@ -1,14 +1,26 @@ from enum import StrEnum from uuid import UUID +from polar.cli.listener import has_active_listener from polar.eventstream.service import publish from polar.kit.utils import generate_uuid +from polar.redis import Redis, create_redis class WebhookEvent(StrEnum): webhook_created = "webhook.created" +_check_redis: Redis | None = None + + +def _get_check_redis() -> Redis: + global _check_redis + if _check_redis is None: + _check_redis = create_redis("app") + return _check_redis + + async def publish_webhook_event( organization_id: UUID, payload: str, @@ -20,6 +32,10 @@ async def publish_webhook_event( organization_id: The organization to publish the event for payload: The raw JSON string of the webhook payload """ + redis = _get_check_redis() + if not await has_active_listener(redis, organization_id): + return + await publish( WebhookEvent.webhook_created, { From 6336e761f81006fb2319c9a93280f6dad6699ac4 Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Thu, 5 Feb 2026 14:48:29 +0100 Subject: [PATCH 16/18] add tests to cli-listen --- server/polar/postgres.py | 4 +- server/tests/cli/__init__.py | 0 server/tests/cli/conftest.py | 25 ++ server/tests/cli/test_endpoints.py | 228 +++++++++++++++++++ server/tests/cli/test_listener.py | 85 +++++++ server/tests/cli/test_webhook_eventstream.py | 72 ++++++ server/tests/eventstream/__init__.py | 0 server/tests/eventstream/conftest.py | 25 ++ server/tests/eventstream/test_subscribe.py | 104 +++++++++ server/tests/fixtures/redis.py | 11 + 10 files changed, 551 insertions(+), 3 deletions(-) create mode 100644 server/tests/cli/__init__.py create mode 100644 server/tests/cli/conftest.py create mode 100644 server/tests/cli/test_endpoints.py create mode 100644 server/tests/cli/test_listener.py create mode 100644 server/tests/cli/test_webhook_eventstream.py create mode 100644 server/tests/eventstream/__init__.py create mode 100644 server/tests/eventstream/conftest.py create mode 100644 server/tests/eventstream/test_subscribe.py diff --git a/server/polar/postgres.py b/server/polar/postgres.py index 7c77623116..ae2e2f2488 100644 --- a/server/polar/postgres.py +++ b/server/polar/postgres.py @@ -94,9 +94,7 @@ async def get_db_session(request: Request) -> AsyncGenerator[AsyncSession]: await session.commit() -async def get_db_read_session( - request: Request, -) -> AsyncGenerator[AsyncReadSession]: +async def get_db_read_session(request: Request) -> AsyncGenerator[AsyncReadSession]: sessionmaker: AsyncReadSessionMaker = request.state.async_read_sessionmaker async with sessionmaker() as session: yield session diff --git a/server/tests/cli/__init__.py b/server/tests/cli/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/server/tests/cli/conftest.py b/server/tests/cli/conftest.py new file mode 100644 index 0000000000..cd2183a319 --- /dev/null +++ b/server/tests/cli/conftest.py @@ -0,0 +1,25 @@ +from collections.abc import AsyncIterator +from unittest.mock import AsyncMock + +import pytest +import pytest_asyncio + +from polar.kit.db.postgres import AsyncSession + + +@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True) +async def initialize_test_database() -> AsyncIterator[None]: + """Override: CLI tests don't need the database.""" + yield + + +@pytest_asyncio.fixture +async def session() -> AsyncIterator[AsyncSession]: + """Override: provide a mock session so patch_middlewares doesn't hit Postgres.""" + yield AsyncMock(spec=AsyncSession) + + +@pytest.fixture(autouse=True) +def patch_middlewares() -> None: + """Override: CLI tests don't need worker middleware patching.""" + pass diff --git a/server/tests/cli/test_endpoints.py b/server/tests/cli/test_endpoints.py new file mode 100644 index 0000000000..95ab4241a2 --- /dev/null +++ b/server/tests/cli/test_endpoints.py @@ -0,0 +1,228 @@ +import json +import uuid +from collections.abc import AsyncGenerator +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from pytest_mock import MockerFixture + +from polar.cli.listener import LISTENER_KEY_PREFIX, has_active_listener +from polar.redis import Redis + + +@pytest.fixture +def org_id() -> uuid.UUID: + return uuid.uuid4() + + +@pytest.fixture +def mock_subscribe(mocker: MockerFixture) -> MagicMock: + """Mock subscribe to yield controlled messages, capturing on_iteration.""" + + async def fake_subscribe( + redis: Any, + channels: list[str], + request: Any, + on_iteration: Any = None, + ) -> AsyncGenerator[Any, Any]: + # Call on_iteration once to simulate a loop tick + if on_iteration is not None: + await on_iteration() + yield json.dumps({"key": "some.event", "payload": {}}) + + return mocker.patch( + "polar.cli.endpoints.subscribe", + side_effect=fake_subscribe, + ) + + +@pytest.fixture +def mock_auth_subject(org_id: uuid.UUID) -> AsyncMock: + subject = AsyncMock() + subject.subject.id = org_id + return subject + + +@pytest.fixture +def mock_request() -> AsyncMock: + return AsyncMock() + + +@pytest.mark.asyncio +class TestListenEndpoint: + async def test_mark_active_on_connect( + self, + redis: Redis, + org_id: uuid.UUID, + mock_subscribe: MagicMock, + mock_auth_subject: AsyncMock, + mock_request: AsyncMock, + ) -> None: + """mark_active is called immediately when listen() is invoked.""" + from polar.cli.endpoints import listen + + response = await listen( + request=mock_request, + auth_subject=mock_auth_subject, + redis=redis, + session=AsyncMock(), + ) + + # Before consuming the generator, the key should already be set + assert await has_active_listener(redis, org_id) is True + + # Consume and close the generator to clean up + gen = response.body_iterator + async for _ in gen: + break + await gen.aclose() + + async def test_mark_inactive_on_disconnect( + self, + redis: Redis, + org_id: uuid.UUID, + mock_subscribe: MagicMock, + mock_auth_subject: AsyncMock, + mock_request: AsyncMock, + ) -> None: + """mark_inactive is called when the stream generator finishes.""" + from polar.cli.endpoints import listen + + response = await listen( + request=mock_request, + auth_subject=mock_auth_subject, + redis=redis, + session=AsyncMock(), + ) + + # Fully consume the generator (simulates client disconnect) + gen = response.body_iterator + async for _ in gen: + pass + + # After generator is exhausted, key should be deleted + assert await has_active_listener(redis, org_id) is False + + async def test_mark_inactive_called_on_subscribe_error( + self, + redis: Redis, + org_id: uuid.UUID, + mock_auth_subject: AsyncMock, + mock_request: AsyncMock, + mocker: MockerFixture, + ) -> None: + """mark_inactive is called even when subscribe raises an exception.""" + + async def error_subscribe( + redis: Any, + channels: list[str], + request: Any, + on_iteration: Any = None, + ) -> AsyncGenerator[Any, Any]: + yield json.dumps({"key": "event.1", "payload": {}}) + raise ConnectionError("connection lost") + + mocker.patch( + "polar.cli.endpoints.subscribe", + side_effect=error_subscribe, + ) + + from polar.cli.endpoints import listen + + response = await listen( + request=mock_request, + auth_subject=mock_auth_subject, + redis=redis, + session=AsyncMock(), + ) + + assert await has_active_listener(redis, org_id) is True + + # Consume the generator — the error in subscribe triggers the finally block + gen = response.body_iterator + with pytest.raises(ConnectionError): + async for _ in gen: + pass + + assert await has_active_listener(redis, org_id) is False + + async def test_refresh_ttl_via_on_iteration( + self, + redis: Redis, + org_id: uuid.UUID, + mock_auth_subject: AsyncMock, + mock_request: AsyncMock, + mocker: MockerFixture, + ) -> None: + """The on_iteration callback passed to subscribe refreshes the TTL.""" + captured_callback: list[Any] = [] + + async def capturing_subscribe( + redis: Any, + channels: list[str], + request: Any, + on_iteration: Any = None, + ) -> AsyncGenerator[Any, Any]: + captured_callback.append(on_iteration) + # Don't yield anything — we just want to capture the callback + return + yield # make it an async generator + + mocker.patch( + "polar.cli.endpoints.subscribe", + side_effect=capturing_subscribe, + ) + + from polar.cli.endpoints import listen + + response = await listen( + request=mock_request, + auth_subject=mock_auth_subject, + redis=redis, + session=AsyncMock(), + ) + + # Consume generator to trigger subscribe call + async for _ in response.body_iterator: + pass + + assert len(captured_callback) == 1 + assert captured_callback[0] is not None + + # Simulate TTL nearly expired + key = f"{LISTENER_KEY_PREFIX}:{org_id}" + await redis.expire(key, 2) + assert await redis.ttl(key) <= 2 + + # Call the captured callback — should refresh TTL + await captured_callback[0]() + assert await redis.ttl(key) > 2 + + async def test_first_event_is_connected( + self, + redis: Redis, + org_id: uuid.UUID, + mock_subscribe: MagicMock, + mock_auth_subject: AsyncMock, + mock_request: AsyncMock, + ) -> None: + """First SSE event is a 'connected' event with the secret.""" + from polar.cli.endpoints import listen + + response = await listen( + request=mock_request, + auth_subject=mock_auth_subject, + redis=redis, + session=AsyncMock(), + ) + + gen = response.body_iterator + first_event = await gen.__anext__() + data = json.loads(first_event) + + assert data["key"] == "connected" + assert data["secret"] == str(org_id).replace("-", "") + assert "ts" in data + + await gen.aclose() diff --git a/server/tests/cli/test_listener.py b/server/tests/cli/test_listener.py new file mode 100644 index 0000000000..e1515c3949 --- /dev/null +++ b/server/tests/cli/test_listener.py @@ -0,0 +1,85 @@ +import uuid + +import pytest + +from polar.cli.listener import ( + LISTENER_KEY_PREFIX, + LISTENER_TTL_SECONDS, + has_active_listener, + mark_active, + mark_inactive, +) +from polar.redis import Redis + + +@pytest.mark.asyncio +class TestMarkActive: + async def test_sets_key(self, redis: Redis) -> None: + org_id = uuid.uuid4() + await mark_active(redis, org_id) + + assert await redis.exists(f"{LISTENER_KEY_PREFIX}:{org_id}") > 0 + + async def test_sets_ttl(self, redis: Redis) -> None: + org_id = uuid.uuid4() + await mark_active(redis, org_id) + + ttl = await redis.ttl(f"{LISTENER_KEY_PREFIX}:{org_id}") + assert 0 < ttl <= LISTENER_TTL_SECONDS + + async def test_refreshes_ttl(self, redis: Redis) -> None: + org_id = uuid.uuid4() + key = f"{LISTENER_KEY_PREFIX}:{org_id}" + + await mark_active(redis, org_id) + # Simulate time passing by lowering TTL manually + await redis.expire(key, 5) + assert await redis.ttl(key) <= 5 + + # Refresh should restore full TTL + await mark_active(redis, org_id) + ttl = await redis.ttl(key) + assert ttl > 5 + + +@pytest.mark.asyncio +class TestMarkInactive: + async def test_deletes_key(self, redis: Redis) -> None: + org_id = uuid.uuid4() + await mark_active(redis, org_id) + assert await redis.exists(f"{LISTENER_KEY_PREFIX}:{org_id}") > 0 + + await mark_inactive(redis, org_id) + assert await redis.exists(f"{LISTENER_KEY_PREFIX}:{org_id}") == 0 + + async def test_noop_when_not_set(self, redis: Redis) -> None: + org_id = uuid.uuid4() + # Should not raise + await mark_inactive(redis, org_id) + + +@pytest.mark.asyncio +class TestHasActiveListener: + async def test_returns_true_when_active(self, redis: Redis) -> None: + org_id = uuid.uuid4() + await mark_active(redis, org_id) + assert await has_active_listener(redis, org_id) is True + + async def test_returns_false_when_inactive(self, redis: Redis) -> None: + org_id = uuid.uuid4() + assert await has_active_listener(redis, org_id) is False + + async def test_returns_false_after_mark_inactive(self, redis: Redis) -> None: + org_id = uuid.uuid4() + await mark_active(redis, org_id) + await mark_inactive(redis, org_id) + assert await has_active_listener(redis, org_id) is False + + async def test_independent_per_org(self, redis: Redis) -> None: + org_a = uuid.uuid4() + org_b = uuid.uuid4() + + await mark_active(redis, org_a) + + assert await has_active_listener(redis, org_a) is True + assert await has_active_listener(redis, org_b) is False diff --git a/server/tests/cli/test_webhook_eventstream.py b/server/tests/cli/test_webhook_eventstream.py new file mode 100644 index 0000000000..80e048d32d --- /dev/null +++ b/server/tests/cli/test_webhook_eventstream.py @@ -0,0 +1,72 @@ +import uuid +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from polar.redis import Redis +from polar.webhook.eventstream import publish_webhook_event + + +@pytest.fixture +def publish_mock(mocker: MockerFixture) -> MagicMock: + return mocker.patch("polar.webhook.eventstream.publish") + + +@pytest.mark.asyncio +class TestPublishWebhookEvent: + async def test_skips_when_no_listener( + self, publish_mock: MagicMock, redis: Redis + ) -> None: + org_id = uuid.uuid4() + + await publish_webhook_event(org_id, '{"type": "test"}') + + publish_mock.assert_not_called() + + async def test_publishes_when_listener_active( + self, publish_mock: MagicMock, redis: Redis + ) -> None: + from polar.cli.listener import mark_active + + org_id = uuid.uuid4() + await mark_active(redis, org_id) + + await publish_webhook_event(org_id, '{"type": "test"}') + + publish_mock.assert_called_once() + call_kwargs = publish_mock.call_args + assert call_kwargs[0][0] == "webhook.created" + assert call_kwargs[1]["organization_id"] == org_id + payload = call_kwargs[0][1] + assert payload["payload"] == '{"type": "test"}' + assert "webhook_event_id" in payload + + async def test_skips_after_listener_disconnects( + self, publish_mock: MagicMock, redis: Redis + ) -> None: + from polar.cli.listener import mark_active, mark_inactive + + org_id = uuid.uuid4() + await mark_active(redis, org_id) + await mark_inactive(redis, org_id) + + await publish_webhook_event(org_id, '{"type": "test"}') + + publish_mock.assert_not_called() + + async def test_only_publishes_to_listening_org( + self, publish_mock: MagicMock, redis: Redis + ) -> None: + from polar.cli.listener import mark_active + + listening_org = uuid.uuid4() + other_org = uuid.uuid4() + + await mark_active(redis, listening_org) + + await publish_webhook_event(other_org, '{"type": "test"}') + publish_mock.assert_not_called() + + await publish_webhook_event(listening_org, '{"type": "test"}') + publish_mock.assert_called_once() diff --git a/server/tests/eventstream/__init__.py b/server/tests/eventstream/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/server/tests/eventstream/conftest.py b/server/tests/eventstream/conftest.py new file mode 100644 index 0000000000..c9b0d49372 --- /dev/null +++ b/server/tests/eventstream/conftest.py @@ -0,0 +1,25 @@ +from collections.abc import AsyncIterator +from unittest.mock import AsyncMock + +import pytest +import pytest_asyncio + +from polar.kit.db.postgres import AsyncSession + + +@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True) +async def initialize_test_database() -> AsyncIterator[None]: + """Override: eventstream subscribe tests don't need the database.""" + yield + + +@pytest_asyncio.fixture +async def session() -> AsyncIterator[AsyncSession]: + """Override: provide a mock session so patch_middlewares doesn't hit Postgres.""" + yield AsyncMock(spec=AsyncSession) + + +@pytest.fixture(autouse=True) +def patch_middlewares() -> None: + """Override: eventstream tests don't need worker middleware patching.""" + pass diff --git a/server/tests/eventstream/test_subscribe.py b/server/tests/eventstream/test_subscribe.py new file mode 100644 index 0000000000..3415df2a3c --- /dev/null +++ b/server/tests/eventstream/test_subscribe.py @@ -0,0 +1,104 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest +from pytest_mock import MockerFixture + +from polar.eventstream.endpoints import subscribe +from polar.redis import Redis + + +@pytest.fixture(autouse=True) +def _no_uvicorn_exit(mocker: MockerFixture) -> None: + """Ensure _uvicorn_should_exit always returns False during tests.""" + mocker.patch( + "polar.eventstream.endpoints._uvicorn_should_exit", + return_value=False, + ) + + +def _make_request(disconnect_after: int) -> AsyncMock: + """Create a mock request that disconnects after N iterations.""" + request = AsyncMock() + request.is_disconnected = AsyncMock( + side_effect=[False] * disconnect_after + [True] + ) + return request + + +@pytest.mark.asyncio +class TestSubscribeOnIteration: + async def test_on_iteration_called_each_loop(self, redis: Redis) -> None: + """on_iteration callback is invoked on every loop iteration.""" + channel = "test:on_iter" + iterations = 3 + callback = AsyncMock() + request = _make_request(disconnect_after=iterations) + + async for _ in subscribe(redis, [channel], request, on_iteration=callback): + pass + + assert callback.call_count == iterations + + async def test_on_iteration_none_does_not_error(self, redis: Redis) -> None: + """subscribe works without an on_iteration callback.""" + channel = "test:no_cb" + request = _make_request(disconnect_after=1) + + # Should complete without error + messages = [] + async for msg in subscribe(redis, [channel], request): + messages.append(msg) + + # No messages expected — just verifying no crash + assert isinstance(messages, list) + + async def test_yields_published_messages(self, redis: Redis) -> None: + """Messages published to the channel are yielded by subscribe.""" + channel = "test:msgs" + callback = AsyncMock() + + # Allow enough iterations to receive messages + disconnect + request = _make_request(disconnect_after=4) + + async def publish_after_subscribe() -> None: + # Small delay to let subscribe register + await asyncio.sleep(0.1) + await redis.publish(channel, "msg1") + await asyncio.sleep(0.05) + await redis.publish(channel, "msg2") + + task = asyncio.create_task(publish_after_subscribe()) + + messages = [] + async for msg in subscribe( + redis, [channel], request, on_iteration=callback + ): + messages.append(msg) + + await task + + # fakeredis returns bytes since decode_responses is not set + assert b"msg1" in messages or "msg1" in messages + assert b"msg2" in messages or "msg2" in messages + assert callback.call_count >= 1 + + async def test_on_iteration_called_before_get_message( + self, redis: Redis + ) -> None: + """on_iteration is called before waiting for messages, not after.""" + channel = "test:order" + call_order: list[str] = [] + + async def track_callback() -> None: + call_order.append("callback") + + request = _make_request(disconnect_after=1) + + async for _ in subscribe( + redis, [channel], request, on_iteration=track_callback + ): + call_order.append("message") + + # callback should have been called even though no messages arrived + assert "callback" in call_order diff --git a/server/tests/fixtures/redis.py b/server/tests/fixtures/redis.py index 65604d8f41..d9651195e5 100644 --- a/server/tests/fixtures/redis.py +++ b/server/tests/fixtures/redis.py @@ -1,7 +1,9 @@ from collections.abc import AsyncIterator +import pytest import pytest_asyncio from fakeredis import FakeAsyncRedis +from pytest_mock import MockerFixture from polar.redis import Redis @@ -9,3 +11,12 @@ @pytest_asyncio.fixture(autouse=True) async def redis() -> AsyncIterator[Redis]: yield FakeAsyncRedis() + + +@pytest.fixture(autouse=True) +def patch_webhook_eventstream_redis(mocker: MockerFixture, redis: Redis) -> None: + """Ensure publish_webhook_event uses fakeredis instead of a real connection.""" + mocker.patch( + "polar.webhook.eventstream._get_check_redis", + return_value=redis, + ) From ad6a785cae7709c2a5da97d86df40929ea49bd92 Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Thu, 5 Feb 2026 14:58:29 +0100 Subject: [PATCH 17/18] fix linting --- server/tests/eventstream/test_subscribe.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/server/tests/eventstream/test_subscribe.py b/server/tests/eventstream/test_subscribe.py index 3415df2a3c..fa2ce8e9aa 100644 --- a/server/tests/eventstream/test_subscribe.py +++ b/server/tests/eventstream/test_subscribe.py @@ -20,9 +20,7 @@ def _no_uvicorn_exit(mocker: MockerFixture) -> None: def _make_request(disconnect_after: int) -> AsyncMock: """Create a mock request that disconnects after N iterations.""" request = AsyncMock() - request.is_disconnected = AsyncMock( - side_effect=[False] * disconnect_after + [True] - ) + request.is_disconnected = AsyncMock(side_effect=[False] * disconnect_after + [True]) return request @@ -71,9 +69,7 @@ async def publish_after_subscribe() -> None: task = asyncio.create_task(publish_after_subscribe()) messages = [] - async for msg in subscribe( - redis, [channel], request, on_iteration=callback - ): + async for msg in subscribe(redis, [channel], request, on_iteration=callback): messages.append(msg) await task @@ -83,9 +79,7 @@ async def publish_after_subscribe() -> None: assert b"msg2" in messages or "msg2" in messages assert callback.call_count >= 1 - async def test_on_iteration_called_before_get_message( - self, redis: Redis - ) -> None: + async def test_on_iteration_called_before_get_message(self, redis: Redis) -> None: """on_iteration is called before waiting for messages, not after.""" channel = "test:order" call_order: list[str] = [] From 9480492306c21ca16610ebaab2d0a274f02c37fd Mon Sep 17 00:00:00 2001 From: Emil Widlund Date: Thu, 5 Feb 2026 15:14:35 +0100 Subject: [PATCH 18/18] fix linter --- server/tests/cli/test_endpoints.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/server/tests/cli/test_endpoints.py b/server/tests/cli/test_endpoints.py index 95ab4241a2..5b7d656d32 100644 --- a/server/tests/cli/test_endpoints.py +++ b/server/tests/cli/test_endpoints.py @@ -72,11 +72,9 @@ async def test_mark_active_on_connect( # Before consuming the generator, the key should already be set assert await has_active_listener(redis, org_id) is True - # Consume and close the generator to clean up - gen = response.body_iterator - async for _ in gen: + # Consume the generator to clean up + async for _ in response.body_iterator: break - await gen.aclose() async def test_mark_inactive_on_disconnect( self, @@ -217,12 +215,10 @@ async def test_first_event_is_connected( session=AsyncMock(), ) - gen = response.body_iterator - first_event = await gen.__anext__() + first_event = await anext(aiter(response.body_iterator)) + assert isinstance(first_event, str) data = json.loads(first_event) assert data["key"] == "connected" assert data["secret"] == str(org_id).replace("-", "") assert "ts" in data - - await gen.aclose()