Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions server/polar/auth/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
from starlette.types import ASGIApp, Receive, Send
from starlette.types import Scope as ASGIScope

from polar.customer_session.service import customer_session as customer_session_service
from polar.customer_session.service import (
CUSTOMER_SESSION_TOKEN_PREFIX,
)
from polar.customer_session.service import (
customer_session as customer_session_service,
)
from polar.kit.crypto import validate_token_checksum
from polar.kit.utils import utc_now
from polar.logging import Logger
from polar.models import (
Expand All @@ -15,13 +21,19 @@
PersonalAccessToken,
UserSession,
)
from polar.oauth2.constants import is_registration_token_prefix
from polar.oauth2.constants import ACCESS_TOKEN_PREFIX, is_registration_token_prefix
from polar.oauth2.exception_handlers import OAuth2Error, oauth2_error_exception_handler
from polar.oauth2.exceptions import InvalidTokenError
from polar.oauth2.service.oauth2_token import oauth2_token as oauth2_token_service
from polar.organization_access_token.service import (
TOKEN_PREFIX as ORGANIZATION_ACCESS_TOKEN_PREFIX,
)
from polar.organization_access_token.service import (
organization_access_token as organization_access_token_service,
)
from polar.personal_access_token.service import (
TOKEN_PREFIX as PERSONAL_ACCESS_TOKEN_PREFIX,
)
from polar.personal_access_token.service import (
personal_access_token as personal_access_token_service,
)
Expand All @@ -36,6 +48,13 @@
log: Logger = structlog.get_logger(__name__)


def _has_valid_oauth2_token_checksum(token: str) -> bool:
for prefix in ACCESS_TOKEN_PREFIX.values():
if validate_token_checksum(token, prefix=prefix):
return True
return False


