Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ SECRET_KEY="" # To generate: python -c "import os; print(os.urandom(32).hex())"
HOST_URI="127.0.0.1:8000"
ENV="development" # change to "production" in production

# CORS — allowed origins for private routes (auth, oauth, dashboard)
# Public API routes (/api/v1/*) always allow all origins.
# Set to your frontend domain(s) in production. Leave empty to block all cross-origin
# requests to private routes (safe default).
# Example: CORS_PRIVATE_ORIGINS='["https://spoo.me","chrome-extension://your-ext-id"]'
CORS_PRIVATE_ORIGINS='[]'

# Logging Configuration
LOG_LEVEL=DEBUG # DEBUG, INFO, WARNING, ERROR, CRITICAL
LOG_FORMAT=console # json (prod) or console (dev)
Expand Down
54 changes: 50 additions & 4 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
configure_openapi,
)
from middleware.rate_limiter import limiter
from middleware.security import MaxContentLengthMiddleware, configure_cors
from middleware.security import (
MaxContentLengthMiddleware,
SecurityHeadersMiddleware,
configure_cors,
)
from repositories.indexes import ensure_indexes
from routes.api_v1 import router as api_v1_router
from routes.auth_routes import router as auth_router
Expand Down Expand Up @@ -120,6 +124,45 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:

await ensure_indexes(app.state.db)

# Warn if session secret is missing when auth is enabled
if settings.jwt and not settings.secret_key:
log.warning(
"secret_key_empty",
detail="SECRET_KEY is empty — session cookies are unsigned. "
"Set a strong SECRET_KEY when auth/OAuth is enabled.",
)

# Warn if CORS private origins not configured in production
if settings.is_production and not settings.cors_private_origins:
log.warning(
"cors_private_origins_empty",
detail="CORS_PRIVATE_ORIGINS is empty — auth/oauth/dashboard routes "
"will reject all cross-origin requests. Set to your frontend domain(s).",
)

# Warn if JWT config is weak (auth is optional, so don't crash)
if settings.jwt and bool(settings.jwt.jwt_private_key) != bool(
settings.jwt.jwt_public_key
):
log.warning(
"jwt_rsa_half_configured",
detail="Only one of JWT_PRIVATE_KEY / JWT_PUBLIC_KEY is set — "
"both are required for RS256. Falling back to HS256.",
)

if settings.jwt and not settings.jwt.use_rs256:
if not settings.jwt.jwt_secret:
log.warning(
"jwt_config_insecure",
detail="RS256 keys not set and JWT_SECRET is empty — tokens can be forged. "
"Set JWT_PRIVATE_KEY + JWT_PUBLIC_KEY or a strong JWT_SECRET.",
)
elif len(settings.jwt.jwt_secret) < 32:
log.warning(
"jwt_secret_weak",
detail="JWT_SECRET is shorter than 32 characters — consider using RS256 keys or a longer secret.",
)

yield

# ── Shutdown ─────────────────────────────────────────────────────────
Expand Down Expand Up @@ -164,13 +207,16 @@ async def docs(request: Request):
# ── Middleware (registered in reverse execution order) ────────────────
# 1. Session — outermost, needed by Authlib OAuth for state storage
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)
# 2. CORS — must wrap everything
# 2. Security headers — must be outer so HSTS/CSP/nosniff apply to
# all responses including CORS preflights (204) and body-limit (413)
app.add_middleware(SecurityHeadersMiddleware, hsts_enabled=settings.is_production)
# 3. CORS
configure_cors(app, settings)
# 3. Body size limit
# 4. Body size limit
app.add_middleware(
MaxContentLengthMiddleware, max_content_length=settings.max_content_length
)
# 4. Request logging — innermost, logs all requests with request_id
# 5. Request logging — innermost, logs all requests with request_id
app.add_middleware(RequestLoggingMiddleware)

# ── Error handlers + rate limiter ────────────────────────────────────
Expand Down
10 changes: 8 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,14 @@ class AppSettings(BaseSettings):
app_url: str = "https://spoo.me"
app_name: str = "spoo.me"

# CORS — default matches current behaviour: all origins, credentials allowed
cors_origins: list[str] = ["*"]
# CORS — public API routes allow all origins (no credentials).
# Private routes (auth, oauth, dashboard) require explicit origin allowlist.
cors_origins: list[str] = ["*"] # deprecated — kept for backward compat
cors_private_origins: list[str] = []

