diff --git a/CLAUDE.md b/CLAUDE.md index 8fe06a5cb8..134b81491b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -65,7 +65,7 @@ curl http://localhost:3000/api/v1/health # backend (via web proxy) ```text src/ai_company/ - api/ # Litestar REST + WebSocket API (controllers, guards, channels) + api/ # Litestar REST + WebSocket API (controllers, guards, channels, JWT + API key auth) budget/ # Cost tracking, budget enforcement (pre-flight/in-flight checks, auto-downgrade), billing periods, cost tiers, quota/subscription tracking, CFO cost optimization (anomaly detection, efficiency analysis, downgrade recommendations, approval decisions), spending reports, budget errors (BudgetExhaustedError, DailyLimitExceededError, QuotaExhaustedError) cli/ # CLI interface (future — thin API wrapper if needed) communication/ # Message bus, dispatcher, messenger, channels, delegation, loop prevention, conflict resolution, meeting protocol diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index dfd1ce9a03..60ec1f775b 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -81,7 +81,7 @@ The MVP validates the core hypothesis: **a single agent can complete a real task > **Implementation snapshot (2026-03-10):** > - **Done:** M0–M6 (tooling, config/core, providers, single-agent engine, multi-agent orchestration, API/CLI surface) + Docker sandbox (#50), MCP bridge (#53), code runner + HR engine (hiring/firing/onboarding/offboarding/registry) + performance tracking (task metrics, quality scoring, collaboration scoring, trend detection, rolling windows). Memory layer backend selected ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). Persistence backend (§7.6) completed (including audit entry persistence via AuditRepository + SQLite backend). Memory retrieval pipeline (#41: ranking, token-budget formatting, context injection, non-inferable filtering) complete. Budget enforcement complete (BudgetEnforcer + configurable cost tiers + quota/subscription tracking). CFO cost optimization complete (CostOptimizer: anomaly detection, efficiency analysis, downgrade recommendations, routing optimization, approval decisions; ReportGenerator: multi-dimensional spending reports). Shared org memory (#125: HybridPromptRetrievalBackend, OrgFactStore, access control, factory) complete. Memory consolidation/archival (#48: ConsolidationService, SimpleConsolidationStrategy, RetentionEnforcer, ArchivalStore protocol) complete. SecOps agent (rule engine, audit log, output scanner, output scan response policies (redact/withhold/log-only/autonomy-tiered), risk classifier, ToolInvoker integration), progressive trust (4 strategies: disabled/weighted/per-category/milestone behind TrustStrategy protocol), promotion/demotion (criteria evaluation, approval strategies, model mapping). Autonomy levels (#42: AutonomyLevel enum, presets, 3-level resolver, rule-based auto-downgrade/human-only promotion change strategy) + approval timeout policies (#126: 4 timeout policies, park/resume service, risk tier classifier, timeout checker) complete. -> - **Remaining:** JWT/OAuth auth, approval workflow gates. +> - **Remaining:** Approval workflow gates. ### 1.5 Configuration Philosophy @@ -2562,6 +2562,7 @@ The REST/WebSocket API is the **primary interface** for all consumers. The Web U ```text /api/v1/ ├── /health # Health check, readiness + ├── /auth # Authentication: setup, login, password change, me ├── /company # CRUD company config ├── /agents # List, hire, fire, modify agents ├── /departments # Department management @@ -2758,6 +2759,7 @@ Circular inheritance is detected via chain tracking and raises `TemplateInherita | **Docker API** | aiodocker | Async-native Docker API client for `DockerSandbox` backend | | **Tool Integration** | MCP SDK (`mcp`) | Industry standard for LLM-to-tool integration | | **Agent Comms** | A2A Protocol compatible | Future-proof inter-agent communication | +| **Authentication** | PyJWT + argon2-cffi | JWT (HMAC HS256/384/512) for session tokens, Argon2id for password hashing, SHA-256 for API key storage | | **Config Format** | YAML + Pydantic validation | Human-readable config with strict validation | | **CLI** | TBD (future, if needed) | Thin wrapper around the REST API for terminal use. May not be needed — interactive Scalar docs at `/docs/api` and `curl`/`httpie` may suffice | @@ -2980,7 +2982,7 @@ ai-company/ │ ├── persistence/ # Operational data persistence (§7.6) │ │ ├── __init__.py # Package exports │ │ ├── protocol.py # PersistenceBackend protocol (M5) -│ │ ├── repositories.py # Repository protocols: TaskRepository, CostRecordRepository, MessageRepository, ParkedContextRepository, AuditRepository +│ │ ├── repositories.py # Repository protocols: TaskRepository, CostRecordRepository, MessageRepository, ParkedContextRepository, AuditRepository, UserRepository, ApiKeyRepository │ │ ├── config.py # PersistenceConfig model (M5) │ │ ├── errors.py # Persistence error hierarchy (M5) │ │ ├── factory.py # create_backend() factory (M5) @@ -2991,7 +2993,8 @@ ai-company/ │ │ ├── hr_repositories.py # SQLite HR repositories (LifecycleEvent, TaskMetricRecord, CollaborationMetricRecord) │ │ ├── parked_context_repo.py # SQLiteParkedContextRepository (park/resume serialized agent state) │ │ ├── audit_repository.py # SQLiteAuditRepository (append-only audit entry persistence) -│ │ └── migrations.py # Schema migrations (user_version pragma) +│ │ ├── user_repo.py # SQLiteUserRepository + SQLiteApiKeyRepository +│ │ └── migrations.py # Schema migrations (user_version pragma, v1–v5) │ ├── observability/ # Structured logging & correlation │ │ ├── __init__.py # get_logger() entry point │ │ ├── _logger.py # Logger configuration @@ -3183,18 +3186,25 @@ ai-company/ │ ├── api/ # REST + WebSocket API (M6) │ │ ├── app.py # Litestar application factory, lifecycle hooks │ │ ├── approval_store.py # In-memory approval queue storage +│ │ ├── auth/ # JWT + API key authentication subsystem +│ │ │ ├── config.py # AuthConfig (frozen Pydantic, HMAC algorithm, exclude paths) +│ │ │ ├── controller.py # AuthController (setup, login, change-password, me) +│ │ │ ├── middleware.py # ApiAuthMiddleware (JWT-first, API key fallback) +│ │ │ ├── models.py # User, ApiKey, AuthenticatedUser, AuthMethod +│ │ │ ├── secret.py # JWT secret resolution (env var → persistence → auto-generate) +│ │ │ └── service.py # AuthService (Argon2id password hashing, JWT ops, API key hashing) │ │ ├── bus_bridge.py # Message-bus → WebSocket bridge │ │ ├── channels.py # WebSocket channel definitions │ │ ├── config.py # API configuration models (ServerConfig, CorsConfig) -│ │ ├── controllers/ # 14 class-based controllers + 1 WebSocket handler (15 route modules) +│ │ ├── controllers/ # 15 class-based controllers + 1 WebSocket handler (16 route modules) │ │ ├── dto.py # Request/response DTOs and envelopes -│ │ ├── errors.py # API error hierarchy (ApiError, NotFoundError, etc.) +│ │ ├── errors.py # API error hierarchy (ApiError, NotFoundError, UnauthorizedError, etc.) │ │ ├── exception_handlers.py # Litestar exception handler registration -│ │ ├── guards.py # Route guards — read/write access (stub auth, M7 real auth) -│ │ ├── middleware.py # Request logging middleware +│ │ ├── guards.py # Route guards — role-based read/write access control (HumanRole enum) +│ │ ├── middleware.py # Request logging, CSP middleware │ │ ├── pagination.py # Cursor-free offset/limit pagination │ │ ├── server.py # Uvicorn server runner -│ │ ├── state.py # Typed AppState container with service access +│ │ ├── state.py # Typed AppState container with service access (deferred auth init) │ │ └── ws_models.py # WebSocket event models (WsEvent, WsEventType) │ ├── cli/ # CLI interface (future, if needed) │ │ ├── __init__.py diff --git a/README.md b/README.md index db2750f090..e0a8074963 100644 --- a/README.md +++ b/README.md @@ -24,10 +24,11 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Memory Interface (M5)** - Pluggable `MemoryBackend` protocol with capability discovery, shared knowledge protocol, domain models, config, factory, and context injection retrieval pipeline (ranking, token-budget formatting, non-inferable filtering). Shared organizational memory via `OrgMemoryBackend` protocol with hybrid prompt+retrieval backend. Memory consolidation/archival with pluggable strategies and retention enforcement - **Coordination Error Taxonomy (M5)** - Post-execution classification pipeline detecting logical contradictions, numerical drift, context omissions, and coordination failures - **Budget Enforcement (M5)** - `BudgetEnforcer` service with pre-flight checks, in-flight budget checking, auto-downgrade, configurable cost tiers, and quota/subscription tracking; `CostOptimizer` CFO service with anomaly detection, efficiency analysis, downgrade recommendations, and approval decisions; `ReportGenerator` for multi-dimensional spending reports -- **Litestar REST API (M6)** - 13 controllers + WebSocket handler covering company, agents, tasks, budget, approvals, analytics, messages, meetings, projects, departments, artifacts, providers, health, and WebSocket real-time feed +- **Litestar REST API (M6)** - 15 controllers + WebSocket handler covering company, agents, tasks, budget, approvals, analytics, messages, meetings, projects, departments, artifacts, providers, health, auth, and WebSocket real-time feed - **Human Approval Queue (M6)** - Approval submission, approve/reject with reason, list/filter by status, WebSocket notifications for approval events - **WebSocket Real-Time Feed (M6)** - Channel-based subscriptions (tasks, agents, budget, messages, system, approvals), per-channel payload filters, message-bus bridge -- **Route Guards (M6)** - Role-based read/write access control (stub auth for M6; real JWT/OAuth planned for M7) +- **Route Guards (M6)** - Role-based read/write access control with 5 human roles (CEO, Manager, Board Member, Pair Programmer, Observer) +- **JWT + API Key Authentication (M7)** - Mandatory auth middleware (JWT-first with API key fallback), Argon2id password hashing, first-run admin setup, password change flow, SHA-256 API key hashing, regex-based path exclusions - **HR Engine (M7)** - Hiring pipeline (request → generate candidate → approval → instantiate), onboarding checklists, offboarding pipeline (reassign → archive → notify → terminate), agent registry - **Performance Tracking (M7)** - Task metrics, CI-based quality scoring, behavioral collaboration scoring, Theil-Sen robust trend detection, multi-window rolling metric aggregation - **Progressive Trust (M7)** - 4 strategies (disabled/weighted/per-category/milestone) behind pluggable `TrustStrategy` protocol, trust level tracking, action permission evaluation @@ -38,12 +39,12 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Memory Backend Adapter (M5)** - Memory protocols, retrieval pipeline, org memory, and consolidation are complete; initial Mem0 adapter backend ([ADR-001](docs/decisions/ADR-001-memory-layer.md)) pending; research backends (GraphRAG, Temporal KG) planned - **CLI Surface** - `cli/` package is placeholder-only -- **Security/Approval System (M7)** - Real authentication (JWT/OAuth) and approval workflow gates are planned +- **Security/Approval System (M7)** - SecOps agent with rule engine (soft-allow/hard-deny, fail-closed), audit log, output scanner, risk classifier, and ToolInvoker integration are implemented; progressive trust (4 strategies), promotion/demotion, autonomy levels (5 tiers with presets, resolver, change strategies) and approval timeout policies (wait-forever, auto-deny, tiered, escalation-chain with task park/resume) are implemented; JWT + API key authentication is implemented; approval workflow gates remain planned - **Advanced Product Surface** - web dashboard, external integrations ## Status -**M7: Security & Approval** partially complete — Docker sandbox, MCP bridge, code runner, SecOps agent, HR engine + performance tracking, progressive trust, promotion/demotion done; authentication/approval workflow gates remain. See [DESIGN_SPEC.md](DESIGN_SPEC.md) for the full high-level specification. +**M7: Security & Approval** partially complete — Docker sandbox, MCP bridge, code runner, SecOps agent, HR engine + performance tracking, progressive trust, promotion/demotion, JWT + API key authentication done; approval workflow gates remain. See [DESIGN_SPEC.md](DESIGN_SPEC.md) for the full high-level specification. ## Tech Stack diff --git a/docker/.env.example b/docker/.env.example index eae54725a1..7f336261e8 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -9,6 +9,14 @@ # API key for the LLM provider (required for agent execution) LLM_API_KEY= +# --- Authentication ---------------------------------------------------------- +# JWT signing secret (optional — auto-generated and persisted on first run). +# Set explicitly only for multi-instance deployments sharing a common secret. +# Must be >= 32 characters if set. +# Generate with: python -c "import secrets; print(secrets.token_urlsafe(48))" +# AI_COMPANY_JWT_SECRET= +# First-run: POST /api/v1/auth/setup to create admin account + # --- Application ------------------------------------------------------------- # Log level: debug, info, warning, error, critical AI_COMPANY_LOG_LEVEL=info diff --git a/pyproject.toml b/pyproject.toml index aa14594df4..3d8e36f09e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,12 +15,14 @@ classifiers = [ dependencies = [ "aiodocker==0.26.0", "aiosqlite==0.22.1", + "argon2-cffi==25.1.0", "jinja2==3.1.6", "jsonschema==4.26.0", "litellm==1.82.1", "litestar[standard,structlog,pydantic,brotli,prometheus]==2.21.1", "mcp==1.26.0", "pydantic==2.12.5", + "pyjwt[crypto]==2.11.0", "pyyaml==6.0.3", "structlog==25.5.0", ] diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index 793b5832fb..a3c896c39e 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -19,6 +19,10 @@ from ai_company import __version__ from ai_company.api.approval_store import ApprovalStore +from ai_company.api.auth.controller import require_password_changed +from ai_company.api.auth.middleware import create_auth_middleware_class +from ai_company.api.auth.secret import resolve_jwt_secret +from ai_company.api.auth.service import AuthService from ai_company.api.bus_bridge import MessageBusBridge from ai_company.api.channels import CHANNEL_APPROVALS, create_channels_plugin from ai_company.api.controllers import ALL_CONTROLLERS @@ -98,6 +102,7 @@ def _build_lifecycle( persistence: PersistenceBackend | None, message_bus: MessageBus | None, bridge: MessageBusBridge | None, + app_state: AppState, ) -> tuple[ Sequence[Callable[[], Awaitable[None]]], Sequence[Callable[[], Awaitable[None]]], @@ -110,7 +115,7 @@ def _build_lifecycle( async def on_startup() -> None: logger.info(API_APP_STARTUP, version=__version__) - await _safe_startup(persistence, message_bus, bridge) + await _safe_startup(persistence, message_bus, bridge, app_state) async def on_shutdown() -> None: logger.info(API_APP_SHUTDOWN, version=__version__) @@ -131,28 +136,69 @@ async def _cleanup_on_failure( try: await message_bus.stop() except Exception: - logger.error( + logger.exception( API_APP_STARTUP, error="Cleanup: failed to stop message bus", - exc_info=True, ) if started_persistence and persistence is not None: try: await persistence.disconnect() except Exception: - logger.error( + logger.exception( API_APP_STARTUP, error="Cleanup: failed to disconnect persistence", - exc_info=True, ) +async def _init_persistence( + persistence: PersistenceBackend, + app_state: AppState, +) -> None: + """Run migrations and resolve JWT secret on an already-connected backend. + + Must only be called after ``persistence.connect()`` has succeeded. + + Args: + persistence: Connected persistence backend. + app_state: Application state for auth service injection. + """ + try: + await persistence.migrate() + except Exception: + logger.exception( + API_APP_STARTUP, + error="Failed to run persistence migrations", + ) + raise + + # Resolve JWT secret after persistence is up + if app_state.has_auth_service: + logger.info( + API_APP_STARTUP, + note="Auth service already configured, skipping JWT secret resolution", + ) + else: + try: + secret = await resolve_jwt_secret(persistence) + auth_config = app_state.config.api.auth.with_secret( + secret, + ) + app_state.set_auth_service(AuthService(auth_config)) + except Exception: + logger.exception( + API_APP_STARTUP, + error="Failed to resolve JWT secret", + ) + raise + + async def _safe_startup( persistence: PersistenceBackend | None, message_bus: MessageBus | None, bridge: MessageBusBridge | None, + app_state: AppState, ) -> None: - """Connect persistence, start message bus and bridge. + """Connect persistence, resolve JWT secret, start message bus and bridge. Executes in order; on failure, cleans up already-started components in reverse order before re-raising. @@ -164,21 +210,23 @@ async def _safe_startup( try: await persistence.connect() except Exception: - logger.error( + logger.exception( API_APP_STARTUP, error="Failed to connect persistence", - exc_info=True, ) raise + # Mark connected immediately so cleanup can disconnect + # if migrate() or JWT resolution fails below. started_persistence = True + await _init_persistence(persistence, app_state) + if message_bus is not None: try: await message_bus.start() except Exception: - logger.error( + logger.exception( API_APP_STARTUP, error="Failed to start message bus", - exc_info=True, ) raise started_bus = True @@ -186,10 +234,9 @@ async def _safe_startup( try: await bridge.start() except Exception: - logger.error( + logger.exception( API_APP_STARTUP, error="Failed to start message bus bridge", - exc_info=True, ) raise except Exception: @@ -212,38 +259,36 @@ async def _safe_shutdown( try: await bridge.stop() except Exception: - logger.error( + logger.exception( API_APP_SHUTDOWN, error="Failed to stop message bus bridge", - exc_info=True, ) if message_bus is not None: try: await message_bus.stop() except Exception: - logger.error( + logger.exception( API_APP_SHUTDOWN, error="Failed to stop message bus", - exc_info=True, ) if persistence is not None: try: await persistence.disconnect() except Exception: - logger.error( + logger.exception( API_APP_SHUTDOWN, error="Failed to disconnect persistence", - exc_info=True, ) -def create_app( +def create_app( # noqa: PLR0913 *, config: RootConfig | None = None, persistence: PersistenceBackend | None = None, message_bus: MessageBus | None = None, cost_tracker: CostTracker | None = None, approval_store: ApprovalStore | None = None, + auth_service: AuthService | None = None, ) -> Litestar: """Create and configure the Litestar application. @@ -256,6 +301,7 @@ def create_app( message_bus: Internal message bus. cost_tracker: Cost tracking service. approval_store: Approval queue store. + auth_service: Pre-built auth service (for testing). Returns: Configured Litestar application. @@ -283,6 +329,7 @@ def create_app( message_bus=message_bus, cost_tracker=cost_tracker, approval_store=effective_approval_store, + auth_service=auth_service, startup_time=time.monotonic(), ) @@ -293,12 +340,14 @@ def create_app( api_router = Router( path=api_config.api_prefix, route_handlers=[*ALL_CONTROLLERS, ws_handler], + guards=[require_password_changed], ) startup, shutdown = _build_lifecycle( persistence, message_bus, bridge, + app_state, ) return Litestar( @@ -369,4 +418,24 @@ def _build_middleware(api_config: ApiConfig) -> list[Middleware]: rate_limit=(rl.time_unit, rl.max_requests), # type: ignore[arg-type] exclude=list(rl.exclude_paths), ) - return [CSPMiddleware, RequestLoggingMiddleware, rate_limit.middleware] + auth = api_config.auth + if auth.exclude_paths is None: + prefix = api_config.api_prefix + auth = auth.model_copy( + update={ + "exclude_paths": ( + f"^{prefix}/health$", + "^/docs", + "^/api$", + f"^{prefix}/auth/setup$", + f"^{prefix}/auth/login$", + ), + }, + ) + auth_middleware = create_auth_middleware_class(auth) + return [ + auth_middleware, + CSPMiddleware, + RequestLoggingMiddleware, + rate_limit.middleware, + ] diff --git a/src/ai_company/api/auth/__init__.py b/src/ai_company/api/auth/__init__.py new file mode 100644 index 0000000000..3bad8b316e --- /dev/null +++ b/src/ai_company/api/auth/__init__.py @@ -0,0 +1,14 @@ +"""Authentication and authorization for the API layer.""" + +from ai_company.api.auth.config import AuthConfig +from ai_company.api.auth.models import ApiKey, AuthenticatedUser, AuthMethod, User +from ai_company.api.auth.service import AuthService + +__all__ = [ + "ApiKey", + "AuthConfig", + "AuthMethod", + "AuthService", + "AuthenticatedUser", + "User", +] diff --git a/src/ai_company/api/auth/config.py b/src/ai_company/api/auth/config.py new file mode 100644 index 0000000000..88f9d58661 --- /dev/null +++ b/src/ai_company/api/auth/config.py @@ -0,0 +1,104 @@ +"""Authentication configuration.""" + +from typing import Literal, Self + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +MIN_SECRET_LENGTH = 32 + + +def _require_valid_secret(secret: str) -> None: + """Raise ``ValueError`` if *secret* is non-empty but too short. + + Args: + secret: JWT signing secret to validate. + + Raises: + ValueError: If *secret* is non-empty and shorter than + ``MIN_SECRET_LENGTH``. + """ + if secret and len(secret) < MIN_SECRET_LENGTH: + msg = ( + f"jwt_secret must be at least {MIN_SECRET_LENGTH} " + f"characters (got {len(secret)})" + ) + raise ValueError(msg) + + +class AuthConfig(BaseModel): + """JWT and authentication configuration. + + The ``jwt_secret`` is resolved at application startup via a + priority chain: + + 1. ``AI_COMPANY_JWT_SECRET`` environment variable (for multi-instance + deployments sharing a common secret). + 2. Previously persisted secret in the ``settings`` table. + 3. Auto-generate a new secret and persist it for future runs. + + At construction time the secret may be empty — it is populated + before the first request is served. + + Attributes: + jwt_secret: HMAC signing key (resolved at startup, repr-hidden). + jwt_algorithm: JWT signing algorithm (HMAC family only). + jwt_expiry_minutes: Token lifetime in minutes. + min_password_length: Minimum password length for setup/change. + exclude_paths: URL paths excluded from auth middleware. + """ + + model_config = ConfigDict(frozen=True) + + jwt_secret: str = Field( + default="", + repr=False, + description="JWT signing secret (resolved at startup)", + ) + jwt_algorithm: Literal["HS256", "HS384", "HS512"] = Field( + default="HS256", + description="JWT signing algorithm (HMAC family)", + ) + jwt_expiry_minutes: int = Field( + default=1440, + ge=1, + le=43200, + description="Token lifetime in minutes (default 24h)", + ) + min_password_length: int = Field( + default=12, + ge=8, + le=128, + description="Minimum password length for setup and password change", + ) + exclude_paths: tuple[str, ...] | None = Field( + default=None, + description=( + "Regex patterns for paths excluded from authentication. " + "When None (default), paths are auto-derived from the " + "API prefix (health, auth/setup, auth/login, docs, " + "scalar UI). " + "Use ^ to anchor at the start of the path and add $ when " + "an exact match (rather than a prefix match) is required." + ), + ) + + @model_validator(mode="after") + def _validate_secret_length(self) -> Self: + """Reject non-empty secrets shorter than the minimum.""" + _require_valid_secret(self.jwt_secret) + return self + + def with_secret(self, secret: str) -> AuthConfig: + """Return a copy with the JWT secret set. + + Args: + secret: Resolved JWT signing secret. + + Returns: + New ``AuthConfig`` with the secret populated. + + Raises: + ValueError: If the secret is too short. + """ + _require_valid_secret(secret) + return self.model_copy(update={"jwt_secret": secret}) diff --git a/src/ai_company/api/auth/controller.py b/src/ai_company/api/auth/controller.py new file mode 100644 index 0000000000..61b4b59df7 --- /dev/null +++ b/src/ai_company/api/auth/controller.py @@ -0,0 +1,413 @@ +"""Authentication controller — setup, login, password change, me.""" + +import uuid +from datetime import UTC, datetime +from typing import Any, Self + +from litestar import Controller, Request, Response, get, post +from litestar.connection import ASGIConnection # noqa: TC002 +from litestar.exceptions import PermissionDeniedException +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ai_company.api.auth.config import AuthConfig +from ai_company.api.auth.models import AuthenticatedUser, User +from ai_company.api.auth.service import AuthService # noqa: TC001 +from ai_company.api.dto import ApiResponse +from ai_company.api.errors import ConflictError, UnauthorizedError +from ai_company.api.guards import HumanRole +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.api import ( + API_AUTH_FAILED, + API_AUTH_PASSWORD_CHANGED, + API_AUTH_SETUP_COMPLETE, + API_AUTH_TOKEN_ISSUED, +) + +logger = get_logger(__name__) + +# Derive from AuthConfig default to prevent silent divergence. +_MIN_PASSWORD_LENGTH: int = AuthConfig.model_fields["min_password_length"].default + +# Pre-computed Argon2id hash for constant-time rejection when the +# username doesn't exist — prevents timing-based username enumeration. +# The actual password is irrelevant; only the verification time matters. +_DUMMY_ARGON2_HASH = ( + "$argon2id$v=19$m=65536,t=3,p=4$" + "c2FsdHNhbHRzYWx0$" + "mB0bZKSNwOhSdxMQfsldT3qGmFyjVqbkntMkutMfdUs" +) + + +def _check_password_length(password: str) -> str: + """Validate that a password meets the minimum length requirement. + + Args: + password: Password to validate. + + Returns: + The password unchanged. + + Raises: + ValueError: If the password is too short. + """ + if len(password) < _MIN_PASSWORD_LENGTH: + msg = f"Password must be at least {_MIN_PASSWORD_LENGTH} characters" + raise ValueError(msg) + return password + + +# ── Request DTOs ────────────────────────────────────────────── + + +class SetupRequest(BaseModel): + """First-run admin account creation payload. + + Attributes: + username: Admin login username. + password: Admin password (min 12 chars). + """ + + model_config = ConfigDict(frozen=True) + + username: NotBlankStr = Field(max_length=128) + password: NotBlankStr = Field(max_length=128) + + @model_validator(mode="after") + def _validate_password_length(self) -> Self: + """Reject passwords shorter than the minimum.""" + _check_password_length(self.password) + return self + + +class LoginRequest(BaseModel): + """Login credentials payload. + + Attributes: + username: Login username. + password: Login password. + """ + + model_config = ConfigDict(frozen=True) + + username: NotBlankStr = Field(max_length=128) + password: NotBlankStr = Field(max_length=128) + + +class ChangePasswordRequest(BaseModel): + """Password change payload. + + Attributes: + current_password: Current password for verification. + new_password: New password (min 12 chars). + """ + + model_config = ConfigDict(frozen=True) + + current_password: NotBlankStr = Field(max_length=128) + new_password: NotBlankStr = Field(max_length=128) + + @model_validator(mode="after") + def _validate_password_length(self) -> Self: + """Reject new passwords shorter than the minimum.""" + _check_password_length(self.new_password) + return self + + +# ── Response DTOs ───────────────────────────────────────────── + + +class TokenResponse(BaseModel): + """JWT token response. + + Attributes: + token: Encoded JWT string. + expires_in: Token lifetime in seconds. + must_change_password: Whether password change is required. + """ + + model_config = ConfigDict(frozen=True) + + token: NotBlankStr + expires_in: int = Field(gt=0) + must_change_password: bool + + +class UserInfoResponse(BaseModel): + """Current user information. + + Attributes: + id: User ID. + username: Login username. + role: Access control role. + must_change_password: Whether password change is required. + """ + + model_config = ConfigDict(frozen=True) + + id: NotBlankStr + username: NotBlankStr + role: HumanRole + must_change_password: bool + + +_PWD_CHANGE_EXEMPT_SUFFIXES = ("/auth/change-password", "/auth/me") + +# ── Guards ──────────────────────────────────────────────────── + + +def require_password_changed( + connection: ASGIConnection, # type: ignore[type-arg] + _: object, +) -> None: + """Guard that blocks users who must change their password. + + Paths ending with ``/auth/change-password`` or ``/auth/me`` + are exempt so the user can actually change the password or + inspect their own profile. + + Args: + connection: The incoming connection. + _: Route handler (unused). + + Raises: + PermissionDeniedException: If password change is required + or the user object is present but not an + ``AuthenticatedUser``. + """ + path = str(connection.url.path) + if any(path.endswith(s) for s in _PWD_CHANGE_EXEMPT_SUFFIXES): + return + user = connection.scope.get("user") + if user is None: + return + if not isinstance(user, AuthenticatedUser): + logger.warning( + API_AUTH_FAILED, + reason="unexpected_user_type", + user_type=type(user).__qualname__, + path=path, + ) + raise PermissionDeniedException(detail="Invalid user session") + if user.must_change_password: + raise PermissionDeniedException(detail="Password change required") + + +# ── Controller ──────────────────────────────────────────────── + + +class AuthController(Controller): + """Authentication endpoints: setup, login, password change, me.""" + + path = "/auth" + tags = ("auth",) + + @post( + "/setup", + status_code=201, + summary="First-run admin setup", + ) + async def setup( + self, + data: SetupRequest, + request: Request[Any, Any, Any], + ) -> Response[ApiResponse[TokenResponse]]: + """Create the first admin account (CEO). + + Only available when no users exist. Returns 409 after + the first account is created. + """ + app_state = request.app.state["app_state"] + auth_service: AuthService = app_state.auth_service + persistence = app_state.persistence + + user_count = await persistence.users.count() + if user_count > 0: + logger.warning(API_AUTH_FAILED, reason="setup_already_completed") + msg = "Setup already completed" + raise ConflictError(msg) + + now = datetime.now(UTC) + password_hash = await auth_service.hash_password_async(data.password) + user = User( + id=str(uuid.uuid4()), + username=data.username, + password_hash=password_hash, + role=HumanRole.CEO, + must_change_password=True, + created_at=now, + updated_at=now, + ) + await persistence.users.save(user) + + # Race guard: undo if another setup completed concurrently + post_count = await persistence.users.count() + if post_count > 1: + await persistence.users.delete(user.id) + logger.warning(API_AUTH_FAILED, reason="setup_race_detected") + msg = "Setup already completed" + raise ConflictError(msg) + + token, expires_in = auth_service.create_token(user) + + logger.info( + API_AUTH_SETUP_COMPLETE, + user_id=user.id, + username=user.username, + ) + + return Response( + content=ApiResponse( + data=TokenResponse( + token=token, + expires_in=expires_in, + must_change_password=user.must_change_password, + ), + ), + status_code=201, + ) + + @post( + "/login", + status_code=200, + summary="Authenticate with credentials", + ) + async def login( + self, + data: LoginRequest, + request: Request[Any, Any, Any], + ) -> Response[ApiResponse[TokenResponse]]: + """Validate credentials and return a JWT.""" + app_state = request.app.state["app_state"] + auth_service: AuthService = app_state.auth_service + persistence = app_state.persistence + + user = await persistence.users.get_by_username(data.username) + if user is not None: + password_valid = await auth_service.verify_password_async( + data.password, user.password_hash + ) + else: + # Constant-time rejection: run verification against a + # dummy hash to prevent timing-based username enumeration. + await auth_service.verify_password_async(data.password, _DUMMY_ARGON2_HASH) + password_valid = False + + if not password_valid: + logger.warning( + API_AUTH_FAILED, + reason="invalid_credentials", + ) + msg = "Invalid credentials" + raise UnauthorizedError(msg) + + token, expires_in = auth_service.create_token(user) + + logger.info( + API_AUTH_TOKEN_ISSUED, + user_id=user.id, + username=user.username, + ) + + return Response( + content=ApiResponse( + data=TokenResponse( + token=token, + expires_in=expires_in, + must_change_password=user.must_change_password, + ), + ), + ) + + @post( + "/change-password", + status_code=200, + summary="Change current user password", + ) + async def change_password( + self, + data: ChangePasswordRequest, + request: Request[Any, Any, Any], + ) -> Response[ApiResponse[UserInfoResponse]]: + """Validate current password and set new one.""" + auth_user = request.scope.get("user") + if not isinstance(auth_user, AuthenticatedUser): + msg = "Authentication required" + raise UnauthorizedError(msg) + app_state = request.app.state["app_state"] + auth_service: AuthService = app_state.auth_service + persistence = app_state.persistence + + user = await persistence.users.get(auth_user.user_id) + if user is None: + logger.warning( + API_AUTH_FAILED, + reason="user_not_found_for_password_change", + user_id=auth_user.user_id, + ) + msg = "User not found" + raise UnauthorizedError(msg) + + if not await auth_service.verify_password_async( + data.current_password, user.password_hash + ): + logger.warning( + API_AUTH_FAILED, + reason="invalid_current_password", + user_id=user.id, + ) + msg = "Invalid current password" + raise UnauthorizedError(msg) + + now = datetime.now(UTC) + new_hash = await auth_service.hash_password_async(data.new_password) + updated_user = user.model_copy( + update={ + "password_hash": new_hash, + "must_change_password": False, + "updated_at": now, + } + ) + await persistence.users.save(updated_user) + + logger.info( + API_AUTH_PASSWORD_CHANGED, + user_id=user.id, + username=user.username, + ) + + return Response( + content=ApiResponse( + data=UserInfoResponse( + id=updated_user.id, + username=updated_user.username, + role=updated_user.role, + must_change_password=False, + ), + ), + ) + + @get( + "/me", + summary="Get current user info", + ) + async def me( + self, + request: Request[Any, Any, Any], + ) -> Response[ApiResponse[UserInfoResponse]]: + """Return information about the authenticated user.""" + auth_user = request.scope.get("user") + if not isinstance(auth_user, AuthenticatedUser): + msg = "Authentication required" + raise UnauthorizedError(msg) + + return Response( + content=ApiResponse( + data=UserInfoResponse( + id=auth_user.user_id, + username=auth_user.username, + role=auth_user.role, + must_change_password=auth_user.must_change_password, + ), + ), + ) diff --git a/src/ai_company/api/auth/middleware.py b/src/ai_company/api/auth/middleware.py new file mode 100644 index 0000000000..a3e20bc83e --- /dev/null +++ b/src/ai_company/api/auth/middleware.py @@ -0,0 +1,278 @@ +"""JWT + API key authentication middleware.""" + +import hashlib +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any + +import jwt +from litestar.exceptions import NotAuthorizedException +from litestar.middleware import ( + AbstractAuthenticationMiddleware, + AuthenticationResult, +) + +from ai_company.api.auth.models import AuthenticatedUser, AuthMethod +from ai_company.api.auth.service import AuthService +from ai_company.observability import get_logger +from ai_company.observability.events.api import ( + API_AUTH_FAILED, + API_AUTH_SUCCESS, +) + +if TYPE_CHECKING: + from litestar.connection import ASGIConnection + + from ai_company.api.auth.config import AuthConfig + from ai_company.api.state import AppState + +logger = get_logger(__name__) + +_BEARER_PARTS = 2 + + +class ApiAuthMiddleware(AbstractAuthenticationMiddleware): + """Authenticate requests via JWT or API key. + + Reads ``Authorization: Bearer `` from the request. + Tokens containing ``.`` are treated exclusively as JWTs. + Tokens without dots are tried as API keys via SHA-256 hash + lookup. + + Requires ``auth_service``, persistence backend on + ``app.state["app_state"]``. + """ + + async def authenticate_request( + self, + connection: ASGIConnection[Any, Any, Any, Any], + ) -> AuthenticationResult: + """Validate the Authorization header. + + Args: + connection: Incoming ASGI connection. + + Returns: + AuthenticationResult with AuthenticatedUser. + + Raises: + NotAuthorizedException: If authentication fails. + """ + auth_header = connection.headers.get("authorization") + if not auth_header: + logger.warning( + API_AUTH_FAILED, + reason="missing_header", + path=str(connection.url.path), + ) + raise NotAuthorizedException(detail="Missing Authorization header") + + token = _extract_bearer_token(auth_header) + if token is None: + logger.warning( + API_AUTH_FAILED, + reason="invalid_scheme", + path=str(connection.url.path), + ) + raise NotAuthorizedException(detail="Invalid authorization scheme") + + app_state = connection.app.state["app_state"] + auth_service: AuthService = app_state.auth_service + + # Try JWT for tokens with dots; API key otherwise + if "." in token: + user = await _try_jwt_auth(token, auth_service, app_state, connection) + if user is not None: + return AuthenticationResult(user=user, auth=token) + raise NotAuthorizedException(detail="Invalid JWT token") + + # API key (no dots in token) + user = await _try_api_key_auth(token, app_state, connection) + if user is not None: + return AuthenticationResult(user=user, auth=token) + + logger.warning( + API_AUTH_FAILED, + reason="invalid_credentials", + path=str(connection.url.path), + ) + raise NotAuthorizedException(detail="Invalid credentials") + + +def _extract_bearer_token(header: str) -> str | None: + """Extract token from ``Bearer `` header value.""" + parts = header.split(None, 1) + if len(parts) != _BEARER_PARTS or parts[0].lower() != "bearer": + return None + return parts[1] + + +async def _try_jwt_auth( + token: str, + auth_service: AuthService, + app_state: AppState, + connection: ASGIConnection[Any, Any, Any, Any], +) -> AuthenticatedUser | None: + """Attempt JWT authentication. + + Returns: + Authenticated user on success, or ``None`` if the token is + invalid, the ``sub`` claim is missing, or the user no longer + exists in the database. + """ + try: + claims = auth_service.decode_token(token) + except jwt.InvalidTokenError as exc: + logger.warning( + API_AUTH_FAILED, + reason="jwt_invalid", + error_type=type(exc).__qualname__, + error=str(exc), + path=str(connection.url.path), + ) + return None + + user_id = claims.get("sub") + if not user_id: + logger.warning( + API_AUTH_FAILED, + reason="jwt_missing_sub", + path=str(connection.url.path), + ) + return None + + persistence = app_state.persistence + db_user = await persistence.users.get(user_id) + if db_user is None: + logger.warning( + API_AUTH_FAILED, + reason="jwt_user_not_found", + user_id=user_id, + path=str(connection.url.path), + ) + return None + + expected_sig = hashlib.sha256( + db_user.password_hash.encode(), + ).hexdigest()[:16] + if claims.get("pwd_sig") != expected_sig: + logger.warning( + API_AUTH_FAILED, + reason="password_changed_since_token_issued", + user_id=user_id, + path=str(connection.url.path), + ) + return None + + authenticated = AuthenticatedUser( + user_id=db_user.id, + username=db_user.username, + role=db_user.role, + auth_method=AuthMethod.JWT, + must_change_password=db_user.must_change_password, + ) + logger.info( + API_AUTH_SUCCESS, + user_id=db_user.id, + username=db_user.username, + auth_method="jwt", + path=str(connection.url.path), + ) + return authenticated + + +async def _try_api_key_auth( + token: str, + app_state: AppState, + connection: ASGIConnection[Any, Any, Any, Any], +) -> AuthenticatedUser | None: + """Attempt API key authentication. + + Returns: + Authenticated user on success, or ``None`` if the key hash + is not found, the key is revoked or expired, or the owning + user no longer exists. + """ + key_hash = AuthService.hash_api_key(token) + persistence = app_state.persistence + api_key = await persistence.api_keys.get_by_hash(key_hash) + if api_key is None: + logger.debug( + API_AUTH_FAILED, + reason="api_key_not_found", + path=str(connection.url.path), + ) + return None + + if api_key.revoked: + logger.warning( + API_AUTH_FAILED, + reason="api_key_revoked", + key_name=api_key.name, + path=str(connection.url.path), + ) + return None + if api_key.expires_at is not None and api_key.expires_at < datetime.now(UTC): + logger.warning( + API_AUTH_FAILED, + reason="api_key_expired", + key_name=api_key.name, + path=str(connection.url.path), + ) + return None + + db_user = await persistence.users.get(api_key.user_id) + if db_user is None: + logger.error( + API_AUTH_FAILED, + reason="api_key_orphaned", + key_name=api_key.name, + user_id=api_key.user_id, + path=str(connection.url.path), + ) + return None + + authenticated = AuthenticatedUser( + user_id=db_user.id, + username=db_user.username, + role=api_key.role, + auth_method=AuthMethod.API_KEY, + must_change_password=db_user.must_change_password, + ) + logger.info( + API_AUTH_SUCCESS, + user_id=db_user.id, + username=db_user.username, + auth_method="api_key", + key_name=api_key.name, + path=str(connection.url.path), + ) + return authenticated + + +def create_auth_middleware_class( + auth_config: AuthConfig, +) -> type[ApiAuthMiddleware]: + """Create a middleware class with excluded paths baked in. + + Litestar's ``AbstractAuthenticationMiddleware.__init__`` takes + ``exclude`` as a parameter (default ``None``). We create a + subclass whose ``__init__`` forwards the configured exclude + list to ``super().__init__``. + + Args: + auth_config: Auth configuration with exclude_paths. + + Returns: + Middleware class ready for use in the Litestar middleware stack. + """ + exclude_paths = ( + list(auth_config.exclude_paths) if auth_config.exclude_paths else None + ) + + class ConfiguredAuthMiddleware(ApiAuthMiddleware): + """Auth middleware with pre-configured exclude paths.""" + + def __init__(self, app: Any) -> None: + super().__init__(app, exclude=exclude_paths) + + return ConfiguredAuthMiddleware diff --git a/src/ai_company/api/auth/models.py b/src/ai_company/api/auth/models.py new file mode 100644 index 0000000000..6c1c9f349a --- /dev/null +++ b/src/ai_company/api/auth/models.py @@ -0,0 +1,87 @@ +"""Authentication domain models.""" + +from enum import StrEnum + +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field + +from ai_company.api.guards import HumanRole # noqa: TC001 +from ai_company.core.types import NotBlankStr # noqa: TC001 + + +class AuthMethod(StrEnum): + """Authentication method used for a request.""" + + JWT = "jwt" + API_KEY = "api_key" + + +class User(BaseModel): + """Persisted user account. + + Attributes: + id: Unique user identifier (UUID). + username: Login username. + password_hash: Argon2id hash (excluded from repr). + role: Access control role. + must_change_password: Whether the user must change password. + created_at: Account creation timestamp. + updated_at: Last modification timestamp. + """ + + model_config = ConfigDict(frozen=True) + + id: NotBlankStr + username: NotBlankStr + password_hash: str = Field(repr=False) + role: HumanRole + must_change_password: bool = True + created_at: AwareDatetime + updated_at: AwareDatetime + + +class ApiKey(BaseModel): + """Persisted API key (hash-only storage). + + Attributes: + id: Unique key identifier (UUID). + key_hash: SHA-256 hex digest of the raw key. + name: Human-readable label. + role: Access control role. + user_id: Owner user ID. + created_at: Key creation timestamp (timezone-aware). + expires_at: Optional expiry timestamp (timezone-aware). + revoked: Whether the key has been revoked. + """ + + model_config = ConfigDict(frozen=True) + + id: NotBlankStr + key_hash: str = Field(repr=False) + name: NotBlankStr + role: HumanRole + user_id: NotBlankStr + created_at: AwareDatetime + expires_at: AwareDatetime | None = None + revoked: bool = False + + +class AuthenticatedUser(BaseModel): + """Lightweight identity attached to ``connection.user``. + + Populated by the auth middleware after successful authentication. + + Attributes: + user_id: User's unique identifier. + username: User's login name. + role: Access control role. + auth_method: How the user authenticated. + must_change_password: Whether forced password change is pending. + """ + + model_config = ConfigDict(frozen=True) + + user_id: NotBlankStr + username: NotBlankStr + role: HumanRole + auth_method: AuthMethod + must_change_password: bool = False diff --git a/src/ai_company/api/auth/secret.py b/src/ai_company/api/auth/secret.py new file mode 100644 index 0000000000..5e38a0f9c9 --- /dev/null +++ b/src/ai_company/api/auth/secret.py @@ -0,0 +1,75 @@ +"""JWT secret resolution — env var → persistence → auto-generate.""" + +import os +import secrets + +from ai_company.api.auth.config import MIN_SECRET_LENGTH +from ai_company.observability import get_logger +from ai_company.observability.events.api import API_APP_STARTUP +from ai_company.persistence.protocol import PersistenceBackend # noqa: TC001 + +logger = get_logger(__name__) + +_SETTING_KEY = "jwt_secret" +_SECRET_LENGTH = 48 # 64 URL-safe base64 chars + + +async def resolve_jwt_secret( + persistence: PersistenceBackend, +) -> str: + """Resolve the JWT signing secret using a priority chain. + + 1. ``AI_COMPANY_JWT_SECRET`` env var (for multi-instance deploys). + 2. Stored secret in persistence ``settings`` table. + 3. Auto-generate, persist, and return. + + Args: + persistence: Connected persistence backend. + + Returns: + JWT signing secret (>= 32 characters). + """ + # 1. Env var override (highest priority) + env_secret = os.environ.get("AI_COMPANY_JWT_SECRET", "").strip() + if env_secret: + if len(env_secret) < MIN_SECRET_LENGTH: + msg = ( + f"AI_COMPANY_JWT_SECRET must be at least " + f"{MIN_SECRET_LENGTH} characters (got {len(env_secret)})" + ) + logger.error(API_APP_STARTUP, error=msg) + raise ValueError(msg) + logger.info( + API_APP_STARTUP, + note="JWT secret loaded from AI_COMPANY_JWT_SECRET env var", + ) + return env_secret + + # 2. Check persistence + stored = await persistence.get_setting(_SETTING_KEY) + if stored: + stored = stored.strip() + if len(stored) < MIN_SECRET_LENGTH: + logger.warning( + API_APP_STARTUP, + note=( + "Stored JWT secret too short " + f"({len(stored)} < {MIN_SECRET_LENGTH}), " + "auto-generating replacement" + ), + ) + else: + logger.info( + API_APP_STARTUP, + note="JWT secret loaded from persistence", + ) + return stored + + # 3. Auto-generate and persist + generated = secrets.token_urlsafe(_SECRET_LENGTH) + await persistence.set_setting(_SETTING_KEY, generated) + logger.info( + API_APP_STARTUP, + note="JWT secret auto-generated and saved to persistence", + ) + return generated diff --git a/src/ai_company/api/auth/service.py b/src/ai_company/api/auth/service.py new file mode 100644 index 0000000000..9e57ceef1b --- /dev/null +++ b/src/ai_company/api/auth/service.py @@ -0,0 +1,195 @@ +"""Authentication service — password hashing, JWT ops, API key hashing.""" + +import asyncio +import hashlib +import secrets +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING, Any + +import argon2 +import jwt + +from ai_company.api.auth.models import User # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.api import API_AUTH_FAILED + +if TYPE_CHECKING: + from ai_company.api.auth.config import AuthConfig + +logger = get_logger(__name__) + +_hasher = argon2.PasswordHasher( + time_cost=3, + memory_cost=65536, + parallelism=4, + hash_len=32, + salt_len=16, +) + + +class AuthService: + """Immutable authentication operations. + + Args: + config: Authentication configuration (carries JWT secret). + """ + + def __init__(self, config: AuthConfig) -> None: + self._config = config + + def hash_password(self, password: str) -> str: + """Hash a password with Argon2id. + + Args: + password: Plaintext password. + + Returns: + Argon2id hash string. + """ + return _hasher.hash(password) + + def verify_password(self, password: str, password_hash: str) -> bool: + """Verify a password against an Argon2id hash. + + Args: + password: Plaintext password to check. + password_hash: Stored Argon2id hash. + + Returns: + ``True`` if the password matches. + """ + try: + return _hasher.verify(password_hash, password) + except argon2.exceptions.VerifyMismatchError: + return False + except argon2.exceptions.VerificationError: + logger.warning( + API_AUTH_FAILED, + reason="hash_verification_error", + exc_info=True, + ) + return False + except argon2.exceptions.InvalidHashError: + logger.error( + API_AUTH_FAILED, + reason="invalid_hash_data_corruption", + exc_info=True, + ) + return False + + async def hash_password_async(self, password: str) -> str: + """Hash a password with Argon2id in a thread executor. + + Offloads the CPU-intensive hashing to avoid blocking the + event loop. + + Args: + password: Plaintext password. + + Returns: + Argon2id hash string. + """ + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self.hash_password, password) + + async def verify_password_async( + self, + password: str, + password_hash: str, + ) -> bool: + """Verify a password against an Argon2id hash in a thread executor. + + Offloads the CPU-intensive verification to avoid blocking the + event loop. + + Args: + password: Plaintext password to check. + password_hash: Stored Argon2id hash. + + Returns: + ``True`` if the password matches. + """ + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, self.verify_password, password, password_hash + ) + + def create_token(self, user: User) -> tuple[str, int]: + """Create a JWT for the given user. + + Args: + user: Authenticated user. + + Returns: + Tuple of (encoded JWT string, expiry seconds). + + Raises: + RuntimeError: If the JWT secret is empty. + """ + if not self._config.jwt_secret: + msg = "JWT secret not configured" + raise RuntimeError(msg) + now = datetime.now(UTC) + expiry_seconds = self._config.jwt_expiry_minutes * 60 + pwd_sig = hashlib.sha256( + user.password_hash.encode(), + ).hexdigest()[:16] + payload: dict[str, Any] = { + "sub": user.id, + "username": user.username, + "role": user.role.value, + "must_change_password": user.must_change_password, + "pwd_sig": pwd_sig, + "iat": now, + "exp": now + timedelta(seconds=expiry_seconds), + } + token = jwt.encode( + payload, + self._config.jwt_secret, + algorithm=self._config.jwt_algorithm, + ) + return token, expiry_seconds + + def decode_token(self, token: str) -> dict[str, Any]: + """Decode and validate a JWT. + + Args: + token: Encoded JWT string. + + Returns: + Decoded claims dictionary. + + Raises: + RuntimeError: If the JWT secret is empty. + jwt.InvalidTokenError: If the token is invalid or expired. + """ + if not self._config.jwt_secret: + msg = "JWT secret not configured" + raise RuntimeError(msg) + return jwt.decode( + token, + self._config.jwt_secret, + algorithms=[self._config.jwt_algorithm], + options={"require": ["exp", "iat", "sub"]}, + ) + + @staticmethod + def hash_api_key(raw_key: str) -> str: + """Compute SHA-256 hex digest of a raw API key. + + Args: + raw_key: The plaintext API key. + + Returns: + Lowercase hex digest. + """ + return hashlib.sha256(raw_key.encode()).hexdigest() + + @staticmethod + def generate_api_key() -> str: + """Generate a cryptographically secure API key. + + Returns: + URL-safe base64 string (43 chars). + """ + return secrets.token_urlsafe(32) diff --git a/src/ai_company/api/config.py b/src/ai_company/api/config.py index d3210999df..e72fd060f2 100644 --- a/src/ai_company/api/config.py +++ b/src/ai_company/api/config.py @@ -1,7 +1,8 @@ """API configuration models. -Frozen Pydantic models for CORS, rate limiting, server, and the -top-level ``ApiConfig`` that aggregates them all. +Frozen Pydantic models for CORS, rate limiting, server, +authentication, and the top-level ``ApiConfig`` that aggregates +them all. """ from enum import StrEnum @@ -9,6 +10,7 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator +from ai_company.api.auth.config import AuthConfig from ai_company.core.types import NotBlankStr # noqa: TC001 @@ -34,7 +36,7 @@ class CorsConfig(BaseModel): description="HTTP methods permitted in cross-origin requests", ) allow_headers: tuple[str, ...] = Field( - default=("Content-Type", "Authorization", "X-Human-Role"), + default=("Content-Type", "Authorization"), description="Headers permitted in cross-origin requests", ) allow_credentials: bool = Field( @@ -150,6 +152,7 @@ class ApiConfig(BaseModel): cors: CORS configuration. rate_limit: Rate limiting configuration. server: Uvicorn server configuration. + auth: Authentication configuration. api_prefix: URL prefix for all API routes. """ @@ -167,6 +170,10 @@ class ApiConfig(BaseModel): default_factory=ServerConfig, description="Uvicorn server configuration", ) + auth: AuthConfig = Field( + default_factory=AuthConfig, + description="Authentication configuration", + ) api_prefix: NotBlankStr = Field( default="/api/v1", description="URL prefix for all API routes", diff --git a/src/ai_company/api/controllers/__init__.py b/src/ai_company/api/controllers/__init__.py index 8343349809..575f6ee8b4 100644 --- a/src/ai_company/api/controllers/__init__.py +++ b/src/ai_company/api/controllers/__init__.py @@ -2,6 +2,7 @@ from litestar import Controller +from ai_company.api.auth.controller import AuthController from ai_company.api.controllers.agents import AgentController from ai_company.api.controllers.analytics import AnalyticsController from ai_company.api.controllers.approvals import ApprovalsController @@ -33,6 +34,7 @@ ProviderController, ApprovalsController, AutonomyController, + AuthController, ) __all__ = [ @@ -41,6 +43,7 @@ "AnalyticsController", "ApprovalsController", "ArtifactController", + "AuthController", "AutonomyController", "BudgetController", "CompanyController", diff --git a/src/ai_company/api/controllers/approvals.py b/src/ai_company/api/controllers/approvals.py index 2056a6d0f0..a3b91099b9 100644 --- a/src/ai_company/api/controllers/approvals.py +++ b/src/ai_company/api/controllers/approvals.py @@ -8,6 +8,7 @@ from litestar.channels import ChannelsPlugin from litestar.datastructures import State # noqa: TC002 +from ai_company.api.auth.models import AuthenticatedUser from ai_company.api.channels import CHANNEL_APPROVALS from ai_company.api.dto import ( ApiResponse, @@ -16,7 +17,7 @@ PaginatedResponse, RejectRequest, ) -from ai_company.api.errors import ConflictError, NotFoundError +from ai_company.api.errors import ConflictError, NotFoundError, UnauthorizedError from ai_company.api.guards import require_read_access, require_write_access from ai_company.api.pagination import PaginationLimit, PaginationOffset, paginate from ai_company.api.state import AppState # noqa: TC001 @@ -102,6 +103,46 @@ def _publish_approval_event( ) +def _resolve_decision( + request: Request[Any, Any, Any], + item: ApprovalItem, + approval_id: str, +) -> AuthenticatedUser: + """Validate that an approval item is pending and extract the auth user. + + Performs the shared pre-checks for approve/reject operations: + look up the authenticated user, and verify the item is still + in PENDING status. + + Args: + request: The incoming HTTP request. + item: The approval item to act on. + approval_id: Approval identifier (for log messages). + + Returns: + The authenticated user making the decision. + + Raises: + UnauthorizedError: If the user is missing from the request scope. + ConflictError: If the approval is not in PENDING status. + """ + if item.status != ApprovalStatus.PENDING: + msg = f"Approval {approval_id!r} is {item.status.value}, not pending" + logger.warning( + API_APPROVAL_CONFLICT, + approval_id=approval_id, + current_status=item.status.value, + ) + raise ConflictError(msg) + + auth_user = request.scope.get("user") + if not isinstance(auth_user, AuthenticatedUser): + msg = "Authentication required" + raise UnauthorizedError(msg) + + return auth_user + + class ApprovalsController(Controller): """Human approval queue — list, create, approve, reject.""" @@ -236,8 +277,8 @@ async def approve( ) -> ApiResponse[ApprovalItem]: """Approve a pending approval item. - The ``decided_by`` field is populated from the - ``X-Human-Role`` header. + The ``decided_by`` field is populated from the authenticated + user's username. Args: state: Application state. @@ -263,22 +304,13 @@ async def approve( ) raise NotFoundError(msg) - if item.status != ApprovalStatus.PENDING: - msg = f"Approval {approval_id!r} is {item.status.value}, not pending" - logger.warning( - API_APPROVAL_CONFLICT, - approval_id=approval_id, - current_status=item.status.value, - ) - raise ConflictError(msg) - - role = request.headers.get("x-human-role", "unknown") + auth_user = _resolve_decision(request, item, approval_id) now = datetime.now(UTC) updated = item.model_copy( update={ "status": ApprovalStatus.APPROVED, "decided_at": now, - "decided_by": role, + "decided_by": auth_user.username, "decision_reason": data.comment, }, ) @@ -301,7 +333,7 @@ async def approve( logger.info( API_APPROVAL_APPROVED, approval_id=approval_id, - decided_by=role, + decided_by=auth_user.username, ) return ApiResponse(data=updated) @@ -319,8 +351,8 @@ async def reject( ) -> ApiResponse[ApprovalItem]: """Reject a pending approval item. - The ``decided_by`` field is populated from the - ``X-Human-Role`` header. + The ``decided_by`` field is populated from the authenticated + user's username. Args: state: Application state. @@ -346,22 +378,13 @@ async def reject( ) raise NotFoundError(msg) - if item.status != ApprovalStatus.PENDING: - msg = f"Approval {approval_id!r} is {item.status.value}, not pending" - logger.warning( - API_APPROVAL_CONFLICT, - approval_id=approval_id, - current_status=item.status.value, - ) - raise ConflictError(msg) - - role = request.headers.get("x-human-role", "unknown") + auth_user = _resolve_decision(request, item, approval_id) now = datetime.now(UTC) updated = item.model_copy( update={ "status": ApprovalStatus.REJECTED, "decided_at": now, - "decided_by": role, + "decided_by": auth_user.username, "decision_reason": data.reason, }, ) @@ -384,6 +407,6 @@ async def reject( logger.info( API_APPROVAL_REJECTED, approval_id=approval_id, - decided_by=role, + decided_by=auth_user.username, ) return ApiResponse(data=updated) diff --git a/src/ai_company/api/errors.py b/src/ai_company/api/errors.py index f903adff73..2e9c86dc42 100644 --- a/src/ai_company/api/errors.py +++ b/src/ai_company/api/errors.py @@ -56,6 +56,15 @@ def __init__(self, message: str | None = None) -> None: super().__init__(message, status_code=403) +class UnauthorizedError(ApiError): + """Raised when authentication is required or invalid (401).""" + + default_message: str = "Authentication required" + + def __init__(self, message: str | None = None) -> None: + super().__init__(message, status_code=401) + + class ServiceUnavailableError(ApiError): """Raised when a required service is not configured (503).""" diff --git a/src/ai_company/api/exception_handlers.py b/src/ai_company/api/exception_handlers.py index 5c73144c37..8b222e2d45 100644 --- a/src/ai_company/api/exception_handlers.py +++ b/src/ai_company/api/exception_handlers.py @@ -9,6 +9,7 @@ from litestar import Request, Response from litestar.exceptions import ( + NotAuthorizedException, PermissionDeniedException, ValidationException, ) @@ -99,12 +100,15 @@ def handle_api_error( ) -> Response[ApiResponse[None]]: """Map ``ApiError`` subclasses to their declared status code.""" _log_error(request, exc, status=exc.status_code) - # Return the class-level default message, not the - # caller-interpolated string (which may contain internal IDs). - exc_cls = type(exc) - default_msg = getattr(exc_cls, "default_message", "Internal server error") + # For 5xx errors return the generic class-level default to avoid + # leaking internals. For 4xx client errors return the actual + # exception message — it was set by the controller and is user-safe. + if exc.status_code >= _SERVER_ERROR_THRESHOLD: + msg = type(exc).default_message + else: + msg = str(exc) or type(exc).default_message return Response( - content=ApiResponse[None](error=default_msg), + content=ApiResponse[None](error=msg), status_code=exc.status_code, ) @@ -129,8 +133,9 @@ def handle_permission_denied( ) -> Response[ApiResponse[None]]: """Map ``PermissionDeniedException`` to 403.""" _log_error(request, exc, status=403) + detail = exc.detail or "Forbidden" return Response( - content=ApiResponse[None](error="Forbidden"), + content=ApiResponse[None](error=detail), status_code=403, ) @@ -149,10 +154,24 @@ def handle_validation_error( ) +def handle_not_authorized( + request: Request[Any, Any, Any], + exc: NotAuthorizedException, +) -> Response[ApiResponse[None]]: + """Map ``NotAuthorizedException`` to 401.""" + _log_error(request, exc, status=401) + detail = exc.detail or "Authentication required" + return Response( + content=ApiResponse[None](error=detail), + status_code=401, + ) + + EXCEPTION_HANDLERS: dict[type[Exception], object] = { RecordNotFoundError: handle_record_not_found, DuplicateRecordError: handle_duplicate_record, PersistenceError: handle_persistence_error, + NotAuthorizedException: handle_not_authorized, PermissionDeniedException: handle_permission_denied, ValidationException: handle_validation_error, ApiError: handle_api_error, diff --git a/src/ai_company/api/guards.py b/src/ai_company/api/guards.py index 553e4e883d..e47f6683ff 100644 --- a/src/ai_company/api/guards.py +++ b/src/ai_company/api/guards.py @@ -1,16 +1,7 @@ """Route guards for access control. -.. warning:: **Security Stub (M6)** - - These guards check the ``X-Human-Role`` header, which is - **self-asserted by the caller** — there is no signature - verification, session token, or JWT. This is intentional for - the M6 milestone scope. **Real authentication and authorization - (pre-shared API key, JWT, or OAuth) will be implemented in M7 - (issue scope: security & HR).** - - Until M7, the API should only be exposed on trusted networks or - behind a reverse proxy that enforces authentication. +Guards read the authenticated user identity from ``connection.user`` +(populated by the auth middleware) and check role-based permissions. """ from enum import StrEnum @@ -45,11 +36,20 @@ class HumanRole(StrEnum): _READ_ROLES: frozenset[HumanRole] = _WRITE_ROLES | frozenset({HumanRole.OBSERVER}) -def _get_role(connection: ASGIConnection) -> str | None: # type: ignore[type-arg] - """Extract the human role from the request header.""" - value = connection.headers.get("x-human-role") - if value is not None: - return value.strip().lower() +def _get_role(connection: ASGIConnection) -> HumanRole | None: # type: ignore[type-arg] + """Extract the human role from the authenticated user.""" + user = connection.scope.get("user") + if user is not None and hasattr(user, "role"): + try: + return HumanRole(user.role) + except ValueError: + logger.warning( + API_GUARD_DENIED, + guard="_get_role", + invalid_role=str(user.role), + path=str(connection.url.path), + ) + return None return None @@ -59,7 +59,7 @@ def require_write_access( ) -> None: """Guard that allows only write-capable roles. - Checks the ``X-Human-Role`` header for ``ceo``, ``manager``, + Checks ``connection.user.role`` for ``ceo``, ``manager``, ``board_member``, or ``pair_programmer``. Args: @@ -86,7 +86,7 @@ def require_read_access( ) -> None: """Guard that allows all recognised roles. - Checks the ``X-Human-Role`` header for any valid role + Checks ``connection.user.role`` for any valid role including ``observer``. Args: diff --git a/src/ai_company/api/state.py b/src/ai_company/api/state.py index a7583548ad..2ffd80dab2 100644 --- a/src/ai_company/api/state.py +++ b/src/ai_company/api/state.py @@ -6,12 +6,13 @@ """ from ai_company.api.approval_store import ApprovalStore # noqa: TC001 +from ai_company.api.auth.service import AuthService # noqa: TC001 from ai_company.api.errors import ServiceUnavailableError from ai_company.budget.tracker import CostTracker # noqa: TC001 from ai_company.communication.bus_protocol import MessageBus # noqa: TC001 from ai_company.config.schema import RootConfig # noqa: TC001 from ai_company.observability import get_logger -from ai_company.observability.events.api import API_SERVICE_UNAVAILABLE +from ai_company.observability.events.api import API_APP_STARTUP, API_SERVICE_UNAVAILABLE from ai_company.persistence.protocol import PersistenceBackend # noqa: TC001 logger = get_logger(__name__) @@ -20,8 +21,9 @@ class AppState: """Typed application state container. - Service fields (``persistence``, ``message_bus``, ``cost_tracker``) - accept ``None`` at construction time for dev/test mode. Property + Service fields (``persistence``, ``message_bus``, ``cost_tracker``, + ``auth_service``) accept ``None`` at construction time for dev/test + mode. Property accessors raise ``ServiceUnavailableError`` (HTTP 503) when the service is not configured, producing a clear error instead of an opaque ``AttributeError``. @@ -33,6 +35,7 @@ class AppState: """ __slots__ = ( + "_auth_service", "_cost_tracker", "_message_bus", "_persistence", @@ -49,6 +52,7 @@ def __init__( # noqa: PLR0913 persistence: PersistenceBackend | None = None, message_bus: MessageBus | None = None, cost_tracker: CostTracker | None = None, + auth_service: AuthService | None = None, startup_time: float = 0.0, ) -> None: self.config = config @@ -56,40 +60,63 @@ def __init__( # noqa: PLR0913 self._persistence = persistence self._message_bus = message_bus self._cost_tracker = cost_tracker + self._auth_service = auth_service self.startup_time = startup_time + def _require_service[T](self, service: T | None, name: str) -> T: + """Return *service* or raise 503 if not configured. + + Args: + service: Service instance (``None`` when not configured). + name: Service name for logging and error message. + + Raises: + ServiceUnavailableError: If *service* is ``None``. + """ + if service is None: + logger.warning(API_SERVICE_UNAVAILABLE, service=name) + msg = f"{name.replace('_', ' ').title()} not configured" + raise ServiceUnavailableError(msg) + return service + @property def persistence(self) -> PersistenceBackend: """Return persistence backend or raise 503.""" - if self._persistence is None: - logger.warning( - API_SERVICE_UNAVAILABLE, - service="persistence", - ) - msg = "Persistence backend not configured" - raise ServiceUnavailableError(msg) - return self._persistence + return self._require_service(self._persistence, "persistence") @property def message_bus(self) -> MessageBus: """Return message bus or raise 503.""" - if self._message_bus is None: - logger.warning( - API_SERVICE_UNAVAILABLE, - service="message_bus", - ) - msg = "Message bus not configured" - raise ServiceUnavailableError(msg) - return self._message_bus + return self._require_service(self._message_bus, "message_bus") @property def cost_tracker(self) -> CostTracker: """Return cost tracker or raise 503.""" - if self._cost_tracker is None: - logger.warning( - API_SERVICE_UNAVAILABLE, - service="cost_tracker", - ) - msg = "Cost tracker not configured" - raise ServiceUnavailableError(msg) - return self._cost_tracker + return self._require_service(self._cost_tracker, "cost_tracker") + + @property + def auth_service(self) -> AuthService: + """Return auth service or raise 503.""" + return self._require_service(self._auth_service, "auth_service") + + @property + def has_auth_service(self) -> bool: + """Check whether the auth service is already configured.""" + return self._auth_service is not None + + def set_auth_service(self, service: AuthService) -> None: + """Set the auth service (deferred initialisation). + + Called once during startup after the JWT secret is resolved. + + Args: + service: Fully configured auth service. + + Raises: + RuntimeError: If the auth service was already configured. + """ + if self._auth_service is not None: + msg = "Auth service already configured" + logger.error(API_APP_STARTUP, error=msg) + raise RuntimeError(msg) + self._auth_service = service diff --git a/src/ai_company/observability/events/api.py b/src/ai_company/observability/events/api.py index 354955e5e6..8890cd7538 100644 --- a/src/ai_company/observability/events/api.py +++ b/src/ai_company/observability/events/api.py @@ -30,3 +30,8 @@ API_WS_TRANSPORT_ERROR: Final[str] = "api.ws.transport_error" API_WS_SEND_FAILED: Final[str] = "api.ws.send_failed" API_SERVICE_UNAVAILABLE: Final[str] = "api.service.unavailable" +API_AUTH_SUCCESS: Final[str] = "api.auth.success" +API_AUTH_FAILED: Final[str] = "api.auth.failed" +API_AUTH_TOKEN_ISSUED: Final[str] = "api.auth.token_issued" # noqa: S105 +API_AUTH_SETUP_COMPLETE: Final[str] = "api.auth.setup_complete" +API_AUTH_PASSWORD_CHANGED: Final[str] = "api.auth.password_changed" # noqa: S105 diff --git a/src/ai_company/observability/events/persistence.py b/src/ai_company/observability/events/persistence.py index 9c95516bd3..b2c048ac2a 100644 --- a/src/ai_company/observability/events/persistence.py +++ b/src/ai_company/observability/events/persistence.py @@ -123,3 +123,28 @@ PERSISTENCE_AUDIT_ENTRY_DESERIALIZE_FAILED: Final[str] = ( "persistence.audit_entry.deserialize_failed" ) + +PERSISTENCE_USER_SAVED: Final[str] = "persistence.user.saved" +PERSISTENCE_USER_SAVE_FAILED: Final[str] = "persistence.user.save_failed" +PERSISTENCE_USER_FETCHED: Final[str] = "persistence.user.fetched" +PERSISTENCE_USER_FETCH_FAILED: Final[str] = "persistence.user.fetch_failed" +PERSISTENCE_USER_LISTED: Final[str] = "persistence.user.listed" +PERSISTENCE_USER_LIST_FAILED: Final[str] = "persistence.user.list_failed" +PERSISTENCE_USER_COUNTED: Final[str] = "persistence.user.counted" +PERSISTENCE_USER_COUNT_FAILED: Final[str] = "persistence.user.count_failed" +PERSISTENCE_USER_DELETED: Final[str] = "persistence.user.deleted" +PERSISTENCE_USER_DELETE_FAILED: Final[str] = "persistence.user.delete_failed" + +PERSISTENCE_API_KEY_SAVED: Final[str] = "persistence.api_key.saved" +PERSISTENCE_API_KEY_SAVE_FAILED: Final[str] = "persistence.api_key.save_failed" +PERSISTENCE_API_KEY_FETCHED: Final[str] = "persistence.api_key.fetched" +PERSISTENCE_API_KEY_FETCH_FAILED: Final[str] = "persistence.api_key.fetch_failed" +PERSISTENCE_API_KEY_LISTED: Final[str] = "persistence.api_key.listed" +PERSISTENCE_API_KEY_LIST_FAILED: Final[str] = "persistence.api_key.list_failed" +PERSISTENCE_API_KEY_DELETED: Final[str] = "persistence.api_key.deleted" +PERSISTENCE_API_KEY_DELETE_FAILED: Final[str] = "persistence.api_key.delete_failed" + +PERSISTENCE_SETTING_FETCHED: Final[str] = "persistence.setting.fetched" +PERSISTENCE_SETTING_FETCH_FAILED: Final[str] = "persistence.setting.fetch_failed" +PERSISTENCE_SETTING_SAVED: Final[str] = "persistence.setting.saved" +PERSISTENCE_SETTING_SAVE_FAILED: Final[str] = "persistence.setting.save_failed" diff --git a/src/ai_company/persistence/protocol.py b/src/ai_company/persistence/protocol.py index 26871358ed..3cfcd03074 100644 --- a/src/ai_company/persistence/protocol.py +++ b/src/ai_company/persistence/protocol.py @@ -13,11 +13,13 @@ TaskMetricRepository, # noqa: TC001 ) from ai_company.persistence.repositories import ( + ApiKeyRepository, # noqa: TC001 AuditRepository, # noqa: TC001 CostRecordRepository, # noqa: TC001 MessageRepository, # noqa: TC001 ParkedContextRepository, # noqa: TC001 TaskRepository, # noqa: TC001 + UserRepository, # noqa: TC001 ) @@ -40,6 +42,8 @@ class PersistenceBackend(Protocol): collaboration_metrics: Repository for CollaborationMetricRecord persistence. parked_contexts: Repository for ParkedContext persistence. audit_entries: Repository for AuditEntry persistence. + users: Repository for User persistence. + api_keys: Repository for ApiKey persistence. """ async def connect(self) -> None: @@ -123,3 +127,41 @@ def parked_contexts(self) -> ParkedContextRepository: def audit_entries(self) -> AuditRepository: """Repository for AuditEntry persistence.""" ... + + @property + def users(self) -> UserRepository: + """Repository for User persistence.""" + ... + + @property + def api_keys(self) -> ApiKeyRepository: + """Repository for ApiKey persistence.""" + ... + + async def get_setting(self, key: NotBlankStr) -> str | None: + """Retrieve a setting value by key. + + Args: + key: Setting key. + + Returns: + The setting value, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def set_setting(self, key: NotBlankStr, value: str) -> None: + """Store a setting value. + + Upserts — creates or updates the key. + + Args: + key: Setting key. + value: Setting value. + + Raises: + PersistenceError: If the operation fails. + """ + ... diff --git a/src/ai_company/persistence/repositories.py b/src/ai_company/persistence/repositories.py index 73fd914fd1..b03e7a8982 100644 --- a/src/ai_company/persistence/repositories.py +++ b/src/ai_company/persistence/repositories.py @@ -8,6 +8,7 @@ from pydantic import AwareDatetime # noqa: TC002 +from ai_company.api.auth.models import ApiKey, User # noqa: TC001 from ai_company.budget.cost_record import CostRecord # noqa: TC001 from ai_company.communication.message import Message # noqa: TC001 from ai_company.core.enums import ApprovalRiskLevel, TaskStatus # noqa: TC001 @@ -22,6 +23,7 @@ from ai_company.security.timeout.parked_context import ParkedContext # noqa: TC001 __all__ = [ + "ApiKeyRepository", "AuditRepository", "CollaborationMetricRepository", "CostRecordRepository", @@ -30,6 +32,7 @@ "ParkedContextRepository", "TaskMetricRepository", "TaskRepository", + "UserRepository", ] @@ -318,3 +321,155 @@ async def query( # noqa: PLR0913 *until* is earlier than *since*. """ ... + + +@runtime_checkable +class UserRepository(Protocol): + """CRUD interface for User persistence.""" + + async def save(self, user: User) -> None: + """Persist a user (insert or update). + + Args: + user: The user to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get(self, user_id: NotBlankStr) -> User | None: + """Retrieve a user by ID. + + Args: + user_id: The user identifier. + + Returns: + The user, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get_by_username(self, username: NotBlankStr) -> User | None: + """Retrieve a user by username. + + Args: + username: The login username. + + Returns: + The user, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def list_users(self) -> tuple[User, ...]: + """List all users. + + Returns: + All users as a tuple. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def count(self) -> int: + """Count the number of users. + + Returns: + Total user count. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def delete(self, user_id: NotBlankStr) -> bool: + """Delete a user by ID. + + Args: + user_id: The user identifier. + + Returns: + ``True`` if deleted, ``False`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + +@runtime_checkable +class ApiKeyRepository(Protocol): + """CRUD interface for API key persistence.""" + + async def save(self, key: ApiKey) -> None: + """Persist an API key. + + Args: + key: The API key to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get(self, key_id: NotBlankStr) -> ApiKey | None: + """Retrieve an API key by ID. + + Args: + key_id: The key identifier. + + Returns: + The API key, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get_by_hash(self, key_hash: NotBlankStr) -> ApiKey | None: + """Retrieve an API key by its hash. + + Args: + key_hash: SHA-256 hex digest. + + Returns: + The API key, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def list_by_user(self, user_id: NotBlankStr) -> tuple[ApiKey, ...]: + """List API keys belonging to a user. + + Args: + user_id: The owner user ID. + + Returns: + API keys for the user. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def delete(self, key_id: NotBlankStr) -> bool: + """Delete an API key by ID. + + Args: + key_id: The key identifier. + + Returns: + ``True`` if deleted, ``False`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... diff --git a/src/ai_company/persistence/sqlite/backend.py b/src/ai_company/persistence/sqlite/backend.py index 282ed4e7f6..9033eee3d8 100644 --- a/src/ai_company/persistence/sqlite/backend.py +++ b/src/ai_company/persistence/sqlite/backend.py @@ -19,8 +19,13 @@ PERSISTENCE_BACKEND_HEALTH_CHECK, PERSISTENCE_BACKEND_NOT_CONNECTED, PERSISTENCE_BACKEND_WAL_MODE_FAILED, + PERSISTENCE_SETTING_FETCH_FAILED, + PERSISTENCE_SETTING_SAVE_FAILED, +) +from ai_company.persistence.errors import ( + PersistenceConnectionError, + QueryError, ) -from ai_company.persistence.errors import PersistenceConnectionError from ai_company.persistence.sqlite.audit_repository import ( SQLiteAuditRepository, ) @@ -38,6 +43,10 @@ SQLiteMessageRepository, SQLiteTaskRepository, ) +from ai_company.persistence.sqlite.user_repo import ( + SQLiteApiKeyRepository, + SQLiteUserRepository, +) if TYPE_CHECKING: from ai_company.persistence.config import SQLiteConfig @@ -68,6 +77,8 @@ def __init__(self, config: SQLiteConfig) -> None: self._collaboration_metrics: SQLiteCollaborationMetricRepository | None = None self._parked_contexts: SQLiteParkedContextRepository | None = None self._audit_entries: SQLiteAuditRepository | None = None + self._users: SQLiteUserRepository | None = None + self._api_keys: SQLiteApiKeyRepository | None = None def _clear_state(self) -> None: """Reset connection and repository references to ``None``.""" @@ -80,6 +91,8 @@ def _clear_state(self) -> None: self._collaboration_metrics = None self._parked_contexts = None self._audit_entries = None + self._users = None + self._api_keys = None async def connect(self) -> None: """Open the SQLite database and configure WAL mode.""" @@ -96,6 +109,9 @@ async def connect(self) -> None: self._db = await aiosqlite.connect(self._config.path) self._db.row_factory = aiosqlite.Row + # Enable foreign key enforcement (off by default in SQLite). + await self._db.execute("PRAGMA foreign_keys = ON") + if self._config.wal_mode: await self._configure_wal() @@ -139,6 +155,8 @@ def _create_repositories(self) -> None: self._collaboration_metrics = SQLiteCollaborationMetricRepository(self._db) self._parked_contexts = SQLiteParkedContextRepository(self._db) self._audit_entries = SQLiteAuditRepository(self._db) + self._users = SQLiteUserRepository(self._db) + self._api_keys = SQLiteApiKeyRepository(self._db) async def _cleanup_failed_connect(self, exc: sqlite3.Error | OSError) -> None: """Log failure, close partial connection, and raise. @@ -320,3 +338,73 @@ def audit_entries(self) -> SQLiteAuditRepository: PersistenceConnectionError: If not connected. """ return self._require_connected(self._audit_entries, "audit_entries") + + @property + def users(self) -> SQLiteUserRepository: + """Repository for User persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._users, "users") + + @property + def api_keys(self) -> SQLiteApiKeyRepository: + """Repository for ApiKey persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._api_keys, "api_keys") + + async def get_setting(self, key: str) -> str | None: + """Retrieve a setting value by key. + + Raises: + PersistenceConnectionError: If not connected. + """ + if self._db is None: + msg = "Not connected — call connect() before accessing settings" + logger.warning(PERSISTENCE_BACKEND_NOT_CONNECTED, error=msg) + raise PersistenceConnectionError(msg) + try: + cursor = await self._db.execute( + "SELECT value FROM settings WHERE key = ?", (key,) + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to get setting {key!r}" + logger.exception( + PERSISTENCE_SETTING_FETCH_FAILED, + key=key, + error=str(exc), + ) + raise QueryError(msg) from exc + return str(row[0]) if row else None + + async def set_setting(self, key: str, value: str) -> None: + """Store a setting value (upsert). + + Raises: + PersistenceConnectionError: If not connected. + """ + if self._db is None: + msg = "Not connected — call connect() before accessing settings" + logger.warning(PERSISTENCE_BACKEND_NOT_CONNECTED, error=msg) + raise PersistenceConnectionError(msg) + try: + await self._db.execute( + """\ +INSERT INTO settings (key, value) VALUES (?, ?) +ON CONFLICT(key) DO UPDATE SET value=excluded.value""", + (key, value), + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to set setting {key!r}" + logger.exception( + PERSISTENCE_SETTING_SAVE_FAILED, + key=key, + error=str(exc), + ) + raise QueryError(msg) from exc diff --git a/src/ai_company/persistence/sqlite/migrations.py b/src/ai_company/persistence/sqlite/migrations.py index cc306d8086..66cb1bfbd6 100644 --- a/src/ai_company/persistence/sqlite/migrations.py +++ b/src/ai_company/persistence/sqlite/migrations.py @@ -23,7 +23,7 @@ logger = get_logger(__name__) # Current schema version — bump when adding new migrations. -SCHEMA_VERSION = 4 +SCHEMA_VERSION = 5 _V1_STATEMENTS: Sequence[str] = ( # ── Tasks ───────────────────────────────────────────── @@ -188,6 +188,39 @@ "CREATE INDEX IF NOT EXISTS idx_ae_risk_level ON audit_entries(risk_level)", ) +_V5_STATEMENTS: Sequence[str] = ( + # ── Settings (key-value store) ───────────────────────── + """\ +CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +)""", + # ── Users ────────────────────────────────────────────── + """\ +CREATE TABLE IF NOT EXISTS users ( + id TEXT PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + role TEXT NOT NULL, + must_change_password INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL +)""", + # ── API keys ─────────────────────────────────────────── + """\ +CREATE TABLE IF NOT EXISTS api_keys ( + id TEXT PRIMARY KEY, + key_hash TEXT NOT NULL UNIQUE, + name TEXT NOT NULL, + role TEXT NOT NULL, + user_id TEXT NOT NULL REFERENCES users(id), + created_at TEXT NOT NULL, + expires_at TEXT, + revoked INTEGER NOT NULL DEFAULT 0 +)""", + "CREATE INDEX IF NOT EXISTS idx_api_keys_user_id ON api_keys(user_id)", +) + _MigrateFn = Callable[[aiosqlite.Connection], Coroutine[Any, Any, None]] @@ -244,6 +277,12 @@ async def _apply_v4(db: aiosqlite.Connection) -> None: await db.execute(stmt) +async def _apply_v5(db: aiosqlite.Connection) -> None: + """Apply schema v5: settings, users, api_keys.""" + for stmt in _V5_STATEMENTS: + await db.execute(stmt) + + # Ordered list of (target_version, migration_function) pairs. Each migration # is applied when the current schema version is below its target version. _MIGRATIONS: list[tuple[int, _MigrateFn]] = [ @@ -251,6 +290,7 @@ async def _apply_v4(db: aiosqlite.Connection) -> None: (2, _apply_v2), (3, _apply_v3), (4, _apply_v4), + (5, _apply_v5), ] diff --git a/src/ai_company/persistence/sqlite/user_repo.py b/src/ai_company/persistence/sqlite/user_repo.py new file mode 100644 index 0000000000..a04f71f3c4 --- /dev/null +++ b/src/ai_company/persistence/sqlite/user_repo.py @@ -0,0 +1,500 @@ +"""SQLite repository implementations for User and ApiKey. + +Provides ``SQLiteUserRepository`` and ``SQLiteApiKeyRepository``, which +persist ``User`` and ``ApiKey`` domain models to SQLite via aiosqlite. +Both use upsert semantics for ``save`` operations. +""" + +import sqlite3 +from datetime import UTC, datetime + +import aiosqlite +from pydantic import ValidationError + +from ai_company.api.auth.models import ApiKey, User +from ai_company.api.guards import HumanRole +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_API_KEY_DELETE_FAILED, + PERSISTENCE_API_KEY_DELETED, + PERSISTENCE_API_KEY_FETCH_FAILED, + PERSISTENCE_API_KEY_FETCHED, + PERSISTENCE_API_KEY_LIST_FAILED, + PERSISTENCE_API_KEY_LISTED, + PERSISTENCE_API_KEY_SAVE_FAILED, + PERSISTENCE_API_KEY_SAVED, + PERSISTENCE_USER_COUNT_FAILED, + PERSISTENCE_USER_COUNTED, + PERSISTENCE_USER_DELETE_FAILED, + PERSISTENCE_USER_DELETED, + PERSISTENCE_USER_FETCH_FAILED, + PERSISTENCE_USER_FETCHED, + PERSISTENCE_USER_LIST_FAILED, + PERSISTENCE_USER_LISTED, + PERSISTENCE_USER_SAVE_FAILED, + PERSISTENCE_USER_SAVED, +) +from ai_company.persistence.errors import QueryError + +logger = get_logger(__name__) + + +def _row_to_user(row: aiosqlite.Row) -> User: + """Reconstruct a ``User`` from a database row. + + Converts SQLite-native types (integers, ISO strings) back into + the domain model's expected Python types. + + Args: + row: A single database row with user columns. + + Returns: + Validated ``User`` model instance. + """ + data = dict(row) + data["must_change_password"] = bool(data["must_change_password"]) + data["role"] = HumanRole(data["role"]) + data["created_at"] = datetime.fromisoformat(data["created_at"]) + data["updated_at"] = datetime.fromisoformat(data["updated_at"]) + return User.model_validate(data) + + +def _row_to_api_key(row: aiosqlite.Row) -> ApiKey: + """Reconstruct an ``ApiKey`` from a database row. + + Converts SQLite-native types (integers, ISO strings) back into + the domain model's expected Python types. + + Args: + row: A single database row with API key columns. + + Returns: + Validated ``ApiKey`` model instance. + """ + data = dict(row) + data["revoked"] = bool(data["revoked"]) + data["role"] = HumanRole(data["role"]) + data["created_at"] = datetime.fromisoformat(data["created_at"]) + if data["expires_at"] is not None: + data["expires_at"] = datetime.fromisoformat(data["expires_at"]) + return ApiKey.model_validate(data) + + +class SQLiteUserRepository: + """SQLite-backed user repository. + + Provides CRUD operations for ``User`` models using a shared + ``aiosqlite.Connection``. All write operations commit + immediately. + + Args: + db: An open aiosqlite connection with ``row_factory`` + set to ``aiosqlite.Row``. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, user: User) -> None: + """Persist a user via upsert (insert or update on conflict). + + Args: + user: User model to persist. + + Raises: + QueryError: If the database operation fails. + """ + try: + await self._db.execute( + """\ +INSERT INTO users (id, username, password_hash, role, + must_change_password, created_at, updated_at) +VALUES (?, ?, ?, ?, ?, ?, ?) +ON CONFLICT(id) DO UPDATE SET + username=excluded.username, + password_hash=excluded.password_hash, + role=excluded.role, + must_change_password=excluded.must_change_password, + updated_at=excluded.updated_at""", + ( + user.id, + user.username, + user.password_hash, + user.role.value, + int(user.must_change_password), + user.created_at.astimezone(UTC).isoformat(), + user.updated_at.astimezone(UTC).isoformat(), + ), + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save user {user.id!r}" + logger.exception( + PERSISTENCE_USER_SAVE_FAILED, + user_id=user.id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.info(PERSISTENCE_USER_SAVED, user_id=user.id) + + async def get(self, user_id: NotBlankStr) -> User | None: + """Retrieve a user by primary key. + + Args: + user_id: Unique user identifier. + + Returns: + The matching ``User``, or ``None`` if not found. + + Raises: + QueryError: If the database query or deserialization fails. + """ + try: + cursor = await self._db.execute( + "SELECT * FROM users WHERE id = ?", (user_id,) + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to fetch user {user_id!r}" + logger.exception( + PERSISTENCE_USER_FETCH_FAILED, + user_id=user_id, + error=str(exc), + ) + raise QueryError(msg) from exc + if row is None: + logger.debug(PERSISTENCE_USER_FETCHED, user_id=user_id, found=False) + return None + try: + user = _row_to_user(row) + except (ValueError, ValidationError) as exc: + msg = f"Failed to deserialize user {user_id!r}" + logger.exception( + PERSISTENCE_USER_FETCH_FAILED, + user_id=user_id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug(PERSISTENCE_USER_FETCHED, user_id=user_id, found=True) + return user + + async def get_by_username(self, username: NotBlankStr) -> User | None: + """Retrieve a user by their unique username. + + Args: + username: Login username to look up. + + Returns: + The matching ``User``, or ``None`` if not found. + + Raises: + QueryError: If the database query or deserialization fails. + """ + try: + cursor = await self._db.execute( + "SELECT * FROM users WHERE username = ?", (username,) + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to fetch user by username {username!r}" + logger.exception( + PERSISTENCE_USER_FETCH_FAILED, + username=username, + error=str(exc), + ) + raise QueryError(msg) from exc + if row is None: + return None + try: + return _row_to_user(row) + except (ValueError, ValidationError) as exc: + msg = f"Failed to deserialize user {username!r}" + logger.exception( + PERSISTENCE_USER_FETCH_FAILED, + username=username, + error=str(exc), + ) + raise QueryError(msg) from exc + + async def list_users(self) -> tuple[User, ...]: + """List all users ordered by creation date. + + Returns: + Tuple of all ``User`` records, oldest first. + + Raises: + QueryError: If the database query or deserialization fails. + """ + try: + cursor = await self._db.execute("SELECT * FROM users ORDER BY created_at") + rows = await cursor.fetchall() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to list users" + logger.exception(PERSISTENCE_USER_LIST_FAILED, error=str(exc)) + raise QueryError(msg) from exc + try: + users = tuple(_row_to_user(row) for row in rows) + except (ValueError, ValidationError) as exc: + msg = "Failed to deserialize users" + logger.exception(PERSISTENCE_USER_LIST_FAILED, error=str(exc)) + raise QueryError(msg) from exc + logger.debug(PERSISTENCE_USER_LISTED, count=len(users)) + return users + + async def count(self) -> int: + """Return the total number of persisted users. + + Returns: + Non-negative integer count. + + Raises: + QueryError: If the database query fails. + """ + try: + cursor = await self._db.execute("SELECT COUNT(*) FROM users") + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to count users" + logger.exception(PERSISTENCE_USER_COUNT_FAILED, error=str(exc)) + raise QueryError(msg) from exc + result = int(row[0]) if row else 0 + logger.debug(PERSISTENCE_USER_COUNTED, count=result) + return result + + async def delete(self, user_id: NotBlankStr) -> bool: + """Delete a user by primary key. + + Args: + user_id: Unique user identifier. + + Returns: + ``True`` if a row was deleted, ``False`` if not found. + + Raises: + QueryError: If the database operation fails. + """ + try: + cursor = await self._db.execute( + "DELETE FROM users WHERE id = ?", (user_id,) + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to delete user {user_id!r}" + logger.exception( + PERSISTENCE_USER_DELETE_FAILED, + user_id=user_id, + error=str(exc), + ) + raise QueryError(msg) from exc + deleted = cursor.rowcount > 0 + logger.info(PERSISTENCE_USER_DELETED, user_id=user_id, deleted=deleted) + return deleted + + +class SQLiteApiKeyRepository: + """SQLite-backed API key repository. + + Provides CRUD operations for ``ApiKey`` models using a shared + ``aiosqlite.Connection``. All write operations commit + immediately. + + Args: + db: An open aiosqlite connection with ``row_factory`` + set to ``aiosqlite.Row``. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, key: ApiKey) -> None: + """Persist an API key via upsert (insert or update on conflict). + + Args: + key: API key model to persist. + + Raises: + QueryError: If the database operation fails. + """ + try: + await self._db.execute( + """\ +INSERT INTO api_keys (id, key_hash, name, role, user_id, + created_at, expires_at, revoked) +VALUES (?, ?, ?, ?, ?, ?, ?, ?) +ON CONFLICT(id) DO UPDATE SET + key_hash=excluded.key_hash, + name=excluded.name, + role=excluded.role, + user_id=excluded.user_id, + expires_at=excluded.expires_at, + revoked=excluded.revoked""", + ( + key.id, + key.key_hash, + key.name, + key.role.value, + key.user_id, + key.created_at.astimezone(UTC).isoformat(), + ( + key.expires_at.astimezone(UTC).isoformat() + if key.expires_at + else None + ), + int(key.revoked), + ), + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save API key {key.id!r}" + logger.exception( + PERSISTENCE_API_KEY_SAVE_FAILED, + key_id=key.id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.info(PERSISTENCE_API_KEY_SAVED, key_id=key.id) + + async def get(self, key_id: NotBlankStr) -> ApiKey | None: + """Retrieve an API key by primary key. + + Args: + key_id: Unique key identifier. + + Returns: + The matching ``ApiKey``, or ``None`` if not found. + + Raises: + QueryError: If the database query or deserialization fails. + """ + try: + cursor = await self._db.execute( + "SELECT * FROM api_keys WHERE id = ?", (key_id,) + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to fetch API key {key_id!r}" + logger.exception( + PERSISTENCE_API_KEY_FETCH_FAILED, + key_id=key_id, + error=str(exc), + ) + raise QueryError(msg) from exc + if row is None: + logger.debug(PERSISTENCE_API_KEY_FETCHED, key_id=key_id, found=False) + return None + try: + key = _row_to_api_key(row) + except (ValueError, ValidationError) as exc: + msg = f"Failed to deserialize API key {key_id!r}" + logger.exception( + PERSISTENCE_API_KEY_FETCH_FAILED, + key_id=key_id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug(PERSISTENCE_API_KEY_FETCHED, key_id=key_id, found=True) + return key + + async def get_by_hash(self, key_hash: NotBlankStr) -> ApiKey | None: + """Retrieve an API key by its SHA-256 hash. + + Args: + key_hash: Hex-encoded SHA-256 digest of the raw key. + + Returns: + The matching ``ApiKey``, or ``None`` if not found. + + Raises: + QueryError: If the database query or deserialization fails. + """ + try: + cursor = await self._db.execute( + "SELECT * FROM api_keys WHERE key_hash = ?", + (key_hash,), + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to fetch API key by hash" + logger.exception(PERSISTENCE_API_KEY_FETCH_FAILED, error=str(exc)) + raise QueryError(msg) from exc + if row is None: + return None + try: + return _row_to_api_key(row) + except (ValueError, ValidationError) as exc: + msg = "Failed to deserialize API key by hash" + logger.exception(PERSISTENCE_API_KEY_FETCH_FAILED, error=str(exc)) + raise QueryError(msg) from exc + + async def list_by_user(self, user_id: NotBlankStr) -> tuple[ApiKey, ...]: + """List all API keys belonging to a user, ordered by creation date. + + Args: + user_id: Owner user identifier. + + Returns: + Tuple of ``ApiKey`` records, oldest first. + + Raises: + QueryError: If the database query or deserialization fails. + """ + try: + cursor = await self._db.execute( + "SELECT * FROM api_keys WHERE user_id = ? ORDER BY created_at", + (user_id,), + ) + rows = await cursor.fetchall() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to list API keys for user {user_id!r}" + logger.exception( + PERSISTENCE_API_KEY_LIST_FAILED, + user_id=user_id, + error=str(exc), + ) + raise QueryError(msg) from exc + try: + keys = tuple(_row_to_api_key(row) for row in rows) + except (ValueError, ValidationError) as exc: + msg = f"Failed to deserialize API keys for user {user_id!r}" + logger.exception( + PERSISTENCE_API_KEY_LIST_FAILED, + user_id=user_id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug( + PERSISTENCE_API_KEY_LISTED, + user_id=user_id, + count=len(keys), + ) + return keys + + async def delete(self, key_id: NotBlankStr) -> bool: + """Delete an API key by primary key. + + Args: + key_id: Unique key identifier. + + Returns: + ``True`` if a row was deleted, ``False`` if not found. + + Raises: + QueryError: If the database operation fails. + """ + try: + cursor = await self._db.execute( + "DELETE FROM api_keys WHERE id = ?", (key_id,) + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to delete API key {key_id!r}" + logger.exception( + PERSISTENCE_API_KEY_DELETE_FAILED, + key_id=key_id, + error=str(exc), + ) + raise QueryError(msg) from exc + deleted = cursor.rowcount > 0 + logger.info(PERSISTENCE_API_KEY_DELETED, key_id=key_id, deleted=deleted) + return deleted diff --git a/tests/unit/api/auth/__init__.py b/tests/unit/api/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/api/auth/test_config.py b/tests/unit/api/auth/test_config.py new file mode 100644 index 0000000000..42025b465a --- /dev/null +++ b/tests/unit/api/auth/test_config.py @@ -0,0 +1,54 @@ +"""Tests for AuthConfig.""" + +import pytest + +from ai_company.api.auth.config import AuthConfig + + +@pytest.mark.unit +class TestAuthConfig: + def test_default_values(self) -> None: + config = AuthConfig() + assert config.jwt_secret == "" + assert config.jwt_algorithm == "HS256" + assert config.jwt_expiry_minutes == 1440 + assert config.exclude_paths is None + + def test_explicit_exclude_paths(self) -> None: + paths = ("^/health$", "^/docs") + config = AuthConfig(exclude_paths=paths) + assert config.exclude_paths == paths + + def test_with_secret_sets_secret(self) -> None: + config = AuthConfig() + updated = config.with_secret( + "a-very-long-secret-that-is-at-least-32-characters" + ) + assert updated.jwt_secret == "a-very-long-secret-that-is-at-least-32-characters" + + def test_with_secret_too_short_raises(self) -> None: + config = AuthConfig() + with pytest.raises(ValueError, match="at least 32"): + config.with_secret("short") + + def test_frozen(self) -> None: + config = AuthConfig() + with pytest.raises(Exception): # noqa: B017, PT011 + config.jwt_secret = "new" # type: ignore[misc] + + def test_original_unchanged_after_with_secret(self) -> None: + config = AuthConfig() + config.with_secret("a-very-long-secret-that-is-at-least-32-characters") + assert config.jwt_secret == "" + + def test_custom_expiry(self) -> None: + config = AuthConfig(jwt_expiry_minutes=60) + assert config.jwt_expiry_minutes == 60 + + def test_expiry_min_bound(self) -> None: + with pytest.raises(Exception): # noqa: B017, PT011 + AuthConfig(jwt_expiry_minutes=0) + + def test_expiry_max_bound(self) -> None: + with pytest.raises(Exception): # noqa: B017, PT011 + AuthConfig(jwt_expiry_minutes=50000) diff --git a/tests/unit/api/auth/test_controller.py b/tests/unit/api/auth/test_controller.py new file mode 100644 index 0000000000..abfc63d0d3 --- /dev/null +++ b/tests/unit/api/auth/test_controller.py @@ -0,0 +1,327 @@ +"""Tests for AuthController endpoints.""" + +from typing import Any + +import pytest +from litestar.testing import TestClient # noqa: TC002 + +from ai_company.api.guards import HumanRole +from tests.unit.api.conftest import make_auth_headers + + +@pytest.fixture +def bare_client(test_client: TestClient[Any]) -> TestClient[Any]: + """Test client with no default Authorization header.""" + test_client.headers.pop("authorization", None) + return test_client + + +@pytest.mark.unit +class TestSetup: + def test_setup_creates_admin(self, bare_client: TestClient[Any]) -> None: + app_state = bare_client.app.state["app_state"] + app_state.persistence._users._users.clear() + + response = bare_client.post( + "/api/v1/auth/setup", + json={ + "username": "newadmin", + "password": "super-secure-password-12", + }, + ) + assert response.status_code == 201 + data = response.json()["data"] + assert "token" in data + assert data["must_change_password"] is True + assert data["expires_in"] > 0 + + def test_setup_409_when_users_exist( + self, + bare_client: TestClient[Any], + ) -> None: + # Re-seed a user so the check fails + import uuid + from datetime import UTC, datetime + + from ai_company.api.auth.models import User + from ai_company.api.auth.service import AuthService # noqa: TC001 + from ai_company.api.guards import HumanRole + + app_state = bare_client.app.state["app_state"] + svc: AuthService = app_state.auth_service + now = datetime.now(UTC) + user = User( + id=str(uuid.uuid4()), + username="existing", + password_hash=svc.hash_password("test-password-12chars"), + role=HumanRole.CEO, + must_change_password=False, + created_at=now, + updated_at=now, + ) + app_state.persistence._users._users[user.id] = user + + response = bare_client.post( + "/api/v1/auth/setup", + json={ + "username": "admin2", + "password": "super-secure-password-12", + }, + ) + assert response.status_code == 409 + + def test_setup_short_password_rejected(self, bare_client: TestClient[Any]) -> None: + app_state = bare_client.app.state["app_state"] + app_state.persistence._users._users.clear() + + response = bare_client.post( + "/api/v1/auth/setup", + json={"username": "admin", "password": "short"}, + ) + assert response.status_code == 400 + + +@pytest.mark.unit +class TestLogin: + def test_login_valid_credentials(self, bare_client: TestClient[Any]) -> None: + app_state = bare_client.app.state["app_state"] + app_state.persistence._users._users.clear() + + bare_client.post( + "/api/v1/auth/setup", + json={ + "username": "loginuser", + "password": "super-secure-password-12", + }, + ) + + response = bare_client.post( + "/api/v1/auth/login", + json={ + "username": "loginuser", + "password": "super-secure-password-12", + }, + ) + assert response.status_code == 200 + data = response.json()["data"] + assert "token" in data + assert data["expires_in"] > 0 + + def test_login_wrong_password(self, bare_client: TestClient[Any]) -> None: + response = bare_client.post( + "/api/v1/auth/login", + json={ + "username": "test-ceo", + "password": "wrong-password-12345", + }, + ) + assert response.status_code == 401 + + def test_login_nonexistent_user(self, bare_client: TestClient[Any]) -> None: + response = bare_client.post( + "/api/v1/auth/login", + json={ + "username": "nonexistent", + "password": "any-password-12345", + }, + ) + assert response.status_code == 401 + + +@pytest.mark.unit +class TestChangePassword: + def test_change_password_success(self, bare_client: TestClient[Any]) -> None: + app_state = bare_client.app.state["app_state"] + app_state.persistence._users._users.clear() + + setup_resp = bare_client.post( + "/api/v1/auth/setup", + json={ + "username": "changepw", + "password": "old-password-12chars", + }, + ) + token = setup_resp.json()["data"]["token"] + + response = bare_client.post( + "/api/v1/auth/change-password", + json={ + "current_password": "old-password-12chars", + "new_password": "new-password-12chars", + }, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 200 + data = response.json()["data"] + assert data["must_change_password"] is False + + def test_change_password_wrong_current(self, test_client: TestClient[Any]) -> None: + response = test_client.post( + "/api/v1/auth/change-password", + json={ + "current_password": "wrong-current-pw-12", + "new_password": "new-password-12chars", + }, + headers=make_auth_headers("ceo"), + ) + assert response.status_code == 401 + + def test_change_password_requires_auth(self, bare_client: TestClient[Any]) -> None: + response = bare_client.post( + "/api/v1/auth/change-password", + json={ + "current_password": "old-password-12chars", + "new_password": "new-password-12chars", + }, + ) + assert response.status_code == 401 + + def test_change_password_short_new_password( + self, + bare_client: TestClient[Any], + ) -> None: + app_state = bare_client.app.state["app_state"] + app_state.persistence._users._users.clear() + + setup_resp = bare_client.post( + "/api/v1/auth/setup", + json={ + "username": "shortpw", + "password": "old-password-12chars", + }, + ) + token = setup_resp.json()["data"]["token"] + + response = bare_client.post( + "/api/v1/auth/change-password", + json={ + "current_password": "old-password-12chars", + "new_password": "short", + }, + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.status_code == 400 + + +@pytest.mark.unit +class TestMe: + def test_me_returns_user_info(self, test_client: TestClient[Any]) -> None: + response = test_client.get( + "/api/v1/auth/me", + headers=make_auth_headers("ceo"), + ) + assert response.status_code == 200 + data = response.json()["data"] + assert data["username"] == "test-ceo" + assert data["role"] == "ceo" + assert data["must_change_password"] is False + + def test_me_requires_auth(self, bare_client: TestClient[Any]) -> None: + response = bare_client.get("/api/v1/auth/me") + assert response.status_code == 401 + + +@pytest.mark.unit +class TestRequirePasswordChanged: + def test_blocks_user_with_must_change_password(self) -> None: + """Guard raises PermissionDeniedException for flagged users.""" + from unittest.mock import MagicMock + + from litestar.exceptions import PermissionDeniedException + + from ai_company.api.auth.controller import require_password_changed + from ai_company.api.auth.models import AuthenticatedUser, AuthMethod + + user = AuthenticatedUser( + user_id="u1", + username="admin", + role=HumanRole.CEO, + auth_method=AuthMethod.JWT, + must_change_password=True, + ) + connection = MagicMock() + connection.scope = {"user": user} + connection.url.path = "/api/v1/health" + + with pytest.raises(PermissionDeniedException): + require_password_changed(connection, None) + + def test_allows_user_without_flag(self) -> None: + """Guard passes when must_change_password is False.""" + from unittest.mock import MagicMock + + from ai_company.api.auth.controller import require_password_changed + from ai_company.api.auth.models import AuthenticatedUser, AuthMethod + + user = AuthenticatedUser( + user_id="u1", + username="admin", + role=HumanRole.CEO, + auth_method=AuthMethod.JWT, + must_change_password=False, + ) + connection = MagicMock() + connection.scope = {"user": user} + connection.url.path = "/api/v1/health" + + # Should not raise + require_password_changed(connection, None) + + def test_allows_when_no_user_in_scope(self) -> None: + """Guard passes when no user is in scope (pre-auth).""" + from unittest.mock import MagicMock + + from ai_company.api.auth.controller import require_password_changed + + connection = MagicMock() + connection.scope = {} + connection.url.path = "/api/v1/health" + + # Should not raise + require_password_changed(connection, None) + + @pytest.mark.parametrize( + "path", + [ + pytest.param("/api/v1/auth/change-password", id="change-password"), + pytest.param("/api/v1/auth/me", id="me"), + ], + ) + def test_exempts_paths_for_must_change_password_users( + self, + path: str, + ) -> None: + """Guard allows must_change_password users on exempt paths.""" + from unittest.mock import MagicMock + + from ai_company.api.auth.controller import require_password_changed + from ai_company.api.auth.models import AuthenticatedUser, AuthMethod + + user = AuthenticatedUser( + user_id="u1", + username="admin", + role=HumanRole.CEO, + auth_method=AuthMethod.JWT, + must_change_password=True, + ) + connection = MagicMock() + connection.scope = {"user": user} + connection.url.path = path + + # Should not raise — exempt path + require_password_changed(connection, None) + + def test_rejects_unknown_user_type(self) -> None: + """Guard raises PermissionDeniedException for non-AuthenticatedUser.""" + from unittest.mock import MagicMock + + from litestar.exceptions import PermissionDeniedException + + from ai_company.api.auth.controller import require_password_changed + + connection = MagicMock() + connection.scope = {"user": "not-an-auth-user"} + connection.url.path = "/api/v1/health" + + with pytest.raises(PermissionDeniedException): + require_password_changed(connection, None) diff --git a/tests/unit/api/auth/test_middleware.py b/tests/unit/api/auth/test_middleware.py new file mode 100644 index 0000000000..5a7c1ffc16 --- /dev/null +++ b/tests/unit/api/auth/test_middleware.py @@ -0,0 +1,313 @@ +"""Tests for ApiAuthMiddleware.""" + +from datetime import UTC, datetime + +import pytest +from litestar import Litestar, get +from litestar.testing import TestClient + +from ai_company.api.auth.config import AuthConfig +from ai_company.api.auth.middleware import create_auth_middleware_class +from ai_company.api.auth.models import ApiKey, User +from ai_company.api.auth.service import AuthService +from ai_company.api.guards import HumanRole +from tests.unit.api.conftest import FakePersistenceBackend + +_SECRET = "test-secret-that-is-at-least-32-characters-long" + + +def _make_auth_service() -> AuthService: + return AuthService(AuthConfig(jwt_secret=_SECRET)) + + +def _make_user(svc: AuthService) -> User: + now = datetime.now(UTC) + return User( + id="mw-user-001", + username="mw-admin", + password_hash=svc.hash_password("test-password-12chars"), + role=HumanRole.CEO, + must_change_password=False, + created_at=now, + updated_at=now, + ) + + +def _build_app( + *, + auth_service: AuthService, + persistence: FakePersistenceBackend, + exclude_paths: tuple[str, ...] = (), +) -> Litestar: + """Build a minimal Litestar app with auth middleware.""" + auth_config = AuthConfig( + jwt_secret=_SECRET, + exclude_paths=exclude_paths, + ) + + @get("/protected") + async def protected_route() -> dict[str, str]: + return {"status": "ok"} + + @get("/public") + async def public_route() -> dict[str, str]: + return {"status": "public"} + + middleware_cls = create_auth_middleware_class(auth_config) + + class _FakeState: + def __init__(self) -> None: + self.auth_service = auth_service + self.persistence = persistence + + app = Litestar( + route_handlers=[protected_route, public_route], + middleware=[middleware_cls], + ) + app.state["app_state"] = _FakeState() + return app + + +@pytest.mark.unit +class TestAuthMiddlewareJWT: + async def test_valid_jwt_authenticates(self) -> None: + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + await persistence.users.save(user) + + app = _build_app(auth_service=svc, persistence=persistence) + token, _ = svc.create_token(user) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 200 + + async def test_missing_header_returns_401(self) -> None: + svc = _make_auth_service() + persistence = FakePersistenceBackend() + await persistence.connect() + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get("/protected") + assert resp.status_code == 401 + + async def test_invalid_scheme_returns_401(self) -> None: + svc = _make_auth_service() + persistence = FakePersistenceBackend() + await persistence.connect() + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": "Basic dXNlcjpwYXNz"}, + ) + assert resp.status_code == 401 + + async def test_invalid_jwt_returns_401(self) -> None: + svc = _make_auth_service() + persistence = FakePersistenceBackend() + await persistence.connect() + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": "Bearer bad.jwt.token"}, + ) + assert resp.status_code == 401 + + async def test_jwt_for_deleted_user_returns_401(self) -> None: + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + # Don't save user — simulate deleted user + token, _ = svc.create_token(user) + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {token}"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.unit +class TestAuthMiddlewareApiKey: + async def test_valid_api_key_authenticates(self) -> None: + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + await persistence.users.save(user) + + raw_key = AuthService.generate_api_key() + key_hash = AuthService.hash_api_key(raw_key) + now = datetime.now(UTC) + api_key = ApiKey( + id="key-001", + key_hash=key_hash, + name="test-key", + role=HumanRole.CEO, + user_id=user.id, + created_at=now, + ) + await persistence.api_keys.save(api_key) + + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {raw_key}"}, + ) + assert resp.status_code == 200 + + async def test_revoked_api_key_returns_401(self) -> None: + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + await persistence.users.save(user) + + raw_key = AuthService.generate_api_key() + key_hash = AuthService.hash_api_key(raw_key) + now = datetime.now(UTC) + api_key = ApiKey( + id="key-002", + key_hash=key_hash, + name="revoked-key", + role=HumanRole.CEO, + user_id=user.id, + created_at=now, + revoked=True, + ) + await persistence.api_keys.save(api_key) + + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {raw_key}"}, + ) + assert resp.status_code == 401 + + async def test_expired_api_key_returns_401(self) -> None: + from datetime import timedelta + + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + await persistence.users.save(user) + + raw_key = AuthService.generate_api_key() + key_hash = AuthService.hash_api_key(raw_key) + now = datetime.now(UTC) + api_key = ApiKey( + id="key-003", + key_hash=key_hash, + name="expired-key", + role=HumanRole.CEO, + user_id=user.id, + created_at=now - timedelta(days=2), + expires_at=now - timedelta(days=1), + ) + await persistence.api_keys.save(api_key) + + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {raw_key}"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.unit +class TestAuthMiddlewareApiKeyEdgeCases: + async def test_api_key_with_deleted_owner_returns_401(self) -> None: + svc = _make_auth_service() + user = _make_user(svc) + persistence = FakePersistenceBackend() + await persistence.connect() + # Save the user, create a key, then delete the user + await persistence.users.save(user) + + raw_key = AuthService.generate_api_key() + key_hash = AuthService.hash_api_key(raw_key) + now = datetime.now(UTC) + api_key = ApiKey( + id="key-orphan", + key_hash=key_hash, + name="orphaned-key", + role=HumanRole.CEO, + user_id=user.id, + created_at=now, + ) + await persistence.api_keys.save(api_key) + await persistence.users.delete(user.id) + + app = _build_app(auth_service=svc, persistence=persistence) + + with TestClient(app) as client: + resp = client.get( + "/protected", + headers={"Authorization": f"Bearer {raw_key}"}, + ) + assert resp.status_code == 401 + + +@pytest.mark.unit +class TestExtractBearerToken: + @pytest.mark.parametrize( + ("header", "expected"), + [ + pytest.param("Bearer mytoken123", "mytoken123", id="valid"), + pytest.param("bearer mytoken123", "mytoken123", id="lowercase"), + pytest.param("BEARER mytoken123", "mytoken123", id="uppercase"), + pytest.param("", None, id="empty"), + pytest.param("Bearer", None, id="no-token"), + pytest.param("Basic dXNlcjpwYXNz", None, id="wrong-scheme"), + pytest.param( + "Bearer token with spaces", + "token with spaces", + id="token-with-spaces", + ), + ], + ) + def test_extract_bearer_token( + self, + header: str, + expected: str | None, + ) -> None: + from ai_company.api.auth.middleware import _extract_bearer_token + + assert _extract_bearer_token(header) == expected + + +@pytest.mark.unit +class TestAuthMiddlewareExcludePaths: + async def test_excluded_path_skips_auth(self) -> None: + svc = _make_auth_service() + persistence = FakePersistenceBackend() + await persistence.connect() + app = _build_app( + auth_service=svc, + persistence=persistence, + exclude_paths=("/public",), + ) + + with TestClient(app) as client: + resp = client.get("/public") + assert resp.status_code == 200 diff --git a/tests/unit/api/auth/test_secret.py b/tests/unit/api/auth/test_secret.py new file mode 100644 index 0000000000..716a9993fd --- /dev/null +++ b/tests/unit/api/auth/test_secret.py @@ -0,0 +1,75 @@ +"""Tests for JWT secret resolution chain.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from ai_company.api.auth.secret import resolve_jwt_secret + + +def _make_persistence(stored_secret: str | None = None) -> AsyncMock: + """Build a fake persistence backend for secret resolution tests.""" + persistence = AsyncMock() + persistence.get_setting = AsyncMock(return_value=stored_secret) + persistence.set_setting = AsyncMock() + return persistence + + +@pytest.mark.unit +class TestResolveJwtSecret: + async def test_env_var_takes_priority(self) -> None: + secret = "env-secret-that-is-at-least-32-characters!!" + persistence = _make_persistence(stored_secret="stored-secret-32-chars-long!!!!") + with patch.dict("os.environ", {"AI_COMPANY_JWT_SECRET": secret}): + result = await resolve_jwt_secret(persistence) + + assert result == secret + persistence.get_setting.assert_not_called() + + async def test_stored_secret_used_when_no_env_var(self) -> None: + stored = "stored-secret-that-is-at-least-32-chars!!" + persistence = _make_persistence(stored_secret=stored) + with patch.dict("os.environ", {}, clear=True): + result = await resolve_jwt_secret(persistence) + + assert result == stored + + async def test_generates_and_persists_when_nothing_stored(self) -> None: + persistence = _make_persistence(stored_secret=None) + with patch.dict("os.environ", {}, clear=True): + result = await resolve_jwt_secret(persistence) + + assert len(result) >= 32 + persistence.set_setting.assert_awaited_once_with("jwt_secret", result) + + async def test_env_var_too_short_raises(self) -> None: + persistence = _make_persistence() + with ( + patch.dict("os.environ", {"AI_COMPANY_JWT_SECRET": "short"}), + pytest.raises(ValueError, match="at least 32 characters"), + ): + await resolve_jwt_secret(persistence) + + async def test_env_var_whitespace_stripped(self) -> None: + secret = " env-secret-that-is-at-least-32-characters!! " + persistence = _make_persistence() + with patch.dict("os.environ", {"AI_COMPANY_JWT_SECRET": secret}): + result = await resolve_jwt_secret(persistence) + + assert result == secret.strip() + + async def test_empty_env_var_falls_through(self) -> None: + stored = "stored-secret-that-is-at-least-32-chars!!" + persistence = _make_persistence(stored_secret=stored) + with patch.dict("os.environ", {"AI_COMPANY_JWT_SECRET": ""}): + result = await resolve_jwt_secret(persistence) + + assert result == stored + + async def test_whitespace_only_env_var_falls_through(self) -> None: + stored = "stored-secret-that-is-at-least-32-chars!!" + persistence = _make_persistence(stored_secret=stored) + with patch.dict("os.environ", {"AI_COMPANY_JWT_SECRET": " "}): + result = await resolve_jwt_secret(persistence) + + assert result == stored diff --git a/tests/unit/api/auth/test_service.py b/tests/unit/api/auth/test_service.py new file mode 100644 index 0000000000..3f2dda24e8 --- /dev/null +++ b/tests/unit/api/auth/test_service.py @@ -0,0 +1,172 @@ +"""Tests for AuthService.""" + +from datetime import UTC, datetime, timedelta + +import pytest + +from ai_company.api.auth.config import AuthConfig +from ai_company.api.auth.models import User +from ai_company.api.auth.service import AuthService +from ai_company.api.guards import HumanRole + +_SECRET = "test-secret-that-is-at-least-32-characters-long" + + +def _make_service() -> AuthService: + return AuthService(AuthConfig(jwt_secret=_SECRET)) + + +def _make_user( + *, + role: HumanRole = HumanRole.CEO, + must_change_password: bool = False, +) -> User: + now = datetime.now(UTC) + svc = _make_service() + return User( + id="user-001", + username="admin", + password_hash=svc.hash_password("test-password-12chars"), + role=role, + must_change_password=must_change_password, + created_at=now, + updated_at=now, + ) + + +@pytest.mark.unit +class TestPasswordHashing: + def test_hash_and_verify(self) -> None: + svc = _make_service() + hashed = svc.hash_password("my-secret-password") + assert svc.verify_password("my-secret-password", hashed) + + def test_wrong_password_fails(self) -> None: + svc = _make_service() + hashed = svc.hash_password("correct-password") + assert not svc.verify_password("wrong-password", hashed) + + def test_hash_is_not_plaintext(self) -> None: + svc = _make_service() + hashed = svc.hash_password("my-secret-password") + assert hashed != "my-secret-password" + assert "$argon2" in hashed + + def test_different_hashes_for_same_password(self) -> None: + svc = _make_service() + h1 = svc.hash_password("same-password") + h2 = svc.hash_password("same-password") + # Different salts produce different hashes + assert h1 != h2 + + def test_verify_password_with_corrupted_hash(self) -> None: + svc = _make_service() + assert not svc.verify_password("my-password", "not-a-valid-argon2-hash") + + def test_verify_password_with_empty_hash(self) -> None: + svc = _make_service() + assert not svc.verify_password("my-password", "") + + +@pytest.mark.unit +class TestJWT: + def test_create_and_decode(self) -> None: + svc = _make_service() + user = _make_user() + token, expires_in = svc.create_token(user) + assert isinstance(token, str) + assert expires_in == 1440 * 60 + + claims = svc.decode_token(token) + assert claims["sub"] == "user-001" + assert claims["username"] == "admin" + assert claims["role"] == "ceo" + + def test_expired_token_raises(self) -> None: + import jwt + + config = AuthConfig(jwt_secret=_SECRET, jwt_expiry_minutes=1) + svc = AuthService(config) + user = _make_user() + _token, _ = svc.create_token(user) + + # Manually create an expired token + expired_payload = { + "sub": user.id, + "username": user.username, + "role": user.role.value, + "must_change_password": False, + "iat": datetime.now(UTC) - timedelta(hours=2), + "exp": datetime.now(UTC) - timedelta(hours=1), + } + expired_token = jwt.encode(expired_payload, _SECRET, algorithm="HS256") + with pytest.raises(jwt.ExpiredSignatureError): + svc.decode_token(expired_token) + + def test_invalid_signature_raises(self) -> None: + import jwt + + svc = _make_service() + user = _make_user() + token, _ = svc.create_token(user) + + # Decode with wrong secret + wrong_svc = AuthService( + AuthConfig(jwt_secret="wrong-secret-that-is-at-least-32-chars!!") + ) + with pytest.raises(jwt.InvalidSignatureError): + wrong_svc.decode_token(token) + + def test_must_change_password_in_claims(self) -> None: + svc = _make_service() + user = _make_user(must_change_password=True) + token, _ = svc.create_token(user) + claims = svc.decode_token(token) + assert claims["must_change_password"] is True + + def test_decode_token_missing_sub_claim(self) -> None: + import jwt as pyjwt + + svc = _make_service() + payload = { + "username": "admin", + "role": "ceo", + "iat": datetime.now(UTC), + "exp": datetime.now(UTC) + timedelta(hours=1), + } + token = pyjwt.encode(payload, _SECRET, algorithm="HS256") + with pytest.raises(pyjwt.MissingRequiredClaimError): + svc.decode_token(token) + + def test_create_token_empty_secret_raises(self) -> None: + svc = AuthService(AuthConfig()) + user = _make_user() + with pytest.raises(RuntimeError, match="JWT secret not configured"): + svc.create_token(user) + + def test_decode_token_empty_secret_raises(self) -> None: + svc = AuthService(AuthConfig()) + with pytest.raises(RuntimeError, match="JWT secret not configured"): + svc.decode_token("any.token.here") + + +@pytest.mark.unit +class TestApiKeyHashing: + def test_hash_deterministic(self) -> None: + h1 = AuthService.hash_api_key("my-key") + h2 = AuthService.hash_api_key("my-key") + assert h1 == h2 + + def test_different_keys_different_hashes(self) -> None: + h1 = AuthService.hash_api_key("key-one") + h2 = AuthService.hash_api_key("key-two") + assert h1 != h2 + + def test_generate_key_unique(self) -> None: + k1 = AuthService.generate_api_key() + k2 = AuthService.generate_api_key() + assert k1 != k2 + + def test_generate_key_length(self) -> None: + key = AuthService.generate_api_key() + assert len(key) > 30 diff --git a/tests/unit/api/conftest.py b/tests/unit/api/conftest.py index b43fc29181..32baeee0ce 100644 --- a/tests/unit/api/conftest.py +++ b/tests/unit/api/conftest.py @@ -1,6 +1,7 @@ """Shared fixtures for API unit tests.""" import asyncio +import uuid from datetime import UTC, datetime, timedelta from typing import Any @@ -9,6 +10,10 @@ from ai_company.api.app import create_app from ai_company.api.approval_store import ApprovalStore +from ai_company.api.auth.config import AuthConfig +from ai_company.api.auth.models import ApiKey, User +from ai_company.api.auth.service import AuthService +from ai_company.api.guards import HumanRole from ai_company.budget.cost_record import CostRecord # noqa: TC001 from ai_company.budget.tracker import CostTracker from ai_company.communication.channel import Channel # noqa: TC001 @@ -25,6 +30,12 @@ from ai_company.security.models import AuditEntry, AuditVerdictStr # noqa: TC001 from ai_company.security.timeout.parked_context import ParkedContext # noqa: TC001 +# ── Test auth constants ─────────────────────────────────────── + +_TEST_JWT_SECRET = "test-secret-that-is-at-least-32-characters-long" +_TEST_USER_ID = "test-user-001" +_TEST_USERNAME = "testadmin" + # ── Fake Repositories ──────────────────────────────────────────── @@ -262,6 +273,59 @@ async def query( # noqa: PLR0913 return tuple(results[:limit]) +class FakeUserRepository: + """In-memory user repository for tests.""" + + def __init__(self) -> None: + self._users: dict[str, User] = {} + + async def save(self, user: User) -> None: + self._users[user.id] = user + + async def get(self, user_id: str) -> User | None: + return self._users.get(user_id) + + async def get_by_username(self, username: str) -> User | None: + for user in self._users.values(): + if user.username == username: + return user + return None + + async def list_users(self) -> tuple[User, ...]: + return tuple(self._users.values()) + + async def count(self) -> int: + return len(self._users) + + async def delete(self, user_id: str) -> bool: + return self._users.pop(user_id, None) is not None + + +class FakeApiKeyRepository: + """In-memory API key repository for tests.""" + + def __init__(self) -> None: + self._keys: dict[str, ApiKey] = {} + + async def save(self, key: ApiKey) -> None: + self._keys[key.id] = key + + async def get(self, key_id: str) -> ApiKey | None: + return self._keys.get(key_id) + + async def get_by_hash(self, key_hash: str) -> ApiKey | None: + for key in self._keys.values(): + if key.key_hash == key_hash: + return key + return None + + async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: + return tuple(k for k in self._keys.values() if k.user_id == user_id) + + async def delete(self, key_id: str) -> bool: + return self._keys.pop(key_id, None) is not None + + class FakePersistenceBackend: """In-memory persistence backend for tests.""" @@ -274,6 +338,9 @@ def __init__(self) -> None: self._collaboration_metrics = FakeCollaborationMetricRepository() self._parked_contexts = FakeParkedContextRepository() self._audit_entries = FakeAuditRepository() + self._users = FakeUserRepository() + self._api_keys = FakeApiKeyRepository() + self._settings: dict[str, str] = {} self._connected = False async def connect(self) -> None: @@ -328,6 +395,20 @@ def parked_contexts(self) -> FakeParkedContextRepository: def audit_entries(self) -> FakeAuditRepository: return self._audit_entries + @property + def users(self) -> FakeUserRepository: + return self._users + + @property + def api_keys(self) -> FakeApiKeyRepository: + return self._api_keys + + async def get_setting(self, key: str) -> str | None: + return self._settings.get(key) + + async def set_setting(self, key: str, value: str) -> None: + self._settings[key] = value + # ── Fake Message Bus ──────────────────────────────────────────── @@ -396,9 +477,104 @@ async def get_channel_history( return () +# ── Auth helpers ──────────────────────────────────────────────── + +# Cache password hashes by role so that make_auth_headers and +# _seed_test_users produce identical pwd_sig claims. +_TEST_PASSWORD_HASHES: dict[str, str] = {} + + +def _make_test_auth_config() -> AuthConfig: + """Create an AuthConfig with a test JWT secret.""" + return AuthConfig(jwt_secret=_TEST_JWT_SECRET) + + +def _make_test_auth_service() -> AuthService: + """Create an AuthService backed by test config.""" + return AuthService(_make_test_auth_config()) + + +def _get_test_password_hash( + role: str, + auth_service: AuthService, +) -> str: + """Return a cached password hash for the given role. + + On the first call for a role, hashes the test password and + caches the result so that ``make_auth_headers`` and + ``_seed_test_users`` produce tokens with matching ``pwd_sig`` + claims. + """ + if role not in _TEST_PASSWORD_HASHES: + _TEST_PASSWORD_HASHES[role] = auth_service.hash_password( + "test-password-12chars", + ) + return _TEST_PASSWORD_HASHES[role] + + +def _make_test_user( + *, + role: HumanRole = HumanRole.CEO, + must_change_password: bool = False, + user_id: str = _TEST_USER_ID, + username: str = _TEST_USERNAME, +) -> User: + """Create a test User with given role.""" + now = datetime.now(UTC) + auth_service = _make_test_auth_service() + return User( + id=user_id, + username=username, + password_hash=_get_test_password_hash(role.value, auth_service), + role=role, + must_change_password=must_change_password, + created_at=now, + updated_at=now, + ) + + +def make_auth_headers( + role: str = "ceo", + *, + must_change_password: bool = False, +) -> dict[str, str]: + """Build an Authorization header with a JWT for the given role. + + Uses deterministic user IDs matching ``_seed_test_users`` so + middleware user lookups succeed. The password hash is cached + per role to ensure the ``pwd_sig`` claim matches the seeded + user in persistence. + """ + auth_service = _make_test_auth_service() + # Must match the ID pattern in _seed_test_users + user_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"test-{role}")) + now = datetime.now(UTC) + user = User( + id=user_id, + username=f"test-{role}", + password_hash=_get_test_password_hash(role, auth_service), + role=HumanRole(role), + must_change_password=must_change_password, + created_at=now, + updated_at=now, + ) + token, _ = auth_service.create_token(user) + return {"Authorization": f"Bearer {token}"} + + # ── Fixtures ──────────────────────────────────────────────────── +@pytest.fixture +def auth_config() -> AuthConfig: + return _make_test_auth_config() + + +@pytest.fixture +def auth_service() -> AuthService: + return _make_test_auth_service() + + @pytest.fixture async def fake_persistence() -> FakePersistenceBackend: backend = FakePersistenceBackend() @@ -429,25 +605,65 @@ def root_config() -> RootConfig: @pytest.fixture -def test_client( +def test_client( # noqa: PLR0913 fake_persistence: FakePersistenceBackend, fake_message_bus: FakeMessageBus, cost_tracker: CostTracker, approval_store: ApprovalStore, root_config: RootConfig, + auth_service: AuthService, ) -> TestClient[Any]: + # Pre-seed users for each role so JWT sub claims resolve + _seed_test_users(fake_persistence, auth_service) + app = create_app( config=root_config, persistence=fake_persistence, message_bus=fake_message_bus, cost_tracker=cost_tracker, approval_store=approval_store, + auth_service=auth_service, ) client = TestClient(app) - client.headers["X-Human-Role"] = "observer" + # Default: CEO token (most tests need write access) + client.headers.update(make_auth_headers("ceo")) return client +def _seed_test_users( + backend: FakePersistenceBackend, + auth_service: AuthService, +) -> None: + """Pre-seed a user for each role so JWT validation succeeds. + + The middleware looks up the user by ``sub`` claim, so we + need matching users in the fake persistence for every role + that tests might use. Uses cached password hashes to ensure + ``pwd_sig`` claims match between seeded users and tokens + produced by ``make_auth_headers``. + + Assigns directly to the fake repository's internal dict + (avoiding async) so this helper works in both sync fixtures + and sync test functions. + """ + now = datetime.now(UTC) + for role in HumanRole: + user_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"test-{role.value}")) + user = User( + id=user_id, + username=f"test-{role.value}", + password_hash=_get_test_password_hash( + role.value, + auth_service, + ), + role=role, + must_change_password=False, + created_at=now, + updated_at=now, + ) + backend._users._users[user.id] = user + + def make_task( # noqa: PLR0913 *, task_id: str = "task-001", diff --git a/tests/unit/api/controllers/test_agents.py b/tests/unit/api/controllers/test_agents.py index 2e35e4d800..c6632b2282 100644 --- a/tests/unit/api/controllers/test_agents.py +++ b/tests/unit/api/controllers/test_agents.py @@ -6,9 +6,10 @@ from litestar.testing import TestClient from ai_company.config.schema import AgentConfig, RootConfig -from tests.unit.api.conftest import ( # noqa: TC001 +from tests.unit.api.conftest import ( FakeMessageBus, FakePersistenceBackend, + make_auth_headers, ) @@ -27,7 +28,9 @@ def test_list_agents_with_data( fake_message_bus: FakeMessageBus, ) -> None: from ai_company.api.app import create_app + from ai_company.api.auth.service import AuthService # noqa: TC001 from ai_company.budget.tracker import CostTracker + from tests.unit.api.conftest import _make_test_auth_service, _seed_test_users config = RootConfig( company_name="test", @@ -39,14 +42,17 @@ def test_list_agents_with_data( ), ), ) + auth_service: AuthService = _make_test_auth_service() + _seed_test_users(fake_persistence, auth_service) app = create_app( config=config, persistence=fake_persistence, message_bus=fake_message_bus, cost_tracker=CostTracker(), + auth_service=auth_service, ) with TestClient(app) as client: - client.headers["X-Human-Role"] = "observer" + client.headers.update(make_auth_headers("observer")) resp = client.get("/api/v1/agents") body = resp.json() assert body["pagination"]["total"] == 1 diff --git a/tests/unit/api/controllers/test_analytics.py b/tests/unit/api/controllers/test_analytics.py index a1462817fb..49fc03f29d 100644 --- a/tests/unit/api/controllers/test_analytics.py +++ b/tests/unit/api/controllers/test_analytics.py @@ -6,8 +6,9 @@ from litestar.testing import TestClient # noqa: TC002 from ai_company.core.enums import TaskStatus +from tests.unit.api.conftest import make_auth_headers -_HEADERS = {"X-Human-Role": "ceo"} +_HEADERS = make_auth_headers("ceo") @pytest.mark.unit @@ -27,6 +28,6 @@ def test_overview_empty(self, test_client: TestClient[Any]) -> None: def test_overview_requires_read_access(self, test_client: TestClient[Any]) -> None: resp = test_client.get( "/api/v1/analytics/overview", - headers={"X-Human-Role": "invalid"}, + headers={"Authorization": "Bearer invalid-token"}, ) - assert resp.status_code == 403 + assert resp.status_code == 401 diff --git a/tests/unit/api/controllers/test_approvals.py b/tests/unit/api/controllers/test_approvals.py index 025e80ee5b..29a986353e 100644 --- a/tests/unit/api/controllers/test_approvals.py +++ b/tests/unit/api/controllers/test_approvals.py @@ -9,11 +9,11 @@ from ai_company.api.approval_store import ApprovalStore # noqa: TC001 from ai_company.core.approval import ApprovalItem from ai_company.core.enums import ApprovalRiskLevel, ApprovalStatus -from tests.unit.api.conftest import make_approval +from tests.unit.api.conftest import make_approval, make_auth_headers _BASE = "/api/v1/approvals" -_WRITE_HEADERS = {"X-Human-Role": "ceo"} -_READ_HEADERS = {"X-Human-Role": "observer"} +_WRITE_HEADERS = make_auth_headers("ceo") +_READ_HEADERS = make_auth_headers("observer") def _create_payload( @@ -134,8 +134,8 @@ async def test_list_pagination( assert body["pagination"]["offset"] == 2 def test_list_blocks_no_role(self, test_client: TestClient[Any]) -> None: - resp = test_client.get(_BASE, headers={"X-Human-Role": "invalid"}) - assert resp.status_code == 403 + resp = test_client.get(_BASE, headers={"Authorization": "Bearer invalid-token"}) + assert resp.status_code == 401 @pytest.mark.unit @@ -164,16 +164,16 @@ def test_get_allows_observer(self, test_client: TestClient[Any]) -> None: # Observer should have read access (even if 404) resp = test_client.get( f"{_BASE}/nonexistent", - headers={"X-Human-Role": "observer"}, + headers=make_auth_headers("observer"), ) assert resp.status_code == 404 # 404 = authorized but not found def test_get_blocks_no_role(self, test_client: TestClient[Any]) -> None: resp = test_client.get( f"{_BASE}/whatever", - headers={"X-Human-Role": "invalid"}, + headers={"Authorization": "Bearer invalid-token"}, ) - assert resp.status_code == 403 + assert resp.status_code == 401 @pytest.mark.unit @@ -234,12 +234,13 @@ def test_create_blocks_observer(self, test_client: TestClient[Any]) -> None: ) assert resp.status_code == 403 - def test_create_blocks_no_role(self, test_client: TestClient[Any]) -> None: + def test_create_blocks_no_auth(self, test_client: TestClient[Any]) -> None: resp = test_client.post( _BASE, json=_create_payload(), + headers={"Authorization": "Bearer invalid-token"}, ) - assert resp.status_code == 403 + assert resp.status_code == 401 @pytest.mark.unit @@ -258,7 +259,7 @@ async def test_approve_pending( assert resp.status_code == 200 body = resp.json() assert body["data"]["status"] == "approved" - assert body["data"]["decided_by"] == "ceo" + assert body["data"]["decided_by"] == "test-ceo" assert body["data"]["decision_reason"] == "Looks good" async def test_approve_records_decided_by_from_header( @@ -270,10 +271,10 @@ async def test_approve_records_decided_by_from_header( resp = test_client.post( f"{_BASE}/approval-001/approve", json={}, - headers={"X-Human-Role": "manager"}, + headers=make_auth_headers("manager"), ) assert resp.status_code == 200 - assert resp.json()["data"]["decided_by"] == "manager" + assert resp.json()["data"]["decided_by"] == "test-manager" def test_approve_not_found(self, test_client: TestClient[Any]) -> None: resp = test_client.post( @@ -359,7 +360,7 @@ async def test_reject_pending( assert resp.status_code == 200 body = resp.json() assert body["data"]["status"] == "rejected" - assert body["data"]["decided_by"] == "ceo" + assert body["data"]["decided_by"] == "test-ceo" assert body["data"]["decision_reason"] == "Too risky" async def test_reject_requires_reason( diff --git a/tests/unit/api/controllers/test_autonomy.py b/tests/unit/api/controllers/test_autonomy.py index c460503db9..c84cde81d7 100644 --- a/tests/unit/api/controllers/test_autonomy.py +++ b/tests/unit/api/controllers/test_autonomy.py @@ -5,9 +5,11 @@ import pytest from litestar.testing import TestClient # noqa: TC002 +from tests.unit.api.conftest import make_auth_headers + _BASE = "/api/v1/agents" -_WRITE_HEADERS = {"X-Human-Role": "ceo"} -_READ_HEADERS = {"X-Human-Role": "observer"} +_WRITE_HEADERS = make_auth_headers("ceo") +_READ_HEADERS = make_auth_headers("observer") def _url(agent_id: str = "agent-001") -> str: @@ -29,8 +31,10 @@ def test_get_autonomy(self, test_client: TestClient[Any]) -> None: def test_get_autonomy_requires_read_access( self, test_client: TestClient[Any] ) -> None: - resp = test_client.get(_url(), headers={"X-Human-Role": "invalid"}) - assert resp.status_code == 403 + resp = test_client.get( + _url(), headers={"Authorization": "Bearer invalid-token"} + ) + assert resp.status_code == 401 @pytest.mark.unit diff --git a/tests/unit/api/controllers/test_budget.py b/tests/unit/api/controllers/test_budget.py index 16289e27d6..28c8e47a68 100644 --- a/tests/unit/api/controllers/test_budget.py +++ b/tests/unit/api/controllers/test_budget.py @@ -8,8 +8,9 @@ from ai_company.budget.cost_record import CostRecord from ai_company.budget.tracker import CostTracker # noqa: TC001 +from tests.unit.api.conftest import make_auth_headers -_HEADERS = {"X-Human-Role": "ceo"} +_HEADERS = make_auth_headers("ceo") @pytest.mark.unit @@ -74,6 +75,6 @@ async def test_agent_spending( def test_budget_requires_read_access(self, test_client: TestClient[Any]) -> None: resp = test_client.get( "/api/v1/budget/config", - headers={"X-Human-Role": "invalid"}, + headers={"Authorization": "Bearer invalid-token"}, ) - assert resp.status_code == 403 + assert resp.status_code == 401 diff --git a/tests/unit/api/controllers/test_company.py b/tests/unit/api/controllers/test_company.py index b447874a58..02a554f3fa 100644 --- a/tests/unit/api/controllers/test_company.py +++ b/tests/unit/api/controllers/test_company.py @@ -5,7 +5,9 @@ import pytest from litestar.testing import TestClient # noqa: TC002 -_HEADERS = {"X-Human-Role": "ceo"} +from tests.unit.api.conftest import make_auth_headers + +_HEADERS = make_auth_headers("ceo") @pytest.mark.unit @@ -27,6 +29,6 @@ def test_list_departments(self, test_client: TestClient[Any]) -> None: def test_company_requires_read_access(self, test_client: TestClient[Any]) -> None: resp = test_client.get( "/api/v1/company", - headers={"X-Human-Role": "invalid"}, + headers={"Authorization": "Bearer invalid-token"}, ) - assert resp.status_code == 403 + assert resp.status_code == 401 diff --git a/tests/unit/api/controllers/test_tasks.py b/tests/unit/api/controllers/test_tasks.py index 95a6895ffd..04755227d9 100644 --- a/tests/unit/api/controllers/test_tasks.py +++ b/tests/unit/api/controllers/test_tasks.py @@ -5,7 +5,7 @@ import pytest from litestar.testing import TestClient # noqa: TC002 -from tests.unit.api.conftest import FakePersistenceBackend, make_task +from tests.unit.api.conftest import FakePersistenceBackend, make_auth_headers, make_task @pytest.mark.unit @@ -77,7 +77,7 @@ def test_create_task(self, test_client: TestClient[Any]) -> None: "project": "proj-1", "created_by": "alice", }, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 201 body = resp.json() @@ -93,14 +93,14 @@ def test_delete_task( fake_persistence.tasks._tasks[task.id] = task resp = test_client.delete( "/api/v1/tasks/task-001", - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 200 def test_delete_task_not_found(self, test_client: TestClient[Any]) -> None: resp = test_client.delete( "/api/v1/tasks/nonexistent", - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 404 @@ -117,7 +117,7 @@ def test_update_task( resp = test_client.patch( "/api/v1/tasks/task-001", json={"title": "Updated title"}, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 200 assert resp.json()["data"]["title"] == "Updated title" @@ -126,7 +126,7 @@ def test_update_not_found(self, test_client: TestClient[Any]) -> None: resp = test_client.patch( "/api/v1/tasks/nonexistent", json={"title": "Nope"}, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 404 @@ -134,7 +134,7 @@ def test_update_requires_write_role(self, test_client: TestClient[Any]) -> None: resp = test_client.patch( "/api/v1/tasks/task-001", json={"title": "Nope"}, - headers={"X-Human-Role": "observer"}, + headers=make_auth_headers("observer"), ) assert resp.status_code == 403 @@ -154,7 +154,7 @@ def test_transition_task( "target_status": "assigned", "assigned_to": "bob", }, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 201 assert resp.json()["data"]["status"] == "assigned" @@ -169,7 +169,7 @@ def test_transition_invalid( resp = test_client.post( "/api/v1/tasks/task-001/transition", json={"target_status": "completed"}, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 422 @@ -177,7 +177,7 @@ def test_transition_not_found(self, test_client: TestClient[Any]) -> None: resp = test_client.post( "/api/v1/tasks/nonexistent/transition", json={"target_status": "assigned"}, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert resp.status_code == 404 @@ -185,6 +185,6 @@ def test_transition_requires_write_role(self, test_client: TestClient[Any]) -> N resp = test_client.post( "/api/v1/tasks/task-001/transition", json={"target_status": "assigned"}, - headers={"X-Human-Role": "observer"}, + headers=make_auth_headers("observer"), ) assert resp.status_code == 403 diff --git a/tests/unit/api/test_app.py b/tests/unit/api/test_app.py index 7d3b7c1688..20117107cb 100644 --- a/tests/unit/api/test_app.py +++ b/tests/unit/api/test_app.py @@ -35,9 +35,14 @@ def test_openapi_schema_accessible(self, test_client: TestClient[Any]) -> None: @pytest.mark.unit class TestAppLifecycle: - async def test_startup_partial_failure_cleanup(self) -> None: + async def test_startup_partial_failure_cleanup( + self, + root_config: Any, + ) -> None: """Persistence ok, bus fails → persistence cleaned up.""" from ai_company.api.app import _safe_startup + from ai_company.api.approval_store import ApprovalStore + from ai_company.api.state import AppState from tests.unit.api.conftest import ( FakeMessageBus, FakePersistenceBackend, @@ -45,6 +50,11 @@ async def test_startup_partial_failure_cleanup(self) -> None: persistence = FakePersistenceBackend() bus = FakeMessageBus() + app_state = AppState( + config=root_config, + approval_store=ApprovalStore(), + persistence=persistence, + ) async def failing_start() -> None: msg = "bus boom" @@ -53,7 +63,7 @@ async def failing_start() -> None: bus.start = failing_start # type: ignore[method-assign] with pytest.raises(RuntimeError, match="bus boom"): - await _safe_startup(persistence, bus, None) + await _safe_startup(persistence, bus, None, app_state) # Persistence should have been disconnected during cleanup assert not persistence.is_connected diff --git a/tests/unit/api/test_exception_handlers.py b/tests/unit/api/test_exception_handlers.py index e89a630a79..9a446afcf2 100644 --- a/tests/unit/api/test_exception_handlers.py +++ b/tests/unit/api/test_exception_handlers.py @@ -6,7 +6,13 @@ from litestar import Litestar, get from litestar.testing import TestClient -from ai_company.api.errors import ConflictError, ForbiddenError, NotFoundError +from ai_company.api.errors import ( + ApiValidationError, + ConflictError, + ForbiddenError, + NotFoundError, + UnauthorizedError, +) from ai_company.api.exception_handlers import EXCEPTION_HANDLERS from ai_company.persistence.errors import ( DuplicateRecordError, @@ -73,8 +79,8 @@ async def handler() -> None: resp = client.get("/test") assert resp.status_code == 404 body = resp.json() - # handle_api_error returns the default class message. - assert body["error"] == "Resource not found" + # 4xx errors return the actual exception message + assert body["error"] == "nope" def test_api_conflict_error_maps_to_409(self) -> None: @get("/test") @@ -86,7 +92,7 @@ async def handler() -> None: resp = client.get("/test") assert resp.status_code == 409 body = resp.json() - assert body["error"] == "Resource conflict" + assert body["error"] == "conflict" def test_api_forbidden_error_maps_to_403(self) -> None: @get("/test") @@ -98,7 +104,7 @@ async def handler() -> None: resp = client.get("/test") assert resp.status_code == 403 body = resp.json() - assert body["error"] == "Forbidden" + assert body["error"] == "denied" def test_value_error_falls_through_to_catch_all(self) -> None: @get("/test") @@ -122,3 +128,28 @@ async def handler() -> None: body = resp.json() assert body["success"] is False assert body["error"] == "Internal server error" + + def test_unauthorized_error_maps_to_401(self) -> None: + @get("/test") + async def handler() -> None: + msg = "Invalid credentials" + raise UnauthorizedError(msg) + + with TestClient(_make_app(handler)) as client: + resp = client.get("/test") + assert resp.status_code == 401 + body = resp.json() + # 4xx returns the actual message, not the generic default + assert body["error"] == "Invalid credentials" + + def test_validation_error_maps_to_422(self) -> None: + @get("/test") + async def handler() -> None: + msg = "Bad field" + raise ApiValidationError(msg) + + with TestClient(_make_app(handler)) as client: + resp = client.get("/test") + assert resp.status_code == 422 + body = resp.json() + assert body["error"] == "Bad field" diff --git a/tests/unit/api/test_guards.py b/tests/unit/api/test_guards.py index aa9e5b2c9b..4a1a543ae7 100644 --- a/tests/unit/api/test_guards.py +++ b/tests/unit/api/test_guards.py @@ -1,14 +1,28 @@ -"""Tests for route guards.""" +"""Tests for route guards with JWT-based authentication.""" from typing import Any import pytest from litestar.testing import TestClient # noqa: TC002 +from tests.unit.api.conftest import make_auth_headers + +# To test "no auth" we need a fresh client without default headers. +# The test_client fixture sets CEO headers. Passing headers={} to +# a request merges with session defaults — it does NOT clear them. +# Instead we create a bare_client fixture. + + +@pytest.fixture +def bare_client(test_client: TestClient[Any]) -> TestClient[Any]: + """Test client with no default Authorization header.""" + test_client.headers.pop("authorization", None) + return test_client + @pytest.mark.unit -class TestGuards: - def test_write_guard_allows_ceo(self, test_client: TestClient[Any]) -> None: +class TestWriteGuard: + def test_allows_ceo(self, test_client: TestClient[Any]) -> None: response = test_client.post( "/api/v1/tasks", json={ @@ -18,11 +32,11 @@ def test_write_guard_allows_ceo(self, test_client: TestClient[Any]) -> None: "project": "proj", "created_by": "alice", }, - headers={"X-Human-Role": "ceo"}, + headers=make_auth_headers("ceo"), ) assert response.status_code == 201 - def test_write_guard_blocks_observer(self, test_client: TestClient[Any]) -> None: + def test_allows_manager(self, test_client: TestClient[Any]) -> None: response = test_client.post( "/api/v1/tasks", json={ @@ -32,13 +46,11 @@ def test_write_guard_blocks_observer(self, test_client: TestClient[Any]) -> None "project": "proj", "created_by": "alice", }, - headers={"X-Human-Role": "observer"}, + headers=make_auth_headers("manager"), ) - assert response.status_code == 403 + assert response.status_code == 201 - def test_write_guard_blocks_missing_role( - self, test_client: TestClient[Any] - ) -> None: + def test_allows_board_member(self, test_client: TestClient[Any]) -> None: response = test_client.post( "/api/v1/tasks", json={ @@ -48,38 +60,52 @@ def test_write_guard_blocks_missing_role( "project": "proj", "created_by": "alice", }, + headers=make_auth_headers("board_member"), ) - assert response.status_code == 403 - - def test_case_insensitive_role(self, test_client: TestClient[Any]) -> None: - response = test_client.get( - "/api/v1/tasks", - headers={"X-Human-Role": "CEO"}, - ) - assert response.status_code == 200 + assert response.status_code == 201 - def test_whitespace_padded_role(self, test_client: TestClient[Any]) -> None: - response = test_client.get( + def test_allows_pair_programmer(self, test_client: TestClient[Any]) -> None: + response = test_client.post( "/api/v1/tasks", - headers={"X-Human-Role": " ceo "}, + json={ + "title": "Test", + "description": "Test desc", + "type": "development", + "project": "proj", + "created_by": "alice", + }, + headers=make_auth_headers("pair_programmer"), ) - assert response.status_code == 200 + assert response.status_code == 201 - def test_read_guard_allows_observer(self, test_client: TestClient[Any]) -> None: - response = test_client.get( + def test_blocks_observer(self, test_client: TestClient[Any]) -> None: + response = test_client.post( "/api/v1/tasks", - headers={"X-Human-Role": "observer"}, + json={ + "title": "Test", + "description": "Test desc", + "type": "development", + "project": "proj", + "created_by": "alice", + }, + headers=make_auth_headers("observer"), ) - assert response.status_code == 200 + assert response.status_code == 403 - def test_read_guard_blocks_missing_role(self, test_client: TestClient[Any]) -> None: - response = test_client.get( + def test_missing_auth_returns_401(self, bare_client: TestClient[Any]) -> None: + response = bare_client.post( "/api/v1/tasks", - headers={"X-Human-Role": "invalid"}, + json={ + "title": "Test", + "description": "Test desc", + "type": "development", + "project": "proj", + "created_by": "alice", + }, ) - assert response.status_code == 403 + assert response.status_code == 401 - def test_write_guard_allows_manager(self, test_client: TestClient[Any]) -> None: + def test_invalid_token_returns_401(self, test_client: TestClient[Any]) -> None: response = test_client.post( "/api/v1/tasks", json={ @@ -89,6 +115,34 @@ def test_write_guard_allows_manager(self, test_client: TestClient[Any]) -> None: "project": "proj", "created_by": "alice", }, - headers={"X-Human-Role": "manager"}, + headers={"Authorization": "Bearer invalid-token"}, ) - assert response.status_code == 201 + assert response.status_code == 401 + + +@pytest.mark.unit +class TestReadGuard: + def test_allows_observer(self, test_client: TestClient[Any]) -> None: + response = test_client.get( + "/api/v1/tasks", + headers=make_auth_headers("observer"), + ) + assert response.status_code == 200 + + def test_allows_ceo(self, test_client: TestClient[Any]) -> None: + response = test_client.get( + "/api/v1/tasks", + headers=make_auth_headers("ceo"), + ) + assert response.status_code == 200 + + def test_missing_auth_returns_401(self, bare_client: TestClient[Any]) -> None: + response = bare_client.get("/api/v1/tasks") + assert response.status_code == 401 + + def test_invalid_token_returns_401(self, test_client: TestClient[Any]) -> None: + response = test_client.get( + "/api/v1/tasks", + headers={"Authorization": "Bearer invalid-token"}, + ) + assert response.status_code == 401 diff --git a/tests/unit/api/test_state.py b/tests/unit/api/test_state.py index c48494ec20..792907cead 100644 --- a/tests/unit/api/test_state.py +++ b/tests/unit/api/test_state.py @@ -56,3 +56,37 @@ def test_cost_tracker_returns_when_set(self) -> None: tracker = CostTracker() state = _make_state(cost_tracker=tracker) assert state.cost_tracker is tracker + + def test_auth_service_raises_when_none(self) -> None: + state = _make_state(auth_service=None) + with pytest.raises(ServiceUnavailableError): + _ = state.auth_service + + def test_auth_service_returns_when_set(self) -> None: + from ai_company.api.auth.config import AuthConfig + from ai_company.api.auth.service import AuthService + + secret = "test-secret-that-is-at-least-32-characters-long" + svc = AuthService(AuthConfig(jwt_secret=secret)) + state = _make_state(auth_service=svc) + assert state.auth_service is svc + + def test_set_auth_service_succeeds_once(self) -> None: + from ai_company.api.auth.config import AuthConfig + from ai_company.api.auth.service import AuthService + + secret = "test-secret-that-is-at-least-32-characters-long" + svc = AuthService(AuthConfig(jwt_secret=secret)) + state = _make_state() + state.set_auth_service(svc) + assert state.auth_service is svc + + def test_set_auth_service_twice_raises(self) -> None: + from ai_company.api.auth.config import AuthConfig + from ai_company.api.auth.service import AuthService + + secret = "test-secret-that-is-at-least-32-characters-long" + svc = AuthService(AuthConfig(jwt_secret=secret)) + state = _make_state(auth_service=svc) + with pytest.raises(RuntimeError, match="already configured"): + state.set_auth_service(svc) diff --git a/tests/unit/config/conftest.py b/tests/unit/config/conftest.py index f05e50b05d..7300e5af2d 100644 --- a/tests/unit/config/conftest.py +++ b/tests/unit/config/conftest.py @@ -5,6 +5,7 @@ import pytest from polyfactory.factories.pydantic_factory import ModelFactory +from ai_company.api.config import ApiConfig from ai_company.budget.config import BudgetConfig from ai_company.budget.coordination_config import CoordinationMetricsConfig from ai_company.budget.cost_tiers import CostTiersConfig @@ -78,6 +79,7 @@ class RootConfigFactory(ModelFactory[RootConfig]): custom_roles = () providers: dict[str, ProviderConfig] = {} # noqa: RUF012 config = CompanyConfig() + api = ApiConfig() budget = BudgetConfig() communication = CommunicationConfig() routing = RoutingConfig() diff --git a/tests/unit/persistence/sqlite/test_migrations.py b/tests/unit/persistence/sqlite/test_migrations.py index f2d979ed87..154a0ba8a4 100644 --- a/tests/unit/persistence/sqlite/test_migrations.py +++ b/tests/unit/persistence/sqlite/test_migrations.py @@ -144,6 +144,35 @@ async def test_v4_creates_audit_entry_indexes( } assert expected.issubset(indexes) + @pytest.mark.parametrize( + "table_name", + ["users", "api_keys", "settings"], + ) + async def test_v5_creates_table( + self, + memory_db: aiosqlite.Connection, + table_name: str, + ) -> None: + """V5 migration creates the expected tables.""" + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,), + ) + row = await cursor.fetchone() + assert row is not None + + async def test_v5_creates_user_indexes( + self, memory_db: aiosqlite.Connection + ) -> None: + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='index' " + "AND name LIKE 'idx_%' AND name LIKE '%user%' ORDER BY name" + ) + indexes = {row[0] for row in await cursor.fetchall()} + assert "idx_api_keys_user_id" in indexes + async def test_migration_failure_raises_migration_error( self, memory_db: aiosqlite.Connection ) -> None: diff --git a/tests/unit/persistence/sqlite/test_user_repo.py b/tests/unit/persistence/sqlite/test_user_repo.py new file mode 100644 index 0000000000..dcbbd3b9dc --- /dev/null +++ b/tests/unit/persistence/sqlite/test_user_repo.py @@ -0,0 +1,199 @@ +"""Tests for SQLiteUserRepository and SQLiteApiKeyRepository.""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +import aiosqlite +import pytest + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + +from ai_company.api.auth.models import ApiKey, User +from ai_company.api.guards import HumanRole +from ai_company.persistence.sqlite.migrations import run_migrations +from ai_company.persistence.sqlite.user_repo import ( + SQLiteApiKeyRepository, + SQLiteUserRepository, +) + + +@pytest.fixture +async def db() -> AsyncGenerator[aiosqlite.Connection]: + """Create an in-memory SQLite DB with schema applied.""" + conn = await aiosqlite.connect(":memory:") + conn.row_factory = aiosqlite.Row + await run_migrations(conn) + yield conn + await conn.close() + + +@pytest.fixture +def user_repo(db: aiosqlite.Connection) -> SQLiteUserRepository: + return SQLiteUserRepository(db) + + +@pytest.fixture +def api_key_repo(db: aiosqlite.Connection) -> SQLiteApiKeyRepository: + return SQLiteApiKeyRepository(db) + + +def _make_user( + *, + user_id: str = "user-001", + username: str = "admin", + role: HumanRole = HumanRole.CEO, +) -> User: + now = datetime.now(UTC) + return User( + id=user_id, + username=username, + password_hash="$argon2id$fake-hash", + role=role, + must_change_password=False, + created_at=now, + updated_at=now, + ) + + +@pytest.mark.unit +class TestSQLiteUserRepository: + async def test_save_and_get(self, user_repo: SQLiteUserRepository) -> None: + user = _make_user() + await user_repo.save(user) + fetched = await user_repo.get("user-001") + assert fetched is not None + assert fetched.id == "user-001" + assert fetched.username == "admin" + + async def test_get_nonexistent(self, user_repo: SQLiteUserRepository) -> None: + result = await user_repo.get("nonexistent") + assert result is None + + async def test_get_by_username(self, user_repo: SQLiteUserRepository) -> None: + user = _make_user() + await user_repo.save(user) + fetched = await user_repo.get_by_username("admin") + assert fetched is not None + assert fetched.id == "user-001" + + async def test_get_by_username_not_found( + self, user_repo: SQLiteUserRepository + ) -> None: + result = await user_repo.get_by_username("nope") + assert result is None + + async def test_list_users(self, user_repo: SQLiteUserRepository) -> None: + await user_repo.save(_make_user(user_id="u1", username="alice")) + await user_repo.save(_make_user(user_id="u2", username="bob")) + users = await user_repo.list_users() + assert len(users) == 2 + + async def test_count(self, user_repo: SQLiteUserRepository) -> None: + assert await user_repo.count() == 0 + await user_repo.save(_make_user()) + assert await user_repo.count() == 1 + + async def test_delete(self, user_repo: SQLiteUserRepository) -> None: + await user_repo.save(_make_user()) + deleted = await user_repo.delete("user-001") + assert deleted is True + assert await user_repo.get("user-001") is None + + async def test_delete_nonexistent(self, user_repo: SQLiteUserRepository) -> None: + deleted = await user_repo.delete("nope") + assert deleted is False + + async def test_upsert(self, user_repo: SQLiteUserRepository) -> None: + user = _make_user() + await user_repo.save(user) + updated = user.model_copy( + update={"username": "new-admin", "updated_at": datetime.now(UTC)} + ) + await user_repo.save(updated) + fetched = await user_repo.get("user-001") + assert fetched is not None + assert fetched.username == "new-admin" + assert await user_repo.count() == 1 + + +@pytest.mark.unit +class TestSQLiteApiKeyRepository: + async def test_save_and_get( + self, + api_key_repo: SQLiteApiKeyRepository, + user_repo: SQLiteUserRepository, + ) -> None: + await user_repo.save(_make_user()) + now = datetime.now(UTC) + key = ApiKey( + id="key-001", + key_hash="abc123hash", + name="test-key", + role=HumanRole.CEO, + user_id="user-001", + created_at=now, + ) + await api_key_repo.save(key) + fetched = await api_key_repo.get("key-001") + assert fetched is not None + assert fetched.name == "test-key" + + async def test_get_by_hash( + self, + api_key_repo: SQLiteApiKeyRepository, + user_repo: SQLiteUserRepository, + ) -> None: + await user_repo.save(_make_user()) + now = datetime.now(UTC) + key = ApiKey( + id="key-002", + key_hash="unique-hash", + name="hash-key", + role=HumanRole.CEO, + user_id="user-001", + created_at=now, + ) + await api_key_repo.save(key) + fetched = await api_key_repo.get_by_hash("unique-hash") + assert fetched is not None + assert fetched.id == "key-002" + + async def test_list_by_user( + self, + api_key_repo: SQLiteApiKeyRepository, + user_repo: SQLiteUserRepository, + ) -> None: + await user_repo.save(_make_user()) + now = datetime.now(UTC) + for i in range(3): + key = ApiKey( + id=f"key-{i}", + key_hash=f"hash-{i}", + name=f"key-{i}", + role=HumanRole.CEO, + user_id="user-001", + created_at=now, + ) + await api_key_repo.save(key) + keys = await api_key_repo.list_by_user("user-001") + assert len(keys) == 3 + + async def test_delete( + self, + api_key_repo: SQLiteApiKeyRepository, + user_repo: SQLiteUserRepository, + ) -> None: + await user_repo.save(_make_user()) + now = datetime.now(UTC) + key = ApiKey( + id="key-del", + key_hash="del-hash", + name="del-key", + role=HumanRole.CEO, + user_id="user-001", + created_at=now, + ) + await api_key_repo.save(key) + assert await api_key_repo.delete("key-del") is True + assert await api_key_repo.get("key-del") is None diff --git a/tests/unit/persistence/test_migrations_v2.py b/tests/unit/persistence/test_migrations_v2.py index 776a3dbe47..9e80beaf52 100644 --- a/tests/unit/persistence/test_migrations_v2.py +++ b/tests/unit/persistence/test_migrations_v2.py @@ -28,8 +28,8 @@ async def memory_db() -> AsyncGenerator[aiosqlite.Connection]: @pytest.mark.unit class TestSchemaMigrations: - async def test_schema_version_is_four(self) -> None: - assert SCHEMA_VERSION == 4 + async def test_schema_version_is_five(self) -> None: + assert SCHEMA_VERSION == 5 async def test_fresh_db_creates_all_v2_tables( self, memory_db: aiosqlite.Connection @@ -58,7 +58,7 @@ async def test_v1_to_v2_migration(self, memory_db: aiosqlite.Connection) -> None assert await get_user_version(memory_db) == 1 await run_migrations(memory_db) - assert await get_user_version(memory_db) == 4 + assert await get_user_version(memory_db) == 5 cursor = await memory_db.execute( "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" diff --git a/tests/unit/persistence/test_protocol.py b/tests/unit/persistence/test_protocol.py index 7b559bd05e..7b6db34524 100644 --- a/tests/unit/persistence/test_protocol.py +++ b/tests/unit/persistence/test_protocol.py @@ -12,16 +12,19 @@ ) from ai_company.persistence.protocol import PersistenceBackend from ai_company.persistence.repositories import ( + ApiKeyRepository, AuditRepository, CostRecordRepository, MessageRepository, ParkedContextRepository, TaskRepository, + UserRepository, ) if TYPE_CHECKING: from pydantic import AwareDatetime + from ai_company.api.auth.models import ApiKey, User from ai_company.budget.cost_record import CostRecord from ai_company.communication.message import Message from ai_company.core.enums import ApprovalRiskLevel, TaskStatus @@ -163,6 +166,43 @@ async def query( # noqa: PLR0913 return () +class _FakeUserRepository: + async def save(self, user: User) -> None: + pass + + async def get(self, user_id: str) -> User | None: + return None + + async def get_by_username(self, username: str) -> User | None: + return None + + async def list_users(self) -> tuple[User, ...]: + return () + + async def count(self) -> int: + return 0 + + async def delete(self, user_id: str) -> bool: + return False + + +class _FakeApiKeyRepository: + async def save(self, key: ApiKey) -> None: + pass + + async def get(self, key_id: str) -> ApiKey | None: + return None + + async def get_by_hash(self, key_hash: str) -> ApiKey | None: + return None + + async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: + return () + + async def delete(self, key_id: str) -> bool: + return False + + class _FakeBackend: async def connect(self) -> None: pass @@ -216,6 +256,20 @@ def collaboration_metrics(self) -> _FakeCollaborationMetricRepository: def audit_entries(self) -> _FakeAuditRepository: return _FakeAuditRepository() + @property + def users(self) -> _FakeUserRepository: + return _FakeUserRepository() + + @property + def api_keys(self) -> _FakeApiKeyRepository: + return _FakeApiKeyRepository() + + async def get_setting(self, key: str) -> str | None: + return None + + async def set_setting(self, key: str, value: str) -> None: + pass + @pytest.mark.unit class TestProtocolCompliance: @@ -255,3 +309,9 @@ def test_fake_parked_context_repo_is_parked_context_repository( def test_fake_audit_repo_is_audit_repository(self) -> None: assert isinstance(_FakeAuditRepository(), AuditRepository) + + def test_fake_user_repo_is_user_repository(self) -> None: + assert isinstance(_FakeUserRepository(), UserRepository) + + def test_fake_api_key_repo_is_api_key_repository(self) -> None: + assert isinstance(_FakeApiKeyRepository(), ApiKeyRepository) diff --git a/uv.lock b/uv.lock index d12431845f..0ef0fd877a 100644 --- a/uv.lock +++ b/uv.lock @@ -8,12 +8,14 @@ source = { editable = "." } dependencies = [ { name = "aiodocker" }, { name = "aiosqlite" }, + { name = "argon2-cffi" }, { name = "jinja2" }, { name = "jsonschema" }, { name = "litellm" }, { name = "litestar", extra = ["brotli", "prometheus", "pydantic", "standard", "structlog"] }, { name = "mcp" }, { name = "pydantic" }, + { name = "pyjwt", extra = ["crypto"] }, { name = "pyyaml" }, { name = "structlog" }, ] @@ -50,12 +52,14 @@ test = [ requires-dist = [ { name = "aiodocker", specifier = "==0.26.0" }, { name = "aiosqlite", specifier = "==0.22.1" }, + { name = "argon2-cffi", specifier = "==25.1.0" }, { name = "jinja2", specifier = "==3.1.6" }, { name = "jsonschema", specifier = "==4.26.0" }, { name = "litellm", specifier = "==1.82.1" }, { name = "litestar", extras = ["brotli", "prometheus", "pydantic", "standard", "structlog"], specifier = "==2.21.1" }, { name = "mcp", specifier = "==1.26.0" }, { name = "pydantic", specifier = "==2.12.5" }, + { name = "pyjwt", extras = ["crypto"], specifier = "==2.11.0" }, { name = "pyyaml", specifier = "==6.0.3" }, { name = "structlog", specifier = "==25.5.0" }, ] @@ -220,6 +224,49 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/74/f5/9373290775639cb67a2fce7f629a1c240dce9f12fe927bc32b2736e16dfc/argcomplete-3.6.3-py3-none-any.whl", hash = "sha256:f5007b3a600ccac5d25bbce33089211dfd49eab4a7718da3f10e3082525a92ce", size = 43846, upload-time = "2025-10-20T03:33:33.021Z" }, ] +[[package]] +name = "argon2-cffi" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "argon2-cffi-bindings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/89/ce5af8a7d472a67cc819d5d998aa8c82c5d860608c4db9f46f1162d7dab9/argon2_cffi-25.1.0.tar.gz", hash = "sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1", size = 45706, upload-time = "2025-06-03T06:55:32.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/d3/a8b22fa575b297cd6e3e3b0155c7e25db170edf1c74783d6a31a2490b8d9/argon2_cffi-25.1.0-py3-none-any.whl", hash = "sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741", size = 14657, upload-time = "2025-06-03T06:55:30.804Z" }, +] + +[[package]] +name = "argon2-cffi-bindings" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/2d/db8af0df73c1cf454f71b2bbe5e356b8c1f8041c979f505b3d3186e520a9/argon2_cffi_bindings-25.1.0.tar.gz", hash = "sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d", size = 1783441, upload-time = "2025-07-30T10:02:05.147Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/97/3c0a35f46e52108d4707c44b95cfe2afcafc50800b5450c197454569b776/argon2_cffi_bindings-25.1.0-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:3d3f05610594151994ca9ccb3c771115bdb4daef161976a266f0dd8aa9996b8f", size = 54393, upload-time = "2025-07-30T10:01:40.97Z" }, + { url = "https://files.pythonhosted.org/packages/9d/f4/98bbd6ee89febd4f212696f13c03ca302b8552e7dbf9c8efa11ea4a388c3/argon2_cffi_bindings-25.1.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:8b8efee945193e667a396cbc7b4fb7d357297d6234d30a489905d96caabde56b", size = 29328, upload-time = "2025-07-30T10:01:41.916Z" }, + { url = "https://files.pythonhosted.org/packages/43/24/90a01c0ef12ac91a6be05969f29944643bc1e5e461155ae6559befa8f00b/argon2_cffi_bindings-25.1.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:3c6702abc36bf3ccba3f802b799505def420a1b7039862014a65db3205967f5a", size = 31269, upload-time = "2025-07-30T10:01:42.716Z" }, + { url = "https://files.pythonhosted.org/packages/d4/d3/942aa10782b2697eee7af5e12eeff5ebb325ccfb86dd8abda54174e377e4/argon2_cffi_bindings-25.1.0-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a1c70058c6ab1e352304ac7e3b52554daadacd8d453c1752e547c76e9c99ac44", size = 86558, upload-time = "2025-07-30T10:01:43.943Z" }, + { url = "https://files.pythonhosted.org/packages/0d/82/b484f702fec5536e71836fc2dbc8c5267b3f6e78d2d539b4eaa6f0db8bf8/argon2_cffi_bindings-25.1.0-cp314-cp314t-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e2fd3bfbff3c5d74fef31a722f729bf93500910db650c925c2d6ef879a7e51cb", size = 92364, upload-time = "2025-07-30T10:01:44.887Z" }, + { url = "https://files.pythonhosted.org/packages/c9/c1/a606ff83b3f1735f3759ad0f2cd9e038a0ad11a3de3b6c673aa41c24bb7b/argon2_cffi_bindings-25.1.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4f9665de60b1b0e99bcd6be4f17d90339698ce954cfd8d9cf4f91c995165a92", size = 85637, upload-time = "2025-07-30T10:01:46.225Z" }, + { url = "https://files.pythonhosted.org/packages/44/b4/678503f12aceb0262f84fa201f6027ed77d71c5019ae03b399b97caa2f19/argon2_cffi_bindings-25.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ba92837e4a9aa6a508c8d2d7883ed5a8f6c308c89a4790e1e447a220deb79a85", size = 91934, upload-time = "2025-07-30T10:01:47.203Z" }, + { url = "https://files.pythonhosted.org/packages/f0/c7/f36bd08ef9bd9f0a9cff9428406651f5937ce27b6c5b07b92d41f91ae541/argon2_cffi_bindings-25.1.0-cp314-cp314t-win32.whl", hash = "sha256:84a461d4d84ae1295871329b346a97f68eade8c53b6ed9a7ca2d7467f3c8ff6f", size = 28158, upload-time = "2025-07-30T10:01:48.341Z" }, + { url = "https://files.pythonhosted.org/packages/b3/80/0106a7448abb24a2c467bf7d527fe5413b7fdfa4ad6d6a96a43a62ef3988/argon2_cffi_bindings-25.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b55aec3565b65f56455eebc9b9f34130440404f27fe21c3b375bf1ea4d8fbae6", size = 32597, upload-time = "2025-07-30T10:01:49.112Z" }, + { url = "https://files.pythonhosted.org/packages/05/b8/d663c9caea07e9180b2cb662772865230715cbd573ba3b5e81793d580316/argon2_cffi_bindings-25.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:87c33a52407e4c41f3b70a9c2d3f6056d88b10dad7695be708c5021673f55623", size = 28231, upload-time = "2025-07-30T10:01:49.92Z" }, + { url = "https://files.pythonhosted.org/packages/1d/57/96b8b9f93166147826da5f90376e784a10582dd39a393c99bb62cfcf52f0/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500", size = 54121, upload-time = "2025-07-30T10:01:50.815Z" }, + { url = "https://files.pythonhosted.org/packages/0a/08/a9bebdb2e0e602dde230bdde8021b29f71f7841bd54801bcfd514acb5dcf/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44", size = 29177, upload-time = "2025-07-30T10:01:51.681Z" }, + { url = "https://files.pythonhosted.org/packages/b6/02/d297943bcacf05e4f2a94ab6f462831dc20158614e5d067c35d4e63b9acb/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0", size = 31090, upload-time = "2025-07-30T10:01:53.184Z" }, + { url = "https://files.pythonhosted.org/packages/c1/93/44365f3d75053e53893ec6d733e4a5e3147502663554b4d864587c7828a7/argon2_cffi_bindings-25.1.0-cp39-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6", size = 81246, upload-time = "2025-07-30T10:01:54.145Z" }, + { url = "https://files.pythonhosted.org/packages/09/52/94108adfdd6e2ddf58be64f959a0b9c7d4ef2fa71086c38356d22dc501ea/argon2_cffi_bindings-25.1.0-cp39-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a", size = 87126, upload-time = "2025-07-30T10:01:55.074Z" }, + { url = "https://files.pythonhosted.org/packages/72/70/7a2993a12b0ffa2a9271259b79cc616e2389ed1a4d93842fac5a1f923ffd/argon2_cffi_bindings-25.1.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c87b72589133f0346a1cb8d5ecca4b933e3c9b64656c9d175270a000e73b288d", size = 80343, upload-time = "2025-07-30T10:01:56.007Z" }, + { url = "https://files.pythonhosted.org/packages/78/9a/4e5157d893ffc712b74dbd868c7f62365618266982b64accab26bab01edc/argon2_cffi_bindings-25.1.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99", size = 86777, upload-time = "2025-07-30T10:01:56.943Z" }, + { url = "https://files.pythonhosted.org/packages/74/cd/15777dfde1c29d96de7f18edf4cc94c385646852e7c7b0320aa91ccca583/argon2_cffi_bindings-25.1.0-cp39-abi3-win32.whl", hash = "sha256:473bcb5f82924b1becbb637b63303ec8d10e84c8d241119419897a26116515d2", size = 27180, upload-time = "2025-07-30T10:01:57.759Z" }, + { url = "https://files.pythonhosted.org/packages/e2/c6/a759ece8f1829d1f162261226fbfd2c6832b3ff7657384045286d2afa384/argon2_cffi_bindings-25.1.0-cp39-abi3-win_amd64.whl", hash = "sha256:a98cd7d17e9f7ce244c0803cad3c23a7d379c301ba618a5fa76a67d116618b98", size = 31715, upload-time = "2025-07-30T10:01:58.56Z" }, + { url = "https://files.pythonhosted.org/packages/42/b9/f8d6fa329ab25128b7e98fd83a3cb34d9db5b059a9847eddb840a0af45dd/argon2_cffi_bindings-25.1.0-cp39-abi3-win_arm64.whl", hash = "sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94", size = 27149, upload-time = "2025-07-30T10:01:59.329Z" }, +] + [[package]] name = "attrs" version = "25.4.0"