async def get_user_session(
request: Request, session: AsyncSession
) -> UserSession | None:
Expand Down Expand Up @@ -92,6 +111,30 @@ async def get_customer_session(
return await customer_session_service.get_by_token(session, value)


def _get_token_prefix(token: str) -> str | None:
"""Return the prefix if the token matches a known token type."""
if token.startswith(CUSTOMER_SESSION_TOKEN_PREFIX):
return CUSTOMER_SESSION_TOKEN_PREFIX
if token.startswith(ORGANIZATION_ACCESS_TOKEN_PREFIX):
return ORGANIZATION_ACCESS_TOKEN_PREFIX
if token.startswith(PERSONAL_ACCESS_TOKEN_PREFIX):
return PERSONAL_ACCESS_TOKEN_PREFIX
for prefix in ACCESS_TOKEN_PREFIX.values():
if token.startswith(prefix):
return prefix
return None


def _log_checksum_mismatch(token_type: str, prefix: str, token_length: int) -> None:
"""Log an error when a valid token has an invalid checksum."""
log.error(
"Valid token has invalid checksum",
token_type=token_type,
token_prefix=prefix,
token_length=token_length,
)


async def get_auth_subject(
request: Request, session: AsyncSession
) -> AuthSubject[Subject]:
Expand All @@ -102,6 +145,10 @@ async def get_auth_subject(

customer_session = await get_customer_session(session, token)
if customer_session:
if not validate_token_checksum(token, prefix=CUSTOMER_SESSION_TOKEN_PREFIX):
_log_checksum_mismatch(
"customer_session", CUSTOMER_SESSION_TOKEN_PREFIX, len(token)
)
return AuthSubject(
customer_session.customer,
{Scope.customer_portal_write},
Expand All @@ -110,6 +157,14 @@ async def get_auth_subject(

organization_access_token = await get_organization_access_token(session, token)
if organization_access_token:
if not validate_token_checksum(
token, prefix=ORGANIZATION_ACCESS_TOKEN_PREFIX
):
_log_checksum_mismatch(
"organization_access_token",
ORGANIZATION_ACCESS_TOKEN_PREFIX,
len(token),
)
return AuthSubject(
organization_access_token.organization,
organization_access_token.scopes,
Expand All @@ -118,10 +173,17 @@ async def get_auth_subject(

oauth2_token = await get_oauth2_token(session, token)
if oauth2_token:
if not _has_valid_oauth2_token_checksum(token):
prefix = _get_token_prefix(token)
_log_checksum_mismatch("oauth2_token", prefix or "unknown", len(token))
return AuthSubject(oauth2_token.sub, oauth2_token.scopes, oauth2_token)

personal_access_token = await get_personal_access_token(session, token)
if personal_access_token:
if not validate_token_checksum(token, prefix=PERSONAL_ACCESS_TOKEN_PREFIX):
_log_checksum_mismatch(
"personal_access_token", PERSONAL_ACCESS_TOKEN_PREFIX, len(token)
)
return AuthSubject(
personal_access_token.user,
personal_access_token.scopes,
Expand Down
25 changes: 25 additions & 0 deletions server/polar/kit/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,28 @@ def generate_token_hash_pair(*, secret: str, prefix: str = "") -> tuple[str, str
"""
token = generate_token(prefix=prefix)
return token, get_token_hash(token, secret=secret)


def validate_token_checksum(token: str, *, prefix: str) -> bool:
"""
Validate that a token has a valid CRC32 checksum.

Tokens are structured as: {prefix}{37 random chars}{6 checksum chars}
Returns True if the checksum is valid, False otherwise.
"""
if not token.startswith(prefix):
return False

token_without_prefix = token[len(prefix) :]

# Token should be 37 random chars + 6 checksum chars = 43 chars
if len(token_without_prefix) != 43:
return False

random_part = token_without_prefix[:37]
checksum_part = token_without_prefix[37:]

expected_checksum = zlib.crc32(random_part.encode("utf-8")) & 0xFFFFFFFF
expected_checksum_base62 = _crc32_to_base62(expected_checksum)

return checksum_part == expected_checksum_base62
Empty file added server/tests/auth/__init__.py
Empty file.
88 changes: 88 additions & 0 deletions server/tests/auth/test_middlewares.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from datetime import timedelta
from unittest.mock import MagicMock

import pytest
from pytest_mock import MockerFixture
from starlette.requests import Request

from polar.auth.middlewares import get_auth_subject
from polar.config import settings
from polar.kit.crypto import generate_token, get_token_hash
from polar.kit.utils import utc_now
from polar.models import PersonalAccessToken, User
from polar.personal_access_token.service import TOKEN_PREFIX as PAT_PREFIX
from polar.postgres import AsyncSession
from tests.fixtures.database import SaveFixture


@pytest.fixture(autouse=True)
def enqueue_job_mock(mocker: MockerFixture) -> MagicMock:
return mocker.patch("polar.auth.middlewares.enqueue_job", autospec=True)


def _make_request_with_token(token: str) -> Request:
scope = {
"type": "http",
"headers": [(b"authorization", f"Bearer {token}".encode())],
}
return Request(scope)


@pytest.mark.asyncio
class TestGetAuthSubjectChecksumValidation:
async def test_valid_checksum_no_error_logged(
self,
save_fixture: SaveFixture,
session: AsyncSession,
user: User,
mocker: MockerFixture,
) -> None:
log_mock = mocker.patch("polar.auth.middlewares.log")

token = generate_token(prefix=PAT_PREFIX)
token_hash = get_token_hash(token, secret=settings.SECRET)
pat = PersonalAccessToken(
comment="Test",
token=token_hash,
user_id=user.id,
expires_at=utc_now() + timedelta(days=1),
scope="openid",
)
await save_fixture(pat)

request = _make_request_with_token(token)
auth_subject = await get_auth_subject(request, session)

assert auth_subject.subject == user
log_mock.error.assert_not_called()

async def test_invalid_checksum_logs_error(
self,
save_fixture: SaveFixture,
session: AsyncSession,
user: User,
mocker: MockerFixture,
) -> None:
log_mock = mocker.patch("polar.auth.middlewares.log")

legacy_token = PAT_PREFIX + "a" * 32
token_hash = get_token_hash(legacy_token, secret=settings.SECRET)
pat = PersonalAccessToken(
comment="Legacy Token",
token=token_hash,
user_id=user.id,
expires_at=utc_now() + timedelta(days=1),
scope="openid",
)
await save_fixture(pat)

request = _make_request_with_token(legacy_token)
auth_subject = await get_auth_subject(request, session)

assert auth_subject.subject == user
log_mock.error.assert_called_once_with(
"Valid token has invalid checksum",
token_type="personal_access_token",
token_prefix=PAT_PREFIX,
token_length=len(legacy_token),
)
34 changes: 34 additions & 0 deletions server/tests/kit/test_crypto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from polar.kit.crypto import generate_token, validate_token_checksum


class TestValidateTokenChecksum:
def test_valid_checksum(self) -> None:
prefix = "polar_pat_"
token = generate_token(prefix=prefix)
assert validate_token_checksum(token, prefix=prefix) is True

def test_invalid_checksum_tampered(self) -> None:
prefix = "polar_pat_"
token = generate_token(prefix=prefix)
tampered = token[:-1] + ("A" if token[-1] != "A" else "B")
assert validate_token_checksum(tampered, prefix=prefix) is False

def test_wrong_prefix(self) -> None:
token = generate_token(prefix="polar_pat_")
assert validate_token_checksum(token, prefix="polar_oat_") is False

def test_token_too_short(self) -> None:
prefix = "polar_pat_"
short_token = prefix + "a" * 32
assert validate_token_checksum(short_token, prefix=prefix) is False

def test_token_too_long(self) -> None:
prefix = "polar_pat_"
long_token = generate_token(prefix=prefix) + "extra"
assert validate_token_checksum(long_token, prefix=prefix) is False

def test_different_prefixes(self) -> None:
prefixes = ["polar_pat_", "polar_oat_", "polar_cst_", "polar_at_u_"]
for prefix in prefixes:
token = generate_token(prefix=prefix)
assert validate_token_checksum(token, prefix=prefix) is True