# Device auth flow - allowed redirect URIs for third-party clients
# (e.g., mobile apps, desktop apps). Browser extensions don't need this.
device_auth_redirect_uris: list[str] = []

# Request body size limit (bytes); 1 MB default
max_content_length: int = 1_048_576
Expand Down
7 changes: 6 additions & 1 deletion infrastructure/oauth_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@ def init_oauth(settings: Any) -> tuple[OAuth | None, dict[str, Any]]:


def generate_oauth_state(
provider: str, action: OAuthAction = OAuthAction.LOGIN, user_id: str | None = None
provider: str,
action: OAuthAction = OAuthAction.LOGIN,
user_id: str | None = None,
next_url: str | None = None,
) -> str:
"""Generate a URL-safe state string for CSRF protection."""
action_str = action.value if isinstance(action, OAuthAction) else action
Expand All @@ -161,6 +164,8 @@ def generate_oauth_state(
]
if user_id:
parts.append(f"user_id={user_id}")
if next_url:
parts.append(f"next={next_url}")
return "&".join(parts)


Expand Down
4 changes: 4 additions & 0 deletions middleware/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class Limits:
PASSWORD_RESET_REQUEST = "3 per hour"
PASSWORD_RESET_CONFIRM = "5 per hour"

# Device auth flow (extensions, apps, CLIs)
DEVICE_AUTH = "10 per minute"
DEVICE_TOKEN = "10 per minute"

# OAuth
OAUTH_INIT = "10 per minute"
OAUTH_CALLBACK = "20 per minute"
Expand Down
119 changes: 110 additions & 9 deletions middleware/security.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,131 @@
"""
Security middleware — CORS configuration and request body size limit.
Security middleware — CORS configuration, security headers, and request body size limit.
"""

from __future__ import annotations

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.responses import PlainTextResponse, Response

from config import AppSettings

# ── Path-based CORS classification ──────────────────────────────────────────

_PRIVATE_PREFIXES = ("/auth", "/oauth", "/dashboard")
_PUBLIC_PREFIXES = ("/api/v1", "/auth/device", "/stats", "/export", "/metric")

_ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
_ALLOWED_HEADERS = "Authorization, Content-Type, Accept, X-Request-ID"


def _classify_path(path: str) -> str:
"""Return 'public', 'private', or 'none' based on the request path."""
for prefix in _PUBLIC_PREFIXES:
if path == prefix or path.startswith(prefix + "/"):
return "public"
for prefix in _PRIVATE_PREFIXES:
if path == prefix or path.startswith(prefix + "/"):
return "private"
# Legacy root shortener endpoint (POST /) is public API
if path == "/":
return "public"
return "none"


class SplitCORSMiddleware(BaseHTTPMiddleware):
"""Apply different CORS policies based on the request path.

- Public routes (``/api/v1/*``, legacy API): ``Access-Control-Allow-Origin: *``,
no credentials.
- Private routes (``/auth/*``, ``/oauth/*``, ``/dashboard/*``): origin checked
against *private_origins* allowlist, credentials allowed.
- All other routes: no CORS headers.
"""

def __init__(self, app, *, private_origins: list[str]) -> None:
super().__init__(app)
self._private_origins: set[str] = {o.rstrip("/") for o in private_origins}

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
origin = request.headers.get("origin")
path = request.url.path
classification = _classify_path(path)

# Handle preflight — only intercept routes that need CORS
if request.method == "OPTIONS" and origin and classification != "none":
return self._preflight(origin, classification)

response = await call_next(request)

if origin and classification != "none":
self._set_cors_headers(response, origin, classification)

return response

def _preflight(self, origin: str, classification: str) -> Response:
"""Return a 204 preflight response with appropriate CORS headers."""
response = PlainTextResponse("", status_code=204)
if classification != "none":
self._set_cors_headers(response, origin, classification)
response.headers["Access-Control-Allow-Methods"] = _ALLOWED_METHODS
response.headers["Access-Control-Allow-Headers"] = _ALLOWED_HEADERS
response.headers["Access-Control-Max-Age"] = "86400"
return response

def _set_cors_headers(
self, response: Response, origin: str, classification: str
) -> None:
"""Set CORS headers based on route classification."""
if classification == "public":
response.headers["Access-Control-Allow-Origin"] = "*"
elif classification == "private" and origin in self._private_origins:
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Vary"] = "Origin"


