diff --git a/.env.example b/.env.example
index c1bc247..e92488c 100644
--- a/.env.example
+++ b/.env.example
@@ -16,6 +16,13 @@ SECRET_KEY="" # To generate: python -c "import os; print(os.urandom(32).hex())"
HOST_URI="127.0.0.1:8000"
ENV="development" # change to "production" in production
+# CORS — allowed origins for private routes (auth, oauth, dashboard)
+# Public API routes (/api/v1/*) always allow all origins.
+# Set to your frontend domain(s) in production. Leave empty to block all cross-origin
+# requests to private routes (safe default).
+# Example: CORS_PRIVATE_ORIGINS='["https://spoo.me","chrome-extension://your-ext-id"]'
+CORS_PRIVATE_ORIGINS='[]'
+
# Logging Configuration
LOG_LEVEL=DEBUG # DEBUG, INFO, WARNING, ERROR, CRITICAL
LOG_FORMAT=console # json (prod) or console (dev)
diff --git a/app.py b/app.py
index 1f43631..29aac91 100644
--- a/app.py
+++ b/app.py
@@ -33,7 +33,11 @@
configure_openapi,
)
from middleware.rate_limiter import limiter
-from middleware.security import MaxContentLengthMiddleware, configure_cors
+from middleware.security import (
+ MaxContentLengthMiddleware,
+ SecurityHeadersMiddleware,
+ configure_cors,
+)
from repositories.indexes import ensure_indexes
from routes.api_v1 import router as api_v1_router
from routes.auth_routes import router as auth_router
@@ -120,6 +124,45 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
await ensure_indexes(app.state.db)
+ # Warn if session secret is missing when auth is enabled
+ if settings.jwt and not settings.secret_key:
+ log.warning(
+ "secret_key_empty",
+ detail="SECRET_KEY is empty — session cookies are unsigned. "
+ "Set a strong SECRET_KEY when auth/OAuth is enabled.",
+ )
+
+ # Warn if CORS private origins not configured in production
+ if settings.is_production and not settings.cors_private_origins:
+ log.warning(
+ "cors_private_origins_empty",
+ detail="CORS_PRIVATE_ORIGINS is empty — auth/oauth/dashboard routes "
+ "will reject all cross-origin requests. Set to your frontend domain(s).",
+ )
+
+ # Warn if JWT config is weak (auth is optional, so don't crash)
+ if settings.jwt and bool(settings.jwt.jwt_private_key) != bool(
+ settings.jwt.jwt_public_key
+ ):
+ log.warning(
+ "jwt_rsa_half_configured",
+ detail="Only one of JWT_PRIVATE_KEY / JWT_PUBLIC_KEY is set — "
+ "both are required for RS256. Falling back to HS256.",
+ )
+
+ if settings.jwt and not settings.jwt.use_rs256:
+ if not settings.jwt.jwt_secret:
+ log.warning(
+ "jwt_config_insecure",
+ detail="RS256 keys not set and JWT_SECRET is empty — tokens can be forged. "
+ "Set JWT_PRIVATE_KEY + JWT_PUBLIC_KEY or a strong JWT_SECRET.",
+ )
+ elif len(settings.jwt.jwt_secret) < 32:
+ log.warning(
+ "jwt_secret_weak",
+ detail="JWT_SECRET is shorter than 32 characters — consider using RS256 keys or a longer secret.",
+ )
+
yield
# ── Shutdown ─────────────────────────────────────────────────────────
@@ -164,13 +207,16 @@ async def docs(request: Request):
# ── Middleware (registered in reverse execution order) ────────────────
# 1. Session — outermost, needed by Authlib OAuth for state storage
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)
- # 2. CORS — must wrap everything
+ # 2. Security headers — must be outer so HSTS/CSP/nosniff apply to
+ # all responses including CORS preflights (204) and body-limit (413)
+ app.add_middleware(SecurityHeadersMiddleware, hsts_enabled=settings.is_production)
+ # 3. CORS
configure_cors(app, settings)
- # 3. Body size limit
+ # 4. Body size limit
app.add_middleware(
MaxContentLengthMiddleware, max_content_length=settings.max_content_length
)
- # 4. Request logging — innermost, logs all requests with request_id
+ # 5. Request logging — innermost, logs all requests with request_id
app.add_middleware(RequestLoggingMiddleware)
# ── Error handlers + rate limiter ────────────────────────────────────
diff --git a/config.py b/config.py
index e5e896b..1ccd589 100644
--- a/config.py
+++ b/config.py
@@ -108,8 +108,14 @@ class AppSettings(BaseSettings):
app_url: str = "https://spoo.me"
app_name: str = "spoo.me"
- # CORS — default matches current behaviour: all origins, credentials allowed
- cors_origins: list[str] = ["*"]
+ # CORS — public API routes allow all origins (no credentials).
+ # Private routes (auth, oauth, dashboard) require explicit origin allowlist.
+ cors_origins: list[str] = ["*"] # deprecated — kept for backward compat
+ cors_private_origins: list[str] = []
+
+ # Device auth flow - allowed redirect URIs for third-party clients
+ # (e.g., mobile apps, desktop apps). Browser extensions don't need this.
+ device_auth_redirect_uris: list[str] = []
# Request body size limit (bytes); 1 MB default
max_content_length: int = 1_048_576
diff --git a/infrastructure/oauth_clients.py b/infrastructure/oauth_clients.py
index b159ebf..ba4136e 100644
--- a/infrastructure/oauth_clients.py
+++ b/infrastructure/oauth_clients.py
@@ -149,7 +149,10 @@ def init_oauth(settings: Any) -> tuple[OAuth | None, dict[str, Any]]:
def generate_oauth_state(
- provider: str, action: OAuthAction = OAuthAction.LOGIN, user_id: str | None = None
+ provider: str,
+ action: OAuthAction = OAuthAction.LOGIN,
+ user_id: str | None = None,
+ next_url: str | None = None,
) -> str:
"""Generate a URL-safe state string for CSRF protection."""
action_str = action.value if isinstance(action, OAuthAction) else action
@@ -161,6 +164,8 @@ def generate_oauth_state(
]
if user_id:
parts.append(f"user_id={user_id}")
+ if next_url:
+ parts.append(f"next={next_url}")
return "&".join(parts)
diff --git a/middleware/rate_limiter.py b/middleware/rate_limiter.py
index 3757379..95ddf1a 100644
--- a/middleware/rate_limiter.py
+++ b/middleware/rate_limiter.py
@@ -46,6 +46,10 @@ class Limits:
PASSWORD_RESET_REQUEST = "3 per hour"
PASSWORD_RESET_CONFIRM = "5 per hour"
+ # Device auth flow (extensions, apps, CLIs)
+ DEVICE_AUTH = "10 per minute"
+ DEVICE_TOKEN = "10 per minute"
+
# OAuth
OAUTH_INIT = "10 per minute"
OAUTH_CALLBACK = "20 per minute"
diff --git a/middleware/security.py b/middleware/security.py
index cc1ab83..4868589 100644
--- a/middleware/security.py
+++ b/middleware/security.py
@@ -1,30 +1,131 @@
"""
-Security middleware — CORS configuration and request body size limit.
+Security middleware — CORS configuration, security headers, and request body size limit.
"""
from __future__ import annotations
from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
-from starlette.responses import Response
+from starlette.responses import PlainTextResponse, Response
from config import AppSettings
+# ── Path-based CORS classification ──────────────────────────────────────────
+
+_PRIVATE_PREFIXES = ("/auth", "/oauth", "/dashboard")
+_PUBLIC_PREFIXES = ("/api/v1", "/auth/device", "/stats", "/export", "/metric")
+
+_ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
+_ALLOWED_HEADERS = "Authorization, Content-Type, Accept, X-Request-ID"
+
+
+def _classify_path(path: str) -> str:
+ """Return 'public', 'private', or 'none' based on the request path."""
+ for prefix in _PUBLIC_PREFIXES:
+ if path == prefix or path.startswith(prefix + "/"):
+ return "public"
+ for prefix in _PRIVATE_PREFIXES:
+ if path == prefix or path.startswith(prefix + "/"):
+ return "private"
+ # Legacy root shortener endpoint (POST /) is public API
+ if path == "/":
+ return "public"
+ return "none"
+
+
+class SplitCORSMiddleware(BaseHTTPMiddleware):
+ """Apply different CORS policies based on the request path.
+
+ - Public routes (``/api/v1/*``, legacy API): ``Access-Control-Allow-Origin: *``,
+ no credentials.
+ - Private routes (``/auth/*``, ``/oauth/*``, ``/dashboard/*``): origin checked
+ against *private_origins* allowlist, credentials allowed.
+ - All other routes: no CORS headers.
+ """
+
+ def __init__(self, app, *, private_origins: list[str]) -> None:
+ super().__init__(app)
+ self._private_origins: set[str] = {o.rstrip("/") for o in private_origins}
+
+ async def dispatch(
+ self, request: Request, call_next: RequestResponseEndpoint
+ ) -> Response:
+ origin = request.headers.get("origin")
+ path = request.url.path
+ classification = _classify_path(path)
+
+ # Handle preflight — only intercept routes that need CORS
+ if request.method == "OPTIONS" and origin and classification != "none":
+ return self._preflight(origin, classification)
+
+ response = await call_next(request)
+
+ if origin and classification != "none":
+ self._set_cors_headers(response, origin, classification)
+
+ return response
+
+ def _preflight(self, origin: str, classification: str) -> Response:
+ """Return a 204 preflight response with appropriate CORS headers."""
+ response = PlainTextResponse("", status_code=204)
+ if classification != "none":
+ self._set_cors_headers(response, origin, classification)
+ response.headers["Access-Control-Allow-Methods"] = _ALLOWED_METHODS
+ response.headers["Access-Control-Allow-Headers"] = _ALLOWED_HEADERS
+ response.headers["Access-Control-Max-Age"] = "86400"
+ return response
+
+ def _set_cors_headers(
+ self, response: Response, origin: str, classification: str
+ ) -> None:
+ """Set CORS headers based on route classification."""
+ if classification == "public":
+ response.headers["Access-Control-Allow-Origin"] = "*"
+ elif classification == "private" and origin in self._private_origins:
+ response.headers["Access-Control-Allow-Origin"] = origin
+ response.headers["Access-Control-Allow-Credentials"] = "true"
+ response.headers["Vary"] = "Origin"
+
def configure_cors(app: FastAPI, settings: AppSettings) -> None:
- """Add CORS middleware from settings. Defaults to allow-all with credentials."""
+ """Add split CORS middleware — different policies for public vs private routes."""
app.add_middleware(
- CORSMiddleware,
- allow_origins=settings.cors_origins,
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
+ SplitCORSMiddleware,
+ private_origins=settings.cors_private_origins,
)
+class SecurityHeadersMiddleware(BaseHTTPMiddleware):
+ """Set standard security headers on all responses."""
+
+ def __init__(self, app, *, hsts_enabled: bool = True) -> None:
+ super().__init__(app)
+ self.hsts_enabled = hsts_enabled
+
+ async def dispatch(
+ self, request: Request, call_next: RequestResponseEndpoint
+ ) -> Response:
+ response = await call_next(request)
+ if self.hsts_enabled:
+ response.headers["Strict-Transport-Security"] = (
+ "max-age=31536000; includeSubDomains"
+ )
+ response.headers["X-Content-Type-Options"] = "nosniff"
+ response.headers["X-Frame-Options"] = "DENY"
+ response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
+ response.headers["Permissions-Policy"] = (
+ "camera=(), microphone=(), geolocation=()"
+ )
+ response.headers["Content-Security-Policy-Report-Only"] = (
+ "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; "
+ "img-src 'self' data: https:; font-src 'self'; connect-src 'self'; "
+ "frame-ancestors 'none'"
+ )
+ return response
+
+
class MaxContentLengthMiddleware(BaseHTTPMiddleware):
"""Reject requests whose Content-Length exceeds the configured limit."""
diff --git a/repositories/indexes.py b/repositories/indexes.py
index dac0f6a..9fd39dd 100644
--- a/repositories/indexes.py
+++ b/repositories/indexes.py
@@ -82,5 +82,9 @@ async def ensure_indexes(db: AsyncDatabase) -> None:
await tokens_col.create_index([("token_hash", 1)])
await tokens_col.create_index([("token_type", 1)])
await tokens_col.create_index([("expires_at", 1)], expireAfterSeconds=0)
+ await tokens_col.create_index(
+ [("user_id", 1), ("token_type", 1), ("used_at", 1), ("created_at", -1)],
+ name="ix_latest_unused_by_user",
+ )
log.info("mongodb_indexes_ensured")
diff --git a/repositories/token_repository.py b/repositories/token_repository.py
index f7c9c96..a8f6145 100644
--- a/repositories/token_repository.py
+++ b/repositories/token_repository.py
@@ -37,14 +37,10 @@ async def create(self, token_data: dict) -> ObjectId:
)
raise
- async def find_by_hash(
+ async def find_by_hash_and_type(
self, token_hash: str, token_type: str
) -> VerificationTokenDoc | None:
- """
- Find a non-used token by its SHA-256 hash and type.
-
- Only returns tokens where ``used_at`` is None (not yet consumed).
- """
+ """Find a non-used token by its SHA-256 hash and type."""
try:
doc = await self._col.find_one(
{
@@ -56,7 +52,35 @@ async def find_by_hash(
return VerificationTokenDoc.from_mongo(doc)
except PyMongoError as exc:
log.error(
- "token_repo_find_by_hash_failed",
+ "token_repo_find_by_hash_and_type_failed",
+ token_type=token_type,
+ error=str(exc),
+ error_type=type(exc).__name__,
+ )
+ raise
+
+ async def consume_by_hash(
+ self, token_hash: str, token_type: str
+ ) -> VerificationTokenDoc | None:
+ """Atomically find an unused, non-expired token and mark it as used.
+
+ Returns the pre-update document, or None if no matching token exists.
+ """
+ now = datetime.now(timezone.utc)
+ try:
+ doc = await self._col.find_one_and_update(
+ {
+ "token_hash": token_hash,
+ "token_type": token_type,
+ "used_at": None,
+ "expires_at": {"$gt": now},
+ },
+ {"$set": {"used_at": now}},
+ )
+ return VerificationTokenDoc.from_mongo(doc)
+ except PyMongoError as exc:
+ log.error(
+ "token_repo_consume_by_hash_failed",
token_type=token_type,
error=str(exc),
error_type=type(exc).__name__,
@@ -84,6 +108,50 @@ async def mark_as_used(self, token_id: ObjectId) -> bool:
)
raise
+ async def find_latest_by_user(
+ self, user_id: ObjectId, token_type: str
+ ) -> VerificationTokenDoc | None:
+ """Find the most recent non-used token for a user and type."""
+ try:
+ doc = await self._col.find_one(
+ {
+ "user_id": user_id,
+ "token_type": token_type,
+ "used_at": None,
+ },
+ sort=[("created_at", -1)],
+ )
+ return VerificationTokenDoc.from_mongo(doc)
+ except PyMongoError as exc:
+ log.error(
+ "token_repo_find_latest_by_user_failed",
+ user_id=str(user_id),
+ token_type=token_type,
+ error=str(exc),
+ error_type=type(exc).__name__,
+ )
+ raise
+
+ async def increment_attempts(self, token_id: ObjectId) -> bool:
+ """Atomically increment the ``attempts`` counter on a token.
+
+ Returns True if a document was modified.
+ """
+ try:
+ result = await self._col.update_one(
+ {"_id": token_id},
+ {"$inc": {"attempts": 1}},
+ )
+ return result.modified_count > 0
+ except PyMongoError as exc:
+ log.error(
+ "token_repo_increment_attempts_failed",
+ token_id=str(token_id),
+ error=str(exc),
+ error_type=type(exc).__name__,
+ )
+ raise
+
async def delete_by_user(
self, user_id: ObjectId, token_type: str | None = None
) -> int:
diff --git a/routes/auth_routes.py b/routes/auth_routes.py
index 77370e3..14bcef6 100644
--- a/routes/auth_routes.py
+++ b/routes/auth_routes.py
@@ -25,12 +25,13 @@
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
-from dependencies import AuthUser, get_auth_service
+from dependencies import AuthUser, OptionalUser, get_auth_service
from errors import AuthenticationError
from middleware.openapi import AUTH_RESPONSES, ERROR_RESPONSES, PUBLIC_SECURITY
from middleware.rate_limiter import Limits, limiter
from routes.cookie_helpers import clear_auth_cookies, set_auth_cookies
from schemas.dto.requests.auth import (
+ DeviceTokenRequest,
LoginRequest,
RegisterRequest,
RequestPasswordResetRequest,
@@ -39,6 +40,7 @@
VerifyEmailRequest,
)
from schemas.dto.responses.auth import (
+ DeviceTokenResponse,
LoginResponse,
LogoutResponse,
MeResponse,
@@ -61,6 +63,13 @@
templates = Jinja2Templates(directory=_TEMPLATE_DIR)
+def _validate_redirect_uri(uri: str, allowed: list[str]) -> str | None:
+ """Return the URI if it's in the allowlist, None otherwise."""
+ if not uri or not allowed:
+ return None
+ return uri if uri in allowed else None
+
+
# ── Redirect shortcuts ────────────────────────────────────────────────────────
@@ -444,3 +453,99 @@ async def reset_password(
body.email.strip().lower(), body.code.strip(), body.password
)
return MessageResponse(success=True, message="password reset successfully")
+
+
+# ── Device auth flow (extensions, apps, CLIs) ────────────────────────────────
+
+
+@router.get("/auth/device/login", include_in_schema=False)
+@limiter.limit(Limits.DEVICE_AUTH)
+async def device_login(
+ request: Request,
+ user: OptionalUser,
+ auth_service: AuthService = Depends(get_auth_service),
+ redirect_uri: str = "",
+ state: str = "",
+) -> RedirectResponse:
+ """Initiate the device auth flow.
+
+ If the user already has a valid session, generates an auth code and
+ redirects to the callback page (or a registered redirect_uri).
+ Otherwise, redirects to the login page.
+ Used by browser extensions, mobile apps, and other third-party clients.
+
+ The ``state`` parameter is passed through for CSRF protection — the
+ client generates it, the server carries it, the client verifies it.
+ """
+ if user:
+ profile = await auth_service.get_user_profile(str(user.user_id))
+ code = await auth_service.create_device_auth_code(profile.id, profile.email)
+ allowed = request.app.state.settings.device_auth_redirect_uris
+ validated_uri = _validate_redirect_uri(redirect_uri, allowed)
+ if validated_uri:
+ separator = "&" if "?" in validated_uri else "?"
+ return RedirectResponse(
+ f"{validated_uri}{separator}code={code}&state={state}",
+ status_code=302,
+ )
+ return RedirectResponse(
+ f"/auth/device/callback?code={code}&state={state}", status_code=302
+ )
+
+ # Preserve params through the login flow
+ params = "state=" + state if state else ""
+ if redirect_uri:
+ params += ("&" if params else "") + f"redirect_uri={redirect_uri}"
+ next_url = "/auth/device/login" + (f"?{params}" if params else "")
+ return RedirectResponse(f"/?next={next_url}", status_code=302)
+
+
+@router.get("/auth/device/callback", include_in_schema=False)
+@limiter.limit(Limits.DEVICE_AUTH)
+async def device_callback(
+ request: Request,
+ code: str = "",
+ state: str = "",
+) -> Response:
+ """Render the device auth callback page.
+
+ The client reads the auth code and state from data attributes on the page.
+ For browser extensions, the content script handles this automatically.
+ """
+ if not code:
+ return RedirectResponse("/", status_code=302)
+ return templates.TemplateResponse(
+ request, "device_callback.html", {"code": code, "state": state}
+ )
+
+
+@router.post(
+ "/auth/device/token",
+ responses=ERROR_RESPONSES,
+ openapi_extra=PUBLIC_SECURITY,
+ operation_id="exchangeDeviceCode",
+ summary="Exchange Device Auth Code",
+)
+@limiter.limit(Limits.DEVICE_TOKEN)
+async def device_token(
+ request: Request,
+ body: DeviceTokenRequest,
+ auth_service: AuthService = Depends(get_auth_service),
+) -> DeviceTokenResponse:
+ """Exchange a one-time device auth code for JWT tokens.
+
+ The code is obtained from the callback page after the user authenticates
+ on spoo.me. Returns access and refresh tokens for the client.
+
+ **Authentication**: Not required (public endpoint)
+
+ **Rate Limits**: 10/min
+ """
+ user, access_token, refresh_token = await auth_service.exchange_device_code(
+ body.code.strip()
+ )
+ return DeviceTokenResponse(
+ access_token=access_token,
+ refresh_token=refresh_token,
+ user=UserProfileResponse.from_user(user),
+ )
diff --git a/routes/legacy/url_shortener.py b/routes/legacy/url_shortener.py
index 66cbf4e..93054af 100644
--- a/routes/legacy/url_shortener.py
+++ b/routes/legacy/url_shortener.py
@@ -37,6 +37,7 @@
validate_alias,
validate_blocked_url,
validate_emoji_alias,
+ validate_safe_redirect,
validate_url,
validate_url_password,
)
@@ -69,7 +70,8 @@
async def index(request: Request, user: OptionalUser) -> Response:
"""Render the index page. Redirect to dashboard if already logged in."""
if user is not None:
- return RedirectResponse("/dashboard", status_code=302)
+ next_url = validate_safe_redirect(request.query_params.get("next", ""))
+ return RedirectResponse(next_url, status_code=302)
return templates.TemplateResponse(
request, "index.html", {"host_url": str(request.base_url)}
)
diff --git a/routes/oauth_routes.py b/routes/oauth_routes.py
index 1f1ac05..eee7671 100644
--- a/routes/oauth_routes.py
+++ b/routes/oauth_routes.py
@@ -34,6 +34,7 @@
from services.oauth_service import OAuthService
from shared.ip_utils import get_client_ip
from shared.logging import get_logger
+from shared.validators import validate_safe_redirect
log = get_logger(__name__)
@@ -132,7 +133,8 @@ async def oauth_login(
raise NotFoundError(f"'{provider}' OAuth not configured")
log.info("oauth_flow_initiated", provider=provider)
- state = generate_oauth_state(provider, OAuthAction.LOGIN)
+ next_url = request.query_params.get("next")
+ state = generate_oauth_state(provider, OAuthAction.LOGIN, next_url=next_url)
redirect_uri = get_oauth_redirect_url(provider, request.app.state.settings.oauth)
return await client.authorize_redirect(request, redirect_uri, state=state)
@@ -207,9 +209,10 @@ async def oauth_callback(
provider, provider_info, action, state_data, client_ip
)
- # ── Redirect to dashboard with cookies ───────────────────────────────────
+ # ── Redirect with cookies ──────────────────────────────────────────────
+ next_url = validate_safe_redirect(state_data.get("next", ""))
jwt_cfg = request.app.state.settings.jwt
- resp = RedirectResponse(_DASHBOARD_URL, status_code=302)
+ resp = RedirectResponse(next_url, status_code=302)
set_auth_cookies(resp, access_token, refresh_token, jwt_cfg)
return resp
diff --git a/schemas/dto/requests/auth.py b/schemas/dto/requests/auth.py
index 159e45f..cfe27f8 100644
--- a/schemas/dto/requests/auth.py
+++ b/schemas/dto/requests/auth.py
@@ -8,6 +8,7 @@
SendVerificationRequest — POST /auth/send-verification (no body)
RequestPasswordResetRequest — POST /auth/request-password-reset
ResetPasswordRequest — POST /auth/reset-password
+DeviceTokenRequest — POST /auth/device/token
"""
from __future__ import annotations
@@ -113,3 +114,15 @@ class ResetPasswordRequest(BaseModel):
description="New password (min 8 chars, must contain letter + number + special char)",
examples=["NewSecurePass456!"],
)
+
+
+class DeviceTokenRequest(BaseModel):
+ """Request body for POST /auth/device/token."""
+
+ model_config = ConfigDict(populate_by_name=True)
+
+ code: str = Field(
+ min_length=1,
+ max_length=128,
+ description="One-time auth code from the device callback page",
+ )
diff --git a/schemas/dto/responses/auth.py b/schemas/dto/responses/auth.py
index 2ea7da3..2aea472 100644
--- a/schemas/dto/responses/auth.py
+++ b/schemas/dto/responses/auth.py
@@ -191,6 +191,16 @@ class SendVerificationResponse(BaseModel):
)
+class DeviceTokenResponse(BaseModel):
+ """Response body for POST /auth/device/token (200)."""
+
+ model_config = ConfigDict(populate_by_name=True)
+
+ access_token: str = Field(description="JWT access token")
+ refresh_token: str = Field(description="JWT refresh token")
+ user: UserProfileResponse = Field(description="User profile")
+
+
class OAuthProviderDetail(BaseModel):
"""Detailed OAuth provider entry for the providers list endpoint."""
diff --git a/schemas/models/token.py b/schemas/models/token.py
index 212dcf3..dc9d322 100644
--- a/schemas/models/token.py
+++ b/schemas/models/token.py
@@ -24,11 +24,13 @@ class TokenType(str, Enum):
EMAIL_VERIFY = "email_verify"
PASSWORD_RESET = "password_reset"
+ DEVICE_AUTH = "extension_auth"
# Backward-compat aliases for existing imports
TOKEN_TYPE_EMAIL_VERIFY = TokenType.EMAIL_VERIFY
TOKEN_TYPE_PASSWORD_RESET = TokenType.PASSWORD_RESET
+TOKEN_TYPE_DEVICE_AUTH = TokenType.DEVICE_AUTH
class VerificationTokenDoc(MongoBaseModel):
diff --git a/services/auth_service.py b/services/auth_service.py
index cd85a91..f07a398 100644
--- a/services/auth_service.py
+++ b/services/auth_service.py
@@ -27,11 +27,15 @@
from infrastructure.email.protocol import EmailProvider
from repositories.token_repository import TokenRepository
from repositories.user_repository import UserRepository
-from schemas.models.token import TOKEN_TYPE_EMAIL_VERIFY, TOKEN_TYPE_PASSWORD_RESET
+from schemas.models.token import (
+ TOKEN_TYPE_DEVICE_AUTH,
+ TOKEN_TYPE_EMAIL_VERIFY,
+ TOKEN_TYPE_PASSWORD_RESET,
+)
from schemas.models.user import UserDoc, UserPlan, UserStatus
from services.token_factory import TokenFactory
from shared.crypto import hash_password, hash_token, verify_password
-from shared.generators import generate_otp_code
+from shared.generators import generate_otp_code, generate_secure_token
from shared.logging import get_logger
from shared.validators import validate_account_password
@@ -42,6 +46,7 @@
OTP_EXPIRY_SECONDS = 600 # 10 minutes
MAX_TOKENS_PER_HOUR = 3
MAX_VERIFICATION_ATTEMPTS = 5
+DEVICE_AUTH_EXPIRY_SECONDS = 300 # 5 minutes
class AuthService:
@@ -152,16 +157,11 @@ async def _verify_otp(
) -> None:
"""Verify an OTP code and mark it as used.
- NOTE: Attempts are checked but NOT incremented (preserves existing
- behavior from utils/verification_utils.py — the increment was never
- implemented in the original code).
-
Raises:
ValidationError: On any verification failure.
AppError: If marking the token as used fails.
"""
- otp_hash = hash_token(otp_code)
- token_doc = await self._token_repo.find_by_hash(otp_hash, token_type)
+ token_doc = await self._token_repo.find_latest_by_user(user_id, token_type)
if not token_doc:
log.warning(
@@ -172,15 +172,6 @@ async def _verify_otp(
)
raise ValidationError("Invalid or expired verification code")
- if str(token_doc.user_id) != str(user_id):
- log.warning(
- "otp_verification_failed",
- user_id=str(user_id),
- reason="user_mismatch",
- token_type=token_type,
- )
- raise ValidationError("Invalid verification code")
-
expires_at = token_doc.expires_at
if not expires_at.tzinfo:
expires_at = expires_at.replace(tzinfo=timezone.utc)
@@ -194,26 +185,29 @@ async def _verify_otp(
)
raise ValidationError("Verification code has expired")
- if token_doc.used_at is not None:
+ if token_doc.attempts >= MAX_VERIFICATION_ATTEMPTS:
log.warning(
"otp_verification_failed",
user_id=str(user_id),
- reason="already_used",
+ reason="max_attempts",
token_type=token_type,
)
- raise ValidationError("Verification code has already been used")
+ raise ValidationError(
+ "Too many failed attempts. Please request a new code."
+ )
- # Check max attempts (but do NOT increment — preserves original behavior)
- if token_doc.attempts >= MAX_VERIFICATION_ATTEMPTS:
+ # Compare hash — increment attempts on mismatch
+ otp_hash = hash_token(otp_code)
+ if token_doc.token_hash != otp_hash:
+ await self._token_repo.increment_attempts(token_doc.id)
log.warning(
"otp_verification_failed",
user_id=str(user_id),
- reason="max_attempts",
+ reason="wrong_code",
token_type=token_type,
+ attempts=token_doc.attempts + 1,
)
- raise ValidationError(
- "Too many failed attempts. Please request a new code."
- )
+ raise ValidationError("Invalid or expired verification code")
marked = await self._token_repo.mark_as_used(token_doc.id)
if not marked:
@@ -509,8 +503,6 @@ async def reset_password(
AppError: DB update failure.
"""
user = await self._user_repo.find_by_email(email)
- if not user:
- raise ValidationError("invalid email or code")
is_valid, missing, _ = validate_account_password(new_password)
if not is_valid:
@@ -519,7 +511,17 @@ async def reset_password(
details={"missing_requirements": missing},
)
- await self._verify_otp(user.id, otp_code, TOKEN_TYPE_PASSWORD_RESET)
+ if not user:
+ # Simulate OTP verification timing to prevent user enumeration
+ _dummy_hash = hash_token(otp_code)
+ raise ValidationError("invalid email or code")
+
+ try:
+ await self._verify_otp(user.id, otp_code, TOKEN_TYPE_PASSWORD_RESET)
+ except ValidationError:
+ # Re-raise with generic message to prevent user enumeration
+ # via distinct error messages (_verify_otp already logs the details)
+ raise ValidationError("invalid email or code") from None
new_hash = hash_password(new_password)
updated = await self._user_repo.update(
@@ -588,3 +590,53 @@ async def get_user_profile(self, user_id: str) -> UserDoc:
if not user:
raise NotFoundError("user not found")
return user
+
+ # ── Device auth flow ─────────────────────────────────────────────────────
+
+ async def create_device_auth_code(self, user_id: ObjectId, email: str) -> str:
+ """Generate a one-time auth code for the device auth flow.
+
+ Returns the raw token (caller redirects to callback with it).
+ """
+ await self._token_repo.delete_by_user(user_id, TOKEN_TYPE_DEVICE_AUTH)
+
+ raw_token = generate_secure_token(48)
+ now = datetime.now(timezone.utc)
+ await self._token_repo.create(
+ {
+ "user_id": user_id,
+ "email": email,
+ "token_hash": hash_token(raw_token),
+ "token_type": TOKEN_TYPE_DEVICE_AUTH,
+ "expires_at": now + timedelta(seconds=DEVICE_AUTH_EXPIRY_SECONDS),
+ "created_at": now,
+ "used_at": None,
+ "attempts": 0,
+ }
+ )
+ log.info("device_auth_code_created", user_id=str(user_id))
+ return raw_token
+
+ async def exchange_device_code(self, code: str) -> tuple[UserDoc, str, str]:
+ """Exchange a one-time device auth code for JWT tokens.
+
+ Returns:
+ (user_doc, access_token, refresh_token)
+
+ Raises:
+ AuthenticationError: Code invalid, expired, or already used.
+ """
+ token_hash = hash_token(code)
+ token_doc = await self._token_repo.consume_by_hash(
+ token_hash, TOKEN_TYPE_DEVICE_AUTH
+ )
+ if not token_doc:
+ raise AuthenticationError("invalid or expired device auth code")
+
+ user = await self._user_repo.find_by_id(token_doc.user_id)
+ if not user or user.status != UserStatus.ACTIVE:
+ raise AuthenticationError("user not found or inactive")
+
+ log.info("device_auth_success", user_id=str(user.id))
+ access_token, refresh_token = self._tokens.issue_tokens(user, "ext")
+ return user, access_token, refresh_token
diff --git a/shared/validators.py b/shared/validators.py
index 15c818a..2cc46ac 100644
--- a/shared/validators.py
+++ b/shared/validators.py
@@ -10,7 +10,7 @@
import re
from collections.abc import Sequence
-from urllib.parse import unquote
+from urllib.parse import unquote, urlparse
import emoji
import regex
@@ -176,3 +176,17 @@ def validate_blocked_url(url: str, patterns: Sequence[str]) -> bool:
except TimeoutError:
pass # Treat timed-out patterns as non-matching (fail open)
return True
+
+
+def validate_safe_redirect(url: str, fallback: str = "/dashboard") -> str:
+ """Return *url* if it's a safe relative path, otherwise *fallback*.
+
+ Only allows paths starting with ``/`` that don't redirect to an external
+ host. Blocks ``//evil.com``, ``/\\evil.com``, and scheme-prefixed URLs.
+ """
+ if not url or not url.startswith("/"):
+ return fallback
+ parsed = urlparse(url)
+ if parsed.scheme or parsed.netloc:
+ return fallback
+ return url
diff --git a/static/js/auth.js b/static/js/auth.js
index 49e0d23..e6f8d7d 100644
--- a/static/js/auth.js
+++ b/static/js/auth.js
@@ -282,7 +282,6 @@ async function submitAuth() {
const data = await res.json().catch(() => ({}));
if (!res.ok) {
- // Handle password validation errors from backend
if (data.missing_requirements && data.missing_requirements.length > 0) {
showPasswordRequirements(data.missing_requirements);
} else {
@@ -292,7 +291,10 @@ async function submitAuth() {
}
closeAuthModal();
- window.location.href = '/dashboard';
+ const nextUrl = new URLSearchParams(window.location.search).get('next');
+ // Only allow relative paths to prevent open redirect
+ const safeUrl = (nextUrl && nextUrl.startsWith('/') && !nextUrl.startsWith('//')) ? nextUrl : '/dashboard';
+ window.location.href = safeUrl;
} catch (e) {
showAuthError('Something went wrong');
}
diff --git a/templates/base.html b/templates/base.html
index 94fc539..43e28e0 100644
--- a/templates/base.html
+++ b/templates/base.html
@@ -46,7 +46,7 @@
-
+
+