From f9da5792cc4e96688515a740e1e49c4fe4fb005d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jesper=20Br=C3=A4nn?= Date: Fri, 12 Dec 2025 10:24:27 +0100 Subject: [PATCH] auth: Validate token checksums Log an error if the checksum is not matching even though the token was valid. --- server/polar/auth/middlewares.py | 66 +++++++++++++++++++- server/polar/kit/crypto.py | 25 ++++++++ server/tests/auth/__init__.py | 0 server/tests/auth/test_middlewares.py | 88 +++++++++++++++++++++++++++ server/tests/kit/test_crypto.py | 34 +++++++++++ 5 files changed, 211 insertions(+), 2 deletions(-) create mode 100644 server/tests/auth/__init__.py create mode 100644 server/tests/auth/test_middlewares.py create mode 100644 server/tests/kit/test_crypto.py diff --git a/server/polar/auth/middlewares.py b/server/polar/auth/middlewares.py index 7c6766b298..6748a21cf0 100644 --- a/server/polar/auth/middlewares.py +++ b/server/polar/auth/middlewares.py @@ -5,7 +5,13 @@ from starlette.types import ASGIApp, Receive, Send from starlette.types import Scope as ASGIScope -from polar.customer_session.service import customer_session as customer_session_service +from polar.customer_session.service import ( + CUSTOMER_SESSION_TOKEN_PREFIX, +) +from polar.customer_session.service import ( + customer_session as customer_session_service, +) +from polar.kit.crypto import validate_token_checksum from polar.kit.utils import utc_now from polar.logging import Logger from polar.models import ( @@ -15,13 +21,19 @@ PersonalAccessToken, UserSession, ) -from polar.oauth2.constants import is_registration_token_prefix +from polar.oauth2.constants import ACCESS_TOKEN_PREFIX, is_registration_token_prefix from polar.oauth2.exception_handlers import OAuth2Error, oauth2_error_exception_handler from polar.oauth2.exceptions import InvalidTokenError from polar.oauth2.service.oauth2_token import oauth2_token as oauth2_token_service +from polar.organization_access_token.service import ( + TOKEN_PREFIX as ORGANIZATION_ACCESS_TOKEN_PREFIX, +) from polar.organization_access_token.service import ( organization_access_token as organization_access_token_service, ) +from polar.personal_access_token.service import ( + TOKEN_PREFIX as PERSONAL_ACCESS_TOKEN_PREFIX, +) from polar.personal_access_token.service import ( personal_access_token as personal_access_token_service, ) @@ -36,6 +48,13 @@ log: Logger = structlog.get_logger(__name__) +def _has_valid_oauth2_token_checksum(token: str) -> bool: + for prefix in ACCESS_TOKEN_PREFIX.values(): + if validate_token_checksum(token, prefix=prefix): + return True + return False + + async def get_user_session( request: Request, session: AsyncSession ) -> UserSession | None: @@ -92,6 +111,30 @@ async def get_customer_session( return await customer_session_service.get_by_token(session, value) +def _get_token_prefix(token: str) -> str | None: + """Return the prefix if the token matches a known token type.""" + if token.startswith(CUSTOMER_SESSION_TOKEN_PREFIX): + return CUSTOMER_SESSION_TOKEN_PREFIX + if token.startswith(ORGANIZATION_ACCESS_TOKEN_PREFIX): + return ORGANIZATION_ACCESS_TOKEN_PREFIX + if token.startswith(PERSONAL_ACCESS_TOKEN_PREFIX): + return PERSONAL_ACCESS_TOKEN_PREFIX + for prefix in ACCESS_TOKEN_PREFIX.values(): + if token.startswith(prefix): + return prefix + return None + + +def _log_checksum_mismatch(token_type: str, prefix: str, token_length: int) -> None: + """Log an error when a valid token has an invalid checksum.""" + log.error( + "Valid token has invalid checksum", + token_type=token_type, + token_prefix=prefix, + token_length=token_length, + ) + + async def get_auth_subject( request: Request, session: AsyncSession ) -> AuthSubject[Subject]: @@ -102,6 +145,10 @@ async def get_auth_subject( customer_session = await get_customer_session(session, token) if customer_session: + if not validate_token_checksum(token, prefix=CUSTOMER_SESSION_TOKEN_PREFIX): + _log_checksum_mismatch( + "customer_session", CUSTOMER_SESSION_TOKEN_PREFIX, len(token) + ) return AuthSubject( customer_session.customer, {Scope.customer_portal_write}, @@ -110,6 +157,14 @@ async def get_auth_subject( organization_access_token = await get_organization_access_token(session, token) if organization_access_token: + if not validate_token_checksum( + token, prefix=ORGANIZATION_ACCESS_TOKEN_PREFIX + ): + _log_checksum_mismatch( + "organization_access_token", + ORGANIZATION_ACCESS_TOKEN_PREFIX, + len(token), + ) return AuthSubject( organization_access_token.organization, organization_access_token.scopes, @@ -118,10 +173,17 @@ async def get_auth_subject( oauth2_token = await get_oauth2_token(session, token) if oauth2_token: + if not _has_valid_oauth2_token_checksum(token): + prefix = _get_token_prefix(token) + _log_checksum_mismatch("oauth2_token", prefix or "unknown", len(token)) return AuthSubject(oauth2_token.sub, oauth2_token.scopes, oauth2_token) personal_access_token = await get_personal_access_token(session, token) if personal_access_token: + if not validate_token_checksum(token, prefix=PERSONAL_ACCESS_TOKEN_PREFIX): + _log_checksum_mismatch( + "personal_access_token", PERSONAL_ACCESS_TOKEN_PREFIX, len(token) + ) return AuthSubject( personal_access_token.user, personal_access_token.scopes, diff --git a/server/polar/kit/crypto.py b/server/polar/kit/crypto.py index 3268803f84..85143e4181 100644 --- a/server/polar/kit/crypto.py +++ b/server/polar/kit/crypto.py @@ -43,3 +43,28 @@ def generate_token_hash_pair(*, secret: str, prefix: str = "") -> tuple[str, str """ token = generate_token(prefix=prefix) return token, get_token_hash(token, secret=secret) + + +def validate_token_checksum(token: str, *, prefix: str) -> bool: + """ + Validate that a token has a valid CRC32 checksum. + + Tokens are structured as: {prefix}{37 random chars}{6 checksum chars} + Returns True if the checksum is valid, False otherwise. + """ + if not token.startswith(prefix): + return False + + token_without_prefix = token[len(prefix) :] + + # Token should be 37 random chars + 6 checksum chars = 43 chars + if len(token_without_prefix) != 43: + return False + + random_part = token_without_prefix[:37] + checksum_part = token_without_prefix[37:] + + expected_checksum = zlib.crc32(random_part.encode("utf-8")) & 0xFFFFFFFF + expected_checksum_base62 = _crc32_to_base62(expected_checksum) + + return checksum_part == expected_checksum_base62 diff --git a/server/tests/auth/__init__.py b/server/tests/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/server/tests/auth/test_middlewares.py b/server/tests/auth/test_middlewares.py new file mode 100644 index 0000000000..172b3239db --- /dev/null +++ b/server/tests/auth/test_middlewares.py @@ -0,0 +1,88 @@ +from datetime import timedelta +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture +from starlette.requests import Request + +from polar.auth.middlewares import get_auth_subject +from polar.config import settings +from polar.kit.crypto import generate_token, get_token_hash +from polar.kit.utils import utc_now +from polar.models import PersonalAccessToken, User +from polar.personal_access_token.service import TOKEN_PREFIX as PAT_PREFIX +from polar.postgres import AsyncSession +from tests.fixtures.database import SaveFixture + + +@pytest.fixture(autouse=True) +def enqueue_job_mock(mocker: MockerFixture) -> MagicMock: + return mocker.patch("polar.auth.middlewares.enqueue_job", autospec=True) + + +def _make_request_with_token(token: str) -> Request: + scope = { + "type": "http", + "headers": [(b"authorization", f"Bearer {token}".encode())], + } + return Request(scope) + + +@pytest.mark.asyncio +class TestGetAuthSubjectChecksumValidation: + async def test_valid_checksum_no_error_logged( + self, + save_fixture: SaveFixture, + session: AsyncSession, + user: User, + mocker: MockerFixture, + ) -> None: + log_mock = mocker.patch("polar.auth.middlewares.log") + + token = generate_token(prefix=PAT_PREFIX) + token_hash = get_token_hash(token, secret=settings.SECRET) + pat = PersonalAccessToken( + comment="Test", + token=token_hash, + user_id=user.id, + expires_at=utc_now() + timedelta(days=1), + scope="openid", + ) + await save_fixture(pat) + + request = _make_request_with_token(token) + auth_subject = await get_auth_subject(request, session) + + assert auth_subject.subject == user + log_mock.error.assert_not_called() + + async def test_invalid_checksum_logs_error( + self, + save_fixture: SaveFixture, + session: AsyncSession, + user: User, + mocker: MockerFixture, + ) -> None: + log_mock = mocker.patch("polar.auth.middlewares.log") + + legacy_token = PAT_PREFIX + "a" * 32 + token_hash = get_token_hash(legacy_token, secret=settings.SECRET) + pat = PersonalAccessToken( + comment="Legacy Token", + token=token_hash, + user_id=user.id, + expires_at=utc_now() + timedelta(days=1), + scope="openid", + ) + await save_fixture(pat) + + request = _make_request_with_token(legacy_token) + auth_subject = await get_auth_subject(request, session) + + assert auth_subject.subject == user + log_mock.error.assert_called_once_with( + "Valid token has invalid checksum", + token_type="personal_access_token", + token_prefix=PAT_PREFIX, + token_length=len(legacy_token), + ) diff --git a/server/tests/kit/test_crypto.py b/server/tests/kit/test_crypto.py new file mode 100644 index 0000000000..c70cd4917a --- /dev/null +++ b/server/tests/kit/test_crypto.py @@ -0,0 +1,34 @@ +from polar.kit.crypto import generate_token, validate_token_checksum + + +class TestValidateTokenChecksum: + def test_valid_checksum(self) -> None: + prefix = "polar_pat_" + token = generate_token(prefix=prefix) + assert validate_token_checksum(token, prefix=prefix) is True + + def test_invalid_checksum_tampered(self) -> None: + prefix = "polar_pat_" + token = generate_token(prefix=prefix) + tampered = token[:-1] + ("A" if token[-1] != "A" else "B") + assert validate_token_checksum(tampered, prefix=prefix) is False + + def test_wrong_prefix(self) -> None: + token = generate_token(prefix="polar_pat_") + assert validate_token_checksum(token, prefix="polar_oat_") is False + + def test_token_too_short(self) -> None: + prefix = "polar_pat_" + short_token = prefix + "a" * 32 + assert validate_token_checksum(short_token, prefix=prefix) is False + + def test_token_too_long(self) -> None: + prefix = "polar_pat_" + long_token = generate_token(prefix=prefix) + "extra" + assert validate_token_checksum(long_token, prefix=prefix) is False + + def test_different_prefixes(self) -> None: + prefixes = ["polar_pat_", "polar_oat_", "polar_cst_", "polar_at_u_"] + for prefix in prefixes: + token = generate_token(prefix=prefix) + assert validate_token_checksum(token, prefix=prefix) is True