def configure_cors(app: FastAPI, settings: AppSettings) -> None:
"""Add CORS middleware from settings. Defaults to allow-all with credentials."""
"""Add split CORS middleware — different policies for public vs private routes."""
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
SplitCORSMiddleware,
private_origins=settings.cors_private_origins,
)


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
"""Set standard security headers on all responses."""

def __init__(self, app, *, hsts_enabled: bool = True) -> None:
super().__init__(app)
self.hsts_enabled = hsts_enabled

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
if self.hsts_enabled:
response.headers["Strict-Transport-Security"] = (
"max-age=31536000; includeSubDomains"
)
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
response.headers["Permissions-Policy"] = (
"camera=(), microphone=(), geolocation=()"
)
response.headers["Content-Security-Policy-Report-Only"] = (
"default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; "
"img-src 'self' data: https:; font-src 'self'; connect-src 'self'; "
"frame-ancestors 'none'"
)
return response


class MaxContentLengthMiddleware(BaseHTTPMiddleware):
"""Reject requests whose Content-Length exceeds the configured limit."""

Expand Down
4 changes: 4 additions & 0 deletions repositories/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,9 @@ async def ensure_indexes(db: AsyncDatabase) -> None:
await tokens_col.create_index([("token_hash", 1)])
await tokens_col.create_index([("token_type", 1)])
await tokens_col.create_index([("expires_at", 1)], expireAfterSeconds=0)
await tokens_col.create_index(
[("user_id", 1), ("token_type", 1), ("used_at", 1), ("created_at", -1)],
name="ix_latest_unused_by_user",
)

log.info("mongodb_indexes_ensured")
82 changes: 75 additions & 7 deletions repositories/token_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,10 @@ async def create(self, token_data: dict) -> ObjectId:
)
raise

async def find_by_hash(
async def find_by_hash_and_type(
self, token_hash: str, token_type: str
) -> VerificationTokenDoc | None:
"""
Find a non-used token by its SHA-256 hash and type.

Only returns tokens where ``used_at`` is None (not yet consumed).
"""
"""Find a non-used token by its SHA-256 hash and type."""
try:
doc = await self._col.find_one(
{
Expand All @@ -56,7 +52,35 @@ async def find_by_hash(
return VerificationTokenDoc.from_mongo(doc)
except PyMongoError as exc:
log.error(
"token_repo_find_by_hash_failed",
"token_repo_find_by_hash_and_type_failed",
token_type=token_type,
error=str(exc),
error_type=type(exc).__name__,
)
raise

async def consume_by_hash(
self, token_hash: str, token_type: str
) -> VerificationTokenDoc | None:
"""Atomically find an unused, non-expired token and mark it as used.

Returns the pre-update document, or None if no matching token exists.
"""
now = datetime.now(timezone.utc)
try:
doc = await self._col.find_one_and_update(
{
"token_hash": token_hash,
"token_type": token_type,
"used_at": None,
"expires_at": {"$gt": now},
},
{"$set": {"used_at": now}},
)
return VerificationTokenDoc.from_mongo(doc)
except PyMongoError as exc:
log.error(
"token_repo_consume_by_hash_failed",
token_type=token_type,
error=str(exc),
error_type=type(exc).__name__,
Expand Down Expand Up @@ -84,6 +108,50 @@ async def mark_as_used(self, token_id: ObjectId) -> bool:
)
raise

async def find_latest_by_user(
self, user_id: ObjectId, token_type: str
) -> VerificationTokenDoc | None:
"""Find the most recent non-used token for a user and type."""
try:
doc = await self._col.find_one(
{
"user_id": user_id,
"token_type": token_type,
"used_at": None,
},
sort=[("created_at", -1)],
)
return VerificationTokenDoc.from_mongo(doc)
except PyMongoError as exc:
log.error(
"token_repo_find_latest_by_user_failed",
user_id=str(user_id),
token_type=token_type,
error=str(exc),
error_type=type(exc).__name__,
)
raise

async def increment_attempts(self, token_id: ObjectId) -> bool:
"""Atomically increment the ``attempts`` counter on a token.

Returns True if a document was modified.
"""
try:
result = await self._col.update_one(
{"_id": token_id},
{"$inc": {"attempts": 1}},
)
return result.modified_count > 0
except PyMongoError as exc:
log.error(
"token_repo_increment_attempts_failed",
token_id=str(token_id),
error=str(exc),
error_type=type(exc).__name__,
)
raise

async def delete_by_user(
self, user_id: ObjectId, token_type: str | None = None
) -> int:
Expand Down
Loading
Loading