diff --git a/.env.example b/.env.example index c1bc247..e92488c 100644 --- a/.env.example +++ b/.env.example @@ -16,6 +16,13 @@ SECRET_KEY="" # To generate: python -c "import os; print(os.urandom(32).hex())" HOST_URI="127.0.0.1:8000" ENV="development" # change to "production" in production +# CORS — allowed origins for private routes (auth, oauth, dashboard) +# Public API routes (/api/v1/*) always allow all origins. +# Set to your frontend domain(s) in production. Leave empty to block all cross-origin +# requests to private routes (safe default). +# Example: CORS_PRIVATE_ORIGINS='["https://spoo.me","chrome-extension://your-ext-id"]' +CORS_PRIVATE_ORIGINS='[]' + # Logging Configuration LOG_LEVEL=DEBUG # DEBUG, INFO, WARNING, ERROR, CRITICAL LOG_FORMAT=console # json (prod) or console (dev) diff --git a/app.py b/app.py index 1f43631..29aac91 100644 --- a/app.py +++ b/app.py @@ -33,7 +33,11 @@ configure_openapi, ) from middleware.rate_limiter import limiter -from middleware.security import MaxContentLengthMiddleware, configure_cors +from middleware.security import ( + MaxContentLengthMiddleware, + SecurityHeadersMiddleware, + configure_cors, +) from repositories.indexes import ensure_indexes from routes.api_v1 import router as api_v1_router from routes.auth_routes import router as auth_router @@ -120,6 +124,45 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: await ensure_indexes(app.state.db) + # Warn if session secret is missing when auth is enabled + if settings.jwt and not settings.secret_key: + log.warning( + "secret_key_empty", + detail="SECRET_KEY is empty — session cookies are unsigned. " + "Set a strong SECRET_KEY when auth/OAuth is enabled.", + ) + + # Warn if CORS private origins not configured in production + if settings.is_production and not settings.cors_private_origins: + log.warning( + "cors_private_origins_empty", + detail="CORS_PRIVATE_ORIGINS is empty — auth/oauth/dashboard routes " + "will reject all cross-origin requests. Set to your frontend domain(s).", + ) + + # Warn if JWT config is weak (auth is optional, so don't crash) + if settings.jwt and bool(settings.jwt.jwt_private_key) != bool( + settings.jwt.jwt_public_key + ): + log.warning( + "jwt_rsa_half_configured", + detail="Only one of JWT_PRIVATE_KEY / JWT_PUBLIC_KEY is set — " + "both are required for RS256. Falling back to HS256.", + ) + + if settings.jwt and not settings.jwt.use_rs256: + if not settings.jwt.jwt_secret: + log.warning( + "jwt_config_insecure", + detail="RS256 keys not set and JWT_SECRET is empty — tokens can be forged. " + "Set JWT_PRIVATE_KEY + JWT_PUBLIC_KEY or a strong JWT_SECRET.", + ) + elif len(settings.jwt.jwt_secret) < 32: + log.warning( + "jwt_secret_weak", + detail="JWT_SECRET is shorter than 32 characters — consider using RS256 keys or a longer secret.", + ) + yield # ── Shutdown ───────────────────────────────────────────────────────── @@ -164,13 +207,16 @@ async def docs(request: Request): # ── Middleware (registered in reverse execution order) ──────────────── # 1. Session — outermost, needed by Authlib OAuth for state storage app.add_middleware(SessionMiddleware, secret_key=settings.secret_key) - # 2. CORS — must wrap everything + # 2. Security headers — must be outer so HSTS/CSP/nosniff apply to + # all responses including CORS preflights (204) and body-limit (413) + app.add_middleware(SecurityHeadersMiddleware, hsts_enabled=settings.is_production) + # 3. CORS configure_cors(app, settings) - # 3. Body size limit + # 4. Body size limit app.add_middleware( MaxContentLengthMiddleware, max_content_length=settings.max_content_length ) - # 4. Request logging — innermost, logs all requests with request_id + # 5. Request logging — innermost, logs all requests with request_id app.add_middleware(RequestLoggingMiddleware) # ── Error handlers + rate limiter ──────────────────────────────────── diff --git a/config.py b/config.py index e5e896b..1ccd589 100644 --- a/config.py +++ b/config.py @@ -108,8 +108,14 @@ class AppSettings(BaseSettings): app_url: str = "https://spoo.me" app_name: str = "spoo.me" - # CORS — default matches current behaviour: all origins, credentials allowed - cors_origins: list[str] = ["*"] + # CORS — public API routes allow all origins (no credentials). + # Private routes (auth, oauth, dashboard) require explicit origin allowlist. + cors_origins: list[str] = ["*"] # deprecated — kept for backward compat + cors_private_origins: list[str] = [] + + # Device auth flow - allowed redirect URIs for third-party clients + # (e.g., mobile apps, desktop apps). Browser extensions don't need this. + device_auth_redirect_uris: list[str] = [] # Request body size limit (bytes); 1 MB default max_content_length: int = 1_048_576 diff --git a/infrastructure/oauth_clients.py b/infrastructure/oauth_clients.py index b159ebf..ba4136e 100644 --- a/infrastructure/oauth_clients.py +++ b/infrastructure/oauth_clients.py @@ -149,7 +149,10 @@ def init_oauth(settings: Any) -> tuple[OAuth | None, dict[str, Any]]: def generate_oauth_state( - provider: str, action: OAuthAction = OAuthAction.LOGIN, user_id: str | None = None + provider: str, + action: OAuthAction = OAuthAction.LOGIN, + user_id: str | None = None, + next_url: str | None = None, ) -> str: """Generate a URL-safe state string for CSRF protection.""" action_str = action.value if isinstance(action, OAuthAction) else action @@ -161,6 +164,8 @@ def generate_oauth_state( ] if user_id: parts.append(f"user_id={user_id}") + if next_url: + parts.append(f"next={next_url}") return "&".join(parts) diff --git a/middleware/rate_limiter.py b/middleware/rate_limiter.py index 3757379..95ddf1a 100644 --- a/middleware/rate_limiter.py +++ b/middleware/rate_limiter.py @@ -46,6 +46,10 @@ class Limits: PASSWORD_RESET_REQUEST = "3 per hour" PASSWORD_RESET_CONFIRM = "5 per hour" + # Device auth flow (extensions, apps, CLIs) + DEVICE_AUTH = "10 per minute" + DEVICE_TOKEN = "10 per minute" + # OAuth OAUTH_INIT = "10 per minute" OAUTH_CALLBACK = "20 per minute" diff --git a/middleware/security.py b/middleware/security.py index cc1ab83..4868589 100644 --- a/middleware/security.py +++ b/middleware/security.py @@ -1,30 +1,131 @@ """ -Security middleware — CORS configuration and request body size limit. +Security middleware — CORS configuration, security headers, and request body size limit. """ from __future__ import annotations from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request -from starlette.responses import Response +from starlette.responses import PlainTextResponse, Response from config import AppSettings +# ── Path-based CORS classification ────────────────────────────────────────── + +_PRIVATE_PREFIXES = ("/auth", "/oauth", "/dashboard") +_PUBLIC_PREFIXES = ("/api/v1", "/auth/device", "/stats", "/export", "/metric") + +_ALLOWED_METHODS = "GET, POST, PUT, PATCH, DELETE, OPTIONS" +_ALLOWED_HEADERS = "Authorization, Content-Type, Accept, X-Request-ID" + + +def _classify_path(path: str) -> str: + """Return 'public', 'private', or 'none' based on the request path.""" + for prefix in _PUBLIC_PREFIXES: + if path == prefix or path.startswith(prefix + "/"): + return "public" + for prefix in _PRIVATE_PREFIXES: + if path == prefix or path.startswith(prefix + "/"): + return "private" + # Legacy root shortener endpoint (POST /) is public API + if path == "/": + return "public" + return "none" + + +class SplitCORSMiddleware(BaseHTTPMiddleware): + """Apply different CORS policies based on the request path. + + - Public routes (``/api/v1/*``, legacy API): ``Access-Control-Allow-Origin: *``, + no credentials. + - Private routes (``/auth/*``, ``/oauth/*``, ``/dashboard/*``): origin checked + against *private_origins* allowlist, credentials allowed. + - All other routes: no CORS headers. + """ + + def __init__(self, app, *, private_origins: list[str]) -> None: + super().__init__(app) + self._private_origins: set[str] = {o.rstrip("/") for o in private_origins} + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + origin = request.headers.get("origin") + path = request.url.path + classification = _classify_path(path) + + # Handle preflight — only intercept routes that need CORS + if request.method == "OPTIONS" and origin and classification != "none": + return self._preflight(origin, classification) + + response = await call_next(request) + + if origin and classification != "none": + self._set_cors_headers(response, origin, classification) + + return response + + def _preflight(self, origin: str, classification: str) -> Response: + """Return a 204 preflight response with appropriate CORS headers.""" + response = PlainTextResponse("", status_code=204) + if classification != "none": + self._set_cors_headers(response, origin, classification) + response.headers["Access-Control-Allow-Methods"] = _ALLOWED_METHODS + response.headers["Access-Control-Allow-Headers"] = _ALLOWED_HEADERS + response.headers["Access-Control-Max-Age"] = "86400" + return response + + def _set_cors_headers( + self, response: Response, origin: str, classification: str + ) -> None: + """Set CORS headers based on route classification.""" + if classification == "public": + response.headers["Access-Control-Allow-Origin"] = "*" + elif classification == "private" and origin in self._private_origins: + response.headers["Access-Control-Allow-Origin"] = origin + response.headers["Access-Control-Allow-Credentials"] = "true" + response.headers["Vary"] = "Origin" + def configure_cors(app: FastAPI, settings: AppSettings) -> None: - """Add CORS middleware from settings. Defaults to allow-all with credentials.""" + """Add split CORS middleware — different policies for public vs private routes.""" app.add_middleware( - CORSMiddleware, - allow_origins=settings.cors_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + SplitCORSMiddleware, + private_origins=settings.cors_private_origins, ) +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + """Set standard security headers on all responses.""" + + def __init__(self, app, *, hsts_enabled: bool = True) -> None: + super().__init__(app) + self.hsts_enabled = hsts_enabled + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + response = await call_next(request) + if self.hsts_enabled: + response.headers["Strict-Transport-Security"] = ( + "max-age=31536000; includeSubDomains" + ) + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + response.headers["Permissions-Policy"] = ( + "camera=(), microphone=(), geolocation=()" + ) + response.headers["Content-Security-Policy-Report-Only"] = ( + "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; " + "img-src 'self' data: https:; font-src 'self'; connect-src 'self'; " + "frame-ancestors 'none'" + ) + return response + + class MaxContentLengthMiddleware(BaseHTTPMiddleware): """Reject requests whose Content-Length exceeds the configured limit.""" diff --git a/repositories/indexes.py b/repositories/indexes.py index dac0f6a..9fd39dd 100644 --- a/repositories/indexes.py +++ b/repositories/indexes.py @@ -82,5 +82,9 @@ async def ensure_indexes(db: AsyncDatabase) -> None: await tokens_col.create_index([("token_hash", 1)]) await tokens_col.create_index([("token_type", 1)]) await tokens_col.create_index([("expires_at", 1)], expireAfterSeconds=0) + await tokens_col.create_index( + [("user_id", 1), ("token_type", 1), ("used_at", 1), ("created_at", -1)], + name="ix_latest_unused_by_user", + ) log.info("mongodb_indexes_ensured") diff --git a/repositories/token_repository.py b/repositories/token_repository.py index f7c9c96..a8f6145 100644 --- a/repositories/token_repository.py +++ b/repositories/token_repository.py @@ -37,14 +37,10 @@ async def create(self, token_data: dict) -> ObjectId: ) raise - async def find_by_hash( + async def find_by_hash_and_type( self, token_hash: str, token_type: str ) -> VerificationTokenDoc | None: - """ - Find a non-used token by its SHA-256 hash and type. - - Only returns tokens where ``used_at`` is None (not yet consumed). - """ + """Find a non-used token by its SHA-256 hash and type.""" try: doc = await self._col.find_one( { @@ -56,7 +52,35 @@ async def find_by_hash( return VerificationTokenDoc.from_mongo(doc) except PyMongoError as exc: log.error( - "token_repo_find_by_hash_failed", + "token_repo_find_by_hash_and_type_failed", + token_type=token_type, + error=str(exc), + error_type=type(exc).__name__, + ) + raise + + async def consume_by_hash( + self, token_hash: str, token_type: str + ) -> VerificationTokenDoc | None: + """Atomically find an unused, non-expired token and mark it as used. + + Returns the pre-update document, or None if no matching token exists. + """ + now = datetime.now(timezone.utc) + try: + doc = await self._col.find_one_and_update( + { + "token_hash": token_hash, + "token_type": token_type, + "used_at": None, + "expires_at": {"$gt": now}, + }, + {"$set": {"used_at": now}}, + ) + return VerificationTokenDoc.from_mongo(doc) + except PyMongoError as exc: + log.error( + "token_repo_consume_by_hash_failed", token_type=token_type, error=str(exc), error_type=type(exc).__name__, @@ -84,6 +108,50 @@ async def mark_as_used(self, token_id: ObjectId) -> bool: ) raise + async def find_latest_by_user( + self, user_id: ObjectId, token_type: str + ) -> VerificationTokenDoc | None: + """Find the most recent non-used token for a user and type.""" + try: + doc = await self._col.find_one( + { + "user_id": user_id, + "token_type": token_type, + "used_at": None, + }, + sort=[("created_at", -1)], + ) + return VerificationTokenDoc.from_mongo(doc) + except PyMongoError as exc: + log.error( + "token_repo_find_latest_by_user_failed", + user_id=str(user_id), + token_type=token_type, + error=str(exc), + error_type=type(exc).__name__, + ) + raise + + async def increment_attempts(self, token_id: ObjectId) -> bool: + """Atomically increment the ``attempts`` counter on a token. + + Returns True if a document was modified. + """ + try: + result = await self._col.update_one( + {"_id": token_id}, + {"$inc": {"attempts": 1}}, + ) + return result.modified_count > 0 + except PyMongoError as exc: + log.error( + "token_repo_increment_attempts_failed", + token_id=str(token_id), + error=str(exc), + error_type=type(exc).__name__, + ) + raise + async def delete_by_user( self, user_id: ObjectId, token_type: str | None = None ) -> int: diff --git a/routes/auth_routes.py b/routes/auth_routes.py index 77370e3..14bcef6 100644 --- a/routes/auth_routes.py +++ b/routes/auth_routes.py @@ -25,12 +25,13 @@ from fastapi.responses import JSONResponse, RedirectResponse from fastapi.templating import Jinja2Templates -from dependencies import AuthUser, get_auth_service +from dependencies import AuthUser, OptionalUser, get_auth_service from errors import AuthenticationError from middleware.openapi import AUTH_RESPONSES, ERROR_RESPONSES, PUBLIC_SECURITY from middleware.rate_limiter import Limits, limiter from routes.cookie_helpers import clear_auth_cookies, set_auth_cookies from schemas.dto.requests.auth import ( + DeviceTokenRequest, LoginRequest, RegisterRequest, RequestPasswordResetRequest, @@ -39,6 +40,7 @@ VerifyEmailRequest, ) from schemas.dto.responses.auth import ( + DeviceTokenResponse, LoginResponse, LogoutResponse, MeResponse, @@ -61,6 +63,13 @@ templates = Jinja2Templates(directory=_TEMPLATE_DIR) +def _validate_redirect_uri(uri: str, allowed: list[str]) -> str | None: + """Return the URI if it's in the allowlist, None otherwise.""" + if not uri or not allowed: + return None + return uri if uri in allowed else None + + # ── Redirect shortcuts ──────────────────────────────────────────────────────── @@ -444,3 +453,99 @@ async def reset_password( body.email.strip().lower(), body.code.strip(), body.password ) return MessageResponse(success=True, message="password reset successfully") + + +# ── Device auth flow (extensions, apps, CLIs) ──────────────────────────────── + + +@router.get("/auth/device/login", include_in_schema=False) +@limiter.limit(Limits.DEVICE_AUTH) +async def device_login( + request: Request, + user: OptionalUser, + auth_service: AuthService = Depends(get_auth_service), + redirect_uri: str = "", + state: str = "", +) -> RedirectResponse: + """Initiate the device auth flow. + + If the user already has a valid session, generates an auth code and + redirects to the callback page (or a registered redirect_uri). + Otherwise, redirects to the login page. + Used by browser extensions, mobile apps, and other third-party clients. + + The ``state`` parameter is passed through for CSRF protection — the + client generates it, the server carries it, the client verifies it. + """ + if user: + profile = await auth_service.get_user_profile(str(user.user_id)) + code = await auth_service.create_device_auth_code(profile.id, profile.email) + allowed = request.app.state.settings.device_auth_redirect_uris + validated_uri = _validate_redirect_uri(redirect_uri, allowed) + if validated_uri: + separator = "&" if "?" in validated_uri else "?" + return RedirectResponse( + f"{validated_uri}{separator}code={code}&state={state}", + status_code=302, + ) + return RedirectResponse( + f"/auth/device/callback?code={code}&state={state}", status_code=302 + ) + + # Preserve params through the login flow + params = "state=" + state if state else "" + if redirect_uri: + params += ("&" if params else "") + f"redirect_uri={redirect_uri}" + next_url = "/auth/device/login" + (f"?{params}" if params else "") + return RedirectResponse(f"/?next={next_url}", status_code=302) + + +@router.get("/auth/device/callback", include_in_schema=False) +@limiter.limit(Limits.DEVICE_AUTH) +async def device_callback( + request: Request, + code: str = "", + state: str = "", +) -> Response: + """Render the device auth callback page. + + The client reads the auth code and state from data attributes on the page. + For browser extensions, the content script handles this automatically. + """ + if not code: + return RedirectResponse("/", status_code=302) + return templates.TemplateResponse( + request, "device_callback.html", {"code": code, "state": state} + ) + + +@router.post( + "/auth/device/token", + responses=ERROR_RESPONSES, + openapi_extra=PUBLIC_SECURITY, + operation_id="exchangeDeviceCode", + summary="Exchange Device Auth Code", +) +@limiter.limit(Limits.DEVICE_TOKEN) +async def device_token( + request: Request, + body: DeviceTokenRequest, + auth_service: AuthService = Depends(get_auth_service), +) -> DeviceTokenResponse: + """Exchange a one-time device auth code for JWT tokens. + + The code is obtained from the callback page after the user authenticates + on spoo.me. Returns access and refresh tokens for the client. + + **Authentication**: Not required (public endpoint) + + **Rate Limits**: 10/min + """ + user, access_token, refresh_token = await auth_service.exchange_device_code( + body.code.strip() + ) + return DeviceTokenResponse( + access_token=access_token, + refresh_token=refresh_token, + user=UserProfileResponse.from_user(user), + ) diff --git a/routes/legacy/url_shortener.py b/routes/legacy/url_shortener.py index 66cbf4e..93054af 100644 --- a/routes/legacy/url_shortener.py +++ b/routes/legacy/url_shortener.py @@ -37,6 +37,7 @@ validate_alias, validate_blocked_url, validate_emoji_alias, + validate_safe_redirect, validate_url, validate_url_password, ) @@ -69,7 +70,8 @@ async def index(request: Request, user: OptionalUser) -> Response: """Render the index page. Redirect to dashboard if already logged in.""" if user is not None: - return RedirectResponse("/dashboard", status_code=302) + next_url = validate_safe_redirect(request.query_params.get("next", "")) + return RedirectResponse(next_url, status_code=302) return templates.TemplateResponse( request, "index.html", {"host_url": str(request.base_url)} ) diff --git a/routes/oauth_routes.py b/routes/oauth_routes.py index 1f1ac05..eee7671 100644 --- a/routes/oauth_routes.py +++ b/routes/oauth_routes.py @@ -34,6 +34,7 @@ from services.oauth_service import OAuthService from shared.ip_utils import get_client_ip from shared.logging import get_logger +from shared.validators import validate_safe_redirect log = get_logger(__name__) @@ -132,7 +133,8 @@ async def oauth_login( raise NotFoundError(f"'{provider}' OAuth not configured") log.info("oauth_flow_initiated", provider=provider) - state = generate_oauth_state(provider, OAuthAction.LOGIN) + next_url = request.query_params.get("next") + state = generate_oauth_state(provider, OAuthAction.LOGIN, next_url=next_url) redirect_uri = get_oauth_redirect_url(provider, request.app.state.settings.oauth) return await client.authorize_redirect(request, redirect_uri, state=state) @@ -207,9 +209,10 @@ async def oauth_callback( provider, provider_info, action, state_data, client_ip ) - # ── Redirect to dashboard with cookies ─────────────────────────────────── + # ── Redirect with cookies ────────────────────────────────────────────── + next_url = validate_safe_redirect(state_data.get("next", "")) jwt_cfg = request.app.state.settings.jwt - resp = RedirectResponse(_DASHBOARD_URL, status_code=302) + resp = RedirectResponse(next_url, status_code=302) set_auth_cookies(resp, access_token, refresh_token, jwt_cfg) return resp diff --git a/schemas/dto/requests/auth.py b/schemas/dto/requests/auth.py index 159e45f..cfe27f8 100644 --- a/schemas/dto/requests/auth.py +++ b/schemas/dto/requests/auth.py @@ -8,6 +8,7 @@ SendVerificationRequest — POST /auth/send-verification (no body) RequestPasswordResetRequest — POST /auth/request-password-reset ResetPasswordRequest — POST /auth/reset-password +DeviceTokenRequest — POST /auth/device/token """ from __future__ import annotations @@ -113,3 +114,15 @@ class ResetPasswordRequest(BaseModel): description="New password (min 8 chars, must contain letter + number + special char)", examples=["NewSecurePass456!"], ) + + +class DeviceTokenRequest(BaseModel): + """Request body for POST /auth/device/token.""" + + model_config = ConfigDict(populate_by_name=True) + + code: str = Field( + min_length=1, + max_length=128, + description="One-time auth code from the device callback page", + ) diff --git a/schemas/dto/responses/auth.py b/schemas/dto/responses/auth.py index 2ea7da3..2aea472 100644 --- a/schemas/dto/responses/auth.py +++ b/schemas/dto/responses/auth.py @@ -191,6 +191,16 @@ class SendVerificationResponse(BaseModel): ) +class DeviceTokenResponse(BaseModel): + """Response body for POST /auth/device/token (200).""" + + model_config = ConfigDict(populate_by_name=True) + + access_token: str = Field(description="JWT access token") + refresh_token: str = Field(description="JWT refresh token") + user: UserProfileResponse = Field(description="User profile") + + class OAuthProviderDetail(BaseModel): """Detailed OAuth provider entry for the providers list endpoint.""" diff --git a/schemas/models/token.py b/schemas/models/token.py index 212dcf3..dc9d322 100644 --- a/schemas/models/token.py +++ b/schemas/models/token.py @@ -24,11 +24,13 @@ class TokenType(str, Enum): EMAIL_VERIFY = "email_verify" PASSWORD_RESET = "password_reset" + DEVICE_AUTH = "extension_auth" # Backward-compat aliases for existing imports TOKEN_TYPE_EMAIL_VERIFY = TokenType.EMAIL_VERIFY TOKEN_TYPE_PASSWORD_RESET = TokenType.PASSWORD_RESET +TOKEN_TYPE_DEVICE_AUTH = TokenType.DEVICE_AUTH class VerificationTokenDoc(MongoBaseModel): diff --git a/services/auth_service.py b/services/auth_service.py index cd85a91..f07a398 100644 --- a/services/auth_service.py +++ b/services/auth_service.py @@ -27,11 +27,15 @@ from infrastructure.email.protocol import EmailProvider from repositories.token_repository import TokenRepository from repositories.user_repository import UserRepository -from schemas.models.token import TOKEN_TYPE_EMAIL_VERIFY, TOKEN_TYPE_PASSWORD_RESET +from schemas.models.token import ( + TOKEN_TYPE_DEVICE_AUTH, + TOKEN_TYPE_EMAIL_VERIFY, + TOKEN_TYPE_PASSWORD_RESET, +) from schemas.models.user import UserDoc, UserPlan, UserStatus from services.token_factory import TokenFactory from shared.crypto import hash_password, hash_token, verify_password -from shared.generators import generate_otp_code +from shared.generators import generate_otp_code, generate_secure_token from shared.logging import get_logger from shared.validators import validate_account_password @@ -42,6 +46,7 @@ OTP_EXPIRY_SECONDS = 600 # 10 minutes MAX_TOKENS_PER_HOUR = 3 MAX_VERIFICATION_ATTEMPTS = 5 +DEVICE_AUTH_EXPIRY_SECONDS = 300 # 5 minutes class AuthService: @@ -152,16 +157,11 @@ async def _verify_otp( ) -> None: """Verify an OTP code and mark it as used. - NOTE: Attempts are checked but NOT incremented (preserves existing - behavior from utils/verification_utils.py — the increment was never - implemented in the original code). - Raises: ValidationError: On any verification failure. AppError: If marking the token as used fails. """ - otp_hash = hash_token(otp_code) - token_doc = await self._token_repo.find_by_hash(otp_hash, token_type) + token_doc = await self._token_repo.find_latest_by_user(user_id, token_type) if not token_doc: log.warning( @@ -172,15 +172,6 @@ async def _verify_otp( ) raise ValidationError("Invalid or expired verification code") - if str(token_doc.user_id) != str(user_id): - log.warning( - "otp_verification_failed", - user_id=str(user_id), - reason="user_mismatch", - token_type=token_type, - ) - raise ValidationError("Invalid verification code") - expires_at = token_doc.expires_at if not expires_at.tzinfo: expires_at = expires_at.replace(tzinfo=timezone.utc) @@ -194,26 +185,29 @@ async def _verify_otp( ) raise ValidationError("Verification code has expired") - if token_doc.used_at is not None: + if token_doc.attempts >= MAX_VERIFICATION_ATTEMPTS: log.warning( "otp_verification_failed", user_id=str(user_id), - reason="already_used", + reason="max_attempts", token_type=token_type, ) - raise ValidationError("Verification code has already been used") + raise ValidationError( + "Too many failed attempts. Please request a new code." + ) - # Check max attempts (but do NOT increment — preserves original behavior) - if token_doc.attempts >= MAX_VERIFICATION_ATTEMPTS: + # Compare hash — increment attempts on mismatch + otp_hash = hash_token(otp_code) + if token_doc.token_hash != otp_hash: + await self._token_repo.increment_attempts(token_doc.id) log.warning( "otp_verification_failed", user_id=str(user_id), - reason="max_attempts", + reason="wrong_code", token_type=token_type, + attempts=token_doc.attempts + 1, ) - raise ValidationError( - "Too many failed attempts. Please request a new code." - ) + raise ValidationError("Invalid or expired verification code") marked = await self._token_repo.mark_as_used(token_doc.id) if not marked: @@ -509,8 +503,6 @@ async def reset_password( AppError: DB update failure. """ user = await self._user_repo.find_by_email(email) - if not user: - raise ValidationError("invalid email or code") is_valid, missing, _ = validate_account_password(new_password) if not is_valid: @@ -519,7 +511,17 @@ async def reset_password( details={"missing_requirements": missing}, ) - await self._verify_otp(user.id, otp_code, TOKEN_TYPE_PASSWORD_RESET) + if not user: + # Simulate OTP verification timing to prevent user enumeration + _dummy_hash = hash_token(otp_code) + raise ValidationError("invalid email or code") + + try: + await self._verify_otp(user.id, otp_code, TOKEN_TYPE_PASSWORD_RESET) + except ValidationError: + # Re-raise with generic message to prevent user enumeration + # via distinct error messages (_verify_otp already logs the details) + raise ValidationError("invalid email or code") from None new_hash = hash_password(new_password) updated = await self._user_repo.update( @@ -588,3 +590,53 @@ async def get_user_profile(self, user_id: str) -> UserDoc: if not user: raise NotFoundError("user not found") return user + + # ── Device auth flow ───────────────────────────────────────────────────── + + async def create_device_auth_code(self, user_id: ObjectId, email: str) -> str: + """Generate a one-time auth code for the device auth flow. + + Returns the raw token (caller redirects to callback with it). + """ + await self._token_repo.delete_by_user(user_id, TOKEN_TYPE_DEVICE_AUTH) + + raw_token = generate_secure_token(48) + now = datetime.now(timezone.utc) + await self._token_repo.create( + { + "user_id": user_id, + "email": email, + "token_hash": hash_token(raw_token), + "token_type": TOKEN_TYPE_DEVICE_AUTH, + "expires_at": now + timedelta(seconds=DEVICE_AUTH_EXPIRY_SECONDS), + "created_at": now, + "used_at": None, + "attempts": 0, + } + ) + log.info("device_auth_code_created", user_id=str(user_id)) + return raw_token + + async def exchange_device_code(self, code: str) -> tuple[UserDoc, str, str]: + """Exchange a one-time device auth code for JWT tokens. + + Returns: + (user_doc, access_token, refresh_token) + + Raises: + AuthenticationError: Code invalid, expired, or already used. + """ + token_hash = hash_token(code) + token_doc = await self._token_repo.consume_by_hash( + token_hash, TOKEN_TYPE_DEVICE_AUTH + ) + if not token_doc: + raise AuthenticationError("invalid or expired device auth code") + + user = await self._user_repo.find_by_id(token_doc.user_id) + if not user or user.status != UserStatus.ACTIVE: + raise AuthenticationError("user not found or inactive") + + log.info("device_auth_success", user_id=str(user.id)) + access_token, refresh_token = self._tokens.issue_tokens(user, "ext") + return user, access_token, refresh_token diff --git a/shared/validators.py b/shared/validators.py index 15c818a..2cc46ac 100644 --- a/shared/validators.py +++ b/shared/validators.py @@ -10,7 +10,7 @@ import re from collections.abc import Sequence -from urllib.parse import unquote +from urllib.parse import unquote, urlparse import emoji import regex @@ -176,3 +176,17 @@ def validate_blocked_url(url: str, patterns: Sequence[str]) -> bool: except TimeoutError: pass # Treat timed-out patterns as non-matching (fail open) return True + + +def validate_safe_redirect(url: str, fallback: str = "/dashboard") -> str: + """Return *url* if it's a safe relative path, otherwise *fallback*. + + Only allows paths starting with ``/`` that don't redirect to an external + host. Blocks ``//evil.com``, ``/\\evil.com``, and scheme-prefixed URLs. + """ + if not url or not url.startswith("/"): + return fallback + parsed = urlparse(url) + if parsed.scheme or parsed.netloc: + return fallback + return url diff --git a/static/js/auth.js b/static/js/auth.js index 49e0d23..e6f8d7d 100644 --- a/static/js/auth.js +++ b/static/js/auth.js @@ -282,7 +282,6 @@ async function submitAuth() { const data = await res.json().catch(() => ({})); if (!res.ok) { - // Handle password validation errors from backend if (data.missing_requirements && data.missing_requirements.length > 0) { showPasswordRequirements(data.missing_requirements); } else { @@ -292,7 +291,10 @@ async function submitAuth() { } closeAuthModal(); - window.location.href = '/dashboard'; + const nextUrl = new URLSearchParams(window.location.search).get('next'); + // Only allow relative paths to prevent open redirect + const safeUrl = (nextUrl && nextUrl.startsWith('/') && !nextUrl.startsWith('//')) ? nextUrl : '/dashboard'; + window.location.href = safeUrl; } catch (e) { showAuthError('Something went wrong'); } diff --git a/templates/base.html b/templates/base.html index 94fc539..43e28e0 100644 --- a/templates/base.html +++ b/templates/base.html @@ -46,7 +46,7 @@ - + + + diff --git a/templates/partials/auth_modal.html b/templates/partials/auth_modal.html index 5b7c5e3..a3c7853 100644 --- a/templates/partials/auth_modal.html +++ b/templates/partials/auth_modal.html @@ -528,19 +528,24 @@ errorEl.style.display = 'none'; } + function _oauthUrl(provider) { + const next = new URLSearchParams(window.location.search).get('next'); + return next ? `/oauth/${provider}?next=${encodeURIComponent(next)}` : `/oauth/${provider}`; + } + function loginWithGoogle() { closeAuthModal(); - window.location.href = '/oauth/google'; + window.location.href = _oauthUrl('google'); } function loginWithGitHub() { closeAuthModal(); - window.location.href = '/oauth/github'; + window.location.href = _oauthUrl('github'); } function loginWithDiscord() { closeAuthModal(); - window.location.href = '/oauth/discord'; + window.location.href = _oauthUrl('discord'); } // Close on overlay/backdrop click @@ -554,5 +559,10 @@ document.addEventListener('keydown', function (ev) { if (ev.key === 'Escape') closeAuthModal(); }); + + // Auto-open modal when ?next= is in the URL (device auth flow) + if (new URLSearchParams(window.location.search).has('next')) { + openAuthModal('login'); + } \ No newline at end of file diff --git a/tests/integration/test_device_auth.py b/tests/integration/test_device_auth.py new file mode 100644 index 0000000..a2a3ad9 --- /dev/null +++ b/tests/integration/test_device_auth.py @@ -0,0 +1,152 @@ +"""Integration tests for the device auth flow endpoints.""" + +from __future__ import annotations + +import os +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +from bson import ObjectId +from fastapi import FastAPI +from fastapi.testclient import TestClient +from slowapi import _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded + +os.environ.setdefault("MONGODB_URI", "mongodb://localhost:27017/") + +from config import AppSettings +from dependencies import get_auth_service, get_current_user +from errors import AuthenticationError +from middleware.error_handler import register_error_handlers +from middleware.rate_limiter import limiter +from routes.auth_routes import router as auth_router +from schemas.models.user import UserDoc + +_USER_OID = ObjectId() +_EMAIL = "test@example.com" + + +def _build_test_app(overrides: dict) -> FastAPI: + settings = AppSettings() + + @asynccontextmanager + async def lifespan(app: FastAPI): + app.state.settings = settings + app.state.db = MagicMock() + app.state.redis = None + app.state.email_provider = MagicMock() + app.state.http_client = MagicMock() + app.state.oauth_providers = {} + yield + + app = FastAPI(lifespan=lifespan) + app.state.limiter = limiter + app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + register_error_handlers(app) + app.include_router(auth_router) + app.dependency_overrides.update(overrides) + return app + + +def _make_user_doc() -> UserDoc: + return UserDoc.from_mongo( + { + "_id": _USER_OID, + "email": _EMAIL, + "email_verified": True, + "password_set": True, + "auth_providers": [], + "plan": "free", + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + "status": "ACTIVE", + } + ) + + +# ── GET /auth/device/login ─────────────────────────────────────────────────── + + +def test_device_login_unauthenticated_redirects_to_index(): + mock_svc = AsyncMock() + app = _build_test_app( + {get_auth_service: lambda: mock_svc, get_current_user: lambda: None} + ) + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.get("/auth/device/login?state=abc", follow_redirects=False) + assert resp.status_code == 302 + assert "/?next=" in resp.headers["location"] + assert "state=abc" in resp.headers["location"] + + +def test_device_login_authenticated_redirects_to_callback(): + from dependencies.auth import CurrentUser + + mock_svc = AsyncMock() + mock_svc.get_user_profile.return_value = _make_user_doc() + mock_svc.create_device_auth_code.return_value = "test-code-123" + + user = CurrentUser(user_id=_USER_OID, email_verified=True) + app = _build_test_app( + {get_auth_service: lambda: mock_svc, get_current_user: lambda: user} + ) + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.get("/auth/device/login?state=xyz", follow_redirects=False) + assert resp.status_code == 302 + loc = resp.headers["location"] + assert "/auth/device/callback" in loc + assert "code=test-code-123" in loc + assert "state=xyz" in loc + + +# ── GET /auth/device/callback ──────────────────────────────────────────────── + + +def test_device_callback_no_code_redirects_home(): + mock_svc = AsyncMock() + app = _build_test_app({get_auth_service: lambda: mock_svc}) + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.get("/auth/device/callback", follow_redirects=False) + assert resp.status_code == 302 + assert resp.headers["location"] == "/" + + +def test_device_callback_with_code_renders_page(): + mock_svc = AsyncMock() + app = _build_test_app({get_auth_service: lambda: mock_svc}) + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.get("/auth/device/callback?code=abc&state=xyz") + assert resp.status_code == 200 + assert 'data-code="abc"' in resp.text + assert 'data-state="xyz"' in resp.text + + +# ── POST /auth/device/token ───────────────────────────────────────────────── + + +def test_device_token_valid_code(): + mock_svc = AsyncMock() + user = _make_user_doc() + mock_svc.exchange_device_code.return_value = (user, "access-tok", "refresh-tok") + + app = _build_test_app({get_auth_service: lambda: mock_svc}) + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.post("/auth/device/token", json={"code": "valid-code"}) + assert resp.status_code == 200 + data = resp.json() + assert data["access_token"] == "access-tok" + assert data["refresh_token"] == "refresh-tok" + assert data["user"]["email"] == _EMAIL + + +def test_device_token_invalid_code(): + mock_svc = AsyncMock() + mock_svc.exchange_device_code.side_effect = AuthenticationError( + "invalid or expired device auth code" + ) + + app = _build_test_app({get_auth_service: lambda: mock_svc}) + with TestClient(app, raise_server_exceptions=False) as c: + resp = c.post("/auth/device/token", json={"code": "bad-code"}) + assert resp.status_code == 401 diff --git a/tests/integration/test_middleware.py b/tests/integration/test_middleware.py index 0adad24..965475f 100644 --- a/tests/integration/test_middleware.py +++ b/tests/integration/test_middleware.py @@ -17,7 +17,11 @@ from middleware.error_handler import register_error_handlers from middleware.logging import RequestLoggingMiddleware from middleware.rate_limiter import limiter -from middleware.security import MaxContentLengthMiddleware, configure_cors +from middleware.security import ( + MaxContentLengthMiddleware, + SecurityHeadersMiddleware, + configure_cors, +) _STATIC_DIR = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "static" @@ -27,6 +31,7 @@ def _build_test_app( include_logging=True, max_content_length=1_048_576, + hsts_enabled=True, ) -> FastAPI: settings = AppSettings() @@ -43,6 +48,7 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan, docs_url=None, redoc_url=None) app.state.limiter = limiter configure_cors(app, settings) + app.add_middleware(SecurityHeadersMiddleware, hsts_enabled=hsts_enabled) app.add_middleware( MaxContentLengthMiddleware, max_content_length=max_content_length ) @@ -152,26 +158,99 @@ def test_accept_json_header_overrides_to_json(): # ── Security Middleware Tests ──────────────────────────────────────────────── -def test_cors_headers_on_preflight(): +def test_security_headers_present(): + app = _build_test_app(include_logging=False) + app.include_router(_test_router) + with TestClient(app) as c: + resp = c.get("/test-ok") + assert resp.headers["x-content-type-options"] == "nosniff" + assert resp.headers["x-frame-options"] == "DENY" + assert resp.headers["referrer-policy"] == "strict-origin-when-cross-origin" + assert ( + resp.headers["permissions-policy"] == "camera=(), microphone=(), geolocation=()" + ) + assert "content-security-policy-report-only" in resp.headers + # HSTS enabled by default in test + assert ( + resp.headers["strict-transport-security"] + == "max-age=31536000; includeSubDomains" + ) + + +def test_security_headers_hsts_disabled(): + app = _build_test_app(include_logging=False, hsts_enabled=False) + app.include_router(_test_router) + with TestClient(app) as c: + resp = c.get("/test-ok") + assert "strict-transport-security" not in resp.headers + # Other headers still present + assert resp.headers["x-content-type-options"] == "nosniff" + + +def test_cors_public_route_allows_any_origin(): + app = _build_test_app(include_logging=False) + app.include_router(_test_router) + with TestClient(app) as c: + resp = c.get( + "/api/v1/test-404", + headers={"Origin": "https://evil.com"}, + ) + assert resp.headers.get("access-control-allow-origin") == "*" + assert "access-control-allow-credentials" not in resp.headers + + +def test_cors_public_route_preflight(): app = _build_test_app(include_logging=False) app.include_router(_test_router) with TestClient(app) as c: resp = c.options( - "/test-ok", + "/api/v1/test-404", headers={ "Origin": "https://example.com", "Access-Control-Request-Method": "GET", }, ) - assert "access-control-allow-origin" in resp.headers + assert resp.status_code == 204 + assert resp.headers.get("access-control-allow-origin") == "*" + assert "access-control-max-age" in resp.headers -def test_cors_allow_credentials(): +def test_cors_private_route_allowed_origin(): + app = _build_test_app(include_logging=False) + # Override private origins for this test + from middleware.security import SplitCORSMiddleware + + for mw in app.user_middleware: + if mw.cls is SplitCORSMiddleware: + mw.kwargs["private_origins"] = ["https://spoo.me"] + app.include_router(_test_router) + with TestClient(app, base_url="http://testserver") as c: + resp = c.get( + "/auth/test-401", + headers={"Origin": "https://spoo.me"}, + ) + assert resp.headers.get("access-control-allow-origin") == "https://spoo.me" + assert resp.headers.get("access-control-allow-credentials") == "true" + + +def test_cors_private_route_disallowed_origin(): + app = _build_test_app(include_logging=False) + app.include_router(_test_router) + with TestClient(app) as c: + resp = c.get( + "/auth/test-401", + headers={"Origin": "https://evil.com"}, + ) + # No CORS headers — browser will block the response + assert "access-control-allow-origin" not in resp.headers + + +def test_cors_unclassified_route_no_cors(): app = _build_test_app(include_logging=False) app.include_router(_test_router) with TestClient(app) as c: resp = c.get("/test-ok", headers={"Origin": "https://example.com"}) - assert resp.headers.get("access-control-allow-credentials") == "true" + assert "access-control-allow-origin" not in resp.headers def test_body_size_limit_rejects_large_payload(): diff --git a/tests/smoke/conftest.py b/tests/smoke/conftest.py index f426673..c46d4cf 100644 --- a/tests/smoke/conftest.py +++ b/tests/smoke/conftest.py @@ -18,7 +18,11 @@ from middleware.error_handler import register_error_handlers from middleware.logging import RequestLoggingMiddleware from middleware.rate_limiter import limiter -from middleware.security import MaxContentLengthMiddleware, configure_cors +from middleware.security import ( + MaxContentLengthMiddleware, + SecurityHeadersMiddleware, + configure_cors, +) from routes.api_v1 import router as api_v1_router from routes.auth_routes import router as auth_router from routes.dashboard_routes import router as dashboard_router @@ -56,6 +60,7 @@ async def lifespan(app: FastAPI): SessionMiddleware, secret_key=settings.secret_key or "test-secret" ) configure_cors(app, settings) + app.add_middleware(SecurityHeadersMiddleware, hsts_enabled=False) app.add_middleware( MaxContentLengthMiddleware, max_content_length=settings.max_content_length ) diff --git a/tests/smoke/test_config_defaults.py b/tests/smoke/test_config_defaults.py index 7699603..db8d3c3 100644 --- a/tests/smoke/test_config_defaults.py +++ b/tests/smoke/test_config_defaults.py @@ -41,8 +41,14 @@ def test_all_sub_configs_populated() -> None: def test_default_cors_origins() -> None: """Default CORS origins should be ["*"].""" - settings = AppSettings() - assert settings.cors_origins == ["*"] + field = AppSettings.model_fields["cors_origins"] + assert field.default == ["*"] + + +def test_default_cors_private_origins_empty() -> None: + """Default CORS private origins should be empty (no cross-origin auth by default).""" + field = AppSettings.model_fields["cors_private_origins"] + assert field.default == [] def test_default_max_content_length() -> None: diff --git a/tests/smoke/test_middleware_stack.py b/tests/smoke/test_middleware_stack.py index fb47a18..c9e31d7 100644 --- a/tests/smoke/test_middleware_stack.py +++ b/tests/smoke/test_middleware_stack.py @@ -16,29 +16,26 @@ def test_x_request_id_header_present(smoke_client: TestClient) -> None: assert resp.headers["x-request-id"].startswith("req_") -def test_cors_headers_on_options(smoke_client: TestClient) -> None: - """An OPTIONS request should receive CORS headers.""" +def test_cors_public_route_allows_any_origin(smoke_client: TestClient) -> None: + """Public API routes should allow any origin without credentials.""" resp = smoke_client.options( - "/health", + "/api/v1/shorten", headers={ - "Origin": "https://example.com", - "Access-Control-Request-Method": "GET", + "Origin": "https://arbitrary-domain.test", + "Access-Control-Request-Method": "POST", }, ) - assert "access-control-allow-origin" in resp.headers + assert resp.headers.get("access-control-allow-origin") == "*" + assert "access-control-allow-credentials" not in resp.headers -def test_cors_allows_all_origins(smoke_client: TestClient) -> None: - """Default CORS config should allow any origin.""" - resp = smoke_client.options( +def test_cors_unclassified_route_no_headers(smoke_client: TestClient) -> None: + """Routes outside public/private groups should not get CORS headers.""" + resp = smoke_client.get( "/health", - headers={ - "Origin": "https://arbitrary-domain.test", - "Access-Control-Request-Method": "GET", - }, + headers={"Origin": "https://example.com"}, ) - allow_origin = resp.headers.get("access-control-allow-origin", "") - assert allow_origin in ("*", "https://arbitrary-domain.test") + assert "access-control-allow-origin" not in resp.headers def test_max_content_length_rejects_large_body(smoke_client: TestClient) -> None: @@ -84,8 +81,8 @@ def test_middleware_ordering_correct(smoke_app) -> None: """Middleware should be stacked in the correct order. FastAPI registers middleware in reverse order (last added = outermost). - Registration order: Session, CORS, MaxContentLength, RequestLogging - Execution order (outermost first): Session -> CORS -> MaxContentLength -> RequestLogging + Registration order: Session, CORS, SecurityHeaders, MaxContentLength, RequestLogging + Execution order (outermost first): Session -> CORS -> SecurityHeaders -> MaxContentLength -> RequestLogging """ # Walk the middleware stack from the app diff --git a/tests/unit/repositories/test_token_repository.py b/tests/unit/repositories/test_token_repository.py index 0a9dc60..8527359 100644 --- a/tests/unit/repositories/test_token_repository.py +++ b/tests/unit/repositories/test_token_repository.py @@ -7,7 +7,7 @@ import pytest -from .conftest import TOKEN_OID, USER_OID, _token_doc, make_collection +from .conftest import TOKEN_OID, USER_OID, make_collection class TestTokenRepository: @@ -23,27 +23,6 @@ async def test_create_returns_id(self): oid = await self._repo(col).create({"token_hash": "abc"}) assert oid == TOKEN_OID - @pytest.mark.asyncio - async def test_find_by_hash_queries_unused_token(self): - col = make_collection() - col.find_one = AsyncMock(return_value=_token_doc()) - result = await self._repo(col).find_by_hash("cafebabe" * 8, "email_verify") - col.find_one.assert_awaited_once_with( - { - "token_hash": "cafebabe" * 8, - "token_type": "email_verify", - "used_at": None, - } - ) - assert result is not None - assert result.token_type == "email_verify" - - @pytest.mark.asyncio - async def test_find_by_hash_returns_none(self): - col = make_collection() - col.find_one = AsyncMock(return_value=None) - assert await self._repo(col).find_by_hash("nope", "email_verify") is None - @pytest.mark.asyncio async def test_mark_as_used_sets_used_at(self): col = make_collection() diff --git a/tests/unit/services/test_auth_service.py b/tests/unit/services/test_auth_service.py index 2a27f5c..2ddee5e 100644 --- a/tests/unit/services/test_auth_service.py +++ b/tests/unit/services/test_auth_service.py @@ -396,7 +396,7 @@ async def test_verify_email_success(self): user = make_user_doc(email_verified=False) token_doc = self._make_token_doc(USER_OID, "123456") svc._user_repo.find_by_id.return_value = user - svc._token_repo.find_by_hash.return_value = token_doc + svc._token_repo.find_latest_by_user.return_value = token_doc svc._token_repo.mark_as_used.return_value = True svc._user_repo.update.return_value = True svc._email.send_welcome_email.return_value = True @@ -429,7 +429,7 @@ async def test_verify_email_invalid_otp_raises(self): svc = make_auth_service() user = make_user_doc(email_verified=False) svc._user_repo.find_by_id.return_value = user - svc._token_repo.find_by_hash.return_value = None # token not found + svc._token_repo.find_latest_by_user.return_value = None # token not found with pytest.raises(ValidationError): await svc.verify_email(str(USER_OID), "000000") @@ -440,32 +440,34 @@ async def test_verify_email_expired_otp_raises(self): user = make_user_doc(email_verified=False) token_doc = self._make_token_doc(USER_OID, "123456", expired=True) svc._user_repo.find_by_id.return_value = user - svc._token_repo.find_by_hash.return_value = token_doc + svc._token_repo.find_latest_by_user.return_value = token_doc with pytest.raises(ValidationError, match="expired"): await svc.verify_email(str(USER_OID), "123456") @pytest.mark.asyncio - async def test_verify_email_used_otp_raises(self): + async def test_verify_email_max_attempts_raises(self): svc = make_auth_service() user = make_user_doc(email_verified=False) - token_doc = self._make_token_doc(USER_OID, "123456", used=True) + token_doc = self._make_token_doc(USER_OID, "123456", attempts=5) svc._user_repo.find_by_id.return_value = user - svc._token_repo.find_by_hash.return_value = token_doc + svc._token_repo.find_latest_by_user.return_value = token_doc - with pytest.raises(ValidationError, match="already been used"): + with pytest.raises(ValidationError, match="Too many failed attempts"): await svc.verify_email(str(USER_OID), "123456") @pytest.mark.asyncio - async def test_verify_email_max_attempts_raises(self): + async def test_verify_email_wrong_code_increments_attempts(self): svc = make_auth_service() user = make_user_doc(email_verified=False) - token_doc = self._make_token_doc(USER_OID, "123456", attempts=5) + token_doc = self._make_token_doc(USER_OID, "123456") svc._user_repo.find_by_id.return_value = user - svc._token_repo.find_by_hash.return_value = token_doc + svc._token_repo.find_latest_by_user.return_value = token_doc + svc._token_repo.increment_attempts.return_value = True - with pytest.raises(ValidationError, match="Too many failed attempts"): - await svc.verify_email(str(USER_OID), "123456") + with pytest.raises(ValidationError, match="Invalid or expired"): + await svc.verify_email(str(USER_OID), "999999") # wrong code + svc._token_repo.increment_attempts.assert_awaited_once() @pytest.mark.asyncio async def test_verify_email_welcome_email_failure_is_non_fatal(self): @@ -473,7 +475,7 @@ async def test_verify_email_welcome_email_failure_is_non_fatal(self): user = make_user_doc(email_verified=False) token_doc = self._make_token_doc(USER_OID, "123456") svc._user_repo.find_by_id.return_value = user - svc._token_repo.find_by_hash.return_value = token_doc + svc._token_repo.find_latest_by_user.return_value = token_doc svc._token_repo.mark_as_used.return_value = True svc._user_repo.update.return_value = True svc._email.send_welcome_email.side_effect = Exception("server down") @@ -668,7 +670,7 @@ async def test_reset_password_success(self): user = make_user_doc(password_set=True) token_doc = self._make_token_doc(USER_OID, "654321") svc._user_repo.find_by_email.return_value = user - svc._token_repo.find_by_hash.return_value = token_doc + svc._token_repo.find_latest_by_user.return_value = token_doc svc._token_repo.mark_as_used.return_value = True svc._user_repo.update.return_value = True @@ -699,11 +701,24 @@ async def test_reset_password_wrong_otp_raises(self): svc = make_auth_service() user = make_user_doc(password_set=True) svc._user_repo.find_by_email.return_value = user - svc._token_repo.find_by_hash.return_value = None # token not found + svc._token_repo.find_latest_by_user.return_value = None # token not found with pytest.raises(ValidationError): await svc.reset_password("test@example.com", "wrong-code", "NewValidPass1!") + @pytest.mark.asyncio + async def test_reset_password_wrong_code_increments_attempts(self): + svc = make_auth_service() + user = make_user_doc(password_set=True) + token_doc = self._make_token_doc(USER_OID, "654321") + svc._user_repo.find_by_email.return_value = user + svc._token_repo.find_latest_by_user.return_value = token_doc + svc._token_repo.increment_attempts.return_value = True + + with pytest.raises(ValidationError, match="invalid email or code"): + await svc.reset_password("test@example.com", "000000", "NewValidPass1!") + svc._token_repo.increment_attempts.assert_awaited_once() + @pytest.mark.asyncio async def test_request_password_reset_deletes_old_tokens_before_creating(self): """_create_password_reset_otp deletes old reset tokens before inserting new one.""" @@ -829,3 +844,67 @@ def test_profile_with_pfp(self): profile = UserProfileResponse.from_user(user) assert profile.pfp.url == "https://img.url" assert profile.pfp.source == "google" + + +# ── Extension auth flow tests ──────────────────────────────────────────────── + + +class TestExtensionAuth: + @pytest.mark.asyncio + async def test_create_device_auth_code(self): + svc = make_auth_service() + svc._token_repo.delete_by_user.return_value = 0 + svc._token_repo.create.return_value = ObjectId() + + code = await svc.create_device_auth_code(USER_OID, "test@example.com") + assert isinstance(code, str) + assert len(code) > 30 # secure token is long + svc._token_repo.create.assert_awaited_once() + + @pytest.mark.asyncio + async def test_exchange_device_code_success(self): + from datetime import timedelta + + from schemas.models.token import TOKEN_TYPE_DEVICE_AUTH, VerificationTokenDoc + from shared.crypto import hash_token + + svc = make_auth_service() + raw_code = "test-code-123" + now = datetime.now(timezone.utc) + token_doc = VerificationTokenDoc.from_mongo( + { + "_id": ObjectId(), + "user_id": USER_OID, + "email": "test@example.com", + "token_hash": hash_token(raw_code), + "token_type": TOKEN_TYPE_DEVICE_AUTH, + "expires_at": now + timedelta(minutes=5), + "created_at": now, + "used_at": None, + "attempts": 0, + } + ) + svc._token_repo.consume_by_hash.return_value = token_doc + svc._user_repo.find_by_id.return_value = make_user_doc(email_verified=True) + + _user, access, refresh = await svc.exchange_device_code(raw_code) + assert isinstance(access, str) + assert isinstance(refresh, str) + svc._token_repo.consume_by_hash.assert_awaited_once() + + @pytest.mark.asyncio + async def test_exchange_device_code_invalid(self): + svc = make_auth_service() + svc._token_repo.consume_by_hash.return_value = None + + with pytest.raises(AuthenticationError, match="invalid or expired"): + await svc.exchange_device_code("bad-code") + + @pytest.mark.asyncio + async def test_exchange_device_code_expired(self): + """Expired codes are filtered out by consume_by_hash (expires_at in query).""" + svc = make_auth_service() + svc._token_repo.consume_by_hash.return_value = None + + with pytest.raises(AuthenticationError, match="invalid or expired"): + await svc.exchange_device_code("expired-code") diff --git a/tests/unit/shared/test_safe_redirect.py b/tests/unit/shared/test_safe_redirect.py new file mode 100644 index 0000000..185e859 --- /dev/null +++ b/tests/unit/shared/test_safe_redirect.py @@ -0,0 +1,43 @@ +"""Unit tests for validate_safe_redirect.""" + +import pytest + +from shared.validators import validate_safe_redirect + + +@pytest.mark.parametrize( + "url, expected", + [ + ("/dashboard", "/dashboard"), + ("/auth/device/login?state=abc", "/auth/device/login?state=abc"), + ("/some/path", "/some/path"), + ("", "/dashboard"), + ("https://evil.com", "/dashboard"), + ("//evil.com", "/dashboard"), + ("http://evil.com", "/dashboard"), + ("javascript:alert(1)", "/dashboard"), + ("data:text/html,

hi

", "/dashboard"), + ( + "/\\evil.com", + "/\\evil.com", + ), # relative path with backslash — allowed (harmless) + ], + ids=[ + "valid_relative", + "valid_with_query", + "valid_nested", + "empty_falls_back", + "absolute_url_blocked", + "protocol_relative_blocked", + "http_blocked", + "javascript_blocked", + "data_uri_blocked", + "backslash_relative", + ], +) +def test_validate_safe_redirect(url, expected): + assert validate_safe_redirect(url) == expected + + +def test_custom_fallback(): + assert validate_safe_redirect("https://evil.com", fallback="/home") == "/home"