From 9e100f6ac2fe4673119e68fbd855332510c20643 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Tue, 10 Mar 2026 23:09:33 +0100 Subject: [PATCH 1/5] feat: add mandatory JWT + API key authentication (#256) Implement enterprise-grade authentication replacing the insecure X-Human-Role header with real credential verification: - JWT authentication with Argon2id password hashing - API key authentication via SHA-256 hash lookup - First-run admin setup flow (POST /auth/setup) - Login, password change, and /me endpoints - Auth middleware (AbstractAuthenticationMiddleware) with exclude paths - Auto-generated JWT secret persisted to settings table - User and API key repositories (protocol + SQLite implementation) - Schema migration v4 (users, api_keys, settings tables) - Guards now read from authenticated connection.user - Approval controller reads decided_by from auth user role - Comprehensive test coverage (config, service, middleware, controller, repos) Closes #256 --- docker/.env.example | 8 + pyproject.toml | 2 + src/ai_company/api/app.py | 40 ++- src/ai_company/api/auth/__init__.py | 14 + src/ai_company/api/auth/config.py | 79 ++++ src/ai_company/api/auth/controller.py | 336 ++++++++++++++++++ src/ai_company/api/auth/middleware.py | 211 +++++++++++ src/ai_company/api/auth/models.py | 88 +++++ src/ai_company/api/auth/secret.py | 56 +++ src/ai_company/api/auth/service.py | 150 ++++++++ src/ai_company/api/config.py | 7 +- src/ai_company/api/controllers/__init__.py | 3 + src/ai_company/api/controllers/approvals.py | 14 +- src/ai_company/api/errors.py | 9 + src/ai_company/api/exception_handlers.py | 14 + src/ai_company/api/guards.py | 30 +- src/ai_company/api/state.py | 16 + src/ai_company/observability/events/api.py | 5 + .../observability/events/persistence.py | 20 ++ src/ai_company/persistence/protocol.py | 40 +++ src/ai_company/persistence/repositories.py | 155 ++++++++ src/ai_company/persistence/sqlite/backend.py | 83 ++++- .../persistence/sqlite/migrations.py | 44 ++- .../persistence/sqlite/user_repo.py | 314 ++++++++++++++++ tests/unit/api/auth/__init__.py | 0 tests/unit/api/auth/test_config.py | 51 +++ tests/unit/api/auth/test_controller.py | 193 ++++++++++ tests/unit/api/auth/test_middleware.py | 251 +++++++++++++ tests/unit/api/auth/test_service.py | 145 ++++++++ tests/unit/api/conftest.py | 191 +++++++++- tests/unit/api/controllers/test_agents.py | 10 +- tests/unit/api/controllers/test_analytics.py | 7 +- tests/unit/api/controllers/test_approvals.py | 23 +- tests/unit/api/controllers/test_autonomy.py | 12 +- tests/unit/api/controllers/test_budget.py | 7 +- tests/unit/api/controllers/test_company.py | 8 +- tests/unit/api/controllers/test_tasks.py | 22 +- tests/unit/api/test_app.py | 14 +- tests/unit/api/test_guards.py | 120 +++++-- .../unit/persistence/sqlite/test_user_repo.py | 199 +++++++++++ tests/unit/persistence/test_migrations_v2.py | 4 +- tests/unit/persistence/test_protocol.py | 61 ++++ uv.lock | 47 +++ 43 files changed, 2996 insertions(+), 107 deletions(-) create mode 100644 src/ai_company/api/auth/__init__.py create mode 100644 src/ai_company/api/auth/config.py create mode 100644 src/ai_company/api/auth/controller.py create mode 100644 src/ai_company/api/auth/middleware.py create mode 100644 src/ai_company/api/auth/models.py create mode 100644 src/ai_company/api/auth/secret.py create mode 100644 src/ai_company/api/auth/service.py create mode 100644 src/ai_company/persistence/sqlite/user_repo.py create mode 100644 tests/unit/api/auth/__init__.py create mode 100644 tests/unit/api/auth/test_config.py create mode 100644 tests/unit/api/auth/test_controller.py create mode 100644 tests/unit/api/auth/test_middleware.py create mode 100644 tests/unit/api/auth/test_service.py create mode 100644 tests/unit/persistence/sqlite/test_user_repo.py diff --git a/docker/.env.example b/docker/.env.example index eae54725a1..7f336261e8 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -9,6 +9,14 @@ # API key for the LLM provider (required for agent execution) LLM_API_KEY= +# --- Authentication ---------------------------------------------------------- +# JWT signing secret (optional — auto-generated and persisted on first run). +# Set explicitly only for multi-instance deployments sharing a common secret. +# Must be >= 32 characters if set. +# Generate with: python -c "import secrets; print(secrets.token_urlsafe(48))" +# AI_COMPANY_JWT_SECRET= +# First-run: POST /api/v1/auth/setup to create admin account + # --- Application ------------------------------------------------------------- # Log level: debug, info, warning, error, critical AI_COMPANY_LOG_LEVEL=info diff --git a/pyproject.toml b/pyproject.toml index aa14594df4..3d8e36f09e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,12 +15,14 @@ classifiers = [ dependencies = [ "aiodocker==0.26.0", "aiosqlite==0.22.1", + "argon2-cffi==25.1.0", "jinja2==3.1.6", "jsonschema==4.26.0", "litellm==1.82.1", "litestar[standard,structlog,pydantic,brotli,prometheus]==2.21.1", "mcp==1.26.0", "pydantic==2.12.5", + "pyjwt[crypto]==2.11.0", "pyyaml==6.0.3", "structlog==25.5.0", ] diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index 793b5832fb..8cffe96771 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -19,6 +19,9 @@ from ai_company import __version__ from ai_company.api.approval_store import ApprovalStore +from ai_company.api.auth.middleware import create_auth_middleware_class +from ai_company.api.auth.secret import resolve_jwt_secret +from ai_company.api.auth.service import AuthService from ai_company.api.bus_bridge import MessageBusBridge from ai_company.api.channels import CHANNEL_APPROVALS, create_channels_plugin from ai_company.api.controllers import ALL_CONTROLLERS @@ -98,6 +101,7 @@ def _build_lifecycle( persistence: PersistenceBackend | None, message_bus: MessageBus | None, bridge: MessageBusBridge | None, + app_state: AppState, ) -> tuple[ Sequence[Callable[[], Awaitable[None]]], Sequence[Callable[[], Awaitable[None]]], @@ -110,7 +114,7 @@ def _build_lifecycle( async def on_startup() -> None: logger.info(API_APP_STARTUP, version=__version__) - await _safe_startup(persistence, message_bus, bridge) + await _safe_startup(persistence, message_bus, bridge, app_state) async def on_shutdown() -> None: logger.info(API_APP_SHUTDOWN, version=__version__) @@ -151,8 +155,9 @@ async def _safe_startup( persistence: PersistenceBackend | None, message_bus: MessageBus | None, bridge: MessageBusBridge | None, + app_state: AppState, ) -> None: - """Connect persistence, start message bus and bridge. + """Connect persistence, resolve JWT secret, start message bus and bridge. Executes in order; on failure, cleans up already-started components in reverse order before re-raising. @@ -171,6 +176,23 @@ async def _safe_startup( ) raise started_persistence = True + + # Resolve JWT secret after persistence is up + if app_state._auth_service is None: # noqa: SLF001 + try: + secret = await resolve_jwt_secret(persistence) + auth_config = app_state.config.api.auth.with_secret( + secret, + ) + app_state._auth_service = AuthService(auth_config) # noqa: SLF001 + except Exception: + logger.error( + API_APP_STARTUP, + error="Failed to resolve JWT secret", + exc_info=True, + ) + raise + if message_bus is not None: try: await message_bus.start() @@ -237,13 +259,14 @@ async def _safe_shutdown( ) -def create_app( +def create_app( # noqa: PLR0913 *, config: RootConfig | None = None, persistence: PersistenceBackend | None = None, message_bus: MessageBus | None = None, cost_tracker: CostTracker | None = None, approval_store: ApprovalStore | None = None, + auth_service: AuthService | None = None, ) -> Litestar: """Create and configure the Litestar application. @@ -256,6 +279,7 @@ def create_app( message_bus: Internal message bus. cost_tracker: Cost tracking service. approval_store: Approval queue store. + auth_service: Pre-built auth service (for testing). Returns: Configured Litestar application. @@ -283,6 +307,7 @@ def create_app( message_bus=message_bus, cost_tracker=cost_tracker, approval_store=effective_approval_store, + auth_service=auth_service, startup_time=time.monotonic(), ) @@ -299,6 +324,7 @@ def create_app( persistence, message_bus, bridge, + app_state, ) return Litestar( @@ -369,4 +395,10 @@ def _build_middleware(api_config: ApiConfig) -> list[Middleware]: rate_limit=(rl.time_unit, rl.max_requests), # type: ignore[arg-type] exclude=list(rl.exclude_paths), ) - return [CSPMiddleware, RequestLoggingMiddleware, rate_limit.middleware] + auth_middleware = create_auth_middleware_class(api_config.auth) + return [ + auth_middleware, + CSPMiddleware, + RequestLoggingMiddleware, + rate_limit.middleware, + ] diff --git a/src/ai_company/api/auth/__init__.py b/src/ai_company/api/auth/__init__.py new file mode 100644 index 0000000000..3bad8b316e --- /dev/null +++ b/src/ai_company/api/auth/__init__.py @@ -0,0 +1,14 @@ +"""Authentication and authorization for the API layer.""" + +from ai_company.api.auth.config import AuthConfig +from ai_company.api.auth.models import ApiKey, AuthenticatedUser, AuthMethod, User +from ai_company.api.auth.service import AuthService + +__all__ = [ + "ApiKey", + "AuthConfig", + "AuthMethod", + "AuthService", + "AuthenticatedUser", + "User", +] diff --git a/src/ai_company/api/auth/config.py b/src/ai_company/api/auth/config.py new file mode 100644 index 0000000000..72af06f111 --- /dev/null +++ b/src/ai_company/api/auth/config.py @@ -0,0 +1,79 @@ +"""Authentication configuration.""" + +from pydantic import BaseModel, ConfigDict, Field + +_MIN_SECRET_LENGTH = 32 + + +class AuthConfig(BaseModel): + """JWT and authentication configuration. + + The ``jwt_secret`` is resolved at application startup via a + priority chain: + + 1. ``AI_COMPANY_JWT_SECRET`` environment variable (for multi-instance + deployments sharing a common secret). + 2. Stored secret in the persistence ``settings`` table (auto-generated + on first run). + 3. Auto-generated and persisted on first startup. + + At construction time the secret may be empty — it is populated + before the first request is served. + + Attributes: + jwt_secret: HMAC signing key (resolved at startup, repr-hidden). + jwt_algorithm: JWT signing algorithm. + jwt_expiry_minutes: Token lifetime in minutes. + exclude_paths: URL paths excluded from auth middleware. + """ + + model_config = ConfigDict(frozen=True) + + jwt_secret: str = Field( + default="", + repr=False, + description="JWT signing secret (resolved at startup)", + ) + jwt_algorithm: str = Field( + default="HS256", + description="JWT signing algorithm", + ) + jwt_expiry_minutes: int = Field( + default=1440, + ge=1, + le=43200, + description="Token lifetime in minutes (default 24h)", + ) + exclude_paths: tuple[str, ...] = Field( + default=( + "^/api/v1/health$", + "^/docs", + "^/api$", + "^/api/v1/auth/setup$", + "^/api/v1/auth/login$", + ), + description=( + "Regex patterns for paths excluded from authentication. " + "Anchor with ^ and $ to avoid substring matches." + ), + ) + + def with_secret(self, secret: str) -> AuthConfig: + """Return a copy with the JWT secret set. + + Args: + secret: Resolved JWT signing secret. + + Returns: + New ``AuthConfig`` with the secret populated. + + Raises: + ValueError: If the secret is too short. + """ + if len(secret) < _MIN_SECRET_LENGTH: + msg = ( + f"jwt_secret must be at least {_MIN_SECRET_LENGTH} " + f"characters (got {len(secret)})" + ) + raise ValueError(msg) + return self.model_copy(update={"jwt_secret": secret}) diff --git a/src/ai_company/api/auth/controller.py b/src/ai_company/api/auth/controller.py new file mode 100644 index 0000000000..8f96838f59 --- /dev/null +++ b/src/ai_company/api/auth/controller.py @@ -0,0 +1,336 @@ +"""Authentication controller — setup, login, password change, me.""" + +import uuid +from datetime import UTC, datetime + +from litestar import Controller, Response, get, post +from litestar.connection import ASGIConnection # noqa: TC002 +from litestar.exceptions import PermissionDeniedException +from pydantic import BaseModel, ConfigDict, Field + +from ai_company.api.auth.models import AuthenticatedUser, User +from ai_company.api.auth.service import AuthService # noqa: TC001 +from ai_company.api.dto import ApiResponse +from ai_company.api.errors import ApiValidationError, ConflictError, UnauthorizedError +from ai_company.api.guards import HumanRole +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.api import ( + API_AUTH_FAILED, + API_AUTH_PASSWORD_CHANGED, + API_AUTH_SETUP_COMPLETE, + API_AUTH_TOKEN_ISSUED, +) + +logger = get_logger(__name__) + +_MIN_PASSWORD_LENGTH = 12 + + +# ── Request DTOs ────────────────────────────────────────────── + + +class SetupRequest(BaseModel): + """First-run admin account creation payload. + + Attributes: + username: Admin login username. + password: Admin password (min 12 chars). + """ + + model_config = ConfigDict(frozen=True) + + username: NotBlankStr = Field(max_length=128) + password: NotBlankStr = Field(min_length=_MIN_PASSWORD_LENGTH, max_length=128) + + +class LoginRequest(BaseModel): + """Login credentials payload. + + Attributes: + username: Login username. + password: Login password. + """ + + model_config = ConfigDict(frozen=True) + + username: NotBlankStr = Field(max_length=128) + password: NotBlankStr = Field(max_length=128) + + +class ChangePasswordRequest(BaseModel): + """Password change payload. + + Attributes: + current_password: Current password for verification. + new_password: New password (min 12 chars). + """ + + model_config = ConfigDict(frozen=True) + + current_password: NotBlankStr = Field(max_length=128) + new_password: NotBlankStr = Field(min_length=_MIN_PASSWORD_LENGTH, max_length=128) + + +# ── Response DTOs ───────────────────────────────────────────── + + +class TokenResponse(BaseModel): + """JWT token response. + + Attributes: + token: Encoded JWT string. + expires_in: Token lifetime in seconds. + must_change_password: Whether password change is required. + """ + + model_config = ConfigDict(frozen=True) + + token: str + expires_in: int + must_change_password: bool + + +class UserInfoResponse(BaseModel): + """Current user information. + + Attributes: + id: User ID. + username: Login username. + role: Access control role. + must_change_password: Whether password change is required. + """ + + model_config = ConfigDict(frozen=True) + + id: str + username: str + role: str + must_change_password: bool + + +# ── Guards ──────────────────────────────────────────────────── + + +def require_password_changed( + connection: ASGIConnection, # type: ignore[type-arg] + _: object, +) -> None: + """Guard that blocks users who must change their password. + + Applied to all routes except ``/auth/change-password`` and + ``/auth/me``. + + Args: + connection: The incoming connection. + _: Route handler (unused). + + Raises: + PermissionDeniedException: If password change is required. + """ + user = connection.scope.get("user") + if ( + user is not None + and isinstance(user, AuthenticatedUser) + and user.must_change_password + ): + raise PermissionDeniedException(detail="Password change required") + + +def _validate_password(password: str) -> None: + """Raise if the password is too short.""" + if len(password) < _MIN_PASSWORD_LENGTH: + msg = f"Password must be at least {_MIN_PASSWORD_LENGTH} characters" + raise ApiValidationError(msg) + + +# ── Controller ──────────────────────────────────────────────── + + +class AuthController(Controller): + """Authentication endpoints: setup, login, password change, me.""" + + path = "/auth" + tags = ("auth",) + + @post( + "/setup", + status_code=201, + summary="First-run admin setup", + ) + async def setup( + self, + data: SetupRequest, + request: ASGIConnection, # type: ignore[type-arg] + ) -> Response[ApiResponse[TokenResponse]]: + """Create the first admin account (CEO). + + Only available when no users exist. Returns 409 after + the first account is created. + """ + _validate_password(data.password) + + app_state = request.app.state["app_state"] + auth_service: AuthService = app_state.auth_service + persistence = app_state.persistence + + user_count = await persistence.users.count() + if user_count > 0: + msg = "Setup already completed" + raise ConflictError(msg) + + now = datetime.now(UTC) + user = User( + id=str(uuid.uuid4()), + username=data.username, + password_hash=auth_service.hash_password(data.password), + role=HumanRole.CEO, + must_change_password=True, + created_at=now, + updated_at=now, + ) + await persistence.users.save(user) + + token, expires_in = auth_service.create_token(user) + + logger.info( + API_AUTH_SETUP_COMPLETE, + user_id=user.id, + username=user.username, + ) + + return Response( + content=ApiResponse( + data=TokenResponse( + token=token, + expires_in=expires_in, + must_change_password=True, + ), + ), + status_code=201, + ) + + @post( + "/login", + status_code=200, + summary="Authenticate with credentials", + ) + async def login( + self, + data: LoginRequest, + request: ASGIConnection, # type: ignore[type-arg] + ) -> Response[ApiResponse[TokenResponse]]: + """Validate credentials and return a JWT.""" + app_state = request.app.state["app_state"] + auth_service: AuthService = app_state.auth_service + persistence = app_state.persistence + + user = await persistence.users.get_by_username(data.username) + if user is None or not auth_service.verify_password( + data.password, user.password_hash + ): + logger.warning( + API_AUTH_FAILED, + reason="invalid_credentials", + username=data.username, + ) + msg = "Invalid credentials" + raise UnauthorizedError(msg) + + token, expires_in = auth_service.create_token(user) + + logger.info( + API_AUTH_TOKEN_ISSUED, + user_id=user.id, + username=user.username, + ) + + return Response( + content=ApiResponse( + data=TokenResponse( + token=token, + expires_in=expires_in, + must_change_password=user.must_change_password, + ), + ), + ) + + @post( + "/change-password", + status_code=200, + summary="Change current user password", + ) + async def change_password( + self, + data: ChangePasswordRequest, + request: ASGIConnection, # type: ignore[type-arg] + ) -> Response[ApiResponse[UserInfoResponse]]: + """Validate current password and set new one.""" + _validate_password(data.new_password) + auth_user: AuthenticatedUser = request.scope["user"] + app_state = request.app.state["app_state"] + auth_service: AuthService = app_state.auth_service + persistence = app_state.persistence + + user = await persistence.users.get(auth_user.user_id) + if user is None: + msg = "User not found" + raise UnauthorizedError(msg) + + if not auth_service.verify_password(data.current_password, user.password_hash): + logger.warning( + API_AUTH_FAILED, + reason="invalid_current_password", + user_id=user.id, + ) + msg = "Invalid current password" + raise UnauthorizedError(msg) + + now = datetime.now(UTC) + updated_user = user.model_copy( + update={ + "password_hash": auth_service.hash_password(data.new_password), + "must_change_password": False, + "updated_at": now, + } + ) + await persistence.users.save(updated_user) + + logger.info( + API_AUTH_PASSWORD_CHANGED, + user_id=user.id, + username=user.username, + ) + + return Response( + content=ApiResponse( + data=UserInfoResponse( + id=updated_user.id, + username=updated_user.username, + role=updated_user.role.value, + must_change_password=False, + ), + ), + ) + + @get( + "/me", + summary="Get current user info", + ) + async def me( + self, + request: ASGIConnection, # type: ignore[type-arg] + ) -> Response[ApiResponse[UserInfoResponse]]: + """Return information about the authenticated user.""" + auth_user: AuthenticatedUser = request.scope["user"] + + return Response( + content=ApiResponse( + data=UserInfoResponse( + id=auth_user.user_id, + username=auth_user.username, + role=auth_user.role.value, + must_change_password=auth_user.must_change_password, + ), + ), + ) diff --git a/src/ai_company/api/auth/middleware.py b/src/ai_company/api/auth/middleware.py new file mode 100644 index 0000000000..620de6ab06 --- /dev/null +++ b/src/ai_company/api/auth/middleware.py @@ -0,0 +1,211 @@ +"""JWT + API key authentication middleware.""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any + +import jwt +from litestar.exceptions import NotAuthorizedException +from litestar.middleware import ( + AbstractAuthenticationMiddleware, + AuthenticationResult, +) + +from ai_company.api.auth.models import AuthenticatedUser, AuthMethod +from ai_company.api.auth.service import AuthService +from ai_company.api.guards import HumanRole +from ai_company.observability import get_logger +from ai_company.observability.events.api import ( + API_AUTH_FAILED, + API_AUTH_SUCCESS, +) + +if TYPE_CHECKING: + from litestar.connection import ASGIConnection + + from ai_company.api.auth.config import AuthConfig + +logger = get_logger(__name__) + +_BEARER_PARTS = 2 + + +class ApiAuthMiddleware(AbstractAuthenticationMiddleware): + """Authenticate requests via JWT or API key. + + Reads ``Authorization: Bearer `` from the request. + Tokens containing ``.`` are tried as JWTs first; if that fails + (or the token has no dots), it is tried as an API key via + SHA-256 hash lookup. + + Requires ``auth_service``, persistence backend on + ``app.state["app_state"]``. + """ + + async def authenticate_request( + self, + connection: ASGIConnection[Any, Any, Any, Any], + ) -> AuthenticationResult: + """Validate the Authorization header. + + Args: + connection: Incoming ASGI connection. + + Returns: + AuthenticationResult with AuthenticatedUser. + + Raises: + NotAuthorizedException: If authentication fails. + """ + auth_header = connection.headers.get("authorization") + if not auth_header: + logger.warning( + API_AUTH_FAILED, + reason="missing_header", + path=str(connection.url.path), + ) + raise NotAuthorizedException(detail="Missing Authorization header") + + token = _extract_bearer_token(auth_header) + if token is None: + logger.warning( + API_AUTH_FAILED, + reason="invalid_scheme", + path=str(connection.url.path), + ) + raise NotAuthorizedException(detail="Invalid authorization scheme") + + app_state = connection.app.state["app_state"] + auth_service: AuthService = app_state.auth_service + + # Try JWT first (tokens with dots are likely JWTs) + if "." in token: + user = await _try_jwt_auth(token, auth_service, app_state, connection) + if user is not None: + return AuthenticationResult(user=user, auth=token) + + # Fall back to API key + user = await _try_api_key_auth(token, app_state, connection) + if user is not None: + return AuthenticationResult(user=user, auth=token) + + logger.warning( + API_AUTH_FAILED, + reason="invalid_credentials", + path=str(connection.url.path), + ) + raise NotAuthorizedException(detail="Invalid credentials") + + +def _extract_bearer_token(header: str) -> str | None: + """Extract token from ``Bearer `` header value.""" + parts = header.split(None, 1) + if len(parts) != _BEARER_PARTS or parts[0].lower() != "bearer": + return None + return parts[1] + + +async def _try_jwt_auth( + token: str, + auth_service: AuthService, + app_state: Any, + connection: ASGIConnection[Any, Any, Any, Any], +) -> AuthenticatedUser | None: + """Attempt JWT authentication.""" + try: + claims = auth_service.decode_token(token) + except jwt.InvalidTokenError: + return None + + user_id = claims.get("sub") + if not user_id: + return None + + # Verify the user still exists + persistence = app_state.persistence + db_user = await persistence.users.get(user_id) + if db_user is None: + return None + + authenticated = AuthenticatedUser( + user_id=db_user.id, + username=db_user.username, + role=db_user.role, + auth_method=AuthMethod.JWT, + must_change_password=db_user.must_change_password, + ) + logger.info( + API_AUTH_SUCCESS, + user_id=db_user.id, + username=db_user.username, + auth_method="jwt", + path=str(connection.url.path), + ) + return authenticated + + +async def _try_api_key_auth( + token: str, + app_state: Any, + connection: ASGIConnection[Any, Any, Any, Any], +) -> AuthenticatedUser | None: + """Attempt API key authentication.""" + key_hash = AuthService.hash_api_key(token) + persistence = app_state.persistence + api_key = await persistence.api_keys.get_by_hash(key_hash) + if api_key is None: + return None + + # Check revocation and expiry + if api_key.revoked: + return None + if api_key.expires_at is not None and api_key.expires_at < datetime.now(UTC): + return None + + # Look up the owning user + db_user = await persistence.users.get(api_key.user_id) + if db_user is None: + return None + + authenticated = AuthenticatedUser( + user_id=db_user.id, + username=db_user.username, + role=HumanRole(api_key.role), + auth_method=AuthMethod.API_KEY, + must_change_password=db_user.must_change_password, + ) + logger.info( + API_AUTH_SUCCESS, + user_id=db_user.id, + username=db_user.username, + auth_method="api_key", + key_name=api_key.name, + path=str(connection.url.path), + ) + return authenticated + + +def create_auth_middleware_class( + auth_config: AuthConfig, +) -> type[ApiAuthMiddleware]: + """Create a middleware class with excluded paths baked in. + + Litestar's ``AbstractAuthenticationMiddleware.__init__`` takes + ``exclude`` as a parameter (default ``None``). We create a + subclass whose ``__init__`` forwards the configured exclude + list to ``super().__init__``. + + Args: + auth_config: Auth configuration with exclude_paths. + + Returns: + Middleware class ready for use in the Litestar middleware stack. + """ + exclude_paths = list(auth_config.exclude_paths) or None + + class ConfiguredAuthMiddleware(ApiAuthMiddleware): + """Auth middleware with pre-configured exclude paths.""" + + def __init__(self, app: Any) -> None: + super().__init__(app, exclude=exclude_paths) + + return ConfiguredAuthMiddleware diff --git a/src/ai_company/api/auth/models.py b/src/ai_company/api/auth/models.py new file mode 100644 index 0000000000..049f9b4cec --- /dev/null +++ b/src/ai_company/api/auth/models.py @@ -0,0 +1,88 @@ +"""Authentication domain models.""" + +from datetime import datetime # noqa: TC003 +from enum import StrEnum + +from pydantic import BaseModel, ConfigDict, Field + +from ai_company.api.guards import HumanRole # noqa: TC001 +from ai_company.core.types import NotBlankStr # noqa: TC001 + + +class AuthMethod(StrEnum): + """Authentication method used for a request.""" + + JWT = "jwt" + API_KEY = "api_key" + + +class User(BaseModel): + """Persisted user account. + + Attributes: + id: Unique user identifier (UUID). + username: Login username. + password_hash: Argon2id hash (excluded from repr). + role: Access control role. + must_change_password: Whether the user must change password. + created_at: Account creation timestamp. + updated_at: Last modification timestamp. + """ + + model_config = ConfigDict(frozen=True) + + id: NotBlankStr + username: NotBlankStr + password_hash: str = Field(repr=False) + role: HumanRole + must_change_password: bool = True + created_at: datetime + updated_at: datetime + + +class ApiKey(BaseModel): + """Persisted API key (hash-only storage). + + Attributes: + id: Unique key identifier (UUID). + key_hash: SHA-256 hex digest of the raw key. + name: Human-readable label. + role: Access control role. + user_id: Owner user ID. + created_at: Key creation timestamp. + expires_at: Optional expiry timestamp. + revoked: Whether the key has been revoked. + """ + + model_config = ConfigDict(frozen=True) + + id: NotBlankStr + key_hash: str = Field(repr=False) + name: NotBlankStr + role: HumanRole + user_id: NotBlankStr + created_at: datetime + expires_at: datetime | None = None + revoked: bool = False + + +class AuthenticatedUser(BaseModel): + """Lightweight identity attached to ``connection.user``. + + Populated by the auth middleware after successful authentication. + + Attributes: + user_id: User's unique identifier. + username: User's login name. + role: Access control role. + auth_method: How the user authenticated. + must_change_password: Whether forced password change is pending. + """ + + model_config = ConfigDict(frozen=True) + + user_id: NotBlankStr + username: NotBlankStr + role: HumanRole + auth_method: AuthMethod + must_change_password: bool = False diff --git a/src/ai_company/api/auth/secret.py b/src/ai_company/api/auth/secret.py new file mode 100644 index 0000000000..7f5b57d540 --- /dev/null +++ b/src/ai_company/api/auth/secret.py @@ -0,0 +1,56 @@ +"""JWT secret resolution — env var → persistence → auto-generate.""" + +import os +import secrets + +from ai_company.observability import get_logger +from ai_company.observability.events.api import API_APP_STARTUP +from ai_company.persistence.protocol import PersistenceBackend # noqa: TC001 + +logger = get_logger(__name__) + +_SETTING_KEY = "jwt_secret" +_SECRET_LENGTH = 48 # 64 URL-safe base64 chars + + +async def resolve_jwt_secret( + persistence: PersistenceBackend, +) -> str: + """Resolve the JWT signing secret using a priority chain. + + 1. ``AI_COMPANY_JWT_SECRET`` env var (for multi-instance deploys). + 2. Stored secret in persistence ``settings`` table. + 3. Auto-generate, persist, and return. + + Args: + persistence: Connected persistence backend. + + Returns: + JWT signing secret (>= 32 characters). + """ + # 1. Env var override (highest priority) + env_secret = os.environ.get("AI_COMPANY_JWT_SECRET", "").strip() + if env_secret: + logger.info( + API_APP_STARTUP, + note="JWT secret loaded from AI_COMPANY_JWT_SECRET env var", + ) + return env_secret + + # 2. Check persistence + stored = await persistence.get_setting(_SETTING_KEY) + if stored: + logger.info( + API_APP_STARTUP, + note="JWT secret loaded from persistence", + ) + return stored + + # 3. Auto-generate and persist + generated = secrets.token_urlsafe(_SECRET_LENGTH) + await persistence.set_setting(_SETTING_KEY, generated) + logger.info( + API_APP_STARTUP, + note="JWT secret auto-generated and saved to persistence", + ) + return generated diff --git a/src/ai_company/api/auth/service.py b/src/ai_company/api/auth/service.py new file mode 100644 index 0000000000..f6539fa36b --- /dev/null +++ b/src/ai_company/api/auth/service.py @@ -0,0 +1,150 @@ +"""Authentication service — password hashing, JWT ops, API key hashing.""" + +import hashlib +import hmac +import secrets +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING, Any + +import argon2 +import jwt + +from ai_company.api.auth.models import User # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.api import API_AUTH_TOKEN_ISSUED + +if TYPE_CHECKING: + from ai_company.api.auth.config import AuthConfig + +logger = get_logger(__name__) + +_hasher = argon2.PasswordHasher( + time_cost=3, + memory_cost=65536, + parallelism=4, + hash_len=32, + salt_len=16, +) + + +class AuthService: + """Stateless authentication operations. + + Args: + config: Authentication configuration (carries JWT secret). + """ + + def __init__(self, config: AuthConfig) -> None: + self._config = config + + def hash_password(self, password: str) -> str: + """Hash a password with Argon2id. + + Args: + password: Plaintext password. + + Returns: + Argon2id hash string. + """ + return _hasher.hash(password) + + def verify_password(self, password: str, password_hash: str) -> bool: + """Verify a password against an Argon2id hash. + + Args: + password: Plaintext password to check. + password_hash: Stored Argon2id hash. + + Returns: + ``True`` if the password matches. + """ + try: + return _hasher.verify(password_hash, password) + except argon2.exceptions.VerifyMismatchError: + return False + except argon2.exceptions.VerificationError: + return False + + def create_token(self, user: User) -> tuple[str, int]: + """Create a JWT for the given user. + + Args: + user: Authenticated user. + + Returns: + Tuple of (encoded JWT string, expiry seconds). + """ + now = datetime.now(UTC) + expiry_seconds = self._config.jwt_expiry_minutes * 60 + payload: dict[str, Any] = { + "sub": user.id, + "username": user.username, + "role": user.role.value, + "must_change_password": user.must_change_password, + "iat": now, + "exp": now + timedelta(seconds=expiry_seconds), + } + token = jwt.encode( + payload, + self._config.jwt_secret, + algorithm=self._config.jwt_algorithm, + ) + logger.info( + API_AUTH_TOKEN_ISSUED, + user_id=user.id, + username=user.username, + ) + return token, expiry_seconds + + def decode_token(self, token: str) -> dict[str, Any]: + """Decode and validate a JWT. + + Args: + token: Encoded JWT string. + + Returns: + Decoded claims dictionary. + + Raises: + jwt.InvalidTokenError: If the token is invalid or expired. + """ + return jwt.decode( + token, + self._config.jwt_secret, + algorithms=[self._config.jwt_algorithm], + ) + + @staticmethod + def hash_api_key(raw_key: str) -> str: + """Compute SHA-256 hex digest of a raw API key. + + Args: + raw_key: The plaintext API key. + + Returns: + Lowercase hex digest. + """ + return hashlib.sha256(raw_key.encode()).hexdigest() + + @staticmethod + def verify_api_key(raw_key: str, stored_hash: str) -> bool: + """Constant-time comparison of API key hash. + + Args: + raw_key: Plaintext API key from request. + stored_hash: SHA-256 hex digest from storage. + + Returns: + ``True`` if the key matches. + """ + computed = hashlib.sha256(raw_key.encode()).hexdigest() + return hmac.compare_digest(computed, stored_hash) + + @staticmethod + def generate_api_key() -> str: + """Generate a cryptographically secure API key. + + Returns: + URL-safe base64 string (43 chars). + """ + return secrets.token_urlsafe(32) diff --git a/src/ai_company/api/config.py b/src/ai_company/api/config.py index d3210999df..9063016220 100644 --- a/src/ai_company/api/config.py +++ b/src/ai_company/api/config.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator +from ai_company.api.auth.config import AuthConfig from ai_company.core.types import NotBlankStr # noqa: TC001 @@ -34,7 +35,7 @@ class CorsConfig(BaseModel): description="HTTP methods permitted in cross-origin requests", ) allow_headers: tuple[str, ...] = Field( - default=("Content-Type", "Authorization", "X-Human-Role"), + default=("Content-Type", "Authorization"), description="Headers permitted in cross-origin requests", ) allow_credentials: bool = Field( @@ -167,6 +168,10 @@ class ApiConfig(BaseModel): default_factory=ServerConfig, description="Uvicorn server configuration", ) + auth: AuthConfig = Field( + default_factory=AuthConfig, + description="Authentication configuration", + ) api_prefix: NotBlankStr = Field( default="/api/v1", description="URL prefix for all API routes", diff --git a/src/ai_company/api/controllers/__init__.py b/src/ai_company/api/controllers/__init__.py index 8343349809..575f6ee8b4 100644 --- a/src/ai_company/api/controllers/__init__.py +++ b/src/ai_company/api/controllers/__init__.py @@ -2,6 +2,7 @@ from litestar import Controller +from ai_company.api.auth.controller import AuthController from ai_company.api.controllers.agents import AgentController from ai_company.api.controllers.analytics import AnalyticsController from ai_company.api.controllers.approvals import ApprovalsController @@ -33,6 +34,7 @@ ProviderController, ApprovalsController, AutonomyController, + AuthController, ) __all__ = [ @@ -41,6 +43,7 @@ "AnalyticsController", "ApprovalsController", "ArtifactController", + "AuthController", "AutonomyController", "BudgetController", "CompanyController", diff --git a/src/ai_company/api/controllers/approvals.py b/src/ai_company/api/controllers/approvals.py index 2056a6d0f0..76fc613154 100644 --- a/src/ai_company/api/controllers/approvals.py +++ b/src/ai_company/api/controllers/approvals.py @@ -236,8 +236,8 @@ async def approve( ) -> ApiResponse[ApprovalItem]: """Approve a pending approval item. - The ``decided_by`` field is populated from the - ``X-Human-Role`` header. + The ``decided_by`` field is populated from the authenticated + user's role. Args: state: Application state. @@ -272,7 +272,8 @@ async def approve( ) raise ConflictError(msg) - role = request.headers.get("x-human-role", "unknown") + auth_user = request.scope.get("user") + role = auth_user.role.value if auth_user is not None else "unknown" now = datetime.now(UTC) updated = item.model_copy( update={ @@ -319,8 +320,8 @@ async def reject( ) -> ApiResponse[ApprovalItem]: """Reject a pending approval item. - The ``decided_by`` field is populated from the - ``X-Human-Role`` header. + The ``decided_by`` field is populated from the authenticated + user's role. Args: state: Application state. @@ -355,7 +356,8 @@ async def reject( ) raise ConflictError(msg) - role = request.headers.get("x-human-role", "unknown") + auth_user = request.scope.get("user") + role = auth_user.role.value if auth_user is not None else "unknown" now = datetime.now(UTC) updated = item.model_copy( update={ diff --git a/src/ai_company/api/errors.py b/src/ai_company/api/errors.py index f903adff73..2e9c86dc42 100644 --- a/src/ai_company/api/errors.py +++ b/src/ai_company/api/errors.py @@ -56,6 +56,15 @@ def __init__(self, message: str | None = None) -> None: super().__init__(message, status_code=403) +class UnauthorizedError(ApiError): + """Raised when authentication is required or invalid (401).""" + + default_message: str = "Authentication required" + + def __init__(self, message: str | None = None) -> None: + super().__init__(message, status_code=401) + + class ServiceUnavailableError(ApiError): """Raised when a required service is not configured (503).""" diff --git a/src/ai_company/api/exception_handlers.py b/src/ai_company/api/exception_handlers.py index 5c73144c37..a41be9ef18 100644 --- a/src/ai_company/api/exception_handlers.py +++ b/src/ai_company/api/exception_handlers.py @@ -9,6 +9,7 @@ from litestar import Request, Response from litestar.exceptions import ( + NotAuthorizedException, PermissionDeniedException, ValidationException, ) @@ -149,10 +150,23 @@ def handle_validation_error( ) +def handle_not_authorized( + request: Request[Any, Any, Any], + exc: NotAuthorizedException, +) -> Response[ApiResponse[None]]: + """Map ``NotAuthorizedException`` to 401.""" + _log_error(request, exc, status=401) + return Response( + content=ApiResponse[None](error="Authentication required"), + status_code=401, + ) + + EXCEPTION_HANDLERS: dict[type[Exception], object] = { RecordNotFoundError: handle_record_not_found, DuplicateRecordError: handle_duplicate_record, PersistenceError: handle_persistence_error, + NotAuthorizedException: handle_not_authorized, PermissionDeniedException: handle_permission_denied, ValidationException: handle_validation_error, ApiError: handle_api_error, diff --git a/src/ai_company/api/guards.py b/src/ai_company/api/guards.py index 553e4e883d..9a46dc2a44 100644 --- a/src/ai_company/api/guards.py +++ b/src/ai_company/api/guards.py @@ -1,16 +1,7 @@ """Route guards for access control. -.. warning:: **Security Stub (M6)** - - These guards check the ``X-Human-Role`` header, which is - **self-asserted by the caller** — there is no signature - verification, session token, or JWT. This is intentional for - the M6 milestone scope. **Real authentication and authorization - (pre-shared API key, JWT, or OAuth) will be implemented in M7 - (issue scope: security & HR).** - - Until M7, the API should only be exposed on trusted networks or - behind a reverse proxy that enforces authentication. +Guards read the authenticated user identity from ``connection.user`` +(populated by the auth middleware) and check role-based permissions. """ from enum import StrEnum @@ -45,11 +36,14 @@ class HumanRole(StrEnum): _READ_ROLES: frozenset[HumanRole] = _WRITE_ROLES | frozenset({HumanRole.OBSERVER}) -def _get_role(connection: ASGIConnection) -> str | None: # type: ignore[type-arg] - """Extract the human role from the request header.""" - value = connection.headers.get("x-human-role") - if value is not None: - return value.strip().lower() +def _get_role(connection: ASGIConnection) -> HumanRole | None: # type: ignore[type-arg] + """Extract the human role from the authenticated user.""" + user = connection.scope.get("user") + if user is not None and hasattr(user, "role"): + try: + return HumanRole(user.role) + except ValueError: + return None return None @@ -59,7 +53,7 @@ def require_write_access( ) -> None: """Guard that allows only write-capable roles. - Checks the ``X-Human-Role`` header for ``ceo``, ``manager``, + Checks ``connection.user.role`` for ``ceo``, ``manager``, ``board_member``, or ``pair_programmer``. Args: @@ -86,7 +80,7 @@ def require_read_access( ) -> None: """Guard that allows all recognised roles. - Checks the ``X-Human-Role`` header for any valid role + Checks ``connection.user.role`` for any valid role including ``observer``. Args: diff --git a/src/ai_company/api/state.py b/src/ai_company/api/state.py index a7583548ad..586381a1d2 100644 --- a/src/ai_company/api/state.py +++ b/src/ai_company/api/state.py @@ -6,6 +6,7 @@ """ from ai_company.api.approval_store import ApprovalStore # noqa: TC001 +from ai_company.api.auth.service import AuthService # noqa: TC001 from ai_company.api.errors import ServiceUnavailableError from ai_company.budget.tracker import CostTracker # noqa: TC001 from ai_company.communication.bus_protocol import MessageBus # noqa: TC001 @@ -33,6 +34,7 @@ class AppState: """ __slots__ = ( + "_auth_service", "_cost_tracker", "_message_bus", "_persistence", @@ -49,6 +51,7 @@ def __init__( # noqa: PLR0913 persistence: PersistenceBackend | None = None, message_bus: MessageBus | None = None, cost_tracker: CostTracker | None = None, + auth_service: AuthService | None = None, startup_time: float = 0.0, ) -> None: self.config = config @@ -56,6 +59,7 @@ def __init__( # noqa: PLR0913 self._persistence = persistence self._message_bus = message_bus self._cost_tracker = cost_tracker + self._auth_service = auth_service self.startup_time = startup_time @property @@ -93,3 +97,15 @@ def cost_tracker(self) -> CostTracker: msg = "Cost tracker not configured" raise ServiceUnavailableError(msg) return self._cost_tracker + + @property + def auth_service(self) -> AuthService: + """Return auth service or raise 503.""" + if self._auth_service is None: + logger.warning( + API_SERVICE_UNAVAILABLE, + service="auth_service", + ) + msg = "Auth service not configured" + raise ServiceUnavailableError(msg) + return self._auth_service diff --git a/src/ai_company/observability/events/api.py b/src/ai_company/observability/events/api.py index 354955e5e6..8890cd7538 100644 --- a/src/ai_company/observability/events/api.py +++ b/src/ai_company/observability/events/api.py @@ -30,3 +30,8 @@ API_WS_TRANSPORT_ERROR: Final[str] = "api.ws.transport_error" API_WS_SEND_FAILED: Final[str] = "api.ws.send_failed" API_SERVICE_UNAVAILABLE: Final[str] = "api.service.unavailable" +API_AUTH_SUCCESS: Final[str] = "api.auth.success" +API_AUTH_FAILED: Final[str] = "api.auth.failed" +API_AUTH_TOKEN_ISSUED: Final[str] = "api.auth.token_issued" # noqa: S105 +API_AUTH_SETUP_COMPLETE: Final[str] = "api.auth.setup_complete" +API_AUTH_PASSWORD_CHANGED: Final[str] = "api.auth.password_changed" # noqa: S105 diff --git a/src/ai_company/observability/events/persistence.py b/src/ai_company/observability/events/persistence.py index 9c95516bd3..48d271e89d 100644 --- a/src/ai_company/observability/events/persistence.py +++ b/src/ai_company/observability/events/persistence.py @@ -123,3 +123,23 @@ PERSISTENCE_AUDIT_ENTRY_DESERIALIZE_FAILED: Final[str] = ( "persistence.audit_entry.deserialize_failed" ) + +PERSISTENCE_USER_SAVED: Final[str] = "persistence.user.saved" +PERSISTENCE_USER_SAVE_FAILED: Final[str] = "persistence.user.save_failed" +PERSISTENCE_USER_FETCHED: Final[str] = "persistence.user.fetched" +PERSISTENCE_USER_FETCH_FAILED: Final[str] = "persistence.user.fetch_failed" +PERSISTENCE_USER_LISTED: Final[str] = "persistence.user.listed" +PERSISTENCE_USER_LIST_FAILED: Final[str] = "persistence.user.list_failed" +PERSISTENCE_USER_COUNTED: Final[str] = "persistence.user.counted" +PERSISTENCE_USER_COUNT_FAILED: Final[str] = "persistence.user.count_failed" +PERSISTENCE_USER_DELETED: Final[str] = "persistence.user.deleted" +PERSISTENCE_USER_DELETE_FAILED: Final[str] = "persistence.user.delete_failed" + +PERSISTENCE_API_KEY_SAVED: Final[str] = "persistence.api_key.saved" +PERSISTENCE_API_KEY_SAVE_FAILED: Final[str] = "persistence.api_key.save_failed" +PERSISTENCE_API_KEY_FETCHED: Final[str] = "persistence.api_key.fetched" +PERSISTENCE_API_KEY_FETCH_FAILED: Final[str] = "persistence.api_key.fetch_failed" +PERSISTENCE_API_KEY_LISTED: Final[str] = "persistence.api_key.listed" +PERSISTENCE_API_KEY_LIST_FAILED: Final[str] = "persistence.api_key.list_failed" +PERSISTENCE_API_KEY_DELETED: Final[str] = "persistence.api_key.deleted" +PERSISTENCE_API_KEY_DELETE_FAILED: Final[str] = "persistence.api_key.delete_failed" diff --git a/src/ai_company/persistence/protocol.py b/src/ai_company/persistence/protocol.py index 26871358ed..bd4d01336f 100644 --- a/src/ai_company/persistence/protocol.py +++ b/src/ai_company/persistence/protocol.py @@ -13,11 +13,13 @@ TaskMetricRepository, # noqa: TC001 ) from ai_company.persistence.repositories import ( + ApiKeyRepository, # noqa: TC001 AuditRepository, # noqa: TC001 CostRecordRepository, # noqa: TC001 MessageRepository, # noqa: TC001 ParkedContextRepository, # noqa: TC001 TaskRepository, # noqa: TC001 + UserRepository, # noqa: TC001 ) @@ -123,3 +125,41 @@ def parked_contexts(self) -> ParkedContextRepository: def audit_entries(self) -> AuditRepository: """Repository for AuditEntry persistence.""" ... + + @property + def users(self) -> UserRepository: + """Repository for User persistence.""" + ... + + @property + def api_keys(self) -> ApiKeyRepository: + """Repository for ApiKey persistence.""" + ... + + async def get_setting(self, key: str) -> str | None: + """Retrieve a setting value by key. + + Args: + key: Setting key. + + Returns: + The setting value, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def set_setting(self, key: str, value: str) -> None: + """Store a setting value. + + Upserts — creates or updates the key. + + Args: + key: Setting key. + value: Setting value. + + Raises: + PersistenceError: If the operation fails. + """ + ... diff --git a/src/ai_company/persistence/repositories.py b/src/ai_company/persistence/repositories.py index 73fd914fd1..0ec74cf34b 100644 --- a/src/ai_company/persistence/repositories.py +++ b/src/ai_company/persistence/repositories.py @@ -8,6 +8,7 @@ from pydantic import AwareDatetime # noqa: TC002 +from ai_company.api.auth.models import ApiKey, User # noqa: TC001 from ai_company.budget.cost_record import CostRecord # noqa: TC001 from ai_company.communication.message import Message # noqa: TC001 from ai_company.core.enums import ApprovalRiskLevel, TaskStatus # noqa: TC001 @@ -22,6 +23,7 @@ from ai_company.security.timeout.parked_context import ParkedContext # noqa: TC001 __all__ = [ + "ApiKeyRepository", "AuditRepository", "CollaborationMetricRepository", "CostRecordRepository", @@ -30,6 +32,7 @@ "ParkedContextRepository", "TaskMetricRepository", "TaskRepository", + "UserRepository", ] @@ -318,3 +321,155 @@ async def query( # noqa: PLR0913 *until* is earlier than *since*. """ ... + + +@runtime_checkable +class UserRepository(Protocol): + """CRUD interface for User persistence.""" + + async def save(self, user: User) -> None: + """Persist a user (insert or update). + + Args: + user: The user to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get(self, user_id: str) -> User | None: + """Retrieve a user by ID. + + Args: + user_id: The user identifier. + + Returns: + The user, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get_by_username(self, username: str) -> User | None: + """Retrieve a user by username. + + Args: + username: The login username. + + Returns: + The user, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def list_users(self) -> tuple[User, ...]: + """List all users. + + Returns: + All users as a tuple. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def count(self) -> int: + """Count the number of users. + + Returns: + Total user count. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def delete(self, user_id: str) -> bool: + """Delete a user by ID. + + Args: + user_id: The user identifier. + + Returns: + ``True`` if deleted, ``False`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + +@runtime_checkable +class ApiKeyRepository(Protocol): + """CRUD interface for API key persistence.""" + + async def save(self, key: ApiKey) -> None: + """Persist an API key. + + Args: + key: The API key to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get(self, key_id: str) -> ApiKey | None: + """Retrieve an API key by ID. + + Args: + key_id: The key identifier. + + Returns: + The API key, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get_by_hash(self, key_hash: str) -> ApiKey | None: + """Retrieve an API key by its hash. + + Args: + key_hash: SHA-256 hex digest. + + Returns: + The API key, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: + """List API keys belonging to a user. + + Args: + user_id: The owner user ID. + + Returns: + API keys for the user. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def delete(self, key_id: str) -> bool: + """Delete an API key by ID. + + Args: + key_id: The key identifier. + + Returns: + ``True`` if deleted, ``False`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... diff --git a/src/ai_company/persistence/sqlite/backend.py b/src/ai_company/persistence/sqlite/backend.py index 282ed4e7f6..b1a5b1fa08 100644 --- a/src/ai_company/persistence/sqlite/backend.py +++ b/src/ai_company/persistence/sqlite/backend.py @@ -20,7 +20,10 @@ PERSISTENCE_BACKEND_NOT_CONNECTED, PERSISTENCE_BACKEND_WAL_MODE_FAILED, ) -from ai_company.persistence.errors import PersistenceConnectionError +from ai_company.persistence.errors import ( + PersistenceConnectionError, + QueryError, +) from ai_company.persistence.sqlite.audit_repository import ( SQLiteAuditRepository, ) @@ -38,6 +41,10 @@ SQLiteMessageRepository, SQLiteTaskRepository, ) +from ai_company.persistence.sqlite.user_repo import ( + SQLiteApiKeyRepository, + SQLiteUserRepository, +) if TYPE_CHECKING: from ai_company.persistence.config import SQLiteConfig @@ -68,6 +75,8 @@ def __init__(self, config: SQLiteConfig) -> None: self._collaboration_metrics: SQLiteCollaborationMetricRepository | None = None self._parked_contexts: SQLiteParkedContextRepository | None = None self._audit_entries: SQLiteAuditRepository | None = None + self._users: SQLiteUserRepository | None = None + self._api_keys: SQLiteApiKeyRepository | None = None def _clear_state(self) -> None: """Reset connection and repository references to ``None``.""" @@ -80,6 +89,8 @@ def _clear_state(self) -> None: self._collaboration_metrics = None self._parked_contexts = None self._audit_entries = None + self._users = None + self._api_keys = None async def connect(self) -> None: """Open the SQLite database and configure WAL mode.""" @@ -139,6 +150,8 @@ def _create_repositories(self) -> None: self._collaboration_metrics = SQLiteCollaborationMetricRepository(self._db) self._parked_contexts = SQLiteParkedContextRepository(self._db) self._audit_entries = SQLiteAuditRepository(self._db) + self._users = SQLiteUserRepository(self._db) + self._api_keys = SQLiteApiKeyRepository(self._db) async def _cleanup_failed_connect(self, exc: sqlite3.Error | OSError) -> None: """Log failure, close partial connection, and raise. @@ -320,3 +333,71 @@ def audit_entries(self) -> SQLiteAuditRepository: PersistenceConnectionError: If not connected. """ return self._require_connected(self._audit_entries, "audit_entries") + + @property + def users(self) -> SQLiteUserRepository: + """Repository for User persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._users, "users") + + @property + def api_keys(self) -> SQLiteApiKeyRepository: + """Repository for ApiKey persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._api_keys, "api_keys") + + async def get_setting(self, key: str) -> str | None: + """Retrieve a setting value by key. + + Raises: + PersistenceConnectionError: If not connected. + """ + if self._db is None: + msg = "Not connected — call connect() before accessing settings" + logger.warning(PERSISTENCE_BACKEND_NOT_CONNECTED, error=msg) + raise PersistenceConnectionError(msg) + try: + cursor = await self._db.execute( + "SELECT value FROM settings WHERE key = ?", (key,) + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to get setting {key!r}" + logger.exception( + PERSISTENCE_BACKEND_NOT_CONNECTED, + error=str(exc), + ) + raise QueryError(msg) from exc + return str(row[0]) if row else None + + async def set_setting(self, key: str, value: str) -> None: + """Store a setting value (upsert). + + Raises: + PersistenceConnectionError: If not connected. + """ + if self._db is None: + msg = "Not connected — call connect() before accessing settings" + logger.warning(PERSISTENCE_BACKEND_NOT_CONNECTED, error=msg) + raise PersistenceConnectionError(msg) + try: + await self._db.execute( + """\ +INSERT INTO settings (key, value) VALUES (?, ?) +ON CONFLICT(key) DO UPDATE SET value=excluded.value""", + (key, value), + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to set setting {key!r}" + logger.exception( + PERSISTENCE_BACKEND_NOT_CONNECTED, + error=str(exc), + ) + raise QueryError(msg) from exc diff --git a/src/ai_company/persistence/sqlite/migrations.py b/src/ai_company/persistence/sqlite/migrations.py index cc306d8086..de331a43f7 100644 --- a/src/ai_company/persistence/sqlite/migrations.py +++ b/src/ai_company/persistence/sqlite/migrations.py @@ -23,7 +23,7 @@ logger = get_logger(__name__) # Current schema version — bump when adding new migrations. -SCHEMA_VERSION = 4 +SCHEMA_VERSION = 5 _V1_STATEMENTS: Sequence[str] = ( # ── Tasks ───────────────────────────────────────────── @@ -188,6 +188,41 @@ "CREATE INDEX IF NOT EXISTS idx_ae_risk_level ON audit_entries(risk_level)", ) +_V5_STATEMENTS: Sequence[str] = ( + # ── Settings (key-value store) ───────────────────────── + """\ +CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +)""", + # ── Users ────────────────────────────────────────────── + """\ +CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + role TEXT NOT NULL, + must_change_password INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +)""", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username)", + # ── API keys ─────────────────────────────────────────── + """\ +CREATE TABLE IF NOT EXISTS api_keys ( + id TEXT PRIMARY KEY, + key_hash TEXT NOT NULL UNIQUE, + name TEXT NOT NULL, + role TEXT NOT NULL, + user_id TEXT NOT NULL REFERENCES users(id), + created_at TEXT NOT NULL, + expires_at TEXT, + revoked INTEGER NOT NULL DEFAULT 0 +)""", + "CREATE INDEX IF NOT EXISTS idx_api_keys_user_id ON api_keys(user_id)", + "CREATE UNIQUE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash)", +) + _MigrateFn = Callable[[aiosqlite.Connection], Coroutine[Any, Any, None]] @@ -244,6 +279,12 @@ async def _apply_v4(db: aiosqlite.Connection) -> None: await db.execute(stmt) +async def _apply_v5(db: aiosqlite.Connection) -> None: + """Apply schema v5: settings, users, api_keys.""" + for stmt in _V5_STATEMENTS: + await db.execute(stmt) + + # Ordered list of (target_version, migration_function) pairs. Each migration # is applied when the current schema version is below its target version. _MIGRATIONS: list[tuple[int, _MigrateFn]] = [ @@ -251,6 +292,7 @@ async def _apply_v4(db: aiosqlite.Connection) -> None: (2, _apply_v2), (3, _apply_v3), (4, _apply_v4), + (5, _apply_v5), ] diff --git a/src/ai_company/persistence/sqlite/user_repo.py b/src/ai_company/persistence/sqlite/user_repo.py new file mode 100644 index 0000000000..cd8ae7e73e --- /dev/null +++ b/src/ai_company/persistence/sqlite/user_repo.py @@ -0,0 +1,314 @@ +"""SQLite repository implementations for User and ApiKey.""" + +import sqlite3 +from datetime import datetime + +import aiosqlite +from pydantic import ValidationError + +from ai_company.api.auth.models import ApiKey, User +from ai_company.api.guards import HumanRole +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_API_KEY_DELETE_FAILED, + PERSISTENCE_API_KEY_DELETED, + PERSISTENCE_API_KEY_FETCH_FAILED, + PERSISTENCE_API_KEY_FETCHED, + PERSISTENCE_API_KEY_LIST_FAILED, + PERSISTENCE_API_KEY_LISTED, + PERSISTENCE_API_KEY_SAVE_FAILED, + PERSISTENCE_API_KEY_SAVED, + PERSISTENCE_USER_COUNT_FAILED, + PERSISTENCE_USER_COUNTED, + PERSISTENCE_USER_DELETE_FAILED, + PERSISTENCE_USER_DELETED, + PERSISTENCE_USER_FETCH_FAILED, + PERSISTENCE_USER_FETCHED, + PERSISTENCE_USER_LIST_FAILED, + PERSISTENCE_USER_LISTED, + PERSISTENCE_USER_SAVE_FAILED, + PERSISTENCE_USER_SAVED, +) +from ai_company.persistence.errors import QueryError + +logger = get_logger(__name__) + + +def _row_to_user(row: aiosqlite.Row) -> User: + """Reconstruct a User from a database row.""" + data = dict(row) + data["must_change_password"] = bool(data["must_change_password"]) + data["role"] = HumanRole(data["role"]) + data["created_at"] = datetime.fromisoformat(data["created_at"]) + data["updated_at"] = datetime.fromisoformat(data["updated_at"]) + return User.model_validate(data) + + +def _row_to_api_key(row: aiosqlite.Row) -> ApiKey: + """Reconstruct an ApiKey from a database row.""" + data = dict(row) + data["revoked"] = bool(data["revoked"]) + data["role"] = HumanRole(data["role"]) + data["created_at"] = datetime.fromisoformat(data["created_at"]) + if data["expires_at"] is not None: + data["expires_at"] = datetime.fromisoformat(data["expires_at"]) + return ApiKey.model_validate(data) + + +class SQLiteUserRepository: + """SQLite implementation of the UserRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, user: User) -> None: + """Persist a user (upsert semantics).""" + try: + await self._db.execute( + """\ +INSERT INTO users (id, username, password_hash, role, + must_change_password, created_at, updated_at) +VALUES (?, ?, ?, ?, ?, ?, ?) +ON CONFLICT(id) DO UPDATE SET + username=excluded.username, + password_hash=excluded.password_hash, + role=excluded.role, + must_change_password=excluded.must_change_password, + updated_at=excluded.updated_at""", + ( + user.id, + user.username, + user.password_hash, + user.role.value, + int(user.must_change_password), + user.created_at.isoformat(), + user.updated_at.isoformat(), + ), + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save user {user.id!r}" + logger.exception( + PERSISTENCE_USER_SAVE_FAILED, + user_id=user.id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug(PERSISTENCE_USER_SAVED, user_id=user.id) + + async def get(self, user_id: str) -> User | None: + """Retrieve a user by ID.""" + try: + cursor = await self._db.execute( + "SELECT * FROM users WHERE id = ?", (user_id,) + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to fetch user {user_id!r}" + logger.exception( + PERSISTENCE_USER_FETCH_FAILED, + user_id=user_id, + error=str(exc), + ) + raise QueryError(msg) from exc + if row is None: + logger.debug(PERSISTENCE_USER_FETCHED, user_id=user_id, found=False) + return None + logger.debug(PERSISTENCE_USER_FETCHED, user_id=user_id, found=True) + return _row_to_user(row) + + async def get_by_username(self, username: str) -> User | None: + """Retrieve a user by username.""" + try: + cursor = await self._db.execute( + "SELECT * FROM users WHERE username = ?", (username,) + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to fetch user by username {username!r}" + logger.exception( + PERSISTENCE_USER_FETCH_FAILED, + username=username, + error=str(exc), + ) + raise QueryError(msg) from exc + if row is None: + return None + return _row_to_user(row) + + async def list_users(self) -> tuple[User, ...]: + """List all users.""" + try: + cursor = await self._db.execute("SELECT * FROM users ORDER BY created_at") + rows = await cursor.fetchall() + except (sqlite3.Error, aiosqlite.Error, ValidationError) as exc: + msg = "Failed to list users" + logger.exception(PERSISTENCE_USER_LIST_FAILED, error=str(exc)) + raise QueryError(msg) from exc + users = tuple(_row_to_user(row) for row in rows) + logger.debug(PERSISTENCE_USER_LISTED, count=len(users)) + return users + + async def count(self) -> int: + """Count the number of users.""" + try: + cursor = await self._db.execute("SELECT COUNT(*) FROM users") + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to count users" + logger.exception(PERSISTENCE_USER_COUNT_FAILED, error=str(exc)) + raise QueryError(msg) from exc + result = int(row[0]) if row else 0 + logger.debug(PERSISTENCE_USER_COUNTED, count=result) + return result + + async def delete(self, user_id: str) -> bool: + """Delete a user by ID.""" + try: + cursor = await self._db.execute( + "DELETE FROM users WHERE id = ?", (user_id,) + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to delete user {user_id!r}" + logger.exception( + PERSISTENCE_USER_DELETE_FAILED, + user_id=user_id, + error=str(exc), + ) + raise QueryError(msg) from exc + deleted = cursor.rowcount > 0 + logger.debug(PERSISTENCE_USER_DELETED, user_id=user_id, deleted=deleted) + return deleted + + +class SQLiteApiKeyRepository: + """SQLite implementation of the ApiKeyRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, key: ApiKey) -> None: + """Persist an API key (upsert semantics).""" + try: + await self._db.execute( + """\ +INSERT INTO api_keys (id, key_hash, name, role, user_id, + created_at, expires_at, revoked) +VALUES (?, ?, ?, ?, ?, ?, ?, ?) +ON CONFLICT(id) DO UPDATE SET + key_hash=excluded.key_hash, + name=excluded.name, + role=excluded.role, + user_id=excluded.user_id, + expires_at=excluded.expires_at, + revoked=excluded.revoked""", + ( + key.id, + key.key_hash, + key.name, + key.role.value, + key.user_id, + key.created_at.isoformat(), + key.expires_at.isoformat() if key.expires_at else None, + int(key.revoked), + ), + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save API key {key.id!r}" + logger.exception( + PERSISTENCE_API_KEY_SAVE_FAILED, + key_id=key.id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug(PERSISTENCE_API_KEY_SAVED, key_id=key.id) + + async def get(self, key_id: str) -> ApiKey | None: + """Retrieve an API key by ID.""" + try: + cursor = await self._db.execute( + "SELECT * FROM api_keys WHERE id = ?", (key_id,) + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to fetch API key {key_id!r}" + logger.exception( + PERSISTENCE_API_KEY_FETCH_FAILED, + key_id=key_id, + error=str(exc), + ) + raise QueryError(msg) from exc + if row is None: + logger.debug(PERSISTENCE_API_KEY_FETCHED, key_id=key_id, found=False) + return None + logger.debug(PERSISTENCE_API_KEY_FETCHED, key_id=key_id, found=True) + return _row_to_api_key(row) + + async def get_by_hash(self, key_hash: str) -> ApiKey | None: + """Retrieve an API key by its hash.""" + try: + cursor = await self._db.execute( + "SELECT * FROM api_keys WHERE key_hash = ?", + (key_hash,), + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to fetch API key by hash" + logger.exception(PERSISTENCE_API_KEY_FETCH_FAILED, error=str(exc)) + raise QueryError(msg) from exc + if row is None: + return None + return _row_to_api_key(row) + + async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: + """List API keys belonging to a user.""" + try: + cursor = await self._db.execute( + "SELECT * FROM api_keys WHERE user_id = ? ORDER BY created_at", + (user_id,), + ) + rows = await cursor.fetchall() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to list API keys for user {user_id!r}" + logger.exception( + PERSISTENCE_API_KEY_LIST_FAILED, + user_id=user_id, + error=str(exc), + ) + raise QueryError(msg) from exc + keys = tuple(_row_to_api_key(row) for row in rows) + logger.debug( + PERSISTENCE_API_KEY_LISTED, + user_id=user_id, + count=len(keys), + ) + return keys + + async def delete(self, key_id: str) -> bool: + """Delete an API key by ID.""" + try: + cursor = await self._db.execute( + "DELETE FROM api_keys WHERE id = ?", (key_id,) + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to delete API key {key_id!r}" + logger.exception( + PERSISTENCE_API_KEY_DELETE_FAILED, + key_id=key_id, + error=str(exc), + ) + raise QueryError(msg) from exc + deleted = cursor.rowcount > 0 + logger.debug(PERSISTENCE_API_KEY_DELETED, key_id=key_id, deleted=deleted) + return deleted diff --git a/tests/unit/api/auth/__init__.py b/tests/unit/api/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/api/auth/test_config.py b/tests/unit/api/auth/test_config.py new file mode 100644 index 0000000000..66eedf538c --- /dev/null +++ b/tests/unit/api/auth/test_config.py @@ -0,0 +1,51 @@ +"""Tests for AuthConfig.""" + +import pytest + +from ai_company.api.auth.config import AuthConfig + + +@pytest.mark.unit +class TestAuthConfig: + def test_default_values(self) -> None: + config = AuthConfig() + assert config.jwt_secret == "" + assert config.jwt_algorithm == "HS256" + assert config.jwt_expiry_minutes == 1440 + assert "^/api/v1/health$" in config.exclude_paths + assert "^/api/v1/auth/setup$" in config.exclude_paths + assert "^/api/v1/auth/login$" in config.exclude_paths + + def test_with_secret_sets_secret(self) -> None: + config = AuthConfig() + updated = config.with_secret( + "a-very-long-secret-that-is-at-least-32-characters" + ) + assert updated.jwt_secret == "a-very-long-secret-that-is-at-least-32-characters" + + def test_with_secret_too_short_raises(self) -> None: + config = AuthConfig() + with pytest.raises(ValueError, match="at least 32"): + config.with_secret("short") + + def test_frozen(self) -> None: + config = AuthConfig() + with pytest.raises(Exception): # noqa: B017, PT011 + config.jwt_secret = "new" # type: ignore[misc] + + def test_original_unchanged_after_with_secret(self) -> None: + config = AuthConfig() + config.with_secret("a-very-long-secret-that-is-at-least-32-characters") + assert config.jwt_secret == "" + + def test_custom_expiry(self) -> None: + config = AuthConfig(jwt_expiry_minutes=60) + assert config.jwt_expiry_minutes == 60 + + def test_expiry_min_bound(self) -> None: + with pytest.raises(Exception): # noqa: B017, PT011 + AuthConfig(jwt_expiry_minutes=0) + + def test_expiry_max_bound(self) -> None: + with pytest.raises(Exception): # noqa: B017, PT011 + AuthConfig(jwt_expiry_minutes=50000) diff --git a/tests/unit/api/auth/test_controller.py b/tests/unit/api/auth/test_controller.py new file mode 100644 index 0000000000..e918299f81 --- /dev/null +++ b/tests/unit/api/auth/test_controller.py @@ -0,0 +1,193 @@ +"""Tests for AuthController endpoints.""" + +from typing import Any + +import pytest +from litestar.testing import TestClient # noqa: TC002 + +from tests.unit.api.conftest import make_auth_headers + + +@pytest.fixture +def bare_client(test_client: TestClient[Any]) -> TestClient[Any]: + """Test client with no default Authorization header.""" + test_client.headers.pop("authorization", None) + return test_client + + +@pytest.mark.unit +class TestSetup: + def test_setup_creates_admin(self, bare_client: TestClient[Any]) -> None: + app_state = bare_client.app.state["app_state"] + app_state.persistence._users._users.clear() + + response = bare_client.post( + "/api/v1/auth/setup", + json={ + "username": "newadmin", + "password": "super-secure-password-12", + }, + ) + assert response.status_code == 201 + data = response.json()["data"] + assert "token" in data + assert data["must_change_password"] is True + assert data["expires_in"] > 0 + + def test_setup_409_when_users_exist(self, bare_client: TestClient[Any]) -> None: + # Re-seed a user so the check fails + import asyncio + import uuid + from datetime import UTC, datetime + + from ai_company.api.auth.models import User + from ai_company.api.auth.service import AuthService # noqa: TC001 + from ai_company.api.guards import HumanRole + + app_state = bare_client.app.state["app_state"] + svc: AuthService = app_state.auth_service + now = datetime.now(UTC) + user = User( + id=str(uuid.uuid4()), + username="existing", + password_hash=svc.hash_password("test-password-12chars"), + role=HumanRole.CEO, + must_change_password=False, + created_at=now, + updated_at=now, + ) + loop = asyncio.get_event_loop() + loop.run_until_complete(app_state.persistence.users.save(user)) + + response = bare_client.post( + "/api/v1/auth/setup", + json={ + "username": "admin2", + "password": "super-secure-password-12", + }, + ) + assert response.status_code == 409 + + def test_setup_short_password_rejected(self, bare_client: TestClient[Any]) -> None: + app_state = bare_client.app.state["app_state"] + app_state.persistence._users._users.clear() + + response = bare_client.post( + "/api/v1/auth/setup", + json={"username": "admin", "password": "short"}, + ) + assert response.status_code == 422 + + +@pytest.mark.unit +class TestLogin: + def test_login_valid_credentials(self, bare_client: TestClient[Any]) -> None: + app_state = bare_client.app.state["app_state"] + app_state.persistence._users._users.clear() + + bare_client.post( + "/api/v1/auth/setup", + json={ + "username": "loginuser", + "password": "super-secure-password-12", + }, + ) + + response = bare_client.post( + "/api/v1/auth/login", + json={ + "username": "loginuser", + "password": "super-secure-password-12", + }, + ) + assert response.status_code == 200 + data = response.json()["data"] + assert "token" in data + assert data["expires_in"] > 0 + + def test_login_wrong_password(self, bare_client: TestClient[Any]) -> None: + response = bare_client.post( + "/api/v1/auth/login", + json={ + "username": "test-ceo", + "password": "wrong-password-12345", + }, + ) + assert response.status_code == 401 + + def test_login_nonexistent_user(self, bare_client: TestClient[Any]) -> None: + response = bare_client.post( + "/api/v1/auth/login", + json={ + "username": "nonexistent", + "password": "any-password-12345", + }, + ) + assert response.status_code == 401 + + +@pytest.mark.unit +class TestChangePassword: + def test_change_password_success(self, bare_client: TestClient[Any]) -> None: + app_state = bare_client.app.state["app_state"] + app_state.persistence._users._users.clear() + + setup_resp = bare_client.post( + "/api/v1/auth/setup", + json={ + "username": "changepw", + "password": "old-password-12chars", + }, + ) + token = setup_resp.json()["data"]["token"] + + response = bare_client.post( + "/api/v1/auth/change-password", + json={ + "current_password": "old-password-12chars", + "new_password": "new-password-12chars", + }, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + data = response.json()["data"] + assert data["must_change_password"] is False + + def test_change_password_wrong_current(self, test_client: TestClient[Any]) -> None: + response = test_client.post( + "/api/v1/auth/change-password", + json={ + "current_password": "wrong-current-pw-12", + "new_password": "new-password-12chars", + }, + headers=make_auth_headers("ceo"), + ) + assert response.status_code == 401 + + def test_change_password_requires_auth(self, bare_client: TestClient[Any]) -> None: + response = bare_client.post( + "/api/v1/auth/change-password", + json={ + "current_password": "old-password-12chars", + "new_password": "new-password-12chars", + }, + ) + assert response.status_code == 401 + + +@pytest.mark.unit +class TestMe: + def test_me_returns_user_info(self, test_client: TestClient[Any]) -> None: + response = test_client.get( + "/api/v1/auth/me", + headers=make_auth_headers("ceo"), + ) + assert response.status_code == 200 + data = response.json()["data"] + assert data["username"] == "test-ceo" + assert data["role"] == "ceo" + assert data["must_change_password"] is False + + def test_me_requires_auth(self, bare_client: TestClient[Any]) -> None: + response = bare_client.get("/api/v1/auth/me") + assert response.status_code == 401 diff --git a/tests/unit/api/auth/test_middleware.py b/tests/unit/api/auth/test_middleware.py new file mode 100644 index 0000000000..203ae80974 --- /dev/null +++ b/tests/unit/api/auth/test_middleware.py @@ -0,0 +1,251 @@ +"""Tests for ApiAuthMiddleware.""" + +from datetime import UTC, datetime + +import pytest +from litestar import Litestar, get +from litestar.testing import TestClient + +from ai_company.api.auth.config import AuthConfig +from ai_company.api.auth.middleware import create_auth_middleware_class +from ai_company.api.auth.models import ApiKey, User +from ai_company.api.auth.service import AuthService +from ai_company.api.guards import HumanRole +from tests.unit.api.conftest import FakePersistenceBackend + +_SECRET = "test-secret-that-is-at-least-32-characters-long" + + +def _make_auth_service() -> AuthService: + return AuthService(AuthConfig(jwt_secret=_SECRET)) + + +def _make_user(svc: AuthService) -> User: + now = datetime.now(UTC) + return User( + id="mw-user-001", + username="mw-admin", + password_hash=svc.hash_password("test-password-12chars"), + role=HumanRole.CEO, + must_change_password=False, + created_at=now, + updated_at=now, + ) + + +def _build_app( + *, + auth_service: AuthService, + persistence: FakePersistenceBackend, + exclude_paths: tuple[str, ...] = (), +) -> Litestar: + """Build a minimal Litestar app with auth middleware.""" + auth_config = AuthConfig( + jwt_secret=_SECRET, + exclude_paths=exclude_paths, + ) + + @get("/protected") + async def protected_route() -> dict[str, str]: + return {"status": "ok"} + + @get("/public") + async def public_route() -> dict[str, str]: + return {"status": "public"} + + middleware_cls = create_auth_middleware_class(auth_config) + + class _FakeState: + def __init__(self) -> None: + self.auth_service = auth_service + self.persistence = persistence + + app = Litestar( + route_handlers=[protected_route, public_route], + middleware=[middleware_cls], + ) + app.state["app_state"] = _FakeState() + return app + + +@pytest.mark.unit +class TestAuthMiddlewareJWT: + async def test_valid_jwt_authenticates(self) -> None: + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + await persistence.users.save(user) + + app = _build_app(auth_service=svc, persistence=persistence) + token, _ = svc.create_token(user) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + + async def test_missing_header_returns_401(self) -> None: + svc = _make_auth_service() + persistence = FakePersistenceBackend() + await persistence.connect() + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get("/protected") + assert resp.status_code == 401 + + async def test_invalid_scheme_returns_401(self) -> None: + svc = _make_auth_service() + persistence = FakePersistenceBackend() + await persistence.connect() + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": "Basic dXNlcjpwYXNz"}, + ) + assert resp.status_code == 401 + + async def test_invalid_jwt_returns_401(self) -> None: + svc = _make_auth_service() + persistence = FakePersistenceBackend() + await persistence.connect() + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": "Bearer bad.jwt.token"}, + ) + assert resp.status_code == 401 + + async def test_jwt_for_deleted_user_returns_401(self) -> None: + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + # Don't save user — simulate deleted user + token, _ = svc.create_token(user) + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.unit +class TestAuthMiddlewareApiKey: + async def test_valid_api_key_authenticates(self) -> None: + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + await persistence.users.save(user) + + raw_key = AuthService.generate_api_key() + key_hash = AuthService.hash_api_key(raw_key) + now = datetime.now(UTC) + api_key = ApiKey( + id="key-001", + key_hash=key_hash, + name="test-key", + role=HumanRole.CEO, + user_id=user.id, + created_at=now, + ) + await persistence.api_keys.save(api_key) + + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {raw_key}"}, + ) + assert resp.status_code == 200 + + async def test_revoked_api_key_returns_401(self) -> None: + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + await persistence.users.save(user) + + raw_key = AuthService.generate_api_key() + key_hash = AuthService.hash_api_key(raw_key) + now = datetime.now(UTC) + api_key = ApiKey( + id="key-002", + key_hash=key_hash, + name="revoked-key", + role=HumanRole.CEO, + user_id=user.id, + created_at=now, + revoked=True, + ) + await persistence.api_keys.save(api_key) + + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {raw_key}"}, + ) + assert resp.status_code == 401 + + async def test_expired_api_key_returns_401(self) -> None: + from datetime import timedelta + + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + await persistence.users.save(user) + + raw_key = AuthService.generate_api_key() + key_hash = AuthService.hash_api_key(raw_key) + now = datetime.now(UTC) + api_key = ApiKey( + id="key-003", + key_hash=key_hash, + name="expired-key", + role=HumanRole.CEO, + user_id=user.id, + created_at=now - timedelta(days=2), + expires_at=now - timedelta(days=1), + ) + await persistence.api_keys.save(api_key) + + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {raw_key}"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.unit +class TestAuthMiddlewareExcludePaths: + async def test_excluded_path_skips_auth(self) -> None: + svc = _make_auth_service() + persistence = FakePersistenceBackend() + await persistence.connect() + app = _build_app( + auth_service=svc, + persistence=persistence, + exclude_paths=("/public",), + ) + + with TestClient(app) as client: + resp = client.get("/public") + assert resp.status_code == 200 diff --git a/tests/unit/api/auth/test_service.py b/tests/unit/api/auth/test_service.py new file mode 100644 index 0000000000..e5992fc032 --- /dev/null +++ b/tests/unit/api/auth/test_service.py @@ -0,0 +1,145 @@ +"""Tests for AuthService.""" + +import pytest + +from ai_company.api.auth.config import AuthConfig +from ai_company.api.auth.models import User +from ai_company.api.auth.service import AuthService +from ai_company.api.guards import HumanRole + +_SECRET = "test-secret-that-is-at-least-32-characters-long" + + +def _make_service() -> AuthService: + return AuthService(AuthConfig(jwt_secret=_SECRET)) + + +def _make_user( + *, + role: HumanRole = HumanRole.CEO, + must_change_password: bool = False, +) -> User: + from datetime import UTC, datetime + + now = datetime.now(UTC) + svc = _make_service() + return User( + id="user-001", + username="admin", + password_hash=svc.hash_password("test-password-12chars"), + role=role, + must_change_password=must_change_password, + created_at=now, + updated_at=now, + ) + + +@pytest.mark.unit +class TestPasswordHashing: + def test_hash_and_verify(self) -> None: + svc = _make_service() + hashed = svc.hash_password("my-secret-password") + assert svc.verify_password("my-secret-password", hashed) + + def test_wrong_password_fails(self) -> None: + svc = _make_service() + hashed = svc.hash_password("correct-password") + assert not svc.verify_password("wrong-password", hashed) + + def test_hash_is_not_plaintext(self) -> None: + svc = _make_service() + hashed = svc.hash_password("my-secret-password") + assert hashed != "my-secret-password" + assert "$argon2" in hashed + + def test_different_hashes_for_same_password(self) -> None: + svc = _make_service() + h1 = svc.hash_password("same-password") + h2 = svc.hash_password("same-password") + # Different salts produce different hashes + assert h1 != h2 + + +@pytest.mark.unit +class TestJWT: + def test_create_and_decode(self) -> None: + svc = _make_service() + user = _make_user() + token, expires_in = svc.create_token(user) + assert isinstance(token, str) + assert expires_in == 1440 * 60 + + claims = svc.decode_token(token) + assert claims["sub"] == "user-001" + assert claims["username"] == "admin" + assert claims["role"] == "ceo" + + def test_expired_token_raises(self) -> None: + import jwt + + config = AuthConfig(jwt_secret=_SECRET, jwt_expiry_minutes=1) + svc = AuthService(config) + user = _make_user() + _token, _ = svc.create_token(user) + + # Manually create an expired token + from datetime import UTC, datetime, timedelta + + expired_payload = { + "sub": user.id, + "username": user.username, + "role": user.role.value, + "must_change_password": False, + "iat": datetime.now(UTC) - timedelta(hours=2), + "exp": datetime.now(UTC) - timedelta(hours=1), + } + expired_token = jwt.encode(expired_payload, _SECRET, algorithm="HS256") + with pytest.raises(jwt.ExpiredSignatureError): + svc.decode_token(expired_token) + + def test_invalid_signature_raises(self) -> None: + import jwt + + svc = _make_service() + user = _make_user() + token, _ = svc.create_token(user) + + # Decode with wrong secret + wrong_svc = AuthService( + AuthConfig(jwt_secret="wrong-secret-that-is-at-least-32-chars!!") + ) + with pytest.raises(jwt.InvalidSignatureError): + wrong_svc.decode_token(token) + + def test_must_change_password_in_claims(self) -> None: + svc = _make_service() + user = _make_user(must_change_password=True) + token, _ = svc.create_token(user) + claims = svc.decode_token(token) + assert claims["must_change_password"] is True + + +@pytest.mark.unit +class TestApiKeyHashing: + def test_hash_deterministic(self) -> None: + h1 = AuthService.hash_api_key("my-key") + h2 = AuthService.hash_api_key("my-key") + assert h1 == h2 + + def test_verify_correct_key(self) -> None: + key = AuthService.generate_api_key() + h = AuthService.hash_api_key(key) + assert AuthService.verify_api_key(key, h) + + def test_verify_wrong_key(self) -> None: + h = AuthService.hash_api_key("real-key") + assert not AuthService.verify_api_key("wrong-key", h) + + def test_generate_key_unique(self) -> None: + k1 = AuthService.generate_api_key() + k2 = AuthService.generate_api_key() + assert k1 != k2 + + def test_generate_key_length(self) -> None: + key = AuthService.generate_api_key() + assert len(key) > 30 diff --git a/tests/unit/api/conftest.py b/tests/unit/api/conftest.py index b43fc29181..722ebdf935 100644 --- a/tests/unit/api/conftest.py +++ b/tests/unit/api/conftest.py @@ -1,6 +1,7 @@ """Shared fixtures for API unit tests.""" import asyncio +import uuid from datetime import UTC, datetime, timedelta from typing import Any @@ -9,6 +10,10 @@ from ai_company.api.app import create_app from ai_company.api.approval_store import ApprovalStore +from ai_company.api.auth.config import AuthConfig +from ai_company.api.auth.models import ApiKey, User +from ai_company.api.auth.service import AuthService +from ai_company.api.guards import HumanRole from ai_company.budget.cost_record import CostRecord # noqa: TC001 from ai_company.budget.tracker import CostTracker from ai_company.communication.channel import Channel # noqa: TC001 @@ -25,6 +30,12 @@ from ai_company.security.models import AuditEntry, AuditVerdictStr # noqa: TC001 from ai_company.security.timeout.parked_context import ParkedContext # noqa: TC001 +# ── Test auth constants ─────────────────────────────────────── + +_TEST_JWT_SECRET = "test-secret-that-is-at-least-32-characters-long" +_TEST_USER_ID = "test-user-001" +_TEST_USERNAME = "testadmin" + # ── Fake Repositories ──────────────────────────────────────────── @@ -262,6 +273,59 @@ async def query( # noqa: PLR0913 return tuple(results[:limit]) +class FakeUserRepository: + """In-memory user repository for tests.""" + + def __init__(self) -> None: + self._users: dict[str, User] = {} + + async def save(self, user: User) -> None: + self._users[user.id] = user + + async def get(self, user_id: str) -> User | None: + return self._users.get(user_id) + + async def get_by_username(self, username: str) -> User | None: + for user in self._users.values(): + if user.username == username: + return user + return None + + async def list_users(self) -> tuple[User, ...]: + return tuple(self._users.values()) + + async def count(self) -> int: + return len(self._users) + + async def delete(self, user_id: str) -> bool: + return self._users.pop(user_id, None) is not None + + +class FakeApiKeyRepository: + """In-memory API key repository for tests.""" + + def __init__(self) -> None: + self._keys: dict[str, ApiKey] = {} + + async def save(self, key: ApiKey) -> None: + self._keys[key.id] = key + + async def get(self, key_id: str) -> ApiKey | None: + return self._keys.get(key_id) + + async def get_by_hash(self, key_hash: str) -> ApiKey | None: + for key in self._keys.values(): + if key.key_hash == key_hash: + return key + return None + + async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: + return tuple(k for k in self._keys.values() if k.user_id == user_id) + + async def delete(self, key_id: str) -> bool: + return self._keys.pop(key_id, None) is not None + + class FakePersistenceBackend: """In-memory persistence backend for tests.""" @@ -274,6 +338,9 @@ def __init__(self) -> None: self._collaboration_metrics = FakeCollaborationMetricRepository() self._parked_contexts = FakeParkedContextRepository() self._audit_entries = FakeAuditRepository() + self._users = FakeUserRepository() + self._api_keys = FakeApiKeyRepository() + self._settings: dict[str, str] = {} self._connected = False async def connect(self) -> None: @@ -328,6 +395,20 @@ def parked_contexts(self) -> FakeParkedContextRepository: def audit_entries(self) -> FakeAuditRepository: return self._audit_entries + @property + def users(self) -> FakeUserRepository: + return self._users + + @property + def api_keys(self) -> FakeApiKeyRepository: + return self._api_keys + + async def get_setting(self, key: str) -> str | None: + return self._settings.get(key) + + async def set_setting(self, key: str, value: str) -> None: + self._settings[key] = value + # ── Fake Message Bus ──────────────────────────────────────────── @@ -396,9 +477,80 @@ async def get_channel_history( return () +# ── Auth helpers ──────────────────────────────────────────────── + + +def _make_test_auth_config() -> AuthConfig: + """Create an AuthConfig with a test JWT secret.""" + return AuthConfig(jwt_secret=_TEST_JWT_SECRET) + + +def _make_test_auth_service() -> AuthService: + """Create an AuthService backed by test config.""" + return AuthService(_make_test_auth_config()) + + +def _make_test_user( + *, + role: HumanRole = HumanRole.CEO, + must_change_password: bool = False, + user_id: str = _TEST_USER_ID, + username: str = _TEST_USERNAME, +) -> User: + """Create a test User with given role.""" + now = datetime.now(UTC) + auth_service = _make_test_auth_service() + return User( + id=user_id, + username=username, + password_hash=auth_service.hash_password("test-password-12chars"), + role=role, + must_change_password=must_change_password, + created_at=now, + updated_at=now, + ) + + +def make_auth_headers( + role: str = "ceo", + *, + must_change_password: bool = False, +) -> dict[str, str]: + """Build an Authorization header with a JWT for the given role. + + Uses deterministic user IDs matching ``_seed_test_users`` so + middleware user lookups succeed. + """ + auth_service = _make_test_auth_service() + # Must match the ID pattern in _seed_test_users + user_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"test-{role}")) + now = datetime.now(UTC) + user = User( + id=user_id, + username=f"test-{role}", + password_hash=auth_service.hash_password("test-password-12chars"), + role=HumanRole(role), + must_change_password=must_change_password, + created_at=now, + updated_at=now, + ) + token, _ = auth_service.create_token(user) + return {"Authorization": f"Bearer {token}"} + + # ── Fixtures ──────────────────────────────────────────────────── +@pytest.fixture +def auth_config() -> AuthConfig: + return _make_test_auth_config() + + +@pytest.fixture +def auth_service() -> AuthService: + return _make_test_auth_service() + + @pytest.fixture async def fake_persistence() -> FakePersistenceBackend: backend = FakePersistenceBackend() @@ -429,25 +581,60 @@ def root_config() -> RootConfig: @pytest.fixture -def test_client( +def test_client( # noqa: PLR0913 fake_persistence: FakePersistenceBackend, fake_message_bus: FakeMessageBus, cost_tracker: CostTracker, approval_store: ApprovalStore, root_config: RootConfig, + auth_service: AuthService, ) -> TestClient[Any]: + # Pre-seed users for each role so JWT sub claims resolve + _seed_test_users(fake_persistence, auth_service) + app = create_app( config=root_config, persistence=fake_persistence, message_bus=fake_message_bus, cost_tracker=cost_tracker, approval_store=approval_store, + auth_service=auth_service, ) client = TestClient(app) - client.headers["X-Human-Role"] = "observer" + # Default: CEO token (most tests need write access) + client.headers.update(make_auth_headers("ceo")) return client +def _seed_test_users( + backend: FakePersistenceBackend, + auth_service: AuthService, +) -> None: + """Pre-seed a user for each role so JWT validation succeeds. + + The middleware looks up the user by ``sub`` claim, so we + need matching users in the fake persistence for every role + that tests might use. + """ + import asyncio + + now = datetime.now(UTC) + for role in HumanRole: + user_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"test-{role.value}")) + user = User( + id=user_id, + username=f"test-{role.value}", + password_hash=auth_service.hash_password("test-password-12chars"), + role=role, + must_change_password=False, + created_at=now, + updated_at=now, + ) + # Save synchronously via event loop + loop = asyncio.get_event_loop() + loop.run_until_complete(backend.users.save(user)) + + def make_task( # noqa: PLR0913 *, task_id: str = "task-001", diff --git a/tests/unit/api/controllers/test_agents.py b/tests/unit/api/controllers/test_agents.py index 2e35e4d800..c6632b2282 100644 --- a/tests/unit/api/controllers/test_agents.py +++ b/tests/unit/api/controllers/test_agents.py @@ -6,9 +6,10 @@ from litestar.testing import TestClient from ai_company.config.schema import AgentConfig, RootConfig -from tests.unit.api.conftest import ( # noqa: TC001 +from tests.unit.api.conftest import ( FakeMessageBus, FakePersistenceBackend, + make_auth_headers, ) @@ -27,7 +28,9 @@ def test_list_agents_with_data( fake_message_bus: FakeMessageBus, ) -> None: from ai_company.api.app import create_app + from ai_company.api.auth.service import AuthService # noqa: TC001 from ai_company.budget.tracker import CostTracker + from tests.unit.api.conftest import _make_test_auth_service, _seed_test_users config = RootConfig( company_name="test", @@ -39,14 +42,17 @@ def test_list_agents_with_data( ), ), ) + auth_service: AuthService = _make_test_auth_service() + _seed_test_users(fake_persistence, auth_service) app = create_app( config=config, persistence=fake_persistence, message_bus=fake_message_bus, cost_tracker=CostTracker(), + auth_service=auth_service, ) with TestClient(app) as client: - client.headers["X-Human-Role"] = "observer" + client.headers.update(make_auth_headers("observer")) resp = client.get("/api/v1/agents") body = resp.json() assert body["pagination"]["total"] == 1 diff --git a/tests/unit/api/controllers/test_analytics.py b/tests/unit/api/controllers/test_analytics.py index a1462817fb..49fc03f29d 100644 --- a/tests/unit/api/controllers/test_analytics.py +++ b/tests/unit/api/controllers/test_analytics.py @@ -6,8 +6,9 @@ from litestar.testing import TestClient # noqa: TC002 from ai_company.core.enums import TaskStatus +from tests.unit.api.conftest import make_auth_headers -_HEADERS = {"X-Human-Role": "ceo"} +_HEADERS = make_auth_headers("ceo") @pytest.mark.unit @@ -27,6 +28,6 @@ def test_overview_empty(self, test_client: TestClient[Any]) -> None: def test_overview_requires_read_access(self, test_client: TestClient[Any]) -> None: resp = test_client.get( "/api/v1/analytics/overview", - headers={"X-Human-Role": "invalid"}, + headers={"Authorization": "Bearer invalid-token"}, ) - assert resp.status_code == 403 + assert resp.status_code == 401 diff --git a/tests/unit/api/controllers/test_approvals.py b/tests/unit/api/controllers/test_approvals.py index 025e80ee5b..6defff389a 100644 --- a/tests/unit/api/controllers/test_approvals.py +++ b/tests/unit/api/controllers/test_approvals.py @@ -9,11 +9,11 @@ from ai_company.api.approval_store import ApprovalStore # noqa: TC001 from ai_company.core.approval import ApprovalItem from ai_company.core.enums import ApprovalRiskLevel, ApprovalStatus -from tests.unit.api.conftest import make_approval +from tests.unit.api.conftest import make_approval, make_auth_headers _BASE = "/api/v1/approvals" -_WRITE_HEADERS = {"X-Human-Role": "ceo"} -_READ_HEADERS = {"X-Human-Role": "observer"} +_WRITE_HEADERS = make_auth_headers("ceo") +_READ_HEADERS = make_auth_headers("observer") def _create_payload( @@ -134,8 +134,8 @@ async def test_list_pagination( assert body["pagination"]["offset"] == 2 def test_list_blocks_no_role(self, test_client: TestClient[Any]) -> None: - resp = test_client.get(_BASE, headers={"X-Human-Role": "invalid"}) - assert resp.status_code == 403 + resp = test_client.get(_BASE, headers={"Authorization": "Bearer invalid-token"}) + assert resp.status_code == 401 @pytest.mark.unit @@ -164,16 +164,16 @@ def test_get_allows_observer(self, test_client: TestClient[Any]) -> None: # Observer should have read access (even if 404) resp = test_client.get( f"{_BASE}/nonexistent", - headers={"X-Human-Role": "observer"}, + headers=make_auth_headers("observer"), ) assert resp.status_code == 404 # 404 = authorized but not found def test_get_blocks_no_role(self, test_client: TestClient[Any]) -> None: resp = test_client.get( f"{_BASE}/whatever", - headers={"X-Human-Role": "invalid"}, + headers={"Authorization": "Bearer invalid-token"}, ) - assert resp.status_code == 403 + assert resp.status_code == 401 @pytest.mark.unit @@ -234,12 +234,13 @@ def test_create_blocks_observer(self, test_client: TestClient[Any]) -> None: ) assert resp.status_code == 403 - def test_create_blocks_no_role(self, test_client: TestClient[Any]) -> None: + def test_create_blocks_no_auth(self, test_client: TestClient[Any]) -> None: resp = test_client.post( _BASE, json=_create_payload(), + headers={"Authorization": "Bearer invalid-token"}, ) - assert resp.status_code == 403 + assert resp.status_code == 401 @pytest.mark.unit @@ -270,7 +271,7 @@ async def test_approve_records_decided_by_from_header( resp = test_client.post( f"{_BASE}/approval-001/approve", json={}, - headers={"X-Human-Role": "manager"}, + headers=make_auth_headers("manager"), ) assert resp.status_code == 200 assert resp.json()["data"]["decided_by"] == "manager" diff --git a/tests/unit/api/controllers/test_autonomy.py b/tests/unit/api/controllers/test_autonomy.py index c460503db9..c84cde81d7 100644 --- a/tests/unit/api/controllers/test_autonomy.py +++ b/tests/unit/api/controllers/test_autonomy.py @@ -5,9 +5,11 @@ import pytest from litestar.testing import TestClient # noqa: TC002 +from tests.unit.api.conftest import make_auth_headers + _BASE = "/api/v1/agents" -_WRITE_HEADERS = {"X-Human-Role": "ceo"} -_READ_HEADERS = {"X-Human-Role": "observer"} +_WRITE_HEADERS = make_auth_headers("ceo") +_READ_HEADERS = make_auth_headers("observer") def _url(agent_id: str = "agent-001") -> str: @@ -29,8 +31,10 @@ def test_get_autonomy(self, test_client: TestClient[Any]) -> None: def test_get_autonomy_requires_read_access( self, test_client: TestClient[Any] ) -> None: - resp = test_client.get(_url(), headers={"X-Human-Role": "invalid"}) - assert resp.status_code == 403 + resp = test_client.get( + _url(), headers={"Authorization": "Bearer invalid-token"} + ) + assert resp.status_code == 401 @pytest.mark.unit diff --git a/tests/unit/api/controllers/test_budget.py b/tests/unit/api/controllers/test_budget.py index 16289e27d6..28c8e47a68 100644 --- a/tests/unit/api/controllers/test_budget.py +++ b/tests/unit/api/controllers/test_budget.py @@ -8,8 +8,9 @@ from ai_company.budget.cost_record import CostRecord from ai_company.budget.tracker import CostTracker # noqa: TC001 +from tests.unit.api.conftest import make_auth_headers -_HEADERS = {"X-Human-Role": "ceo"} +_HEADERS = make_auth_headers("ceo") @pytest.mark.unit @@ -74,6 +75,6 @@ async def test_agent_spending( def test_budget_requires_read_access(self, test_client: TestClient[Any]) -> None: resp = test_client.get( "/api/v1/budget/config", - headers={"X-Human-Role": "invalid"}, + headers={"Authorization": "Bearer invalid-token"}, ) - assert resp.status_code == 403 + assert resp.status_code == 401 diff --git a/tests/unit/api/controllers/test_company.py b/tests/unit/api/controllers/test_company.py index b447874a58..02a554f3fa 100644 --- a/tests/unit/api/controllers/test_company.py +++ b/tests/unit/api/controllers/test_company.py @@ -5,7 +5,9 @@ import pytest from litestar.testing import TestClient # noqa: TC002 -_HEADERS = {"X-Human-Role": "ceo"} +from tests.unit.api.conftest import make_auth_headers + +_HEADERS = make_auth_headers("ceo") @pytest.mark.unit @@ -27,6 +29,6 @@ def test_list_departments(self, test_client: TestClient[Any]) -> None: def test_company_requires_read_access(self, test_client: TestClient[Any]) -> None: resp = test_client.get( "/api/v1/company", - headers={"X-Human-Role": "invalid"}, + headers={"Authorization": "Bearer invalid-token"}, ) - assert resp.status_code == 403 + assert resp.status_code == 401 diff --git a/tests/unit/api/controllers/test_tasks.py b/tests/unit/api/controllers/test_tasks.py index 95a6895ffd..04755227d9 100644 --- a/tests/unit/api/controllers/test_tasks.py +++ b/tests/unit/api/controllers/test_tasks.py @@ -5,7 +5,7 @@ import pytest from litestar.testing import TestClient # noqa: TC002 -from tests.unit.api.conftest import FakePersistenceBackend, make_task +from tests.unit.api.conftest import FakePersistenceBackend, make_auth_headers, make_task @pytest.mark.unit @@ -77,7 +77,7 @@ def test_create_task(self, test_client: TestClient[Any]) -> None: "project": "proj-1", "created_by": "alice", }, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 201 body = resp.json() @@ -93,14 +93,14 @@ def test_delete_task( fake_persistence.tasks._tasks[task.id] = task resp = test_client.delete( "/api/v1/tasks/task-001", - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 200 def test_delete_task_not_found(self, test_client: TestClient[Any]) -> None: resp = test_client.delete( "/api/v1/tasks/nonexistent", - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 404 @@ -117,7 +117,7 @@ def test_update_task( resp = test_client.patch( "/api/v1/tasks/task-001", json={"title": "Updated title"}, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 200 assert resp.json()["data"]["title"] == "Updated title" @@ -126,7 +126,7 @@ def test_update_not_found(self, test_client: TestClient[Any]) -> None: resp = test_client.patch( "/api/v1/tasks/nonexistent", json={"title": "Nope"}, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 404 @@ -134,7 +134,7 @@ def test_update_requires_write_role(self, test_client: TestClient[Any]) -> None: resp = test_client.patch( "/api/v1/tasks/task-001", json={"title": "Nope"}, - headers={"X-Human-Role": "observer"}, + headers=make_auth_headers("observer"), ) assert resp.status_code == 403 @@ -154,7 +154,7 @@ def test_transition_task( "target_status": "assigned", "assigned_to": "bob", }, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 201 assert resp.json()["data"]["status"] == "assigned" @@ -169,7 +169,7 @@ def test_transition_invalid( resp = test_client.post( "/api/v1/tasks/task-001/transition", json={"target_status": "completed"}, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 422 @@ -177,7 +177,7 @@ def test_transition_not_found(self, test_client: TestClient[Any]) -> None: resp = test_client.post( "/api/v1/tasks/nonexistent/transition", json={"target_status": "assigned"}, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 404 @@ -185,6 +185,6 @@ def test_transition_requires_write_role(self, test_client: TestClient[Any]) -> N resp = test_client.post( "/api/v1/tasks/task-001/transition", json={"target_status": "assigned"}, - headers={"X-Human-Role": "observer"}, + headers=make_auth_headers("observer"), ) assert resp.status_code == 403 diff --git a/tests/unit/api/test_app.py b/tests/unit/api/test_app.py index 7d3b7c1688..20117107cb 100644 --- a/tests/unit/api/test_app.py +++ b/tests/unit/api/test_app.py @@ -35,9 +35,14 @@ def test_openapi_schema_accessible(self, test_client: TestClient[Any]) -> None: @pytest.mark.unit class TestAppLifecycle: - async def test_startup_partial_failure_cleanup(self) -> None: + async def test_startup_partial_failure_cleanup( + self, + root_config: Any, + ) -> None: """Persistence ok, bus fails → persistence cleaned up.""" from ai_company.api.app import _safe_startup + from ai_company.api.approval_store import ApprovalStore + from ai_company.api.state import AppState from tests.unit.api.conftest import ( FakeMessageBus, FakePersistenceBackend, @@ -45,6 +50,11 @@ async def test_startup_partial_failure_cleanup(self) -> None: persistence = FakePersistenceBackend() bus = FakeMessageBus() + app_state = AppState( + config=root_config, + approval_store=ApprovalStore(), + persistence=persistence, + ) async def failing_start() -> None: msg = "bus boom" @@ -53,7 +63,7 @@ async def failing_start() -> None: bus.start = failing_start # type: ignore[method-assign] with pytest.raises(RuntimeError, match="bus boom"): - await _safe_startup(persistence, bus, None) + await _safe_startup(persistence, bus, None, app_state) # Persistence should have been disconnected during cleanup assert not persistence.is_connected diff --git a/tests/unit/api/test_guards.py b/tests/unit/api/test_guards.py index aa9e5b2c9b..4a1a543ae7 100644 --- a/tests/unit/api/test_guards.py +++ b/tests/unit/api/test_guards.py @@ -1,14 +1,28 @@ -"""Tests for route guards.""" +"""Tests for route guards with JWT-based authentication.""" from typing import Any import pytest from litestar.testing import TestClient # noqa: TC002 +from tests.unit.api.conftest import make_auth_headers + +# To test "no auth" we need a fresh client without default headers. +# The test_client fixture sets CEO headers. Passing headers={} to +# a request merges with session defaults — it does NOT clear them. +# Instead we create a bare_client fixture. + + +@pytest.fixture +def bare_client(test_client: TestClient[Any]) -> TestClient[Any]: + """Test client with no default Authorization header.""" + test_client.headers.pop("authorization", None) + return test_client + @pytest.mark.unit -class TestGuards: - def test_write_guard_allows_ceo(self, test_client: TestClient[Any]) -> None: +class TestWriteGuard: + def test_allows_ceo(self, test_client: TestClient[Any]) -> None: response = test_client.post( "/api/v1/tasks", json={ @@ -18,11 +32,11 @@ def test_write_guard_allows_ceo(self, test_client: TestClient[Any]) -> None: "project": "proj", "created_by": "alice", }, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert response.status_code == 201 - def test_write_guard_blocks_observer(self, test_client: TestClient[Any]) -> None: + def test_allows_manager(self, test_client: TestClient[Any]) -> None: response = test_client.post( "/api/v1/tasks", json={ @@ -32,13 +46,11 @@ def test_write_guard_blocks_observer(self, test_client: TestClient[Any]) -> None "project": "proj", "created_by": "alice", }, - headers={"X-Human-Role": "observer"}, + headers=make_auth_headers("manager"), ) - assert response.status_code == 403 + assert response.status_code == 201 - def test_write_guard_blocks_missing_role( - self, test_client: TestClient[Any] - ) -> None: + def test_allows_board_member(self, test_client: TestClient[Any]) -> None: response = test_client.post( "/api/v1/tasks", json={ @@ -48,38 +60,52 @@ def test_write_guard_blocks_missing_role( "project": "proj", "created_by": "alice", }, + headers=make_auth_headers("board_member"), ) - assert response.status_code == 403 - - def test_case_insensitive_role(self, test_client: TestClient[Any]) -> None: - response = test_client.get( - "/api/v1/tasks", - headers={"X-Human-Role": "CEO"}, - ) - assert response.status_code == 200 + assert response.status_code == 201 - def test_whitespace_padded_role(self, test_client: TestClient[Any]) -> None: - response = test_client.get( + def test_allows_pair_programmer(self, test_client: TestClient[Any]) -> None: + response = test_client.post( "/api/v1/tasks", - headers={"X-Human-Role": " ceo "}, + json={ + "title": "Test", + "description": "Test desc", + "type": "development", + "project": "proj", + "created_by": "alice", + }, + headers=make_auth_headers("pair_programmer"), ) - assert response.status_code == 200 + assert response.status_code == 201 - def test_read_guard_allows_observer(self, test_client: TestClient[Any]) -> None: - response = test_client.get( + def test_blocks_observer(self, test_client: TestClient[Any]) -> None: + response = test_client.post( "/api/v1/tasks", - headers={"X-Human-Role": "observer"}, + json={ + "title": "Test", + "description": "Test desc", + "type": "development", + "project": "proj", + "created_by": "alice", + }, + headers=make_auth_headers("observer"), ) - assert response.status_code == 200 + assert response.status_code == 403 - def test_read_guard_blocks_missing_role(self, test_client: TestClient[Any]) -> None: - response = test_client.get( + def test_missing_auth_returns_401(self, bare_client: TestClient[Any]) -> None: + response = bare_client.post( "/api/v1/tasks", - headers={"X-Human-Role": "invalid"}, + json={ + "title": "Test", + "description": "Test desc", + "type": "development", + "project": "proj", + "created_by": "alice", + }, ) - assert response.status_code == 403 + assert response.status_code == 401 - def test_write_guard_allows_manager(self, test_client: TestClient[Any]) -> None: + def test_invalid_token_returns_401(self, test_client: TestClient[Any]) -> None: response = test_client.post( "/api/v1/tasks", json={ @@ -89,6 +115,34 @@ def test_write_guard_allows_manager(self, test_client: TestClient[Any]) -> None: "project": "proj", "created_by": "alice", }, - headers={"X-Human-Role": "manager"}, + headers={"Authorization": "Bearer invalid-token"}, ) - assert response.status_code == 201 + assert response.status_code == 401 + + +@pytest.mark.unit +class TestReadGuard: + def test_allows_observer(self, test_client: TestClient[Any]) -> None: + response = test_client.get( + "/api/v1/tasks", + headers=make_auth_headers("observer"), + ) + assert response.status_code == 200 + + def test_allows_ceo(self, test_client: TestClient[Any]) -> None: + response = test_client.get( + "/api/v1/tasks", + headers=make_auth_headers("ceo"), + ) + assert response.status_code == 200 + + def test_missing_auth_returns_401(self, bare_client: TestClient[Any]) -> None: + response = bare_client.get("/api/v1/tasks") + assert response.status_code == 401 + + def test_invalid_token_returns_401(self, test_client: TestClient[Any]) -> None: + response = test_client.get( + "/api/v1/tasks", + headers={"Authorization": "Bearer invalid-token"}, + ) + assert response.status_code == 401 diff --git a/tests/unit/persistence/sqlite/test_user_repo.py b/tests/unit/persistence/sqlite/test_user_repo.py new file mode 100644 index 0000000000..dcbbd3b9dc --- /dev/null +++ b/tests/unit/persistence/sqlite/test_user_repo.py @@ -0,0 +1,199 @@ +"""Tests for SQLiteUserRepository and SQLiteApiKeyRepository.""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +import aiosqlite +import pytest + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + +from ai_company.api.auth.models import ApiKey, User +from ai_company.api.guards import HumanRole +from ai_company.persistence.sqlite.migrations import run_migrations +from ai_company.persistence.sqlite.user_repo import ( + SQLiteApiKeyRepository, + SQLiteUserRepository, +) + + +@pytest.fixture +async def db() -> AsyncGenerator[aiosqlite.Connection]: + """Create an in-memory SQLite DB with schema applied.""" + conn = await aiosqlite.connect(":memory:") + conn.row_factory = aiosqlite.Row + await run_migrations(conn) + yield conn + await conn.close() + + +@pytest.fixture +def user_repo(db: aiosqlite.Connection) -> SQLiteUserRepository: + return SQLiteUserRepository(db) + + +@pytest.fixture +def api_key_repo(db: aiosqlite.Connection) -> SQLiteApiKeyRepository: + return SQLiteApiKeyRepository(db) + + +def _make_user( + *, + user_id: str = "user-001", + username: str = "admin", + role: HumanRole = HumanRole.CEO, +) -> User: + now = datetime.now(UTC) + return User( + id=user_id, + username=username, + password_hash="$argon2id$fake-hash", + role=role, + must_change_password=False, + created_at=now, + updated_at=now, + ) + + +@pytest.mark.unit +class TestSQLiteUserRepository: + async def test_save_and_get(self, user_repo: SQLiteUserRepository) -> None: + user = _make_user() + await user_repo.save(user) + fetched = await user_repo.get("user-001") + assert fetched is not None + assert fetched.id == "user-001" + assert fetched.username == "admin" + + async def test_get_nonexistent(self, user_repo: SQLiteUserRepository) -> None: + result = await user_repo.get("nonexistent") + assert result is None + + async def test_get_by_username(self, user_repo: SQLiteUserRepository) -> None: + user = _make_user() + await user_repo.save(user) + fetched = await user_repo.get_by_username("admin") + assert fetched is not None + assert fetched.id == "user-001" + + async def test_get_by_username_not_found( + self, user_repo: SQLiteUserRepository + ) -> None: + result = await user_repo.get_by_username("nope") + assert result is None + + async def test_list_users(self, user_repo: SQLiteUserRepository) -> None: + await user_repo.save(_make_user(user_id="u1", username="alice")) + await user_repo.save(_make_user(user_id="u2", username="bob")) + users = await user_repo.list_users() + assert len(users) == 2 + + async def test_count(self, user_repo: SQLiteUserRepository) -> None: + assert await user_repo.count() == 0 + await user_repo.save(_make_user()) + assert await user_repo.count() == 1 + + async def test_delete(self, user_repo: SQLiteUserRepository) -> None: + await user_repo.save(_make_user()) + deleted = await user_repo.delete("user-001") + assert deleted is True + assert await user_repo.get("user-001") is None + + async def test_delete_nonexistent(self, user_repo: SQLiteUserRepository) -> None: + deleted = await user_repo.delete("nope") + assert deleted is False + + async def test_upsert(self, user_repo: SQLiteUserRepository) -> None: + user = _make_user() + await user_repo.save(user) + updated = user.model_copy( + update={"username": "new-admin", "updated_at": datetime.now(UTC)} + ) + await user_repo.save(updated) + fetched = await user_repo.get("user-001") + assert fetched is not None + assert fetched.username == "new-admin" + assert await user_repo.count() == 1 + + +@pytest.mark.unit +class TestSQLiteApiKeyRepository: + async def test_save_and_get( + self, + api_key_repo: SQLiteApiKeyRepository, + user_repo: SQLiteUserRepository, + ) -> None: + await user_repo.save(_make_user()) + now = datetime.now(UTC) + key = ApiKey( + id="key-001", + key_hash="abc123hash", + name="test-key", + role=HumanRole.CEO, + user_id="user-001", + created_at=now, + ) + await api_key_repo.save(key) + fetched = await api_key_repo.get("key-001") + assert fetched is not None + assert fetched.name == "test-key" + + async def test_get_by_hash( + self, + api_key_repo: SQLiteApiKeyRepository, + user_repo: SQLiteUserRepository, + ) -> None: + await user_repo.save(_make_user()) + now = datetime.now(UTC) + key = ApiKey( + id="key-002", + key_hash="unique-hash", + name="hash-key", + role=HumanRole.CEO, + user_id="user-001", + created_at=now, + ) + await api_key_repo.save(key) + fetched = await api_key_repo.get_by_hash("unique-hash") + assert fetched is not None + assert fetched.id == "key-002" + + async def test_list_by_user( + self, + api_key_repo: SQLiteApiKeyRepository, + user_repo: SQLiteUserRepository, + ) -> None: + await user_repo.save(_make_user()) + now = datetime.now(UTC) + for i in range(3): + key = ApiKey( + id=f"key-{i}", + key_hash=f"hash-{i}", + name=f"key-{i}", + role=HumanRole.CEO, + user_id="user-001", + created_at=now, + ) + await api_key_repo.save(key) + keys = await api_key_repo.list_by_user("user-001") + assert len(keys) == 3 + + async def test_delete( + self, + api_key_repo: SQLiteApiKeyRepository, + user_repo: SQLiteUserRepository, + ) -> None: + await user_repo.save(_make_user()) + now = datetime.now(UTC) + key = ApiKey( + id="key-del", + key_hash="del-hash", + name="del-key", + role=HumanRole.CEO, + user_id="user-001", + created_at=now, + ) + await api_key_repo.save(key) + assert await api_key_repo.delete("key-del") is True + assert await api_key_repo.get("key-del") is None diff --git a/tests/unit/persistence/test_migrations_v2.py b/tests/unit/persistence/test_migrations_v2.py index 776a3dbe47..1c33500761 100644 --- a/tests/unit/persistence/test_migrations_v2.py +++ b/tests/unit/persistence/test_migrations_v2.py @@ -28,8 +28,8 @@ async def memory_db() -> AsyncGenerator[aiosqlite.Connection]: @pytest.mark.unit class TestSchemaMigrations: - async def test_schema_version_is_four(self) -> None: - assert SCHEMA_VERSION == 4 + async def test_schema_version_is_five(self) -> None: + assert SCHEMA_VERSION == 5 async def test_fresh_db_creates_all_v2_tables( self, memory_db: aiosqlite.Connection diff --git a/tests/unit/persistence/test_protocol.py b/tests/unit/persistence/test_protocol.py index 7b559bd05e..06d932ea59 100644 --- a/tests/unit/persistence/test_protocol.py +++ b/tests/unit/persistence/test_protocol.py @@ -12,16 +12,19 @@ ) from ai_company.persistence.protocol import PersistenceBackend from ai_company.persistence.repositories import ( + ApiKeyRepository, AuditRepository, CostRecordRepository, MessageRepository, ParkedContextRepository, TaskRepository, + UserRepository, ) if TYPE_CHECKING: from pydantic import AwareDatetime + from ai_company.api.auth.models import ApiKey, User from ai_company.budget.cost_record import CostRecord from ai_company.communication.message import Message from ai_company.core.enums import ApprovalRiskLevel, TaskStatus @@ -163,6 +166,44 @@ async def query( # noqa: PLR0913 return () +class _FakeUserRepository: + async def save(self, user: User) -> None: + pass + + async def get(self, user_id: str) -> User | None: + return None + + async def get_by_username(self, username: str) -> User | None: + return None + + async def list_users(self) -> tuple[User, ...]: + return () + + async def count(self) -> int: + return 0 + + async def delete(self, user_id: str) -> bool: + return False + + +class _FakeApiKeyRepository: + async def save(self, key: ApiKey) -> None: + pass + + async def get(self, key_id: str) -> ApiKey | None: + return None + + async def get_by_hash(self, key_hash: str) -> ApiKey | None: + return None + + async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: + return () + + async def delete(self, key_id: str) -> bool: + return False + + + class _FakeBackend: async def connect(self) -> None: pass @@ -216,6 +257,20 @@ def collaboration_metrics(self) -> _FakeCollaborationMetricRepository: def audit_entries(self) -> _FakeAuditRepository: return _FakeAuditRepository() + @property + def users(self) -> _FakeUserRepository: + return _FakeUserRepository() + + @property + def api_keys(self) -> _FakeApiKeyRepository: + return _FakeApiKeyRepository() + + async def get_setting(self, key: str) -> str | None: + return None + + async def set_setting(self, key: str, value: str) -> None: + pass + @pytest.mark.unit class TestProtocolCompliance: @@ -255,3 +310,9 @@ def test_fake_parked_context_repo_is_parked_context_repository( def test_fake_audit_repo_is_audit_repository(self) -> None: assert isinstance(_FakeAuditRepository(), AuditRepository) + + def test_fake_user_repo_is_user_repository(self) -> None: + assert isinstance(_FakeUserRepository(), UserRepository) + + def test_fake_api_key_repo_is_api_key_repository(self) -> None: + assert isinstance(_FakeApiKeyRepository(), ApiKeyRepository) diff --git a/uv.lock b/uv.lock index d12431845f..0ef0fd877a 100644 --- a/uv.lock +++ b/uv.lock @@ -8,12 +8,14 @@ source = { editable = "." } dependencies = [ { name = "aiodocker" }, { name = "aiosqlite" }, + { name = "argon2-cffi" }, { name = "jinja2" }, { name = "jsonschema" }, { name = "litellm" }, { name = "litestar", extra = ["brotli", "prometheus", "pydantic", "standard", "structlog"] }, { name = "mcp" }, { name = "pydantic" }, + { name = "pyjwt", extra = ["crypto"] }, { name = "pyyaml" }, { name = "structlog" }, ] @@ -50,12 +52,14 @@ test = [ requires-dist = [ { name = "aiodocker", specifier = "==0.26.0" }, { name = "aiosqlite", specifier = "==0.22.1" }, + { name = "argon2-cffi", specifier = "==25.1.0" }, { name = "jinja2", specifier = "==3.1.6" }, { name = "jsonschema", specifier = "==4.26.0" }, { name = "litellm", specifier = "==1.82.1" }, { name = "litestar", extras = ["brotli", "prometheus", "pydantic", "standard", "structlog"], specifier = "==2.21.1" }, { name = "mcp", specifier = "==1.26.0" }, { name = "pydantic", specifier = "==2.12.5" }, + { name = "pyjwt", extras = ["crypto"], specifier = "==2.11.0" }, { name = "pyyaml", specifier = "==6.0.3" }, { name = "structlog", specifier = "==25.5.0" }, ] @@ -220,6 +224,49 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/74/f5/9373290775639cb67a2fce7f629a1c240dce9f12fe927bc32b2736e16dfc/argcomplete-3.6.3-py3-none-any.whl", hash = "sha256:f5007b3a600ccac5d25bbce33089211dfd49eab4a7718da3f10e3082525a92ce", size = 43846, upload-time = "2025-10-20T03:33:33.021Z" }, ] +[[package]] +name = "argon2-cffi" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "argon2-cffi-bindings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/89/ce5af8a7d472a67cc819d5d998aa8c82c5d860608c4db9f46f1162d7dab9/argon2_cffi-25.1.0.tar.gz", hash = "sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1", size = 45706, upload-time = "2025-06-03T06:55:32.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/d3/a8b22fa575b297cd6e3e3b0155c7e25db170edf1c74783d6a31a2490b8d9/argon2_cffi-25.1.0-py3-none-any.whl", hash = "sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741", size = 14657, upload-time = "2025-06-03T06:55:30.804Z" }, +] + +[[package]] +name = "argon2-cffi-bindings" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/2d/db8af0df73c1cf454f71b2bbe5e356b8c1f8041c979f505b3d3186e520a9/argon2_cffi_bindings-25.1.0.tar.gz", hash = "sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d", size = 1783441, upload-time = "2025-07-30T10:02:05.147Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/97/3c0a35f46e52108d4707c44b95cfe2afcafc50800b5450c197454569b776/argon2_cffi_bindings-25.1.0-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:3d3f05610594151994ca9ccb3c771115bdb4daef161976a266f0dd8aa9996b8f", size = 54393, upload-time = "2025-07-30T10:01:40.97Z" }, + { url = "https://files.pythonhosted.org/packages/9d/f4/98bbd6ee89febd4f212696f13c03ca302b8552e7dbf9c8efa11ea4a388c3/argon2_cffi_bindings-25.1.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:8b8efee945193e667a396cbc7b4fb7d357297d6234d30a489905d96caabde56b", size = 29328, upload-time = "2025-07-30T10:01:41.916Z" }, + { url = "https://files.pythonhosted.org/packages/43/24/90a01c0ef12ac91a6be05969f29944643bc1e5e461155ae6559befa8f00b/argon2_cffi_bindings-25.1.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:3c6702abc36bf3ccba3f802b799505def420a1b7039862014a65db3205967f5a", size = 31269, upload-time = "2025-07-30T10:01:42.716Z" }, + { url = "https://files.pythonhosted.org/packages/d4/d3/942aa10782b2697eee7af5e12eeff5ebb325ccfb86dd8abda54174e377e4/argon2_cffi_bindings-25.1.0-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a1c70058c6ab1e352304ac7e3b52554daadacd8d453c1752e547c76e9c99ac44", size = 86558, upload-time = "2025-07-30T10:01:43.943Z" }, + { url = "https://files.pythonhosted.org/packages/0d/82/b484f702fec5536e71836fc2dbc8c5267b3f6e78d2d539b4eaa6f0db8bf8/argon2_cffi_bindings-25.1.0-cp314-cp314t-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e2fd3bfbff3c5d74fef31a722f729bf93500910db650c925c2d6ef879a7e51cb", size = 92364, upload-time = "2025-07-30T10:01:44.887Z" }, + { url = "https://files.pythonhosted.org/packages/c9/c1/a606ff83b3f1735f3759ad0f2cd9e038a0ad11a3de3b6c673aa41c24bb7b/argon2_cffi_bindings-25.1.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4f9665de60b1b0e99bcd6be4f17d90339698ce954cfd8d9cf4f91c995165a92", size = 85637, upload-time = "2025-07-30T10:01:46.225Z" }, + { url = "https://files.pythonhosted.org/packages/44/b4/678503f12aceb0262f84fa201f6027ed77d71c5019ae03b399b97caa2f19/argon2_cffi_bindings-25.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ba92837e4a9aa6a508c8d2d7883ed5a8f6c308c89a4790e1e447a220deb79a85", size = 91934, upload-time = "2025-07-30T10:01:47.203Z" }, + { url = "https://files.pythonhosted.org/packages/f0/c7/f36bd08ef9bd9f0a9cff9428406651f5937ce27b6c5b07b92d41f91ae541/argon2_cffi_bindings-25.1.0-cp314-cp314t-win32.whl", hash = "sha256:84a461d4d84ae1295871329b346a97f68eade8c53b6ed9a7ca2d7467f3c8ff6f", size = 28158, upload-time = "2025-07-30T10:01:48.341Z" }, + { url = "https://files.pythonhosted.org/packages/b3/80/0106a7448abb24a2c467bf7d527fe5413b7fdfa4ad6d6a96a43a62ef3988/argon2_cffi_bindings-25.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b55aec3565b65f56455eebc9b9f34130440404f27fe21c3b375bf1ea4d8fbae6", size = 32597, upload-time = "2025-07-30T10:01:49.112Z" }, + { url = "https://files.pythonhosted.org/packages/05/b8/d663c9caea07e9180b2cb662772865230715cbd573ba3b5e81793d580316/argon2_cffi_bindings-25.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:87c33a52407e4c41f3b70a9c2d3f6056d88b10dad7695be708c5021673f55623", size = 28231, upload-time = "2025-07-30T10:01:49.92Z" }, + { url = "https://files.pythonhosted.org/packages/1d/57/96b8b9f93166147826da5f90376e784a10582dd39a393c99bb62cfcf52f0/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500", size = 54121, upload-time = "2025-07-30T10:01:50.815Z" }, + { url = "https://files.pythonhosted.org/packages/0a/08/a9bebdb2e0e602dde230bdde8021b29f71f7841bd54801bcfd514acb5dcf/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44", size = 29177, upload-time = "2025-07-30T10:01:51.681Z" }, + { url = "https://files.pythonhosted.org/packages/b6/02/d297943bcacf05e4f2a94ab6f462831dc20158614e5d067c35d4e63b9acb/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0", size = 31090, upload-time = "2025-07-30T10:01:53.184Z" }, + { url = "https://files.pythonhosted.org/packages/c1/93/44365f3d75053e53893ec6d733e4a5e3147502663554b4d864587c7828a7/argon2_cffi_bindings-25.1.0-cp39-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6", size = 81246, upload-time = "2025-07-30T10:01:54.145Z" }, + { url = "https://files.pythonhosted.org/packages/09/52/94108adfdd6e2ddf58be64f959a0b9c7d4ef2fa71086c38356d22dc501ea/argon2_cffi_bindings-25.1.0-cp39-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a", size = 87126, upload-time = "2025-07-30T10:01:55.074Z" }, + { url = "https://files.pythonhosted.org/packages/72/70/7a2993a12b0ffa2a9271259b79cc616e2389ed1a4d93842fac5a1f923ffd/argon2_cffi_bindings-25.1.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c87b72589133f0346a1cb8d5ecca4b933e3c9b64656c9d175270a000e73b288d", size = 80343, upload-time = "2025-07-30T10:01:56.007Z" }, + { url = "https://files.pythonhosted.org/packages/78/9a/4e5157d893ffc712b74dbd868c7f62365618266982b64accab26bab01edc/argon2_cffi_bindings-25.1.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99", size = 86777, upload-time = "2025-07-30T10:01:56.943Z" }, + { url = "https://files.pythonhosted.org/packages/74/cd/15777dfde1c29d96de7f18edf4cc94c385646852e7c7b0320aa91ccca583/argon2_cffi_bindings-25.1.0-cp39-abi3-win32.whl", hash = "sha256:473bcb5f82924b1becbb637b63303ec8d10e84c8d241119419897a26116515d2", size = 27180, upload-time = "2025-07-30T10:01:57.759Z" }, + { url = "https://files.pythonhosted.org/packages/e2/c6/a759ece8f1829d1f162261226fbfd2c6832b3ff7657384045286d2afa384/argon2_cffi_bindings-25.1.0-cp39-abi3-win_amd64.whl", hash = "sha256:a98cd7d17e9f7ce244c0803cad3c23a7d379c301ba618a5fa76a67d116618b98", size = 31715, upload-time = "2025-07-30T10:01:58.56Z" }, + { url = "https://files.pythonhosted.org/packages/42/b9/f8d6fa329ab25128b7e98fd83a3cb34d9db5b059a9847eddb840a0af45dd/argon2_cffi_bindings-25.1.0-cp39-abi3-win_arm64.whl", hash = "sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94", size = 27149, upload-time = "2025-07-30T10:01:59.329Z" }, +] + [[package]] name = "attrs" version = "25.4.0" From 335bbeb175059f99a9b215cce3ebbc71f6310407 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Wed, 11 Mar 2026 00:31:08 +0100 Subject: [PATCH 2/5] refactor: resolve merge conflicts and apply pre-PR review fixes Resolve conflicts between audit repository (main) and auth feature (this branch) by keeping both: v4 migration for audit_entries, v5 migration for settings/users/api_keys. Bump SCHEMA_VERSION to 5. Apply 25 findings from 10 review agents: add password length model validators, pass through exception detail messages, add test coverage for auth guards/secret resolution/state management, update docs. --- CLAUDE.md | 2 +- DESIGN_SPEC.md | 24 ++++-- README.md | 9 ++- src/ai_company/api/app.py | 52 ++++++------- src/ai_company/api/auth/config.py | 45 +++++++---- src/ai_company/api/auth/controller.py | 70 ++++++++++------- src/ai_company/api/auth/middleware.py | 70 ++++++++++++++--- src/ai_company/api/auth/secret.py | 8 ++ src/ai_company/api/auth/service.py | 36 ++++----- src/ai_company/api/exception_handlers.py | 19 +++-- src/ai_company/api/guards.py | 6 ++ src/ai_company/api/state.py | 68 +++++++++-------- .../observability/events/persistence.py | 3 + src/ai_company/persistence/repositories.py | 14 ++-- src/ai_company/persistence/sqlite/backend.py | 8 +- .../persistence/sqlite/user_repo.py | 51 +++++++++++-- tests/unit/api/auth/test_controller.py | 60 ++++++++++++++- tests/unit/api/auth/test_middleware.py | 34 +++++++++ tests/unit/api/auth/test_secret.py | 75 +++++++++++++++++++ tests/unit/api/auth/test_service.py | 12 +-- tests/unit/api/test_exception_handlers.py | 41 ++++++++-- tests/unit/api/test_state.py | 34 +++++++++ tests/unit/config/conftest.py | 2 + .../persistence/sqlite/test_migrations.py | 42 +++++++++++ tests/unit/persistence/test_migrations_v2.py | 2 +- tests/unit/persistence/test_protocol.py | 1 - 26 files changed, 604 insertions(+), 184 deletions(-) create mode 100644 tests/unit/api/auth/test_secret.py diff --git a/CLAUDE.md b/CLAUDE.md index 8fe06a5cb8..134b81491b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -65,7 +65,7 @@ curl http://localhost:3000/api/v1/health # backend (via web proxy) ```text src/ai_company/ - api/ # Litestar REST + WebSocket API (controllers, guards, channels) + api/ # Litestar REST + WebSocket API (controllers, guards, channels, JWT + API key auth) budget/ # Cost tracking, budget enforcement (pre-flight/in-flight checks, auto-downgrade), billing periods, cost tiers, quota/subscription tracking, CFO cost optimization (anomaly detection, efficiency analysis, downgrade recommendations, approval decisions), spending reports, budget errors (BudgetExhaustedError, DailyLimitExceededError, QuotaExhaustedError) cli/ # CLI interface (future — thin API wrapper if needed) communication/ # Message bus, dispatcher, messenger, channels, delegation, loop prevention, conflict resolution, meeting protocol diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index dfd1ce9a03..a28583aafa 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -2562,6 +2562,7 @@ The REST/WebSocket API is the **primary interface** for all consumers. The Web U ```text /api/v1/ ├── /health # Health check, readiness + ├── /auth # Authentication: setup, login, password change, me ├── /company # CRUD company config ├── /agents # List, hire, fire, modify agents ├── /departments # Department management @@ -2758,6 +2759,7 @@ Circular inheritance is detected via chain tracking and raises `TemplateInherita | **Docker API** | aiodocker | Async-native Docker API client for `DockerSandbox` backend | | **Tool Integration** | MCP SDK (`mcp`) | Industry standard for LLM-to-tool integration | | **Agent Comms** | A2A Protocol compatible | Future-proof inter-agent communication | +| **Authentication** | PyJWT + argon2-cffi | JWT (HMAC HS256/384/512) for session tokens, Argon2id for password hashing, SHA-256 for API key storage | | **Config Format** | YAML + Pydantic validation | Human-readable config with strict validation | | **CLI** | TBD (future, if needed) | Thin wrapper around the REST API for terminal use. May not be needed — interactive Scalar docs at `/docs/api` and `curl`/`httpie` may suffice | @@ -2980,7 +2982,7 @@ ai-company/ │ ├── persistence/ # Operational data persistence (§7.6) │ │ ├── __init__.py # Package exports │ │ ├── protocol.py # PersistenceBackend protocol (M5) -│ │ ├── repositories.py # Repository protocols: TaskRepository, CostRecordRepository, MessageRepository, ParkedContextRepository, AuditRepository +│ │ ├── repositories.py # Repository protocols: TaskRepository, CostRecordRepository, MessageRepository, ParkedContextRepository, AuditRepository, UserRepository, ApiKeyRepository │ │ ├── config.py # PersistenceConfig model (M5) │ │ ├── errors.py # Persistence error hierarchy (M5) │ │ ├── factory.py # create_backend() factory (M5) @@ -2991,7 +2993,8 @@ ai-company/ │ │ ├── hr_repositories.py # SQLite HR repositories (LifecycleEvent, TaskMetricRecord, CollaborationMetricRecord) │ │ ├── parked_context_repo.py # SQLiteParkedContextRepository (park/resume serialized agent state) │ │ ├── audit_repository.py # SQLiteAuditRepository (append-only audit entry persistence) -│ │ └── migrations.py # Schema migrations (user_version pragma) +│ │ ├── user_repo.py # SQLiteUserRepository + SQLiteApiKeyRepository +│ │ └── migrations.py # Schema migrations (user_version pragma, v1–v5) │ ├── observability/ # Structured logging & correlation │ │ ├── __init__.py # get_logger() entry point │ │ ├── _logger.py # Logger configuration @@ -3183,18 +3186,25 @@ ai-company/ │ ├── api/ # REST + WebSocket API (M6) │ │ ├── app.py # Litestar application factory, lifecycle hooks │ │ ├── approval_store.py # In-memory approval queue storage +│ │ ├── auth/ # JWT + API key authentication subsystem +│ │ │ ├── config.py # AuthConfig (frozen Pydantic, HMAC algorithm, exclude paths) +│ │ │ ├── controller.py # AuthController (setup, login, change-password, me) +│ │ │ ├── middleware.py # ApiAuthMiddleware (JWT-first, API key fallback) +│ │ │ ├── models.py # User, ApiKey, AuthenticatedUser, AuthMethod +│ │ │ ├── secret.py # JWT secret resolution (env var → persistence → auto-generate) +│ │ │ └── service.py # AuthService (Argon2id password hashing, JWT ops, API key hashing) │ │ ├── bus_bridge.py # Message-bus → WebSocket bridge │ │ ├── channels.py # WebSocket channel definitions │ │ ├── config.py # API configuration models (ServerConfig, CorsConfig) -│ │ ├── controllers/ # 14 class-based controllers + 1 WebSocket handler (15 route modules) +│ │ ├── controllers/ # 15 class-based controllers + 1 WebSocket handler (16 route modules) │ │ ├── dto.py # Request/response DTOs and envelopes -│ │ ├── errors.py # API error hierarchy (ApiError, NotFoundError, etc.) +│ │ ├── errors.py # API error hierarchy (ApiError, NotFoundError, UnauthorizedError, etc.) │ │ ├── exception_handlers.py # Litestar exception handler registration -│ │ ├── guards.py # Route guards — read/write access (stub auth, M7 real auth) -│ │ ├── middleware.py # Request logging middleware +│ │ ├── guards.py # Route guards — role-based read/write access control (HumanRole enum) +│ │ ├── middleware.py # Request logging, CSP middleware │ │ ├── pagination.py # Cursor-free offset/limit pagination │ │ ├── server.py # Uvicorn server runner -│ │ ├── state.py # Typed AppState container with service access +│ │ ├── state.py # Typed AppState container with service access (deferred auth init) │ │ └── ws_models.py # WebSocket event models (WsEvent, WsEventType) │ ├── cli/ # CLI interface (future, if needed) │ │ ├── __init__.py diff --git a/README.md b/README.md index db2750f090..e0a8074963 100644 --- a/README.md +++ b/README.md @@ -24,10 +24,11 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Memory Interface (M5)** - Pluggable `MemoryBackend` protocol with capability discovery, shared knowledge protocol, domain models, config, factory, and context injection retrieval pipeline (ranking, token-budget formatting, non-inferable filtering). Shared organizational memory via `OrgMemoryBackend` protocol with hybrid prompt+retrieval backend. Memory consolidation/archival with pluggable strategies and retention enforcement - **Coordination Error Taxonomy (M5)** - Post-execution classification pipeline detecting logical contradictions, numerical drift, context omissions, and coordination failures - **Budget Enforcement (M5)** - `BudgetEnforcer` service with pre-flight checks, in-flight budget checking, auto-downgrade, configurable cost tiers, and quota/subscription tracking; `CostOptimizer` CFO service with anomaly detection, efficiency analysis, downgrade recommendations, and approval decisions; `ReportGenerator` for multi-dimensional spending reports -- **Litestar REST API (M6)** - 13 controllers + WebSocket handler covering company, agents, tasks, budget, approvals, analytics, messages, meetings, projects, departments, artifacts, providers, health, and WebSocket real-time feed +- **Litestar REST API (M6)** - 15 controllers + WebSocket handler covering company, agents, tasks, budget, approvals, analytics, messages, meetings, projects, departments, artifacts, providers, health, auth, and WebSocket real-time feed - **Human Approval Queue (M6)** - Approval submission, approve/reject with reason, list/filter by status, WebSocket notifications for approval events - **WebSocket Real-Time Feed (M6)** - Channel-based subscriptions (tasks, agents, budget, messages, system, approvals), per-channel payload filters, message-bus bridge -- **Route Guards (M6)** - Role-based read/write access control (stub auth for M6; real JWT/OAuth planned for M7) +- **Route Guards (M6)** - Role-based read/write access control with 5 human roles (CEO, Manager, Board Member, Pair Programmer, Observer) +- **JWT + API Key Authentication (M7)** - Mandatory auth middleware (JWT-first with API key fallback), Argon2id password hashing, first-run admin setup, password change flow, SHA-256 API key hashing, regex-based path exclusions - **HR Engine (M7)** - Hiring pipeline (request → generate candidate → approval → instantiate), onboarding checklists, offboarding pipeline (reassign → archive → notify → terminate), agent registry - **Performance Tracking (M7)** - Task metrics, CI-based quality scoring, behavioral collaboration scoring, Theil-Sen robust trend detection, multi-window rolling metric aggregation - **Progressive Trust (M7)** - 4 strategies (disabled/weighted/per-category/milestone) behind pluggable `TrustStrategy` protocol, trust level tracking, action permission evaluation @@ -38,12 +39,12 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Memory Backend Adapter (M5)** - Memory protocols, retrieval pipeline, org memory, and consolidation are complete; initial Mem0 adapter backend ([ADR-001](docs/decisions/ADR-001-memory-layer.md)) pending; research backends (GraphRAG, Temporal KG) planned - **CLI Surface** - `cli/` package is placeholder-only -- **Security/Approval System (M7)** - Real authentication (JWT/OAuth) and approval workflow gates are planned +- **Security/Approval System (M7)** - SecOps agent with rule engine (soft-allow/hard-deny, fail-closed), audit log, output scanner, risk classifier, and ToolInvoker integration are implemented; progressive trust (4 strategies), promotion/demotion, autonomy levels (5 tiers with presets, resolver, change strategies) and approval timeout policies (wait-forever, auto-deny, tiered, escalation-chain with task park/resume) are implemented; JWT + API key authentication is implemented; approval workflow gates remain planned - **Advanced Product Surface** - web dashboard, external integrations ## Status -**M7: Security & Approval** partially complete — Docker sandbox, MCP bridge, code runner, SecOps agent, HR engine + performance tracking, progressive trust, promotion/demotion done; authentication/approval workflow gates remain. See [DESIGN_SPEC.md](DESIGN_SPEC.md) for the full high-level specification. +**M7: Security & Approval** partially complete — Docker sandbox, MCP bridge, code runner, SecOps agent, HR engine + performance tracking, progressive trust, promotion/demotion, JWT + API key authentication done; approval workflow gates remain. See [DESIGN_SPEC.md](DESIGN_SPEC.md) for the full high-level specification. ## Tech Stack diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index 8cffe96771..6fc3853001 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -135,19 +135,17 @@ async def _cleanup_on_failure( try: await message_bus.stop() except Exception: - logger.error( + logger.exception( API_APP_STARTUP, error="Cleanup: failed to stop message bus", - exc_info=True, ) if started_persistence and persistence is not None: try: await persistence.disconnect() except Exception: - logger.error( + logger.exception( API_APP_STARTUP, error="Cleanup: failed to disconnect persistence", - exc_info=True, ) @@ -169,38 +167,36 @@ async def _safe_startup( try: await persistence.connect() except Exception: - logger.error( + logger.exception( API_APP_STARTUP, error="Failed to connect persistence", - exc_info=True, ) raise started_persistence = True # Resolve JWT secret after persistence is up - if app_state._auth_service is None: # noqa: SLF001 - try: - secret = await resolve_jwt_secret(persistence) - auth_config = app_state.config.api.auth.with_secret( - secret, - ) - app_state._auth_service = AuthService(auth_config) # noqa: SLF001 - except Exception: - logger.error( - API_APP_STARTUP, - error="Failed to resolve JWT secret", - exc_info=True, - ) - raise + try: + secret = await resolve_jwt_secret(persistence) + auth_config = app_state.config.api.auth.with_secret( + secret, + ) + app_state.set_auth_service(AuthService(auth_config)) + except RuntimeError: + pass # Already configured (e.g. test-injected) + except Exception: + logger.exception( + API_APP_STARTUP, + error="Failed to resolve JWT secret", + ) + raise if message_bus is not None: try: await message_bus.start() except Exception: - logger.error( + logger.exception( API_APP_STARTUP, error="Failed to start message bus", - exc_info=True, ) raise started_bus = True @@ -208,10 +204,9 @@ async def _safe_startup( try: await bridge.start() except Exception: - logger.error( + logger.exception( API_APP_STARTUP, error="Failed to start message bus bridge", - exc_info=True, ) raise except Exception: @@ -234,28 +229,25 @@ async def _safe_shutdown( try: await bridge.stop() except Exception: - logger.error( + logger.exception( API_APP_SHUTDOWN, error="Failed to stop message bus bridge", - exc_info=True, ) if message_bus is not None: try: await message_bus.stop() except Exception: - logger.error( + logger.exception( API_APP_SHUTDOWN, error="Failed to stop message bus", - exc_info=True, ) if persistence is not None: try: await persistence.disconnect() except Exception: - logger.error( + logger.exception( API_APP_SHUTDOWN, error="Failed to disconnect persistence", - exc_info=True, ) diff --git a/src/ai_company/api/auth/config.py b/src/ai_company/api/auth/config.py index 72af06f111..ddce43c8ef 100644 --- a/src/ai_company/api/auth/config.py +++ b/src/ai_company/api/auth/config.py @@ -1,10 +1,29 @@ """Authentication configuration.""" -from pydantic import BaseModel, ConfigDict, Field +from typing import Literal, Self + +from pydantic import BaseModel, ConfigDict, Field, model_validator _MIN_SECRET_LENGTH = 32 +def _require_valid_secret(secret: str) -> None: + """Raise ``ValueError`` if *secret* is non-empty but too short. + + Args: + secret: JWT signing secret to validate. + + Raises: + ValueError: If *secret* is shorter than ``_MIN_SECRET_LENGTH``. + """ + if secret and len(secret) < _MIN_SECRET_LENGTH: + msg = ( + f"jwt_secret must be at least {_MIN_SECRET_LENGTH} " + f"characters (got {len(secret)})" + ) + raise ValueError(msg) + + class AuthConfig(BaseModel): """JWT and authentication configuration. @@ -13,16 +32,15 @@ class AuthConfig(BaseModel): 1. ``AI_COMPANY_JWT_SECRET`` environment variable (for multi-instance deployments sharing a common secret). - 2. Stored secret in the persistence ``settings`` table (auto-generated - on first run). - 3. Auto-generated and persisted on first startup. + 2. Previously persisted secret in the ``settings`` table. + 3. Auto-generate a new secret and persist it for future runs. At construction time the secret may be empty — it is populated before the first request is served. Attributes: jwt_secret: HMAC signing key (resolved at startup, repr-hidden). - jwt_algorithm: JWT signing algorithm. + jwt_algorithm: JWT signing algorithm (HMAC family only). jwt_expiry_minutes: Token lifetime in minutes. exclude_paths: URL paths excluded from auth middleware. """ @@ -34,9 +52,9 @@ class AuthConfig(BaseModel): repr=False, description="JWT signing secret (resolved at startup)", ) - jwt_algorithm: str = Field( + jwt_algorithm: Literal["HS256", "HS384", "HS512"] = Field( default="HS256", - description="JWT signing algorithm", + description="JWT signing algorithm (HMAC family)", ) jwt_expiry_minutes: int = Field( default=1440, @@ -58,6 +76,12 @@ class AuthConfig(BaseModel): ), ) + @model_validator(mode="after") + def _validate_secret_length(self) -> Self: + """Reject non-empty secrets shorter than the minimum.""" + _require_valid_secret(self.jwt_secret) + return self + def with_secret(self, secret: str) -> AuthConfig: """Return a copy with the JWT secret set. @@ -70,10 +94,5 @@ def with_secret(self, secret: str) -> AuthConfig: Raises: ValueError: If the secret is too short. """ - if len(secret) < _MIN_SECRET_LENGTH: - msg = ( - f"jwt_secret must be at least {_MIN_SECRET_LENGTH} " - f"characters (got {len(secret)})" - ) - raise ValueError(msg) + _require_valid_secret(secret) return self.model_copy(update={"jwt_secret": secret}) diff --git a/src/ai_company/api/auth/controller.py b/src/ai_company/api/auth/controller.py index 8f96838f59..eec8a8ca94 100644 --- a/src/ai_company/api/auth/controller.py +++ b/src/ai_company/api/auth/controller.py @@ -2,16 +2,17 @@ import uuid from datetime import UTC, datetime +from typing import Self from litestar import Controller, Response, get, post from litestar.connection import ASGIConnection # noqa: TC002 from litestar.exceptions import PermissionDeniedException -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator from ai_company.api.auth.models import AuthenticatedUser, User from ai_company.api.auth.service import AuthService # noqa: TC001 from ai_company.api.dto import ApiResponse -from ai_company.api.errors import ApiValidationError, ConflictError, UnauthorizedError +from ai_company.api.errors import ConflictError, UnauthorizedError from ai_company.api.guards import HumanRole from ai_company.core.types import NotBlankStr # noqa: TC001 from ai_company.observability import get_logger @@ -27,6 +28,24 @@ _MIN_PASSWORD_LENGTH = 12 +def _check_password_length(password: str) -> str: + """Validate that a password meets the minimum length requirement. + + Args: + password: Password to validate. + + Returns: + The password unchanged. + + Raises: + ValueError: If the password is too short. + """ + if len(password) < _MIN_PASSWORD_LENGTH: + msg = f"Password must be at least {_MIN_PASSWORD_LENGTH} characters" + raise ValueError(msg) + return password + + # ── Request DTOs ────────────────────────────────────────────── @@ -41,7 +60,13 @@ class SetupRequest(BaseModel): model_config = ConfigDict(frozen=True) username: NotBlankStr = Field(max_length=128) - password: NotBlankStr = Field(min_length=_MIN_PASSWORD_LENGTH, max_length=128) + password: NotBlankStr = Field(max_length=128) + + @model_validator(mode="after") + def _validate_password_length(self) -> Self: + """Reject passwords shorter than the minimum.""" + _check_password_length(self.password) + return self class LoginRequest(BaseModel): @@ -69,7 +94,13 @@ class ChangePasswordRequest(BaseModel): model_config = ConfigDict(frozen=True) current_password: NotBlankStr = Field(max_length=128) - new_password: NotBlankStr = Field(min_length=_MIN_PASSWORD_LENGTH, max_length=128) + new_password: NotBlankStr = Field(max_length=128) + + @model_validator(mode="after") + def _validate_password_length(self) -> Self: + """Reject new passwords shorter than the minimum.""" + _check_password_length(self.new_password) + return self # ── Response DTOs ───────────────────────────────────────────── @@ -103,9 +134,9 @@ class UserInfoResponse(BaseModel): model_config = ConfigDict(frozen=True) - id: str - username: str - role: str + id: NotBlankStr + username: NotBlankStr + role: HumanRole must_change_password: bool @@ -129,21 +160,12 @@ def require_password_changed( PermissionDeniedException: If password change is required. """ user = connection.scope.get("user") - if ( - user is not None - and isinstance(user, AuthenticatedUser) - and user.must_change_password - ): + if not isinstance(user, AuthenticatedUser): + return + if user.must_change_password: raise PermissionDeniedException(detail="Password change required") -def _validate_password(password: str) -> None: - """Raise if the password is too short.""" - if len(password) < _MIN_PASSWORD_LENGTH: - msg = f"Password must be at least {_MIN_PASSWORD_LENGTH} characters" - raise ApiValidationError(msg) - - # ── Controller ──────────────────────────────────────────────── @@ -168,8 +190,6 @@ async def setup( Only available when no users exist. Returns 409 after the first account is created. """ - _validate_password(data.password) - app_state = request.app.state["app_state"] auth_service: AuthService = app_state.auth_service persistence = app_state.persistence @@ -204,7 +224,7 @@ async def setup( data=TokenResponse( token=token, expires_in=expires_in, - must_change_password=True, + must_change_password=user.must_change_password, ), ), status_code=201, @@ -232,7 +252,6 @@ async def login( logger.warning( API_AUTH_FAILED, reason="invalid_credentials", - username=data.username, ) msg = "Invalid credentials" raise UnauthorizedError(msg) @@ -266,7 +285,6 @@ async def change_password( request: ASGIConnection, # type: ignore[type-arg] ) -> Response[ApiResponse[UserInfoResponse]]: """Validate current password and set new one.""" - _validate_password(data.new_password) auth_user: AuthenticatedUser = request.scope["user"] app_state = request.app.state["app_state"] auth_service: AuthService = app_state.auth_service @@ -307,7 +325,7 @@ async def change_password( data=UserInfoResponse( id=updated_user.id, username=updated_user.username, - role=updated_user.role.value, + role=updated_user.role, must_change_password=False, ), ), @@ -329,7 +347,7 @@ async def me( data=UserInfoResponse( id=auth_user.user_id, username=auth_user.username, - role=auth_user.role.value, + role=auth_user.role, must_change_password=auth_user.must_change_password, ), ), diff --git a/src/ai_company/api/auth/middleware.py b/src/ai_company/api/auth/middleware.py index 620de6ab06..9e089d93b5 100644 --- a/src/ai_company/api/auth/middleware.py +++ b/src/ai_company/api/auth/middleware.py @@ -12,7 +12,6 @@ from ai_company.api.auth.models import AuthenticatedUser, AuthMethod from ai_company.api.auth.service import AuthService -from ai_company.api.guards import HumanRole from ai_company.observability import get_logger from ai_company.observability.events.api import ( API_AUTH_FAILED, @@ -23,6 +22,7 @@ from litestar.connection import ASGIConnection from ai_company.api.auth.config import AuthConfig + from ai_company.api.state import AppState logger = get_logger(__name__) @@ -107,23 +107,46 @@ def _extract_bearer_token(header: str) -> str | None: async def _try_jwt_auth( token: str, auth_service: AuthService, - app_state: Any, + app_state: AppState, connection: ASGIConnection[Any, Any, Any, Any], ) -> AuthenticatedUser | None: - """Attempt JWT authentication.""" + """Attempt JWT authentication. + + Returns: + Authenticated user on success, or ``None`` if the token is + invalid, the ``sub`` claim is missing, or the user no longer + exists in the database. + """ try: claims = auth_service.decode_token(token) - except jwt.InvalidTokenError: + except jwt.InvalidTokenError as exc: + logger.warning( + API_AUTH_FAILED, + reason="jwt_invalid", + error_type=type(exc).__qualname__, + error=str(exc), + path=str(connection.url.path), + ) return None user_id = claims.get("sub") if not user_id: + logger.warning( + API_AUTH_FAILED, + reason="jwt_missing_sub", + path=str(connection.url.path), + ) return None - # Verify the user still exists persistence = app_state.persistence db_user = await persistence.users.get(user_id) if db_user is None: + logger.warning( + API_AUTH_FAILED, + reason="jwt_user_not_found", + user_id=user_id, + path=str(connection.url.path), + ) return None authenticated = AuthenticatedUser( @@ -145,31 +168,54 @@ async def _try_jwt_auth( async def _try_api_key_auth( token: str, - app_state: Any, + app_state: AppState, connection: ASGIConnection[Any, Any, Any, Any], ) -> AuthenticatedUser | None: - """Attempt API key authentication.""" + """Attempt API key authentication. + + Returns: + Authenticated user on success, or ``None`` if the key hash + is not found, the key is revoked or expired, or the owning + user no longer exists. + """ key_hash = AuthService.hash_api_key(token) persistence = app_state.persistence api_key = await persistence.api_keys.get_by_hash(key_hash) if api_key is None: return None - # Check revocation and expiry if api_key.revoked: + logger.warning( + API_AUTH_FAILED, + reason="api_key_revoked", + key_name=api_key.name, + path=str(connection.url.path), + ) return None if api_key.expires_at is not None and api_key.expires_at < datetime.now(UTC): + logger.warning( + API_AUTH_FAILED, + reason="api_key_expired", + key_name=api_key.name, + path=str(connection.url.path), + ) return None - # Look up the owning user db_user = await persistence.users.get(api_key.user_id) if db_user is None: + logger.error( + API_AUTH_FAILED, + reason="api_key_orphaned", + key_name=api_key.name, + user_id=api_key.user_id, + path=str(connection.url.path), + ) return None authenticated = AuthenticatedUser( user_id=db_user.id, username=db_user.username, - role=HumanRole(api_key.role), + role=api_key.role, auth_method=AuthMethod.API_KEY, must_change_password=db_user.must_change_password, ) @@ -200,7 +246,9 @@ def create_auth_middleware_class( Returns: Middleware class ready for use in the Litestar middleware stack. """ - exclude_paths = list(auth_config.exclude_paths) or None + exclude_paths = ( + list(auth_config.exclude_paths) if auth_config.exclude_paths else None + ) class ConfiguredAuthMiddleware(ApiAuthMiddleware): """Auth middleware with pre-configured exclude paths.""" diff --git a/src/ai_company/api/auth/secret.py b/src/ai_company/api/auth/secret.py index 7f5b57d540..34c825a3fa 100644 --- a/src/ai_company/api/auth/secret.py +++ b/src/ai_company/api/auth/secret.py @@ -11,6 +11,7 @@ _SETTING_KEY = "jwt_secret" _SECRET_LENGTH = 48 # 64 URL-safe base64 chars +_MIN_SECRET_LENGTH = 32 async def resolve_jwt_secret( @@ -31,6 +32,13 @@ async def resolve_jwt_secret( # 1. Env var override (highest priority) env_secret = os.environ.get("AI_COMPANY_JWT_SECRET", "").strip() if env_secret: + if len(env_secret) < _MIN_SECRET_LENGTH: + msg = ( + f"AI_COMPANY_JWT_SECRET must be at least " + f"{_MIN_SECRET_LENGTH} characters (got {len(env_secret)})" + ) + logger.error(API_APP_STARTUP, error=msg) + raise ValueError(msg) logger.info( API_APP_STARTUP, note="JWT secret loaded from AI_COMPANY_JWT_SECRET env var", diff --git a/src/ai_company/api/auth/service.py b/src/ai_company/api/auth/service.py index f6539fa36b..21172d1af4 100644 --- a/src/ai_company/api/auth/service.py +++ b/src/ai_company/api/auth/service.py @@ -1,7 +1,6 @@ """Authentication service — password hashing, JWT ops, API key hashing.""" import hashlib -import hmac import secrets from datetime import UTC, datetime, timedelta from typing import TYPE_CHECKING, Any @@ -11,7 +10,7 @@ from ai_company.api.auth.models import User # noqa: TC001 from ai_company.observability import get_logger -from ai_company.observability.events.api import API_AUTH_TOKEN_ISSUED +from ai_company.observability.events.api import API_AUTH_FAILED if TYPE_CHECKING: from ai_company.api.auth.config import AuthConfig @@ -28,7 +27,7 @@ class AuthService: - """Stateless authentication operations. + """Immutable authentication operations. Args: config: Authentication configuration (carries JWT secret). @@ -63,6 +62,18 @@ def verify_password(self, password: str, password_hash: str) -> bool: except argon2.exceptions.VerifyMismatchError: return False except argon2.exceptions.VerificationError: + logger.warning( + API_AUTH_FAILED, + reason="hash_verification_error", + exc_info=True, + ) + return False + except argon2.exceptions.InvalidHashError: + logger.warning( + API_AUTH_FAILED, + reason="invalid_hash_format", + exc_info=True, + ) return False def create_token(self, user: User) -> tuple[str, int]: @@ -89,11 +100,6 @@ def create_token(self, user: User) -> tuple[str, int]: self._config.jwt_secret, algorithm=self._config.jwt_algorithm, ) - logger.info( - API_AUTH_TOKEN_ISSUED, - user_id=user.id, - username=user.username, - ) return token, expiry_seconds def decode_token(self, token: str) -> dict[str, Any]: @@ -126,20 +132,6 @@ def hash_api_key(raw_key: str) -> str: """ return hashlib.sha256(raw_key.encode()).hexdigest() - @staticmethod - def verify_api_key(raw_key: str, stored_hash: str) -> bool: - """Constant-time comparison of API key hash. - - Args: - raw_key: Plaintext API key from request. - stored_hash: SHA-256 hex digest from storage. - - Returns: - ``True`` if the key matches. - """ - computed = hashlib.sha256(raw_key.encode()).hexdigest() - return hmac.compare_digest(computed, stored_hash) - @staticmethod def generate_api_key() -> str: """Generate a cryptographically secure API key. diff --git a/src/ai_company/api/exception_handlers.py b/src/ai_company/api/exception_handlers.py index a41be9ef18..8b222e2d45 100644 --- a/src/ai_company/api/exception_handlers.py +++ b/src/ai_company/api/exception_handlers.py @@ -100,12 +100,15 @@ def handle_api_error( ) -> Response[ApiResponse[None]]: """Map ``ApiError`` subclasses to their declared status code.""" _log_error(request, exc, status=exc.status_code) - # Return the class-level default message, not the - # caller-interpolated string (which may contain internal IDs). - exc_cls = type(exc) - default_msg = getattr(exc_cls, "default_message", "Internal server error") + # For 5xx errors return the generic class-level default to avoid + # leaking internals. For 4xx client errors return the actual + # exception message — it was set by the controller and is user-safe. + if exc.status_code >= _SERVER_ERROR_THRESHOLD: + msg = type(exc).default_message + else: + msg = str(exc) or type(exc).default_message return Response( - content=ApiResponse[None](error=default_msg), + content=ApiResponse[None](error=msg), status_code=exc.status_code, ) @@ -130,8 +133,9 @@ def handle_permission_denied( ) -> Response[ApiResponse[None]]: """Map ``PermissionDeniedException`` to 403.""" _log_error(request, exc, status=403) + detail = exc.detail or "Forbidden" return Response( - content=ApiResponse[None](error="Forbidden"), + content=ApiResponse[None](error=detail), status_code=403, ) @@ -156,8 +160,9 @@ def handle_not_authorized( ) -> Response[ApiResponse[None]]: """Map ``NotAuthorizedException`` to 401.""" _log_error(request, exc, status=401) + detail = exc.detail or "Authentication required" return Response( - content=ApiResponse[None](error="Authentication required"), + content=ApiResponse[None](error=detail), status_code=401, ) diff --git a/src/ai_company/api/guards.py b/src/ai_company/api/guards.py index 9a46dc2a44..e47f6683ff 100644 --- a/src/ai_company/api/guards.py +++ b/src/ai_company/api/guards.py @@ -43,6 +43,12 @@ def _get_role(connection: ASGIConnection) -> HumanRole | None: # type: ignore[t try: return HumanRole(user.role) except ValueError: + logger.warning( + API_GUARD_DENIED, + guard="_get_role", + invalid_role=str(user.role), + path=str(connection.url.path), + ) return None return None diff --git a/src/ai_company/api/state.py b/src/ai_company/api/state.py index 586381a1d2..95db5ab407 100644 --- a/src/ai_company/api/state.py +++ b/src/ai_company/api/state.py @@ -62,50 +62,54 @@ def __init__( # noqa: PLR0913 self._auth_service = auth_service self.startup_time = startup_time + def _require_service[T](self, service: T | None, name: str) -> T: + """Return *service* or raise 503 if not configured. + + Args: + service: Service instance (``None`` when not configured). + name: Service name for logging and error message. + + Raises: + ServiceUnavailableError: If *service* is ``None``. + """ + if service is None: + logger.warning(API_SERVICE_UNAVAILABLE, service=name) + msg = f"{name.replace('_', ' ').title()} not configured" + raise ServiceUnavailableError(msg) + return service + @property def persistence(self) -> PersistenceBackend: """Return persistence backend or raise 503.""" - if self._persistence is None: - logger.warning( - API_SERVICE_UNAVAILABLE, - service="persistence", - ) - msg = "Persistence backend not configured" - raise ServiceUnavailableError(msg) - return self._persistence + return self._require_service(self._persistence, "persistence") @property def message_bus(self) -> MessageBus: """Return message bus or raise 503.""" - if self._message_bus is None: - logger.warning( - API_SERVICE_UNAVAILABLE, - service="message_bus", - ) - msg = "Message bus not configured" - raise ServiceUnavailableError(msg) - return self._message_bus + return self._require_service(self._message_bus, "message_bus") @property def cost_tracker(self) -> CostTracker: """Return cost tracker or raise 503.""" - if self._cost_tracker is None: - logger.warning( - API_SERVICE_UNAVAILABLE, - service="cost_tracker", - ) - msg = "Cost tracker not configured" - raise ServiceUnavailableError(msg) - return self._cost_tracker + return self._require_service(self._cost_tracker, "cost_tracker") @property def auth_service(self) -> AuthService: """Return auth service or raise 503.""" - if self._auth_service is None: - logger.warning( - API_SERVICE_UNAVAILABLE, - service="auth_service", - ) - msg = "Auth service not configured" - raise ServiceUnavailableError(msg) - return self._auth_service + return self._require_service(self._auth_service, "auth_service") + + def set_auth_service(self, service: AuthService) -> None: + """Set the auth service (deferred initialisation). + + Called once during startup after the JWT secret is resolved. + + Args: + service: Fully configured auth service. + + Raises: + RuntimeError: If the auth service was already configured. + """ + if self._auth_service is not None: + msg = "Auth service already configured" + raise RuntimeError(msg) + self._auth_service = service diff --git a/src/ai_company/observability/events/persistence.py b/src/ai_company/observability/events/persistence.py index 48d271e89d..e0d7eb1a7a 100644 --- a/src/ai_company/observability/events/persistence.py +++ b/src/ai_company/observability/events/persistence.py @@ -143,3 +143,6 @@ PERSISTENCE_API_KEY_LIST_FAILED: Final[str] = "persistence.api_key.list_failed" PERSISTENCE_API_KEY_DELETED: Final[str] = "persistence.api_key.deleted" PERSISTENCE_API_KEY_DELETE_FAILED: Final[str] = "persistence.api_key.delete_failed" + +PERSISTENCE_SETTING_FETCH_FAILED: Final[str] = "persistence.setting.fetch_failed" +PERSISTENCE_SETTING_SAVE_FAILED: Final[str] = "persistence.setting.save_failed" diff --git a/src/ai_company/persistence/repositories.py b/src/ai_company/persistence/repositories.py index 0ec74cf34b..b03e7a8982 100644 --- a/src/ai_company/persistence/repositories.py +++ b/src/ai_company/persistence/repositories.py @@ -338,7 +338,7 @@ async def save(self, user: User) -> None: """ ... - async def get(self, user_id: str) -> User | None: + async def get(self, user_id: NotBlankStr) -> User | None: """Retrieve a user by ID. Args: @@ -352,7 +352,7 @@ async def get(self, user_id: str) -> User | None: """ ... - async def get_by_username(self, username: str) -> User | None: + async def get_by_username(self, username: NotBlankStr) -> User | None: """Retrieve a user by username. Args: @@ -388,7 +388,7 @@ async def count(self) -> int: """ ... - async def delete(self, user_id: str) -> bool: + async def delete(self, user_id: NotBlankStr) -> bool: """Delete a user by ID. Args: @@ -418,7 +418,7 @@ async def save(self, key: ApiKey) -> None: """ ... - async def get(self, key_id: str) -> ApiKey | None: + async def get(self, key_id: NotBlankStr) -> ApiKey | None: """Retrieve an API key by ID. Args: @@ -432,7 +432,7 @@ async def get(self, key_id: str) -> ApiKey | None: """ ... - async def get_by_hash(self, key_hash: str) -> ApiKey | None: + async def get_by_hash(self, key_hash: NotBlankStr) -> ApiKey | None: """Retrieve an API key by its hash. Args: @@ -446,7 +446,7 @@ async def get_by_hash(self, key_hash: str) -> ApiKey | None: """ ... - async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: + async def list_by_user(self, user_id: NotBlankStr) -> tuple[ApiKey, ...]: """List API keys belonging to a user. Args: @@ -460,7 +460,7 @@ async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: """ ... - async def delete(self, key_id: str) -> bool: + async def delete(self, key_id: NotBlankStr) -> bool: """Delete an API key by ID. Args: diff --git a/src/ai_company/persistence/sqlite/backend.py b/src/ai_company/persistence/sqlite/backend.py index b1a5b1fa08..61816c470b 100644 --- a/src/ai_company/persistence/sqlite/backend.py +++ b/src/ai_company/persistence/sqlite/backend.py @@ -19,6 +19,8 @@ PERSISTENCE_BACKEND_HEALTH_CHECK, PERSISTENCE_BACKEND_NOT_CONNECTED, PERSISTENCE_BACKEND_WAL_MODE_FAILED, + PERSISTENCE_SETTING_FETCH_FAILED, + PERSISTENCE_SETTING_SAVE_FAILED, ) from ai_company.persistence.errors import ( PersistenceConnectionError, @@ -370,7 +372,8 @@ async def get_setting(self, key: str) -> str | None: except (sqlite3.Error, aiosqlite.Error) as exc: msg = f"Failed to get setting {key!r}" logger.exception( - PERSISTENCE_BACKEND_NOT_CONNECTED, + PERSISTENCE_SETTING_FETCH_FAILED, + key=key, error=str(exc), ) raise QueryError(msg) from exc @@ -397,7 +400,8 @@ async def set_setting(self, key: str, value: str) -> None: except (sqlite3.Error, aiosqlite.Error) as exc: msg = f"Failed to set setting {key!r}" logger.exception( - PERSISTENCE_BACKEND_NOT_CONNECTED, + PERSISTENCE_SETTING_SAVE_FAILED, + key=key, error=str(exc), ) raise QueryError(msg) from exc diff --git a/src/ai_company/persistence/sqlite/user_repo.py b/src/ai_company/persistence/sqlite/user_repo.py index cd8ae7e73e..b686f6464e 100644 --- a/src/ai_company/persistence/sqlite/user_repo.py +++ b/src/ai_company/persistence/sqlite/user_repo.py @@ -118,8 +118,18 @@ async def get(self, user_id: str) -> User | None: if row is None: logger.debug(PERSISTENCE_USER_FETCHED, user_id=user_id, found=False) return None + try: + user = _row_to_user(row) + except (ValueError, ValidationError) as exc: + msg = f"Failed to deserialize user {user_id!r}" + logger.exception( + PERSISTENCE_USER_FETCH_FAILED, + user_id=user_id, + error=str(exc), + ) + raise QueryError(msg) from exc logger.debug(PERSISTENCE_USER_FETCHED, user_id=user_id, found=True) - return _row_to_user(row) + return user async def get_by_username(self, username: str) -> User | None: """Retrieve a user by username.""" @@ -138,18 +148,32 @@ async def get_by_username(self, username: str) -> User | None: raise QueryError(msg) from exc if row is None: return None - return _row_to_user(row) + try: + return _row_to_user(row) + except (ValueError, ValidationError) as exc: + msg = f"Failed to deserialize user {username!r}" + logger.exception( + PERSISTENCE_USER_FETCH_FAILED, + username=username, + error=str(exc), + ) + raise QueryError(msg) from exc async def list_users(self) -> tuple[User, ...]: """List all users.""" try: cursor = await self._db.execute("SELECT * FROM users ORDER BY created_at") rows = await cursor.fetchall() - except (sqlite3.Error, aiosqlite.Error, ValidationError) as exc: + except (sqlite3.Error, aiosqlite.Error) as exc: msg = "Failed to list users" logger.exception(PERSISTENCE_USER_LIST_FAILED, error=str(exc)) raise QueryError(msg) from exc - users = tuple(_row_to_user(row) for row in rows) + try: + users = tuple(_row_to_user(row) for row in rows) + except (ValueError, ValidationError) as exc: + msg = "Failed to deserialize users" + logger.exception(PERSISTENCE_USER_LIST_FAILED, error=str(exc)) + raise QueryError(msg) from exc logger.debug(PERSISTENCE_USER_LISTED, count=len(users)) return users @@ -251,8 +275,18 @@ async def get(self, key_id: str) -> ApiKey | None: if row is None: logger.debug(PERSISTENCE_API_KEY_FETCHED, key_id=key_id, found=False) return None + try: + key = _row_to_api_key(row) + except (ValueError, ValidationError) as exc: + msg = f"Failed to deserialize API key {key_id!r}" + logger.exception( + PERSISTENCE_API_KEY_FETCH_FAILED, + key_id=key_id, + error=str(exc), + ) + raise QueryError(msg) from exc logger.debug(PERSISTENCE_API_KEY_FETCHED, key_id=key_id, found=True) - return _row_to_api_key(row) + return key async def get_by_hash(self, key_hash: str) -> ApiKey | None: """Retrieve an API key by its hash.""" @@ -268,7 +302,12 @@ async def get_by_hash(self, key_hash: str) -> ApiKey | None: raise QueryError(msg) from exc if row is None: return None - return _row_to_api_key(row) + try: + return _row_to_api_key(row) + except (ValueError, ValidationError) as exc: + msg = "Failed to deserialize API key by hash" + logger.exception(PERSISTENCE_API_KEY_FETCH_FAILED, error=str(exc)) + raise QueryError(msg) from exc async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: """List API keys belonging to a user.""" diff --git a/tests/unit/api/auth/test_controller.py b/tests/unit/api/auth/test_controller.py index e918299f81..365660fcc3 100644 --- a/tests/unit/api/auth/test_controller.py +++ b/tests/unit/api/auth/test_controller.py @@ -5,6 +5,7 @@ import pytest from litestar.testing import TestClient # noqa: TC002 +from ai_company.api.guards import HumanRole from tests.unit.api.conftest import make_auth_headers @@ -76,7 +77,7 @@ def test_setup_short_password_rejected(self, bare_client: TestClient[Any]) -> No "/api/v1/auth/setup", json={"username": "admin", "password": "short"}, ) - assert response.status_code == 422 + assert response.status_code == 400 @pytest.mark.unit @@ -191,3 +192,60 @@ def test_me_returns_user_info(self, test_client: TestClient[Any]) -> None: def test_me_requires_auth(self, bare_client: TestClient[Any]) -> None: response = bare_client.get("/api/v1/auth/me") assert response.status_code == 401 + + +@pytest.mark.unit +class TestRequirePasswordChanged: + def test_blocks_user_with_must_change_password(self) -> None: + """Guard raises PermissionDeniedException for flagged users.""" + from unittest.mock import MagicMock + + from litestar.exceptions import PermissionDeniedException + + from ai_company.api.auth.controller import require_password_changed + from ai_company.api.auth.models import AuthenticatedUser, AuthMethod + + user = AuthenticatedUser( + user_id="u1", + username="admin", + role=HumanRole.CEO, + auth_method=AuthMethod.JWT, + must_change_password=True, + ) + connection = MagicMock() + connection.scope = {"user": user} + + with pytest.raises(PermissionDeniedException): + require_password_changed(connection, None) + + def test_allows_user_without_flag(self) -> None: + """Guard passes when must_change_password is False.""" + from unittest.mock import MagicMock + + from ai_company.api.auth.controller import require_password_changed + from ai_company.api.auth.models import AuthenticatedUser, AuthMethod + + user = AuthenticatedUser( + user_id="u1", + username="admin", + role=HumanRole.CEO, + auth_method=AuthMethod.JWT, + must_change_password=False, + ) + connection = MagicMock() + connection.scope = {"user": user} + + # Should not raise + require_password_changed(connection, None) + + def test_allows_when_no_user_in_scope(self) -> None: + """Guard passes when no user is in scope (pre-auth).""" + from unittest.mock import MagicMock + + from ai_company.api.auth.controller import require_password_changed + + connection = MagicMock() + connection.scope = {} + + # Should not raise + require_password_changed(connection, None) diff --git a/tests/unit/api/auth/test_middleware.py b/tests/unit/api/auth/test_middleware.py index 203ae80974..9361d97abb 100644 --- a/tests/unit/api/auth/test_middleware.py +++ b/tests/unit/api/auth/test_middleware.py @@ -234,6 +234,40 @@ async def test_expired_api_key_returns_401(self) -> None: assert resp.status_code == 401 +@pytest.mark.unit +class TestAuthMiddlewareApiKeyEdgeCases: + async def test_api_key_with_deleted_owner_returns_401(self) -> None: + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + # Save the user, create a key, then delete the user + await persistence.users.save(user) + + raw_key = AuthService.generate_api_key() + key_hash = AuthService.hash_api_key(raw_key) + now = datetime.now(UTC) + api_key = ApiKey( + id="key-orphan", + key_hash=key_hash, + name="orphaned-key", + role=HumanRole.CEO, + user_id=user.id, + created_at=now, + ) + await persistence.api_keys.save(api_key) + await persistence.users.delete(user.id) + + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {raw_key}"}, + ) + assert resp.status_code == 401 + + @pytest.mark.unit class TestAuthMiddlewareExcludePaths: async def test_excluded_path_skips_auth(self) -> None: diff --git a/tests/unit/api/auth/test_secret.py b/tests/unit/api/auth/test_secret.py new file mode 100644 index 0000000000..716a9993fd --- /dev/null +++ b/tests/unit/api/auth/test_secret.py @@ -0,0 +1,75 @@ +"""Tests for JWT secret resolution chain.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from ai_company.api.auth.secret import resolve_jwt_secret + + +def _make_persistence(stored_secret: str | None = None) -> AsyncMock: + """Build a fake persistence backend for secret resolution tests.""" + persistence = AsyncMock() + persistence.get_setting = AsyncMock(return_value=stored_secret) + persistence.set_setting = AsyncMock() + return persistence + + +@pytest.mark.unit +class TestResolveJwtSecret: + async def test_env_var_takes_priority(self) -> None: + secret = "env-secret-that-is-at-least-32-characters!!" + persistence = _make_persistence(stored_secret="stored-secret-32-chars-long!!!!") + with patch.dict("os.environ", {"AI_COMPANY_JWT_SECRET": secret}): + result = await resolve_jwt_secret(persistence) + + assert result == secret + persistence.get_setting.assert_not_called() + + async def test_stored_secret_used_when_no_env_var(self) -> None: + stored = "stored-secret-that-is-at-least-32-chars!!" + persistence = _make_persistence(stored_secret=stored) + with patch.dict("os.environ", {}, clear=True): + result = await resolve_jwt_secret(persistence) + + assert result == stored + + async def test_generates_and_persists_when_nothing_stored(self) -> None: + persistence = _make_persistence(stored_secret=None) + with patch.dict("os.environ", {}, clear=True): + result = await resolve_jwt_secret(persistence) + + assert len(result) >= 32 + persistence.set_setting.assert_awaited_once_with("jwt_secret", result) + + async def test_env_var_too_short_raises(self) -> None: + persistence = _make_persistence() + with ( + patch.dict("os.environ", {"AI_COMPANY_JWT_SECRET": "short"}), + pytest.raises(ValueError, match="at least 32 characters"), + ): + await resolve_jwt_secret(persistence) + + async def test_env_var_whitespace_stripped(self) -> None: + secret = " env-secret-that-is-at-least-32-characters!! " + persistence = _make_persistence() + with patch.dict("os.environ", {"AI_COMPANY_JWT_SECRET": secret}): + result = await resolve_jwt_secret(persistence) + + assert result == secret.strip() + + async def test_empty_env_var_falls_through(self) -> None: + stored = "stored-secret-that-is-at-least-32-chars!!" + persistence = _make_persistence(stored_secret=stored) + with patch.dict("os.environ", {"AI_COMPANY_JWT_SECRET": ""}): + result = await resolve_jwt_secret(persistence) + + assert result == stored + + async def test_whitespace_only_env_var_falls_through(self) -> None: + stored = "stored-secret-that-is-at-least-32-chars!!" + persistence = _make_persistence(stored_secret=stored) + with patch.dict("os.environ", {"AI_COMPANY_JWT_SECRET": " "}): + result = await resolve_jwt_secret(persistence) + + assert result == stored diff --git a/tests/unit/api/auth/test_service.py b/tests/unit/api/auth/test_service.py index e5992fc032..d0e9615568 100644 --- a/tests/unit/api/auth/test_service.py +++ b/tests/unit/api/auth/test_service.py @@ -126,14 +126,10 @@ def test_hash_deterministic(self) -> None: h2 = AuthService.hash_api_key("my-key") assert h1 == h2 - def test_verify_correct_key(self) -> None: - key = AuthService.generate_api_key() - h = AuthService.hash_api_key(key) - assert AuthService.verify_api_key(key, h) - - def test_verify_wrong_key(self) -> None: - h = AuthService.hash_api_key("real-key") - assert not AuthService.verify_api_key("wrong-key", h) + def test_different_keys_different_hashes(self) -> None: + h1 = AuthService.hash_api_key("key-one") + h2 = AuthService.hash_api_key("key-two") + assert h1 != h2 def test_generate_key_unique(self) -> None: k1 = AuthService.generate_api_key() diff --git a/tests/unit/api/test_exception_handlers.py b/tests/unit/api/test_exception_handlers.py index e89a630a79..9a446afcf2 100644 --- a/tests/unit/api/test_exception_handlers.py +++ b/tests/unit/api/test_exception_handlers.py @@ -6,7 +6,13 @@ from litestar import Litestar, get from litestar.testing import TestClient -from ai_company.api.errors import ConflictError, ForbiddenError, NotFoundError +from ai_company.api.errors import ( + ApiValidationError, + ConflictError, + ForbiddenError, + NotFoundError, + UnauthorizedError, +) from ai_company.api.exception_handlers import EXCEPTION_HANDLERS from ai_company.persistence.errors import ( DuplicateRecordError, @@ -73,8 +79,8 @@ async def handler() -> None: resp = client.get("/test") assert resp.status_code == 404 body = resp.json() - # handle_api_error returns the default class message. - assert body["error"] == "Resource not found" + # 4xx errors return the actual exception message + assert body["error"] == "nope" def test_api_conflict_error_maps_to_409(self) -> None: @get("/test") @@ -86,7 +92,7 @@ async def handler() -> None: resp = client.get("/test") assert resp.status_code == 409 body = resp.json() - assert body["error"] == "Resource conflict" + assert body["error"] == "conflict" def test_api_forbidden_error_maps_to_403(self) -> None: @get("/test") @@ -98,7 +104,7 @@ async def handler() -> None: resp = client.get("/test") assert resp.status_code == 403 body = resp.json() - assert body["error"] == "Forbidden" + assert body["error"] == "denied" def test_value_error_falls_through_to_catch_all(self) -> None: @get("/test") @@ -122,3 +128,28 @@ async def handler() -> None: body = resp.json() assert body["success"] is False assert body["error"] == "Internal server error" + + def test_unauthorized_error_maps_to_401(self) -> None: + @get("/test") + async def handler() -> None: + msg = "Invalid credentials" + raise UnauthorizedError(msg) + + with TestClient(_make_app(handler)) as client: + resp = client.get("/test") + assert resp.status_code == 401 + body = resp.json() + # 4xx returns the actual message, not the generic default + assert body["error"] == "Invalid credentials" + + def test_validation_error_maps_to_422(self) -> None: + @get("/test") + async def handler() -> None: + msg = "Bad field" + raise ApiValidationError(msg) + + with TestClient(_make_app(handler)) as client: + resp = client.get("/test") + assert resp.status_code == 422 + body = resp.json() + assert body["error"] == "Bad field" diff --git a/tests/unit/api/test_state.py b/tests/unit/api/test_state.py index c48494ec20..792907cead 100644 --- a/tests/unit/api/test_state.py +++ b/tests/unit/api/test_state.py @@ -56,3 +56,37 @@ def test_cost_tracker_returns_when_set(self) -> None: tracker = CostTracker() state = _make_state(cost_tracker=tracker) assert state.cost_tracker is tracker + + def test_auth_service_raises_when_none(self) -> None: + state = _make_state(auth_service=None) + with pytest.raises(ServiceUnavailableError): + _ = state.auth_service + + def test_auth_service_returns_when_set(self) -> None: + from ai_company.api.auth.config import AuthConfig + from ai_company.api.auth.service import AuthService + + secret = "test-secret-that-is-at-least-32-characters-long" + svc = AuthService(AuthConfig(jwt_secret=secret)) + state = _make_state(auth_service=svc) + assert state.auth_service is svc + + def test_set_auth_service_succeeds_once(self) -> None: + from ai_company.api.auth.config import AuthConfig + from ai_company.api.auth.service import AuthService + + secret = "test-secret-that-is-at-least-32-characters-long" + svc = AuthService(AuthConfig(jwt_secret=secret)) + state = _make_state() + state.set_auth_service(svc) + assert state.auth_service is svc + + def test_set_auth_service_twice_raises(self) -> None: + from ai_company.api.auth.config import AuthConfig + from ai_company.api.auth.service import AuthService + + secret = "test-secret-that-is-at-least-32-characters-long" + svc = AuthService(AuthConfig(jwt_secret=secret)) + state = _make_state(auth_service=svc) + with pytest.raises(RuntimeError, match="already configured"): + state.set_auth_service(svc) diff --git a/tests/unit/config/conftest.py b/tests/unit/config/conftest.py index f05e50b05d..7300e5af2d 100644 --- a/tests/unit/config/conftest.py +++ b/tests/unit/config/conftest.py @@ -5,6 +5,7 @@ import pytest from polyfactory.factories.pydantic_factory import ModelFactory +from ai_company.api.config import ApiConfig from ai_company.budget.config import BudgetConfig from ai_company.budget.coordination_config import CoordinationMetricsConfig from ai_company.budget.cost_tiers import CostTiersConfig @@ -78,6 +79,7 @@ class RootConfigFactory(ModelFactory[RootConfig]): custom_roles = () providers: dict[str, ProviderConfig] = {} # noqa: RUF012 config = CompanyConfig() + api = ApiConfig() budget = BudgetConfig() communication = CommunicationConfig() routing = RoutingConfig() diff --git a/tests/unit/persistence/sqlite/test_migrations.py b/tests/unit/persistence/sqlite/test_migrations.py index f2d979ed87..4a7f4821fa 100644 --- a/tests/unit/persistence/sqlite/test_migrations.py +++ b/tests/unit/persistence/sqlite/test_migrations.py @@ -144,6 +144,48 @@ async def test_v4_creates_audit_entry_indexes( } assert expected.issubset(indexes) + async def test_v5_creates_users_table( + self, memory_db: aiosqlite.Connection + ) -> None: + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='users'" + ) + row = await cursor.fetchone() + assert row is not None + + async def test_v5_creates_api_keys_table( + self, memory_db: aiosqlite.Connection + ) -> None: + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='api_keys'" + ) + row = await cursor.fetchone() + assert row is not None + + async def test_v5_creates_settings_table( + self, memory_db: aiosqlite.Connection + ) -> None: + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='settings'" + ) + row = await cursor.fetchone() + assert row is not None + + async def test_v5_creates_user_indexes( + self, memory_db: aiosqlite.Connection + ) -> None: + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='index' " + "AND name LIKE 'idx_%' AND name LIKE '%user%' ORDER BY name" + ) + indexes = {row[0] for row in await cursor.fetchall()} + assert "idx_users_username" in indexes + assert "idx_api_keys_user_id" in indexes + async def test_migration_failure_raises_migration_error( self, memory_db: aiosqlite.Connection ) -> None: diff --git a/tests/unit/persistence/test_migrations_v2.py b/tests/unit/persistence/test_migrations_v2.py index 1c33500761..9e80beaf52 100644 --- a/tests/unit/persistence/test_migrations_v2.py +++ b/tests/unit/persistence/test_migrations_v2.py @@ -58,7 +58,7 @@ async def test_v1_to_v2_migration(self, memory_db: aiosqlite.Connection) -> None assert await get_user_version(memory_db) == 1 await run_migrations(memory_db) - assert await get_user_version(memory_db) == 4 + assert await get_user_version(memory_db) == 5 cursor = await memory_db.execute( "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" diff --git a/tests/unit/persistence/test_protocol.py b/tests/unit/persistence/test_protocol.py index 06d932ea59..7b6db34524 100644 --- a/tests/unit/persistence/test_protocol.py +++ b/tests/unit/persistence/test_protocol.py @@ -203,7 +203,6 @@ async def delete(self, key_id: str) -> bool: return False - class _FakeBackend: async def connect(self) -> None: pass From e044eca2b1123725d7777ae0aee1eff6b2f57d3c Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Wed, 11 Mar 2026 01:53:08 +0100 Subject: [PATCH 3/5] fix: address 42 PR review items from local agents and external reviewers - AuthConfig: make exclude_paths None-default, derived from api_prefix - AuthService: add empty-secret guards, pwd_sig claim, async hashing - Middleware: JWT-only fallthrough for dot-tokens, pwd_sig validation - Controller: TOCTOU race guard on setup, path-based guard exemptions - Models: switch datetime to AwareDatetime, TokenResponse validation - Secret: import MIN_SECRET_LENGTH, validate stored secret length - State: add has_auth_service property, error logging - Protocol: add users/api_keys to docstring, NotBlankStr for settings - SQLite repos: add list_by_user error handling, NotBlankStr params - Events: add PERSISTENCE_SETTING_FETCHED/SAVED constants - DESIGN_SPEC: update implementation snapshot (JWT auth done) - Tests: add corrupted hash, missing sub, bearer edge cases, parametrize guards, short password change test --- DESIGN_SPEC.md | 2 +- src/ai_company/api/app.py | 92 ++++++++++++++----- src/ai_company/api/auth/config.py | 32 ++++--- src/ai_company/api/auth/controller.py | 66 +++++++++---- src/ai_company/api/auth/middleware.py | 29 +++++- src/ai_company/api/auth/models.py | 15 ++- src/ai_company/api/auth/secret.py | 27 ++++-- src/ai_company/api/auth/service.py | 57 +++++++++++- src/ai_company/api/config.py | 6 +- src/ai_company/api/controllers/approvals.py | 8 +- src/ai_company/api/state.py | 13 ++- .../observability/events/persistence.py | 2 + src/ai_company/persistence/protocol.py | 6 +- .../persistence/sqlite/user_repo.py | 26 ++++-- tests/unit/api/auth/test_config.py | 9 +- tests/unit/api/auth/test_controller.py | 75 +++++++++++++++ tests/unit/api/auth/test_middleware.py | 28 ++++++ tests/unit/api/auth/test_service.py | 35 +++++++ tests/unit/api/conftest.py | 39 +++++++- 19 files changed, 462 insertions(+), 105 deletions(-) diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index a28583aafa..60ec1f775b 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -81,7 +81,7 @@ The MVP validates the core hypothesis: **a single agent can complete a real task > **Implementation snapshot (2026-03-10):** > - **Done:** M0–M6 (tooling, config/core, providers, single-agent engine, multi-agent orchestration, API/CLI surface) + Docker sandbox (#50), MCP bridge (#53), code runner + HR engine (hiring/firing/onboarding/offboarding/registry) + performance tracking (task metrics, quality scoring, collaboration scoring, trend detection, rolling windows). Memory layer backend selected ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). Persistence backend (§7.6) completed (including audit entry persistence via AuditRepository + SQLite backend). Memory retrieval pipeline (#41: ranking, token-budget formatting, context injection, non-inferable filtering) complete. Budget enforcement complete (BudgetEnforcer + configurable cost tiers + quota/subscription tracking). CFO cost optimization complete (CostOptimizer: anomaly detection, efficiency analysis, downgrade recommendations, routing optimization, approval decisions; ReportGenerator: multi-dimensional spending reports). Shared org memory (#125: HybridPromptRetrievalBackend, OrgFactStore, access control, factory) complete. Memory consolidation/archival (#48: ConsolidationService, SimpleConsolidationStrategy, RetentionEnforcer, ArchivalStore protocol) complete. SecOps agent (rule engine, audit log, output scanner, output scan response policies (redact/withhold/log-only/autonomy-tiered), risk classifier, ToolInvoker integration), progressive trust (4 strategies: disabled/weighted/per-category/milestone behind TrustStrategy protocol), promotion/demotion (criteria evaluation, approval strategies, model mapping). Autonomy levels (#42: AutonomyLevel enum, presets, 3-level resolver, rule-based auto-downgrade/human-only promotion change strategy) + approval timeout policies (#126: 4 timeout policies, park/resume service, risk tier classifier, timeout checker) complete. -> - **Remaining:** JWT/OAuth auth, approval workflow gates. +> - **Remaining:** Approval workflow gates. ### 1.5 Configuration Philosophy diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index 6fc3853001..c48c84dcbf 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -19,6 +19,7 @@ from ai_company import __version__ from ai_company.api.approval_store import ApprovalStore +from ai_company.api.auth.controller import require_password_changed from ai_company.api.auth.middleware import create_auth_middleware_class from ai_company.api.auth.secret import resolve_jwt_secret from ai_company.api.auth.service import AuthService @@ -149,6 +150,55 @@ async def _cleanup_on_failure( ) +async def _init_persistence( + persistence: PersistenceBackend, + app_state: AppState, +) -> None: + """Connect persistence, run migrations, and resolve JWT secret. + + Args: + persistence: Persistence backend to initialise. + app_state: Application state for auth service injection. + """ + try: + await persistence.connect() + except Exception: + logger.exception( + API_APP_STARTUP, + error="Failed to connect persistence", + ) + raise + + try: + await persistence.migrate() + except Exception: + logger.exception( + API_APP_STARTUP, + error="Failed to run persistence migrations", + ) + raise + + # Resolve JWT secret after persistence is up + if app_state.has_auth_service: + logger.info( + API_APP_STARTUP, + note="Auth service already configured, skipping JWT secret resolution", + ) + else: + try: + secret = await resolve_jwt_secret(persistence) + auth_config = app_state.config.api.auth.with_secret( + secret, + ) + app_state.set_auth_service(AuthService(auth_config)) + except Exception: + logger.exception( + API_APP_STARTUP, + error="Failed to resolve JWT secret", + ) + raise + + async def _safe_startup( persistence: PersistenceBackend | None, message_bus: MessageBus | None, @@ -164,32 +214,9 @@ async def _safe_startup( started_persistence = False try: if persistence is not None: - try: - await persistence.connect() - except Exception: - logger.exception( - API_APP_STARTUP, - error="Failed to connect persistence", - ) - raise + await _init_persistence(persistence, app_state) started_persistence = True - # Resolve JWT secret after persistence is up - try: - secret = await resolve_jwt_secret(persistence) - auth_config = app_state.config.api.auth.with_secret( - secret, - ) - app_state.set_auth_service(AuthService(auth_config)) - except RuntimeError: - pass # Already configured (e.g. test-injected) - except Exception: - logger.exception( - API_APP_STARTUP, - error="Failed to resolve JWT secret", - ) - raise - if message_bus is not None: try: await message_bus.start() @@ -310,6 +337,7 @@ def create_app( # noqa: PLR0913 api_router = Router( path=api_config.api_prefix, route_handlers=[*ALL_CONTROLLERS, ws_handler], + guards=[require_password_changed], ) startup, shutdown = _build_lifecycle( @@ -387,7 +415,21 @@ def _build_middleware(api_config: ApiConfig) -> list[Middleware]: rate_limit=(rl.time_unit, rl.max_requests), # type: ignore[arg-type] exclude=list(rl.exclude_paths), ) - auth_middleware = create_auth_middleware_class(api_config.auth) + auth = api_config.auth + if auth.exclude_paths is None: + prefix = api_config.api_prefix + auth = auth.model_copy( + update={ + "exclude_paths": ( + f"^{prefix}/health$", + "^/docs", + "^/api$", + f"^{prefix}/auth/setup$", + f"^{prefix}/auth/login$", + ), + }, + ) + auth_middleware = create_auth_middleware_class(auth) return [ auth_middleware, CSPMiddleware, diff --git a/src/ai_company/api/auth/config.py b/src/ai_company/api/auth/config.py index ddce43c8ef..88f9d58661 100644 --- a/src/ai_company/api/auth/config.py +++ b/src/ai_company/api/auth/config.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator -_MIN_SECRET_LENGTH = 32 +MIN_SECRET_LENGTH = 32 def _require_valid_secret(secret: str) -> None: @@ -14,11 +14,12 @@ def _require_valid_secret(secret: str) -> None: secret: JWT signing secret to validate. Raises: - ValueError: If *secret* is shorter than ``_MIN_SECRET_LENGTH``. + ValueError: If *secret* is non-empty and shorter than + ``MIN_SECRET_LENGTH``. """ - if secret and len(secret) < _MIN_SECRET_LENGTH: + if secret and len(secret) < MIN_SECRET_LENGTH: msg = ( - f"jwt_secret must be at least {_MIN_SECRET_LENGTH} " + f"jwt_secret must be at least {MIN_SECRET_LENGTH} " f"characters (got {len(secret)})" ) raise ValueError(msg) @@ -42,6 +43,7 @@ class AuthConfig(BaseModel): jwt_secret: HMAC signing key (resolved at startup, repr-hidden). jwt_algorithm: JWT signing algorithm (HMAC family only). jwt_expiry_minutes: Token lifetime in minutes. + min_password_length: Minimum password length for setup/change. exclude_paths: URL paths excluded from auth middleware. """ @@ -62,17 +64,21 @@ class AuthConfig(BaseModel): le=43200, description="Token lifetime in minutes (default 24h)", ) - exclude_paths: tuple[str, ...] = Field( - default=( - "^/api/v1/health$", - "^/docs", - "^/api$", - "^/api/v1/auth/setup$", - "^/api/v1/auth/login$", - ), + min_password_length: int = Field( + default=12, + ge=8, + le=128, + description="Minimum password length for setup and password change", + ) + exclude_paths: tuple[str, ...] | None = Field( + default=None, description=( "Regex patterns for paths excluded from authentication. " - "Anchor with ^ and $ to avoid substring matches." + "When None (default), paths are auto-derived from the " + "API prefix (health, auth/setup, auth/login, docs, " + "scalar UI). " + "Use ^ to anchor at the start of the path and add $ when " + "an exact match (rather than a prefix match) is required." ), ) diff --git a/src/ai_company/api/auth/controller.py b/src/ai_company/api/auth/controller.py index eec8a8ca94..623b330466 100644 --- a/src/ai_company/api/auth/controller.py +++ b/src/ai_company/api/auth/controller.py @@ -2,9 +2,9 @@ import uuid from datetime import UTC, datetime -from typing import Self +from typing import Any, Self -from litestar import Controller, Response, get, post +from litestar import Controller, Request, Response, get, post from litestar.connection import ASGIConnection # noqa: TC002 from litestar.exceptions import PermissionDeniedException from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -117,8 +117,8 @@ class TokenResponse(BaseModel): model_config = ConfigDict(frozen=True) - token: str - expires_in: int + token: NotBlankStr + expires_in: int = Field(gt=0) must_change_password: bool @@ -140,6 +140,8 @@ class UserInfoResponse(BaseModel): must_change_password: bool +_PWD_CHANGE_EXEMPT_SUFFIXES = ("/auth/change-password", "/auth/me") + # ── Guards ──────────────────────────────────────────────────── @@ -149,19 +151,33 @@ def require_password_changed( ) -> None: """Guard that blocks users who must change their password. - Applied to all routes except ``/auth/change-password`` and - ``/auth/me``. + Paths ending with ``/auth/change-password`` or ``/auth/me`` + are exempt so the user can actually change the password or + inspect their own profile. Args: connection: The incoming connection. _: Route handler (unused). Raises: - PermissionDeniedException: If password change is required. + PermissionDeniedException: If password change is required + or the user object is present but not an + ``AuthenticatedUser``. """ + path = str(connection.url.path) + if any(path.endswith(s) for s in _PWD_CHANGE_EXEMPT_SUFFIXES): + return user = connection.scope.get("user") - if not isinstance(user, AuthenticatedUser): + if user is None: return + if not isinstance(user, AuthenticatedUser): + logger.warning( + API_AUTH_FAILED, + reason="unexpected_user_type", + user_type=type(user).__qualname__, + path=path, + ) + raise PermissionDeniedException(detail="Invalid user session") if user.must_change_password: raise PermissionDeniedException(detail="Password change required") @@ -183,7 +199,7 @@ class AuthController(Controller): async def setup( self, data: SetupRequest, - request: ASGIConnection, # type: ignore[type-arg] + request: Request[Any, Any, Any], ) -> Response[ApiResponse[TokenResponse]]: """Create the first admin account (CEO). @@ -196,14 +212,16 @@ async def setup( user_count = await persistence.users.count() if user_count > 0: + logger.warning(API_AUTH_FAILED, reason="setup_already_completed") msg = "Setup already completed" raise ConflictError(msg) now = datetime.now(UTC) + password_hash = await auth_service.hash_password_async(data.password) user = User( id=str(uuid.uuid4()), username=data.username, - password_hash=auth_service.hash_password(data.password), + password_hash=password_hash, role=HumanRole.CEO, must_change_password=True, created_at=now, @@ -211,6 +229,14 @@ async def setup( ) await persistence.users.save(user) + # Race guard: undo if another setup completed concurrently + post_count = await persistence.users.count() + if post_count > 1: + await persistence.users.delete(user.id) + logger.warning(API_AUTH_FAILED, reason="setup_race_detected") + msg = "Setup already completed" + raise ConflictError(msg) + token, expires_in = auth_service.create_token(user) logger.info( @@ -238,7 +264,7 @@ async def setup( async def login( self, data: LoginRequest, - request: ASGIConnection, # type: ignore[type-arg] + request: Request[Any, Any, Any], ) -> Response[ApiResponse[TokenResponse]]: """Validate credentials and return a JWT.""" app_state = request.app.state["app_state"] @@ -246,7 +272,7 @@ async def login( persistence = app_state.persistence user = await persistence.users.get_by_username(data.username) - if user is None or not auth_service.verify_password( + if user is None or not await auth_service.verify_password_async( data.password, user.password_hash ): logger.warning( @@ -282,7 +308,7 @@ async def login( async def change_password( self, data: ChangePasswordRequest, - request: ASGIConnection, # type: ignore[type-arg] + request: Request[Any, Any, Any], ) -> Response[ApiResponse[UserInfoResponse]]: """Validate current password and set new one.""" auth_user: AuthenticatedUser = request.scope["user"] @@ -292,10 +318,17 @@ async def change_password( user = await persistence.users.get(auth_user.user_id) if user is None: + logger.warning( + API_AUTH_FAILED, + reason="user_not_found_for_password_change", + user_id=auth_user.user_id, + ) msg = "User not found" raise UnauthorizedError(msg) - if not auth_service.verify_password(data.current_password, user.password_hash): + if not await auth_service.verify_password_async( + data.current_password, user.password_hash + ): logger.warning( API_AUTH_FAILED, reason="invalid_current_password", @@ -305,9 +338,10 @@ async def change_password( raise UnauthorizedError(msg) now = datetime.now(UTC) + new_hash = await auth_service.hash_password_async(data.new_password) updated_user = user.model_copy( update={ - "password_hash": auth_service.hash_password(data.new_password), + "password_hash": new_hash, "must_change_password": False, "updated_at": now, } @@ -337,7 +371,7 @@ async def change_password( ) async def me( self, - request: ASGIConnection, # type: ignore[type-arg] + request: Request[Any, Any, Any], ) -> Response[ApiResponse[UserInfoResponse]]: """Return information about the authenticated user.""" auth_user: AuthenticatedUser = request.scope["user"] diff --git a/src/ai_company/api/auth/middleware.py b/src/ai_company/api/auth/middleware.py index 9e089d93b5..a3e20bc83e 100644 --- a/src/ai_company/api/auth/middleware.py +++ b/src/ai_company/api/auth/middleware.py @@ -1,5 +1,6 @@ """JWT + API key authentication middleware.""" +import hashlib from datetime import UTC, datetime from typing import TYPE_CHECKING, Any @@ -33,9 +34,9 @@ class ApiAuthMiddleware(AbstractAuthenticationMiddleware): """Authenticate requests via JWT or API key. Reads ``Authorization: Bearer `` from the request. - Tokens containing ``.`` are tried as JWTs first; if that fails - (or the token has no dots), it is tried as an API key via - SHA-256 hash lookup. + Tokens containing ``.`` are treated exclusively as JWTs. + Tokens without dots are tried as API keys via SHA-256 hash + lookup. Requires ``auth_service``, persistence backend on ``app.state["app_state"]``. @@ -77,13 +78,14 @@ async def authenticate_request( app_state = connection.app.state["app_state"] auth_service: AuthService = app_state.auth_service - # Try JWT first (tokens with dots are likely JWTs) + # Try JWT for tokens with dots; API key otherwise if "." in token: user = await _try_jwt_auth(token, auth_service, app_state, connection) if user is not None: return AuthenticationResult(user=user, auth=token) + raise NotAuthorizedException(detail="Invalid JWT token") - # Fall back to API key + # API key (no dots in token) user = await _try_api_key_auth(token, app_state, connection) if user is not None: return AuthenticationResult(user=user, auth=token) @@ -149,6 +151,18 @@ async def _try_jwt_auth( ) return None + expected_sig = hashlib.sha256( + db_user.password_hash.encode(), + ).hexdigest()[:16] + if claims.get("pwd_sig") != expected_sig: + logger.warning( + API_AUTH_FAILED, + reason="password_changed_since_token_issued", + user_id=user_id, + path=str(connection.url.path), + ) + return None + authenticated = AuthenticatedUser( user_id=db_user.id, username=db_user.username, @@ -182,6 +196,11 @@ async def _try_api_key_auth( persistence = app_state.persistence api_key = await persistence.api_keys.get_by_hash(key_hash) if api_key is None: + logger.debug( + API_AUTH_FAILED, + reason="api_key_not_found", + path=str(connection.url.path), + ) return None if api_key.revoked: diff --git a/src/ai_company/api/auth/models.py b/src/ai_company/api/auth/models.py index 049f9b4cec..6c1c9f349a 100644 --- a/src/ai_company/api/auth/models.py +++ b/src/ai_company/api/auth/models.py @@ -1,9 +1,8 @@ """Authentication domain models.""" -from datetime import datetime # noqa: TC003 from enum import StrEnum -from pydantic import BaseModel, ConfigDict, Field +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field from ai_company.api.guards import HumanRole # noqa: TC001 from ai_company.core.types import NotBlankStr # noqa: TC001 @@ -36,8 +35,8 @@ class User(BaseModel): password_hash: str = Field(repr=False) role: HumanRole must_change_password: bool = True - created_at: datetime - updated_at: datetime + created_at: AwareDatetime + updated_at: AwareDatetime class ApiKey(BaseModel): @@ -49,8 +48,8 @@ class ApiKey(BaseModel): name: Human-readable label. role: Access control role. user_id: Owner user ID. - created_at: Key creation timestamp. - expires_at: Optional expiry timestamp. + created_at: Key creation timestamp (timezone-aware). + expires_at: Optional expiry timestamp (timezone-aware). revoked: Whether the key has been revoked. """ @@ -61,8 +60,8 @@ class ApiKey(BaseModel): name: NotBlankStr role: HumanRole user_id: NotBlankStr - created_at: datetime - expires_at: datetime | None = None + created_at: AwareDatetime + expires_at: AwareDatetime | None = None revoked: bool = False diff --git a/src/ai_company/api/auth/secret.py b/src/ai_company/api/auth/secret.py index 34c825a3fa..5e38a0f9c9 100644 --- a/src/ai_company/api/auth/secret.py +++ b/src/ai_company/api/auth/secret.py @@ -3,6 +3,7 @@ import os import secrets +from ai_company.api.auth.config import MIN_SECRET_LENGTH from ai_company.observability import get_logger from ai_company.observability.events.api import API_APP_STARTUP from ai_company.persistence.protocol import PersistenceBackend # noqa: TC001 @@ -11,7 +12,6 @@ _SETTING_KEY = "jwt_secret" _SECRET_LENGTH = 48 # 64 URL-safe base64 chars -_MIN_SECRET_LENGTH = 32 async def resolve_jwt_secret( @@ -32,10 +32,10 @@ async def resolve_jwt_secret( # 1. Env var override (highest priority) env_secret = os.environ.get("AI_COMPANY_JWT_SECRET", "").strip() if env_secret: - if len(env_secret) < _MIN_SECRET_LENGTH: + if len(env_secret) < MIN_SECRET_LENGTH: msg = ( f"AI_COMPANY_JWT_SECRET must be at least " - f"{_MIN_SECRET_LENGTH} characters (got {len(env_secret)})" + f"{MIN_SECRET_LENGTH} characters (got {len(env_secret)})" ) logger.error(API_APP_STARTUP, error=msg) raise ValueError(msg) @@ -48,11 +48,22 @@ async def resolve_jwt_secret( # 2. Check persistence stored = await persistence.get_setting(_SETTING_KEY) if stored: - logger.info( - API_APP_STARTUP, - note="JWT secret loaded from persistence", - ) - return stored + stored = stored.strip() + if len(stored) < MIN_SECRET_LENGTH: + logger.warning( + API_APP_STARTUP, + note=( + "Stored JWT secret too short " + f"({len(stored)} < {MIN_SECRET_LENGTH}), " + "auto-generating replacement" + ), + ) + else: + logger.info( + API_APP_STARTUP, + note="JWT secret loaded from persistence", + ) + return stored # 3. Auto-generate and persist generated = secrets.token_urlsafe(_SECRET_LENGTH) diff --git a/src/ai_company/api/auth/service.py b/src/ai_company/api/auth/service.py index 21172d1af4..9e57ceef1b 100644 --- a/src/ai_company/api/auth/service.py +++ b/src/ai_company/api/auth/service.py @@ -1,5 +1,6 @@ """Authentication service — password hashing, JWT ops, API key hashing.""" +import asyncio import hashlib import secrets from datetime import UTC, datetime, timedelta @@ -69,13 +70,50 @@ def verify_password(self, password: str, password_hash: str) -> bool: ) return False except argon2.exceptions.InvalidHashError: - logger.warning( + logger.error( API_AUTH_FAILED, - reason="invalid_hash_format", + reason="invalid_hash_data_corruption", exc_info=True, ) return False + async def hash_password_async(self, password: str) -> str: + """Hash a password with Argon2id in a thread executor. + + Offloads the CPU-intensive hashing to avoid blocking the + event loop. + + Args: + password: Plaintext password. + + Returns: + Argon2id hash string. + """ + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self.hash_password, password) + + async def verify_password_async( + self, + password: str, + password_hash: str, + ) -> bool: + """Verify a password against an Argon2id hash in a thread executor. + + Offloads the CPU-intensive verification to avoid blocking the + event loop. + + Args: + password: Plaintext password to check. + password_hash: Stored Argon2id hash. + + Returns: + ``True`` if the password matches. + """ + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, self.verify_password, password, password_hash + ) + def create_token(self, user: User) -> tuple[str, int]: """Create a JWT for the given user. @@ -84,14 +122,24 @@ def create_token(self, user: User) -> tuple[str, int]: Returns: Tuple of (encoded JWT string, expiry seconds). + + Raises: + RuntimeError: If the JWT secret is empty. """ + if not self._config.jwt_secret: + msg = "JWT secret not configured" + raise RuntimeError(msg) now = datetime.now(UTC) expiry_seconds = self._config.jwt_expiry_minutes * 60 + pwd_sig = hashlib.sha256( + user.password_hash.encode(), + ).hexdigest()[:16] payload: dict[str, Any] = { "sub": user.id, "username": user.username, "role": user.role.value, "must_change_password": user.must_change_password, + "pwd_sig": pwd_sig, "iat": now, "exp": now + timedelta(seconds=expiry_seconds), } @@ -112,12 +160,17 @@ def decode_token(self, token: str) -> dict[str, Any]: Decoded claims dictionary. Raises: + RuntimeError: If the JWT secret is empty. jwt.InvalidTokenError: If the token is invalid or expired. """ + if not self._config.jwt_secret: + msg = "JWT secret not configured" + raise RuntimeError(msg) return jwt.decode( token, self._config.jwt_secret, algorithms=[self._config.jwt_algorithm], + options={"require": ["exp", "iat", "sub"]}, ) @staticmethod diff --git a/src/ai_company/api/config.py b/src/ai_company/api/config.py index 9063016220..e72fd060f2 100644 --- a/src/ai_company/api/config.py +++ b/src/ai_company/api/config.py @@ -1,7 +1,8 @@ """API configuration models. -Frozen Pydantic models for CORS, rate limiting, server, and the -top-level ``ApiConfig`` that aggregates them all. +Frozen Pydantic models for CORS, rate limiting, server, +authentication, and the top-level ``ApiConfig`` that aggregates +them all. """ from enum import StrEnum @@ -151,6 +152,7 @@ class ApiConfig(BaseModel): cors: CORS configuration. rate_limit: Rate limiting configuration. server: Uvicorn server configuration. + auth: Authentication configuration. api_prefix: URL prefix for all API routes. """ diff --git a/src/ai_company/api/controllers/approvals.py b/src/ai_company/api/controllers/approvals.py index 76fc613154..251b906a05 100644 --- a/src/ai_company/api/controllers/approvals.py +++ b/src/ai_company/api/controllers/approvals.py @@ -272,8 +272,8 @@ async def approve( ) raise ConflictError(msg) - auth_user = request.scope.get("user") - role = auth_user.role.value if auth_user is not None else "unknown" + auth_user = request.scope["user"] + role = auth_user.role.value now = datetime.now(UTC) updated = item.model_copy( update={ @@ -356,8 +356,8 @@ async def reject( ) raise ConflictError(msg) - auth_user = request.scope.get("user") - role = auth_user.role.value if auth_user is not None else "unknown" + auth_user = request.scope["user"] + role = auth_user.role.value now = datetime.now(UTC) updated = item.model_copy( update={ diff --git a/src/ai_company/api/state.py b/src/ai_company/api/state.py index 95db5ab407..2ffd80dab2 100644 --- a/src/ai_company/api/state.py +++ b/src/ai_company/api/state.py @@ -12,7 +12,7 @@ from ai_company.communication.bus_protocol import MessageBus # noqa: TC001 from ai_company.config.schema import RootConfig # noqa: TC001 from ai_company.observability import get_logger -from ai_company.observability.events.api import API_SERVICE_UNAVAILABLE +from ai_company.observability.events.api import API_APP_STARTUP, API_SERVICE_UNAVAILABLE from ai_company.persistence.protocol import PersistenceBackend # noqa: TC001 logger = get_logger(__name__) @@ -21,8 +21,9 @@ class AppState: """Typed application state container. - Service fields (``persistence``, ``message_bus``, ``cost_tracker``) - accept ``None`` at construction time for dev/test mode. Property + Service fields (``persistence``, ``message_bus``, ``cost_tracker``, + ``auth_service``) accept ``None`` at construction time for dev/test + mode. Property accessors raise ``ServiceUnavailableError`` (HTTP 503) when the service is not configured, producing a clear error instead of an opaque ``AttributeError``. @@ -98,6 +99,11 @@ def auth_service(self) -> AuthService: """Return auth service or raise 503.""" return self._require_service(self._auth_service, "auth_service") + @property + def has_auth_service(self) -> bool: + """Check whether the auth service is already configured.""" + return self._auth_service is not None + def set_auth_service(self, service: AuthService) -> None: """Set the auth service (deferred initialisation). @@ -111,5 +117,6 @@ def set_auth_service(self, service: AuthService) -> None: """ if self._auth_service is not None: msg = "Auth service already configured" + logger.error(API_APP_STARTUP, error=msg) raise RuntimeError(msg) self._auth_service = service diff --git a/src/ai_company/observability/events/persistence.py b/src/ai_company/observability/events/persistence.py index e0d7eb1a7a..b2c048ac2a 100644 --- a/src/ai_company/observability/events/persistence.py +++ b/src/ai_company/observability/events/persistence.py @@ -144,5 +144,7 @@ PERSISTENCE_API_KEY_DELETED: Final[str] = "persistence.api_key.deleted" PERSISTENCE_API_KEY_DELETE_FAILED: Final[str] = "persistence.api_key.delete_failed" +PERSISTENCE_SETTING_FETCHED: Final[str] = "persistence.setting.fetched" PERSISTENCE_SETTING_FETCH_FAILED: Final[str] = "persistence.setting.fetch_failed" +PERSISTENCE_SETTING_SAVED: Final[str] = "persistence.setting.saved" PERSISTENCE_SETTING_SAVE_FAILED: Final[str] = "persistence.setting.save_failed" diff --git a/src/ai_company/persistence/protocol.py b/src/ai_company/persistence/protocol.py index bd4d01336f..3cfcd03074 100644 --- a/src/ai_company/persistence/protocol.py +++ b/src/ai_company/persistence/protocol.py @@ -42,6 +42,8 @@ class PersistenceBackend(Protocol): collaboration_metrics: Repository for CollaborationMetricRecord persistence. parked_contexts: Repository for ParkedContext persistence. audit_entries: Repository for AuditEntry persistence. + users: Repository for User persistence. + api_keys: Repository for ApiKey persistence. """ async def connect(self) -> None: @@ -136,7 +138,7 @@ def api_keys(self) -> ApiKeyRepository: """Repository for ApiKey persistence.""" ... - async def get_setting(self, key: str) -> str | None: + async def get_setting(self, key: NotBlankStr) -> str | None: """Retrieve a setting value by key. Args: @@ -150,7 +152,7 @@ async def get_setting(self, key: str) -> str | None: """ ... - async def set_setting(self, key: str, value: str) -> None: + async def set_setting(self, key: NotBlankStr, value: str) -> None: """Store a setting value. Upserts — creates or updates the key. diff --git a/src/ai_company/persistence/sqlite/user_repo.py b/src/ai_company/persistence/sqlite/user_repo.py index b686f6464e..6317c63e35 100644 --- a/src/ai_company/persistence/sqlite/user_repo.py +++ b/src/ai_company/persistence/sqlite/user_repo.py @@ -8,6 +8,7 @@ from ai_company.api.auth.models import ApiKey, User from ai_company.api.guards import HumanRole +from ai_company.core.types import NotBlankStr # noqa: TC001 from ai_company.observability import get_logger from ai_company.observability.events.persistence import ( PERSISTENCE_API_KEY_DELETE_FAILED, @@ -100,7 +101,7 @@ async def save(self, user: User) -> None: raise QueryError(msg) from exc logger.debug(PERSISTENCE_USER_SAVED, user_id=user.id) - async def get(self, user_id: str) -> User | None: + async def get(self, user_id: NotBlankStr) -> User | None: """Retrieve a user by ID.""" try: cursor = await self._db.execute( @@ -131,7 +132,7 @@ async def get(self, user_id: str) -> User | None: logger.debug(PERSISTENCE_USER_FETCHED, user_id=user_id, found=True) return user - async def get_by_username(self, username: str) -> User | None: + async def get_by_username(self, username: NotBlankStr) -> User | None: """Retrieve a user by username.""" try: cursor = await self._db.execute( @@ -190,7 +191,7 @@ async def count(self) -> int: logger.debug(PERSISTENCE_USER_COUNTED, count=result) return result - async def delete(self, user_id: str) -> bool: + async def delete(self, user_id: NotBlankStr) -> bool: """Delete a user by ID.""" try: cursor = await self._db.execute( @@ -257,7 +258,7 @@ async def save(self, key: ApiKey) -> None: raise QueryError(msg) from exc logger.debug(PERSISTENCE_API_KEY_SAVED, key_id=key.id) - async def get(self, key_id: str) -> ApiKey | None: + async def get(self, key_id: NotBlankStr) -> ApiKey | None: """Retrieve an API key by ID.""" try: cursor = await self._db.execute( @@ -288,7 +289,7 @@ async def get(self, key_id: str) -> ApiKey | None: logger.debug(PERSISTENCE_API_KEY_FETCHED, key_id=key_id, found=True) return key - async def get_by_hash(self, key_hash: str) -> ApiKey | None: + async def get_by_hash(self, key_hash: NotBlankStr) -> ApiKey | None: """Retrieve an API key by its hash.""" try: cursor = await self._db.execute( @@ -309,7 +310,7 @@ async def get_by_hash(self, key_hash: str) -> ApiKey | None: logger.exception(PERSISTENCE_API_KEY_FETCH_FAILED, error=str(exc)) raise QueryError(msg) from exc - async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: + async def list_by_user(self, user_id: NotBlankStr) -> tuple[ApiKey, ...]: """List API keys belonging to a user.""" try: cursor = await self._db.execute( @@ -325,7 +326,16 @@ async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: error=str(exc), ) raise QueryError(msg) from exc - keys = tuple(_row_to_api_key(row) for row in rows) + try: + keys = tuple(_row_to_api_key(row) for row in rows) + except (ValueError, ValidationError) as exc: + msg = f"Failed to deserialize API keys for user {user_id!r}" + logger.exception( + PERSISTENCE_API_KEY_LIST_FAILED, + user_id=user_id, + error=str(exc), + ) + raise QueryError(msg) from exc logger.debug( PERSISTENCE_API_KEY_LISTED, user_id=user_id, @@ -333,7 +343,7 @@ async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: ) return keys - async def delete(self, key_id: str) -> bool: + async def delete(self, key_id: NotBlankStr) -> bool: """Delete an API key by ID.""" try: cursor = await self._db.execute( diff --git a/tests/unit/api/auth/test_config.py b/tests/unit/api/auth/test_config.py index 66eedf538c..42025b465a 100644 --- a/tests/unit/api/auth/test_config.py +++ b/tests/unit/api/auth/test_config.py @@ -12,9 +12,12 @@ def test_default_values(self) -> None: assert config.jwt_secret == "" assert config.jwt_algorithm == "HS256" assert config.jwt_expiry_minutes == 1440 - assert "^/api/v1/health$" in config.exclude_paths - assert "^/api/v1/auth/setup$" in config.exclude_paths - assert "^/api/v1/auth/login$" in config.exclude_paths + assert config.exclude_paths is None + + def test_explicit_exclude_paths(self) -> None: + paths = ("^/health$", "^/docs") + config = AuthConfig(exclude_paths=paths) + assert config.exclude_paths == paths def test_with_secret_sets_secret(self) -> None: config = AuthConfig() diff --git a/tests/unit/api/auth/test_controller.py b/tests/unit/api/auth/test_controller.py index 365660fcc3..bd4a4a5037 100644 --- a/tests/unit/api/auth/test_controller.py +++ b/tests/unit/api/auth/test_controller.py @@ -175,6 +175,32 @@ def test_change_password_requires_auth(self, bare_client: TestClient[Any]) -> No ) assert response.status_code == 401 + def test_change_password_short_new_password( + self, + bare_client: TestClient[Any], + ) -> None: + app_state = bare_client.app.state["app_state"] + app_state.persistence._users._users.clear() + + setup_resp = bare_client.post( + "/api/v1/auth/setup", + json={ + "username": "shortpw", + "password": "old-password-12chars", + }, + ) + token = setup_resp.json()["data"]["token"] + + response = bare_client.post( + "/api/v1/auth/change-password", + json={ + "current_password": "old-password-12chars", + "new_password": "short", + }, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 400 + @pytest.mark.unit class TestMe: @@ -214,6 +240,7 @@ def test_blocks_user_with_must_change_password(self) -> None: ) connection = MagicMock() connection.scope = {"user": user} + connection.url.path = "/api/v1/health" with pytest.raises(PermissionDeniedException): require_password_changed(connection, None) @@ -234,6 +261,7 @@ def test_allows_user_without_flag(self) -> None: ) connection = MagicMock() connection.scope = {"user": user} + connection.url.path = "/api/v1/health" # Should not raise require_password_changed(connection, None) @@ -246,6 +274,53 @@ def test_allows_when_no_user_in_scope(self) -> None: connection = MagicMock() connection.scope = {} + connection.url.path = "/api/v1/health" # Should not raise require_password_changed(connection, None) + + @pytest.mark.parametrize( + "path", + [ + pytest.param("/api/v1/auth/change-password", id="change-password"), + pytest.param("/api/v1/auth/me", id="me"), + ], + ) + def test_exempts_paths_for_must_change_password_users( + self, + path: str, + ) -> None: + """Guard allows must_change_password users on exempt paths.""" + from unittest.mock import MagicMock + + from ai_company.api.auth.controller import require_password_changed + from ai_company.api.auth.models import AuthenticatedUser, AuthMethod + + user = AuthenticatedUser( + user_id="u1", + username="admin", + role=HumanRole.CEO, + auth_method=AuthMethod.JWT, + must_change_password=True, + ) + connection = MagicMock() + connection.scope = {"user": user} + connection.url.path = path + + # Should not raise — exempt path + require_password_changed(connection, None) + + def test_rejects_unknown_user_type(self) -> None: + """Guard raises PermissionDeniedException for non-AuthenticatedUser.""" + from unittest.mock import MagicMock + + from litestar.exceptions import PermissionDeniedException + + from ai_company.api.auth.controller import require_password_changed + + connection = MagicMock() + connection.scope = {"user": "not-an-auth-user"} + connection.url.path = "/api/v1/health" + + with pytest.raises(PermissionDeniedException): + require_password_changed(connection, None) diff --git a/tests/unit/api/auth/test_middleware.py b/tests/unit/api/auth/test_middleware.py index 9361d97abb..5a7c1ffc16 100644 --- a/tests/unit/api/auth/test_middleware.py +++ b/tests/unit/api/auth/test_middleware.py @@ -268,6 +268,34 @@ async def test_api_key_with_deleted_owner_returns_401(self) -> None: assert resp.status_code == 401 +@pytest.mark.unit +class TestExtractBearerToken: + @pytest.mark.parametrize( + ("header", "expected"), + [ + pytest.param("Bearer mytoken123", "mytoken123", id="valid"), + pytest.param("bearer mytoken123", "mytoken123", id="lowercase"), + pytest.param("BEARER mytoken123", "mytoken123", id="uppercase"), + pytest.param("", None, id="empty"), + pytest.param("Bearer", None, id="no-token"), + pytest.param("Basic dXNlcjpwYXNz", None, id="wrong-scheme"), + pytest.param( + "Bearer token with spaces", + "token with spaces", + id="token-with-spaces", + ), + ], + ) + def test_extract_bearer_token( + self, + header: str, + expected: str | None, + ) -> None: + from ai_company.api.auth.middleware import _extract_bearer_token + + assert _extract_bearer_token(header) == expected + + @pytest.mark.unit class TestAuthMiddlewareExcludePaths: async def test_excluded_path_skips_auth(self) -> None: diff --git a/tests/unit/api/auth/test_service.py b/tests/unit/api/auth/test_service.py index d0e9615568..64f7a8e72e 100644 --- a/tests/unit/api/auth/test_service.py +++ b/tests/unit/api/auth/test_service.py @@ -59,6 +59,14 @@ def test_different_hashes_for_same_password(self) -> None: # Different salts produce different hashes assert h1 != h2 + def test_verify_password_with_corrupted_hash(self) -> None: + svc = _make_service() + assert not svc.verify_password("my-password", "not-a-valid-argon2-hash") + + def test_verify_password_with_empty_hash(self) -> None: + svc = _make_service() + assert not svc.verify_password("my-password", "") + @pytest.mark.unit class TestJWT: @@ -118,6 +126,33 @@ def test_must_change_password_in_claims(self) -> None: claims = svc.decode_token(token) assert claims["must_change_password"] is True + def test_decode_token_missing_sub_claim(self) -> None: + from datetime import UTC, datetime, timedelta + + import jwt as pyjwt + + svc = _make_service() + payload = { + "username": "admin", + "role": "ceo", + "iat": datetime.now(UTC), + "exp": datetime.now(UTC) + timedelta(hours=1), + } + token = pyjwt.encode(payload, _SECRET, algorithm="HS256") + with pytest.raises(pyjwt.MissingRequiredClaimError): + svc.decode_token(token) + + def test_create_token_empty_secret_raises(self) -> None: + svc = AuthService(AuthConfig()) + user = _make_user() + with pytest.raises(RuntimeError, match="JWT secret not configured"): + svc.create_token(user) + + def test_decode_token_empty_secret_raises(self) -> None: + svc = AuthService(AuthConfig()) + with pytest.raises(RuntimeError, match="JWT secret not configured"): + svc.decode_token("any.token.here") + @pytest.mark.unit class TestApiKeyHashing: diff --git a/tests/unit/api/conftest.py b/tests/unit/api/conftest.py index 722ebdf935..4c97c05426 100644 --- a/tests/unit/api/conftest.py +++ b/tests/unit/api/conftest.py @@ -479,6 +479,10 @@ async def get_channel_history( # ── Auth helpers ──────────────────────────────────────────────── +# Cache password hashes by role so that make_auth_headers and +# _seed_test_users produce identical pwd_sig claims. +_TEST_PASSWORD_HASHES: dict[str, str] = {} + def _make_test_auth_config() -> AuthConfig: """Create an AuthConfig with a test JWT secret.""" @@ -490,6 +494,24 @@ def _make_test_auth_service() -> AuthService: return AuthService(_make_test_auth_config()) +def _get_test_password_hash( + role: str, + auth_service: AuthService, +) -> str: + """Return a cached password hash for the given role. + + On the first call for a role, hashes the test password and + caches the result so that ``make_auth_headers`` and + ``_seed_test_users`` produce tokens with matching ``pwd_sig`` + claims. + """ + if role not in _TEST_PASSWORD_HASHES: + _TEST_PASSWORD_HASHES[role] = auth_service.hash_password( + "test-password-12chars", + ) + return _TEST_PASSWORD_HASHES[role] + + def _make_test_user( *, role: HumanRole = HumanRole.CEO, @@ -503,7 +525,7 @@ def _make_test_user( return User( id=user_id, username=username, - password_hash=auth_service.hash_password("test-password-12chars"), + password_hash=_get_test_password_hash(role.value, auth_service), role=role, must_change_password=must_change_password, created_at=now, @@ -519,7 +541,9 @@ def make_auth_headers( """Build an Authorization header with a JWT for the given role. Uses deterministic user IDs matching ``_seed_test_users`` so - middleware user lookups succeed. + middleware user lookups succeed. The password hash is cached + per role to ensure the ``pwd_sig`` claim matches the seeded + user in persistence. """ auth_service = _make_test_auth_service() # Must match the ID pattern in _seed_test_users @@ -528,7 +552,7 @@ def make_auth_headers( user = User( id=user_id, username=f"test-{role}", - password_hash=auth_service.hash_password("test-password-12chars"), + password_hash=_get_test_password_hash(role, auth_service), role=HumanRole(role), must_change_password=must_change_password, created_at=now, @@ -614,7 +638,9 @@ def _seed_test_users( The middleware looks up the user by ``sub`` claim, so we need matching users in the fake persistence for every role - that tests might use. + that tests might use. Uses cached password hashes to ensure + ``pwd_sig`` claims match between seeded users and tokens + produced by ``make_auth_headers``. """ import asyncio @@ -624,7 +650,10 @@ def _seed_test_users( user = User( id=user_id, username=f"test-{role.value}", - password_hash=auth_service.hash_password("test-password-12chars"), + password_hash=_get_test_password_hash( + role.value, + auth_service, + ), role=role, must_change_password=False, created_at=now, From 0e3a157b717bf5fb7341e17699278672ce11b602 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Wed, 11 Mar 2026 03:02:56 +0100 Subject: [PATCH 4/5] fix: address CodeRabbit round 2 review items - Normalize timestamps to UTC before SQLite serialization - Use defensive scope.get("user") in change_password and me endpoints - Simplify test fixtures to sync with direct dict assignment --- src/ai_company/api/auth/controller.py | 10 ++++++++-- src/ai_company/persistence/sqlite/user_repo.py | 14 +++++++++----- tests/unit/api/auth/test_controller.py | 9 +++++---- tests/unit/api/conftest.py | 10 +++++----- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/ai_company/api/auth/controller.py b/src/ai_company/api/auth/controller.py index 623b330466..ae980846a4 100644 --- a/src/ai_company/api/auth/controller.py +++ b/src/ai_company/api/auth/controller.py @@ -311,7 +311,10 @@ async def change_password( request: Request[Any, Any, Any], ) -> Response[ApiResponse[UserInfoResponse]]: """Validate current password and set new one.""" - auth_user: AuthenticatedUser = request.scope["user"] + auth_user = request.scope.get("user") + if not isinstance(auth_user, AuthenticatedUser): + msg = "Authentication required" + raise UnauthorizedError(msg) app_state = request.app.state["app_state"] auth_service: AuthService = app_state.auth_service persistence = app_state.persistence @@ -374,7 +377,10 @@ async def me( request: Request[Any, Any, Any], ) -> Response[ApiResponse[UserInfoResponse]]: """Return information about the authenticated user.""" - auth_user: AuthenticatedUser = request.scope["user"] + auth_user = request.scope.get("user") + if not isinstance(auth_user, AuthenticatedUser): + msg = "Authentication required" + raise UnauthorizedError(msg) return Response( content=ApiResponse( diff --git a/src/ai_company/persistence/sqlite/user_repo.py b/src/ai_company/persistence/sqlite/user_repo.py index 6317c63e35..bc14ec16cf 100644 --- a/src/ai_company/persistence/sqlite/user_repo.py +++ b/src/ai_company/persistence/sqlite/user_repo.py @@ -1,7 +1,7 @@ """SQLite repository implementations for User and ApiKey.""" import sqlite3 -from datetime import datetime +from datetime import UTC, datetime import aiosqlite from pydantic import ValidationError @@ -86,8 +86,8 @@ async def save(self, user: User) -> None: user.password_hash, user.role.value, int(user.must_change_password), - user.created_at.isoformat(), - user.updated_at.isoformat(), + user.created_at.astimezone(UTC).isoformat(), + user.updated_at.astimezone(UTC).isoformat(), ), ) await self._db.commit() @@ -242,8 +242,12 @@ async def save(self, key: ApiKey) -> None: key.name, key.role.value, key.user_id, - key.created_at.isoformat(), - key.expires_at.isoformat() if key.expires_at else None, + key.created_at.astimezone(UTC).isoformat(), + ( + key.expires_at.astimezone(UTC).isoformat() + if key.expires_at + else None + ), int(key.revoked), ), ) diff --git a/tests/unit/api/auth/test_controller.py b/tests/unit/api/auth/test_controller.py index bd4a4a5037..abfc63d0d3 100644 --- a/tests/unit/api/auth/test_controller.py +++ b/tests/unit/api/auth/test_controller.py @@ -35,9 +35,11 @@ def test_setup_creates_admin(self, bare_client: TestClient[Any]) -> None: assert data["must_change_password"] is True assert data["expires_in"] > 0 - def test_setup_409_when_users_exist(self, bare_client: TestClient[Any]) -> None: + def test_setup_409_when_users_exist( + self, + bare_client: TestClient[Any], + ) -> None: # Re-seed a user so the check fails - import asyncio import uuid from datetime import UTC, datetime @@ -57,8 +59,7 @@ def test_setup_409_when_users_exist(self, bare_client: TestClient[Any]) -> None: created_at=now, updated_at=now, ) - loop = asyncio.get_event_loop() - loop.run_until_complete(app_state.persistence.users.save(user)) + app_state.persistence._users._users[user.id] = user response = bare_client.post( "/api/v1/auth/setup", diff --git a/tests/unit/api/conftest.py b/tests/unit/api/conftest.py index 4c97c05426..32baeee0ce 100644 --- a/tests/unit/api/conftest.py +++ b/tests/unit/api/conftest.py @@ -641,9 +641,11 @@ def _seed_test_users( that tests might use. Uses cached password hashes to ensure ``pwd_sig`` claims match between seeded users and tokens produced by ``make_auth_headers``. - """ - import asyncio + Assigns directly to the fake repository's internal dict + (avoiding async) so this helper works in both sync fixtures + and sync test functions. + """ now = datetime.now(UTC) for role in HumanRole: user_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"test-{role.value}")) @@ -659,9 +661,7 @@ def _seed_test_users( created_at=now, updated_at=now, ) - # Save synchronously via event loop - loop = asyncio.get_event_loop() - loop.run_until_complete(backend.users.save(user)) + backend._users._users[user.id] = user def make_task( # noqa: PLR0913 From 573c52f5311bbcbffae2e31d825f4ce1597e77a4 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Wed, 11 Mar 2026 07:05:51 +0100 Subject: [PATCH 5/5] fix: address 13 remaining PR review comments - Mitigate timing side-channel in login (dummy argon2 verification) - Derive _MIN_PASSWORD_LENGTH from AuthConfig default - Remove redundant unique indexes (users.username, api_keys.key_hash) - Defensive scope.get("user") in approvals controller - Record username (not role) in decided_by field - Extract _resolve_decision helper in approvals (DRY) - Fix started_persistence cleanup on post-connect failure - Enable PRAGMA foreign_keys=ON in SQLite backend - Log user/api-key saves and deletes at INFO (state transitions) - Hoist imports in test_service.py - Parametrize v5 schema tests - Expand user_repo docstrings to Google style --- src/ai_company/api/app.py | 27 +-- src/ai_company/api/auth/controller.py | 27 ++- src/ai_company/api/controllers/approvals.py | 79 +++++--- src/ai_company/persistence/sqlite/backend.py | 3 + .../persistence/sqlite/migrations.py | 2 - .../persistence/sqlite/user_repo.py | 177 +++++++++++++++--- tests/unit/api/auth/test_service.py | 8 +- tests/unit/api/controllers/test_approvals.py | 6 +- .../persistence/sqlite/test_migrations.py | 35 ++-- 9 files changed, 262 insertions(+), 102 deletions(-) diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index c48c84dcbf..a3c896c39e 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -154,21 +154,14 @@ async def _init_persistence( persistence: PersistenceBackend, app_state: AppState, ) -> None: - """Connect persistence, run migrations, and resolve JWT secret. + """Run migrations and resolve JWT secret on an already-connected backend. + + Must only be called after ``persistence.connect()`` has succeeded. Args: - persistence: Persistence backend to initialise. + persistence: Connected persistence backend. app_state: Application state for auth service injection. """ - try: - await persistence.connect() - except Exception: - logger.exception( - API_APP_STARTUP, - error="Failed to connect persistence", - ) - raise - try: await persistence.migrate() except Exception: @@ -214,8 +207,18 @@ async def _safe_startup( started_persistence = False try: if persistence is not None: - await _init_persistence(persistence, app_state) + try: + await persistence.connect() + except Exception: + logger.exception( + API_APP_STARTUP, + error="Failed to connect persistence", + ) + raise + # Mark connected immediately so cleanup can disconnect + # if migrate() or JWT resolution fails below. started_persistence = True + await _init_persistence(persistence, app_state) if message_bus is not None: try: diff --git a/src/ai_company/api/auth/controller.py b/src/ai_company/api/auth/controller.py index ae980846a4..61b4b59df7 100644 --- a/src/ai_company/api/auth/controller.py +++ b/src/ai_company/api/auth/controller.py @@ -9,6 +9,7 @@ from litestar.exceptions import PermissionDeniedException from pydantic import BaseModel, ConfigDict, Field, model_validator +from ai_company.api.auth.config import AuthConfig from ai_company.api.auth.models import AuthenticatedUser, User from ai_company.api.auth.service import AuthService # noqa: TC001 from ai_company.api.dto import ApiResponse @@ -25,7 +26,17 @@ logger = get_logger(__name__) -_MIN_PASSWORD_LENGTH = 12 +# Derive from AuthConfig default to prevent silent divergence. +_MIN_PASSWORD_LENGTH: int = AuthConfig.model_fields["min_password_length"].default + +# Pre-computed Argon2id hash for constant-time rejection when the +# username doesn't exist — prevents timing-based username enumeration. +# The actual password is irrelevant; only the verification time matters. +_DUMMY_ARGON2_HASH = ( + "$argon2id$v=19$m=65536,t=3,p=4$" + "c2FsdHNhbHRzYWx0$" + "mB0bZKSNwOhSdxMQfsldT3qGmFyjVqbkntMkutMfdUs" +) def _check_password_length(password: str) -> str: @@ -272,9 +283,17 @@ async def login( persistence = app_state.persistence user = await persistence.users.get_by_username(data.username) - if user is None or not await auth_service.verify_password_async( - data.password, user.password_hash - ): + if user is not None: + password_valid = await auth_service.verify_password_async( + data.password, user.password_hash + ) + else: + # Constant-time rejection: run verification against a + # dummy hash to prevent timing-based username enumeration. + await auth_service.verify_password_async(data.password, _DUMMY_ARGON2_HASH) + password_valid = False + + if not password_valid: logger.warning( API_AUTH_FAILED, reason="invalid_credentials", diff --git a/src/ai_company/api/controllers/approvals.py b/src/ai_company/api/controllers/approvals.py index 251b906a05..a3b91099b9 100644 --- a/src/ai_company/api/controllers/approvals.py +++ b/src/ai_company/api/controllers/approvals.py @@ -8,6 +8,7 @@ from litestar.channels import ChannelsPlugin from litestar.datastructures import State # noqa: TC002 +from ai_company.api.auth.models import AuthenticatedUser from ai_company.api.channels import CHANNEL_APPROVALS from ai_company.api.dto import ( ApiResponse, @@ -16,7 +17,7 @@ PaginatedResponse, RejectRequest, ) -from ai_company.api.errors import ConflictError, NotFoundError +from ai_company.api.errors import ConflictError, NotFoundError, UnauthorizedError from ai_company.api.guards import require_read_access, require_write_access from ai_company.api.pagination import PaginationLimit, PaginationOffset, paginate from ai_company.api.state import AppState # noqa: TC001 @@ -102,6 +103,46 @@ def _publish_approval_event( ) +def _resolve_decision( + request: Request[Any, Any, Any], + item: ApprovalItem, + approval_id: str, +) -> AuthenticatedUser: + """Validate that an approval item is pending and extract the auth user. + + Performs the shared pre-checks for approve/reject operations: + look up the authenticated user, and verify the item is still + in PENDING status. + + Args: + request: The incoming HTTP request. + item: The approval item to act on. + approval_id: Approval identifier (for log messages). + + Returns: + The authenticated user making the decision. + + Raises: + UnauthorizedError: If the user is missing from the request scope. + ConflictError: If the approval is not in PENDING status. + """ + if item.status != ApprovalStatus.PENDING: + msg = f"Approval {approval_id!r} is {item.status.value}, not pending" + logger.warning( + API_APPROVAL_CONFLICT, + approval_id=approval_id, + current_status=item.status.value, + ) + raise ConflictError(msg) + + auth_user = request.scope.get("user") + if not isinstance(auth_user, AuthenticatedUser): + msg = "Authentication required" + raise UnauthorizedError(msg) + + return auth_user + + class ApprovalsController(Controller): """Human approval queue — list, create, approve, reject.""" @@ -237,7 +278,7 @@ async def approve( """Approve a pending approval item. The ``decided_by`` field is populated from the authenticated - user's role. + user's username. Args: state: Application state. @@ -263,23 +304,13 @@ async def approve( ) raise NotFoundError(msg) - if item.status != ApprovalStatus.PENDING: - msg = f"Approval {approval_id!r} is {item.status.value}, not pending" - logger.warning( - API_APPROVAL_CONFLICT, - approval_id=approval_id, - current_status=item.status.value, - ) - raise ConflictError(msg) - - auth_user = request.scope["user"] - role = auth_user.role.value + auth_user = _resolve_decision(request, item, approval_id) now = datetime.now(UTC) updated = item.model_copy( update={ "status": ApprovalStatus.APPROVED, "decided_at": now, - "decided_by": role, + "decided_by": auth_user.username, "decision_reason": data.comment, }, ) @@ -302,7 +333,7 @@ async def approve( logger.info( API_APPROVAL_APPROVED, approval_id=approval_id, - decided_by=role, + decided_by=auth_user.username, ) return ApiResponse(data=updated) @@ -321,7 +352,7 @@ async def reject( """Reject a pending approval item. The ``decided_by`` field is populated from the authenticated - user's role. + user's username. Args: state: Application state. @@ -347,23 +378,13 @@ async def reject( ) raise NotFoundError(msg) - if item.status != ApprovalStatus.PENDING: - msg = f"Approval {approval_id!r} is {item.status.value}, not pending" - logger.warning( - API_APPROVAL_CONFLICT, - approval_id=approval_id, - current_status=item.status.value, - ) - raise ConflictError(msg) - - auth_user = request.scope["user"] - role = auth_user.role.value + auth_user = _resolve_decision(request, item, approval_id) now = datetime.now(UTC) updated = item.model_copy( update={ "status": ApprovalStatus.REJECTED, "decided_at": now, - "decided_by": role, + "decided_by": auth_user.username, "decision_reason": data.reason, }, ) @@ -386,6 +407,6 @@ async def reject( logger.info( API_APPROVAL_REJECTED, approval_id=approval_id, - decided_by=role, + decided_by=auth_user.username, ) return ApiResponse(data=updated) diff --git a/src/ai_company/persistence/sqlite/backend.py b/src/ai_company/persistence/sqlite/backend.py index 61816c470b..9033eee3d8 100644 --- a/src/ai_company/persistence/sqlite/backend.py +++ b/src/ai_company/persistence/sqlite/backend.py @@ -109,6 +109,9 @@ async def connect(self) -> None: self._db = await aiosqlite.connect(self._config.path) self._db.row_factory = aiosqlite.Row + # Enable foreign key enforcement (off by default in SQLite). + await self._db.execute("PRAGMA foreign_keys = ON") + if self._config.wal_mode: await self._configure_wal() diff --git a/src/ai_company/persistence/sqlite/migrations.py b/src/ai_company/persistence/sqlite/migrations.py index de331a43f7..66cb1bfbd6 100644 --- a/src/ai_company/persistence/sqlite/migrations.py +++ b/src/ai_company/persistence/sqlite/migrations.py @@ -206,7 +206,6 @@ created_at TEXT NOT NULL, updated_at TEXT NOT NULL )""", - "CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username)", # ── API keys ─────────────────────────────────────────── """\ CREATE TABLE IF NOT EXISTS api_keys ( @@ -220,7 +219,6 @@ revoked INTEGER NOT NULL DEFAULT 0 )""", "CREATE INDEX IF NOT EXISTS idx_api_keys_user_id ON api_keys(user_id)", - "CREATE UNIQUE INDEX IF NOT EXISTS idx_api_keys_hash ON api_keys(key_hash)", ) _MigrateFn = Callable[[aiosqlite.Connection], Coroutine[Any, Any, None]] diff --git a/src/ai_company/persistence/sqlite/user_repo.py b/src/ai_company/persistence/sqlite/user_repo.py index bc14ec16cf..a04f71f3c4 100644 --- a/src/ai_company/persistence/sqlite/user_repo.py +++ b/src/ai_company/persistence/sqlite/user_repo.py @@ -1,4 +1,9 @@ -"""SQLite repository implementations for User and ApiKey.""" +"""SQLite repository implementations for User and ApiKey. + +Provides ``SQLiteUserRepository`` and ``SQLiteApiKeyRepository``, which +persist ``User`` and ``ApiKey`` domain models to SQLite via aiosqlite. +Both use upsert semantics for ``save`` operations. +""" import sqlite3 from datetime import UTC, datetime @@ -36,7 +41,17 @@ def _row_to_user(row: aiosqlite.Row) -> User: - """Reconstruct a User from a database row.""" + """Reconstruct a ``User`` from a database row. + + Converts SQLite-native types (integers, ISO strings) back into + the domain model's expected Python types. + + Args: + row: A single database row with user columns. + + Returns: + Validated ``User`` model instance. + """ data = dict(row) data["must_change_password"] = bool(data["must_change_password"]) data["role"] = HumanRole(data["role"]) @@ -46,7 +61,17 @@ def _row_to_user(row: aiosqlite.Row) -> User: def _row_to_api_key(row: aiosqlite.Row) -> ApiKey: - """Reconstruct an ApiKey from a database row.""" + """Reconstruct an ``ApiKey`` from a database row. + + Converts SQLite-native types (integers, ISO strings) back into + the domain model's expected Python types. + + Args: + row: A single database row with API key columns. + + Returns: + Validated ``ApiKey`` model instance. + """ data = dict(row) data["revoked"] = bool(data["revoked"]) data["role"] = HumanRole(data["role"]) @@ -57,17 +82,29 @@ def _row_to_api_key(row: aiosqlite.Row) -> ApiKey: class SQLiteUserRepository: - """SQLite implementation of the UserRepository protocol. + """SQLite-backed user repository. + + Provides CRUD operations for ``User`` models using a shared + ``aiosqlite.Connection``. All write operations commit + immediately. Args: - db: An open aiosqlite connection. + db: An open aiosqlite connection with ``row_factory`` + set to ``aiosqlite.Row``. """ def __init__(self, db: aiosqlite.Connection) -> None: self._db = db async def save(self, user: User) -> None: - """Persist a user (upsert semantics).""" + """Persist a user via upsert (insert or update on conflict). + + Args: + user: User model to persist. + + Raises: + QueryError: If the database operation fails. + """ try: await self._db.execute( """\ @@ -99,10 +136,20 @@ async def save(self, user: User) -> None: error=str(exc), ) raise QueryError(msg) from exc - logger.debug(PERSISTENCE_USER_SAVED, user_id=user.id) + logger.info(PERSISTENCE_USER_SAVED, user_id=user.id) async def get(self, user_id: NotBlankStr) -> User | None: - """Retrieve a user by ID.""" + """Retrieve a user by primary key. + + Args: + user_id: Unique user identifier. + + Returns: + The matching ``User``, or ``None`` if not found. + + Raises: + QueryError: If the database query or deserialization fails. + """ try: cursor = await self._db.execute( "SELECT * FROM users WHERE id = ?", (user_id,) @@ -133,7 +180,17 @@ async def get(self, user_id: NotBlankStr) -> User | None: return user async def get_by_username(self, username: NotBlankStr) -> User | None: - """Retrieve a user by username.""" + """Retrieve a user by their unique username. + + Args: + username: Login username to look up. + + Returns: + The matching ``User``, or ``None`` if not found. + + Raises: + QueryError: If the database query or deserialization fails. + """ try: cursor = await self._db.execute( "SELECT * FROM users WHERE username = ?", (username,) @@ -161,7 +218,14 @@ async def get_by_username(self, username: NotBlankStr) -> User | None: raise QueryError(msg) from exc async def list_users(self) -> tuple[User, ...]: - """List all users.""" + """List all users ordered by creation date. + + Returns: + Tuple of all ``User`` records, oldest first. + + Raises: + QueryError: If the database query or deserialization fails. + """ try: cursor = await self._db.execute("SELECT * FROM users ORDER BY created_at") rows = await cursor.fetchall() @@ -179,7 +243,14 @@ async def list_users(self) -> tuple[User, ...]: return users async def count(self) -> int: - """Count the number of users.""" + """Return the total number of persisted users. + + Returns: + Non-negative integer count. + + Raises: + QueryError: If the database query fails. + """ try: cursor = await self._db.execute("SELECT COUNT(*) FROM users") row = await cursor.fetchone() @@ -192,7 +263,17 @@ async def count(self) -> int: return result async def delete(self, user_id: NotBlankStr) -> bool: - """Delete a user by ID.""" + """Delete a user by primary key. + + Args: + user_id: Unique user identifier. + + Returns: + ``True`` if a row was deleted, ``False`` if not found. + + Raises: + QueryError: If the database operation fails. + """ try: cursor = await self._db.execute( "DELETE FROM users WHERE id = ?", (user_id,) @@ -207,22 +288,34 @@ async def delete(self, user_id: NotBlankStr) -> bool: ) raise QueryError(msg) from exc deleted = cursor.rowcount > 0 - logger.debug(PERSISTENCE_USER_DELETED, user_id=user_id, deleted=deleted) + logger.info(PERSISTENCE_USER_DELETED, user_id=user_id, deleted=deleted) return deleted class SQLiteApiKeyRepository: - """SQLite implementation of the ApiKeyRepository protocol. + """SQLite-backed API key repository. + + Provides CRUD operations for ``ApiKey`` models using a shared + ``aiosqlite.Connection``. All write operations commit + immediately. Args: - db: An open aiosqlite connection. + db: An open aiosqlite connection with ``row_factory`` + set to ``aiosqlite.Row``. """ def __init__(self, db: aiosqlite.Connection) -> None: self._db = db async def save(self, key: ApiKey) -> None: - """Persist an API key (upsert semantics).""" + """Persist an API key via upsert (insert or update on conflict). + + Args: + key: API key model to persist. + + Raises: + QueryError: If the database operation fails. + """ try: await self._db.execute( """\ @@ -260,10 +353,20 @@ async def save(self, key: ApiKey) -> None: error=str(exc), ) raise QueryError(msg) from exc - logger.debug(PERSISTENCE_API_KEY_SAVED, key_id=key.id) + logger.info(PERSISTENCE_API_KEY_SAVED, key_id=key.id) async def get(self, key_id: NotBlankStr) -> ApiKey | None: - """Retrieve an API key by ID.""" + """Retrieve an API key by primary key. + + Args: + key_id: Unique key identifier. + + Returns: + The matching ``ApiKey``, or ``None`` if not found. + + Raises: + QueryError: If the database query or deserialization fails. + """ try: cursor = await self._db.execute( "SELECT * FROM api_keys WHERE id = ?", (key_id,) @@ -294,7 +397,17 @@ async def get(self, key_id: NotBlankStr) -> ApiKey | None: return key async def get_by_hash(self, key_hash: NotBlankStr) -> ApiKey | None: - """Retrieve an API key by its hash.""" + """Retrieve an API key by its SHA-256 hash. + + Args: + key_hash: Hex-encoded SHA-256 digest of the raw key. + + Returns: + The matching ``ApiKey``, or ``None`` if not found. + + Raises: + QueryError: If the database query or deserialization fails. + """ try: cursor = await self._db.execute( "SELECT * FROM api_keys WHERE key_hash = ?", @@ -315,7 +428,17 @@ async def get_by_hash(self, key_hash: NotBlankStr) -> ApiKey | None: raise QueryError(msg) from exc async def list_by_user(self, user_id: NotBlankStr) -> tuple[ApiKey, ...]: - """List API keys belonging to a user.""" + """List all API keys belonging to a user, ordered by creation date. + + Args: + user_id: Owner user identifier. + + Returns: + Tuple of ``ApiKey`` records, oldest first. + + Raises: + QueryError: If the database query or deserialization fails. + """ try: cursor = await self._db.execute( "SELECT * FROM api_keys WHERE user_id = ? ORDER BY created_at", @@ -348,7 +471,17 @@ async def list_by_user(self, user_id: NotBlankStr) -> tuple[ApiKey, ...]: return keys async def delete(self, key_id: NotBlankStr) -> bool: - """Delete an API key by ID.""" + """Delete an API key by primary key. + + Args: + key_id: Unique key identifier. + + Returns: + ``True`` if a row was deleted, ``False`` if not found. + + Raises: + QueryError: If the database operation fails. + """ try: cursor = await self._db.execute( "DELETE FROM api_keys WHERE id = ?", (key_id,) @@ -363,5 +496,5 @@ async def delete(self, key_id: NotBlankStr) -> bool: ) raise QueryError(msg) from exc deleted = cursor.rowcount > 0 - logger.debug(PERSISTENCE_API_KEY_DELETED, key_id=key_id, deleted=deleted) + logger.info(PERSISTENCE_API_KEY_DELETED, key_id=key_id, deleted=deleted) return deleted diff --git a/tests/unit/api/auth/test_service.py b/tests/unit/api/auth/test_service.py index 64f7a8e72e..3f2dda24e8 100644 --- a/tests/unit/api/auth/test_service.py +++ b/tests/unit/api/auth/test_service.py @@ -1,5 +1,7 @@ """Tests for AuthService.""" +from datetime import UTC, datetime, timedelta + import pytest from ai_company.api.auth.config import AuthConfig @@ -19,8 +21,6 @@ def _make_user( role: HumanRole = HumanRole.CEO, must_change_password: bool = False, ) -> User: - from datetime import UTC, datetime - now = datetime.now(UTC) svc = _make_service() return User( @@ -91,8 +91,6 @@ def test_expired_token_raises(self) -> None: _token, _ = svc.create_token(user) # Manually create an expired token - from datetime import UTC, datetime, timedelta - expired_payload = { "sub": user.id, "username": user.username, @@ -127,8 +125,6 @@ def test_must_change_password_in_claims(self) -> None: assert claims["must_change_password"] is True def test_decode_token_missing_sub_claim(self) -> None: - from datetime import UTC, datetime, timedelta - import jwt as pyjwt svc = _make_service() diff --git a/tests/unit/api/controllers/test_approvals.py b/tests/unit/api/controllers/test_approvals.py index 6defff389a..29a986353e 100644 --- a/tests/unit/api/controllers/test_approvals.py +++ b/tests/unit/api/controllers/test_approvals.py @@ -259,7 +259,7 @@ async def test_approve_pending( assert resp.status_code == 200 body = resp.json() assert body["data"]["status"] == "approved" - assert body["data"]["decided_by"] == "ceo" + assert body["data"]["decided_by"] == "test-ceo" assert body["data"]["decision_reason"] == "Looks good" async def test_approve_records_decided_by_from_header( @@ -274,7 +274,7 @@ async def test_approve_records_decided_by_from_header( headers=make_auth_headers("manager"), ) assert resp.status_code == 200 - assert resp.json()["data"]["decided_by"] == "manager" + assert resp.json()["data"]["decided_by"] == "test-manager" def test_approve_not_found(self, test_client: TestClient[Any]) -> None: resp = test_client.post( @@ -360,7 +360,7 @@ async def test_reject_pending( assert resp.status_code == 200 body = resp.json() assert body["data"]["status"] == "rejected" - assert body["data"]["decided_by"] == "ceo" + assert body["data"]["decided_by"] == "test-ceo" assert body["data"]["decision_reason"] == "Too risky" async def test_reject_requires_reason( diff --git a/tests/unit/persistence/sqlite/test_migrations.py b/tests/unit/persistence/sqlite/test_migrations.py index 4a7f4821fa..154a0ba8a4 100644 --- a/tests/unit/persistence/sqlite/test_migrations.py +++ b/tests/unit/persistence/sqlite/test_migrations.py @@ -144,32 +144,20 @@ async def test_v4_creates_audit_entry_indexes( } assert expected.issubset(indexes) - async def test_v5_creates_users_table( - self, memory_db: aiosqlite.Connection - ) -> None: - await run_migrations(memory_db) - cursor = await memory_db.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='users'" - ) - row = await cursor.fetchone() - assert row is not None - - async def test_v5_creates_api_keys_table( - self, memory_db: aiosqlite.Connection - ) -> None: - await run_migrations(memory_db) - cursor = await memory_db.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='api_keys'" - ) - row = await cursor.fetchone() - assert row is not None - - async def test_v5_creates_settings_table( - self, memory_db: aiosqlite.Connection + @pytest.mark.parametrize( + "table_name", + ["users", "api_keys", "settings"], + ) + async def test_v5_creates_table( + self, + memory_db: aiosqlite.Connection, + table_name: str, ) -> None: + """V5 migration creates the expected tables.""" await run_migrations(memory_db) cursor = await memory_db.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='settings'" + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,), ) row = await cursor.fetchone() assert row is not None @@ -183,7 +171,6 @@ async def test_v5_creates_user_indexes( "AND name LIKE 'idx_%' AND name LIKE '%user%' ORDER BY name" ) indexes = {row[0] for row in await cursor.fetchall()} - assert "idx_users_username" in indexes assert "idx_api_keys_user_id" in indexes async def test_migration_failure_raises_migration_error(