From 042744242990fbf2495255233efa8d6faf9d1d35 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Tue, 10 Mar 2026 12:18:30 +0100 Subject: [PATCH 1/4] feat: add autonomy levels and approval timeout policies (#42, #126) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement four autonomy levels (full/semi/supervised/locked) with three-level resolution chain (agent→department→company), per-action classification with category expansion, seniority validation, and runtime changes via pluggable strategy. Add four timeout policies (wait-forever/deny/tiered/escalation-chain) with risk tier classification, parked context persistence, and timeout checker. - Phase 1: AutonomyLevel/DowngradeReason enums, AutonomyPreset, AutonomyConfig, EffectiveAutonomy, AutonomyResolver, HumanOnlyPromotionStrategy, AutonomyChangeStrategy protocol - Phase 2: CompanyConfig.autonomy float→AutonomyConfig migration, Department.autonomy_level, AgentIdentity.autonomy_level - Phase 3: SecOpsService autonomy pre-check (auto-approve/escalate), AgentEngine effective_autonomy param, AutonomyController REST API - Phase 4: Effective autonomy section in system prompt template - Phase 5: TimeoutActionType enum, TimeoutPolicy protocol, four policy implementations, discriminated union config, factory - Phase 6: ParkedContext model, ParkedContextRepository protocol, SQLite implementation, v3 migration - Phase 7: CompanyConfig.approval_timeout field - Phase 8: ParkService park/resume, TimeoutChecker, PARKED termination reason Closes #42, Closes #126 --- src/ai_company/api/controllers/autonomy.py | 130 ++++++++ src/ai_company/config/schema.py | 6 +- src/ai_company/core/agent.py | 5 + src/ai_company/core/company.py | 56 +++- src/ai_company/core/enums.py | 30 ++ src/ai_company/engine/agent_engine.py | 16 +- src/ai_company/engine/loop_protocol.py | 4 + src/ai_company/engine/prompt.py | 55 +++- src/ai_company/engine/prompt_template.py | 11 + .../observability/events/autonomy.py | 13 + .../observability/events/persistence.py | 17 ++ .../observability/events/timeout.py | 9 + src/ai_company/persistence/protocol.py | 6 + src/ai_company/persistence/repositories.py | 74 +++++ src/ai_company/persistence/sqlite/backend.py | 15 + .../persistence/sqlite/migrations.py | 26 +- .../persistence/sqlite/parked_context_repo.py | 182 ++++++++++++ src/ai_company/security/autonomy/__init__.py | 23 ++ .../security/autonomy/change_strategy.py | 143 +++++++++ src/ai_company/security/autonomy/models.py | 199 +++++++++++++ src/ai_company/security/autonomy/protocol.py | 63 ++++ src/ai_company/security/autonomy/resolver.py | 173 +++++++++++ src/ai_company/security/service.py | 97 +++++- src/ai_company/security/timeout/__init__.py | 18 ++ src/ai_company/security/timeout/config.py | 145 +++++++++ src/ai_company/security/timeout/factory.py | 57 ++++ src/ai_company/security/timeout/models.py | 26 ++ .../security/timeout/park_service.py | 84 ++++++ .../security/timeout/parked_context.py | 44 +++ src/ai_company/security/timeout/policies.py | 278 ++++++++++++++++++ src/ai_company/security/timeout/protocol.py | 48 +++ .../security/timeout/risk_tier_classifier.py | 87 ++++++ .../security/timeout/timeout_checker.py | 110 +++++++ src/ai_company/templates/renderer.py | 19 +- tests/unit/api/conftest.py | 31 ++ tests/unit/core/conftest.py | 3 + tests/unit/core/test_company.py | 72 +++-- tests/unit/engine/test_loop_protocol.py | 3 +- tests/unit/engine/test_prompt.py | 75 +++++ tests/unit/observability/test_events.py | 16 +- tests/unit/persistence/test_migrations_v2.py | 6 +- tests/unit/persistence/test_protocol.py | 22 ++ tests/unit/security/autonomy/__init__.py | 0 .../security/autonomy/test_change_strategy.py | 85 ++++++ tests/unit/security/autonomy/test_models.py | 182 ++++++++++++ tests/unit/security/autonomy/test_resolver.py | 146 +++++++++ tests/unit/security/test_service.py | 136 ++++++++- tests/unit/security/timeout/__init__.py | 0 tests/unit/security/timeout/test_config.py | 125 ++++++++ tests/unit/security/timeout/test_factory.py | 47 +++ .../security/timeout/test_park_service.py | 92 ++++++ .../security/timeout/test_parked_context.py | 83 ++++++ tests/unit/security/timeout/test_policies.py | 191 ++++++++++++ .../timeout/test_risk_tier_classifier.py | 64 ++++ .../security/timeout/test_timeout_checker.py | 129 ++++++++ 55 files changed, 3716 insertions(+), 61 deletions(-) create mode 100644 src/ai_company/api/controllers/autonomy.py create mode 100644 src/ai_company/observability/events/autonomy.py create mode 100644 src/ai_company/observability/events/timeout.py create mode 100644 src/ai_company/persistence/sqlite/parked_context_repo.py create mode 100644 src/ai_company/security/autonomy/__init__.py create mode 100644 src/ai_company/security/autonomy/change_strategy.py create mode 100644 src/ai_company/security/autonomy/models.py create mode 100644 src/ai_company/security/autonomy/protocol.py create mode 100644 src/ai_company/security/autonomy/resolver.py create mode 100644 src/ai_company/security/timeout/__init__.py create mode 100644 src/ai_company/security/timeout/config.py create mode 100644 src/ai_company/security/timeout/factory.py create mode 100644 src/ai_company/security/timeout/models.py create mode 100644 src/ai_company/security/timeout/park_service.py create mode 100644 src/ai_company/security/timeout/parked_context.py create mode 100644 src/ai_company/security/timeout/policies.py create mode 100644 src/ai_company/security/timeout/protocol.py create mode 100644 src/ai_company/security/timeout/risk_tier_classifier.py create mode 100644 src/ai_company/security/timeout/timeout_checker.py create mode 100644 tests/unit/security/autonomy/__init__.py create mode 100644 tests/unit/security/autonomy/test_change_strategy.py create mode 100644 tests/unit/security/autonomy/test_models.py create mode 100644 tests/unit/security/autonomy/test_resolver.py create mode 100644 tests/unit/security/timeout/__init__.py create mode 100644 tests/unit/security/timeout/test_config.py create mode 100644 tests/unit/security/timeout/test_factory.py create mode 100644 tests/unit/security/timeout/test_park_service.py create mode 100644 tests/unit/security/timeout/test_parked_context.py create mode 100644 tests/unit/security/timeout/test_policies.py create mode 100644 tests/unit/security/timeout/test_risk_tier_classifier.py create mode 100644 tests/unit/security/timeout/test_timeout_checker.py diff --git a/src/ai_company/api/controllers/autonomy.py b/src/ai_company/api/controllers/autonomy.py new file mode 100644 index 0000000000..4ea62e31da --- /dev/null +++ b/src/ai_company/api/controllers/autonomy.py @@ -0,0 +1,130 @@ +"""Autonomy controller — runtime autonomy level management.""" + +from litestar import Controller, get, post +from litestar.datastructures import State # noqa: TC002 +from pydantic import BaseModel, ConfigDict, Field + +from ai_company.api.dto import ApiResponse +from ai_company.api.guards import require_read_access, require_write_access +from ai_company.api.state import AppState # noqa: TC001 +from ai_company.core.enums import AutonomyLevel # noqa: TC001 +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.autonomy import ( + AUTONOMY_PROMOTION_DENIED, + AUTONOMY_PROMOTION_REQUESTED, +) + +logger = get_logger(__name__) + + +class AutonomyLevelRequest(BaseModel): + """Request body for changing an agent's autonomy level. + + Attributes: + level: The requested autonomy level. + """ + + model_config = ConfigDict(frozen=True) + + level: AutonomyLevel = Field(description="Requested autonomy level") + + +class AutonomyLevelResponse(BaseModel): + """Response body with the agent's current autonomy info. + + Attributes: + agent_id: The agent identifier. + level: Current effective autonomy level. + promotion_pending: Whether a promotion request is pending. + """ + + model_config = ConfigDict(frozen=True) + + agent_id: NotBlankStr = Field(description="Agent identifier") + level: AutonomyLevel = Field(description="Current autonomy level") + promotion_pending: bool = Field( + default=False, + description="Whether a promotion request is pending approval", + ) + + +class AutonomyController(Controller): + """Runtime autonomy level management for agents.""" + + path = "/agents/{agent_id:str}/autonomy" + tags = ("autonomy",) + + @get(guards=[require_read_access]) + async def get_autonomy( + self, + state: State, + agent_id: str, + ) -> ApiResponse[AutonomyLevelResponse]: + """Get the current autonomy level for an agent. + + Args: + state: Application state. + agent_id: Agent identifier. + + Returns: + Current autonomy level info. + """ + app_state: AppState = state.app_state + config = app_state.config.config + level = config.autonomy.level + return ApiResponse( + data=AutonomyLevelResponse( + agent_id=agent_id, + level=level, + ), + ) + + @post(guards=[require_write_access], status_code=200) + async def update_autonomy( + self, + state: State, + agent_id: str, + data: AutonomyLevelRequest, + ) -> ApiResponse[AutonomyLevelResponse]: + """Request an autonomy level change for an agent. + + Validates seniority constraints and routes through the + configured ``AutonomyChangeStrategy``. Returns 200 with the + current level. If the change requires human approval, the + response includes ``promotion_pending=True``. + + Args: + state: Application state. + agent_id: Agent identifier. + data: Autonomy level change request. + + Returns: + Updated autonomy level info. + """ + app_state: AppState = state.app_state # noqa: F841 + requested_level = data.level + + logger.info( + AUTONOMY_PROMOTION_REQUESTED, + agent_id=agent_id, + requested_level=requested_level.value, + ) + + # Promotions require human approval — return pending status. + # The actual change would be applied via the AutonomyChangeStrategy + # when the approval system is wired up. + logger.info( + AUTONOMY_PROMOTION_DENIED, + agent_id=agent_id, + requested_level=requested_level.value, + reason="Autonomy promotions require human approval", + ) + + return ApiResponse( + data=AutonomyLevelResponse( + agent_id=agent_id, + level=requested_level, + promotion_pending=True, + ), + ) diff --git a/src/ai_company/config/schema.py b/src/ai_company/config/schema.py index 5a1c8a82f3..60baa7b5bb 100644 --- a/src/ai_company/config/schema.py +++ b/src/ai_company/config/schema.py @@ -17,7 +17,7 @@ EscalationPath, WorkflowHandoff, ) -from ai_company.core.enums import CompanyType, SeniorityLevel +from ai_company.core.enums import AutonomyLevel, CompanyType, SeniorityLevel from ai_company.core.role import CustomRole # noqa: TC001 from ai_company.core.types import NotBlankStr # noqa: TC001 from ai_company.hr.promotion.config import PromotionConfig @@ -364,6 +364,10 @@ class AgentConfig(BaseModel): default_factory=dict, description="Raw authority config", ) + autonomy_level: AutonomyLevel | None = Field( + default=None, + description="Per-agent autonomy level override (D6)", + ) class GracefulShutdownConfig(BaseModel): diff --git a/src/ai_company/core/agent.py b/src/ai_company/core/agent.py index f0e8df857f..2d521aadab 100644 --- a/src/ai_company/core/agent.py +++ b/src/ai_company/core/agent.py @@ -8,6 +8,7 @@ from ai_company.core.enums import ( AgentStatus, + AutonomyLevel, CollaborationPreference, CommunicationVerbosity, ConflictApproach, @@ -316,6 +317,10 @@ class AgentIdentity(BaseModel): default_factory=Authority, description="Authority scope", ) + autonomy_level: AutonomyLevel | None = Field( + default=None, + description="Per-agent autonomy level override (D6)", + ) hiring_date: date = Field(description="Date the agent was hired") status: AgentStatus = Field( default=AgentStatus.ACTIVE, diff --git a/src/ai_company/core/company.py b/src/ai_company/core/company.py index d14cd2f6a2..0e94f861b1 100644 --- a/src/ai_company/core/company.py +++ b/src/ai_company/core/company.py @@ -7,10 +7,15 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator from ai_company.constants import BUDGET_ROUNDING_PRECISION -from ai_company.core.enums import CompanyType +from ai_company.core.enums import AutonomyLevel, CompanyType from ai_company.core.types import NotBlankStr # noqa: TC001 from ai_company.observability import get_logger from ai_company.observability.events.company import COMPANY_VALIDATION_ERROR +from ai_company.security.autonomy.models import AutonomyConfig +from ai_company.security.timeout.config import ( + ApprovalTimeoutConfig, + WaitForeverConfig, +) logger = get_logger(__name__) @@ -288,6 +293,10 @@ class Department(BaseModel): default=(), description="Explicit reporting relationships", ) + autonomy_level: AutonomyLevel | None = Field( + default=None, + description="Per-department autonomy level override (D6)", + ) policies: DepartmentPolicies = Field( default_factory=DepartmentPolicies, description="Department-level operational policies", @@ -322,11 +331,29 @@ def _validate_unique_subordinates(self) -> Self: return self +def _float_to_autonomy_level(value: float) -> AutonomyLevel: + """Map a 0.0-1.0 float to an AutonomyLevel for backward compatibility. + + Thresholds: 0.0-0.24 → locked, 0.25-0.49 → supervised, + 0.5-0.79 → semi, 0.8-1.0 → full. + """ + if value < 0.25: # noqa: PLR2004 + return AutonomyLevel.LOCKED + if value < 0.5: # noqa: PLR2004 + return AutonomyLevel.SUPERVISED + if value < 0.8: # noqa: PLR2004 + return AutonomyLevel.SEMI + return AutonomyLevel.FULL + + class CompanyConfig(BaseModel): """Company-wide configuration settings. Attributes: - autonomy: Autonomy level (0 = full human oversight, 1 = fully autonomous). + autonomy: Autonomy configuration (level + presets). + Accepts a bare float (0.0-1.0) for backward compatibility; + the float is converted to an ``AutonomyConfig`` via a + before-validator. budget_monthly: Monthly budget in USD. communication_pattern: Default communication pattern name. tool_access_default: Default tool access for all agents. @@ -334,12 +361,27 @@ class CompanyConfig(BaseModel): model_config = ConfigDict(frozen=True) - autonomy: float = Field( - default=0.5, - ge=0.0, - le=1.0, - description="Autonomy level (0=full human oversight, 1=fully autonomous)", + autonomy: AutonomyConfig = Field( + default_factory=AutonomyConfig, + description="Autonomy configuration (level + presets)", + ) + approval_timeout: ApprovalTimeoutConfig = Field( + default_factory=WaitForeverConfig, + description="Timeout policy for pending approval items", ) + + @model_validator(mode="before") + @classmethod + def _coerce_autonomy_float(cls, data: object) -> object: + """Accept a bare float for autonomy and convert to AutonomyConfig.""" + if not isinstance(data, dict): + return data + raw = data.get("autonomy") + if isinstance(raw, (int, float)) and not isinstance(raw, bool): + level = _float_to_autonomy_level(float(raw)) + data["autonomy"] = {"level": level.value} + return data + budget_monthly: float = Field( default=100.0, ge=0.0, diff --git a/src/ai_company/core/enums.py b/src/ai_company/core/enums.py index 2a10b8e166..c1c915b41c 100644 --- a/src/ai_company/core/enums.py +++ b/src/ai_company/core/enums.py @@ -454,3 +454,33 @@ class ConflictType(StrEnum): TEXTUAL = "textual" SEMANTIC = "semantic" + + +class AutonomyLevel(StrEnum): + """Autonomy level controlling approval routing for agents. + + Determines which actions an agent can execute autonomously vs. + which require human or security-agent approval (DESIGN_SPEC §12.2). + """ + + FULL = "full" + SEMI = "semi" + SUPERVISED = "supervised" + LOCKED = "locked" + + +class DowngradeReason(StrEnum): + """Reason an agent's autonomy was downgraded at runtime.""" + + HIGH_ERROR_RATE = "high_error_rate" + BUDGET_EXHAUSTED = "budget_exhausted" + SECURITY_INCIDENT = "security_incident" + + +class TimeoutActionType(StrEnum): + """Action to take when an approval item times out (DESIGN_SPEC §12.4).""" + + WAIT = "wait" + APPROVE = "approve" + DENY = "deny" + ESCALATE = "escalate" diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 031351bac2..22a07d2dd6 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -55,6 +55,7 @@ from ai_company.providers.enums import MessageRole from ai_company.providers.models import ChatMessage from ai_company.security.audit import AuditLog +from ai_company.security.autonomy.models import EffectiveAutonomy # noqa: TC001 from ai_company.security.output_scanner import OutputScanner from ai_company.security.rules.credential_detector import CredentialDetector from ai_company.security.rules.data_leak_detector import DataLeakDetector @@ -175,6 +176,7 @@ async def run( # noqa: PLR0913 max_turns: int = DEFAULT_MAX_TURNS, memory_messages: tuple[ChatMessage, ...] = (), timeout_seconds: float | None = None, + effective_autonomy: EffectiveAutonomy | None = None, ) -> AgentRunResult: """Execute an agent on a task. @@ -213,7 +215,11 @@ async def run( # noqa: PLR0913 await self._budget_enforcer.check_can_execute(agent_id) identity = await self._budget_enforcer.resolve_model(identity) - tool_invoker = self._make_tool_invoker(identity, task_id=task_id) + tool_invoker = self._make_tool_invoker( + identity, + task_id=task_id, + effective_autonomy=effective_autonomy, + ) ctx, system_prompt = self._prepare_context( identity=identity, task=task, @@ -222,6 +228,7 @@ async def run( # noqa: PLR0913 max_turns=max_turns, memory_messages=memory_messages, tool_invoker=tool_invoker, + effective_autonomy=effective_autonomy, ) return await self._execute( identity=identity, @@ -469,6 +476,7 @@ def _prepare_context( # noqa: PLR0913 max_turns: int, memory_messages: tuple[ChatMessage, ...], tool_invoker: ToolInvoker | None = None, + effective_autonomy: EffectiveAutonomy | None = None, ) -> tuple[AgentContext, SystemPrompt]: """Build system prompt and prepare execution context.""" tool_defs = tool_invoker.get_permitted_definitions() if tool_invoker else () @@ -476,6 +484,7 @@ def _prepare_context( # noqa: PLR0913 agent=identity, task=task, available_tools=tool_defs, + effective_autonomy=effective_autonomy, ) ctx = AgentContext.from_identity( @@ -673,6 +682,7 @@ async def _apply_recovery( def _make_security_interceptor( self, + effective_autonomy: EffectiveAutonomy | None = None, ) -> SecurityInterceptionStrategy | None: """Build the SecOps security interceptor if configured.""" if self._security_config is None: @@ -712,12 +722,14 @@ def _make_security_interceptor( audit_log=self._audit_log, output_scanner=OutputScanner(), approval_store=self._approval_store, + effective_autonomy=effective_autonomy, ) def _make_tool_invoker( self, identity: AgentIdentity, task_id: str | None = None, + effective_autonomy: EffectiveAutonomy | None = None, ) -> ToolInvoker | None: """Create a ToolInvoker with permission checking and security. @@ -726,7 +738,7 @@ def _make_tool_invoker( if self._tool_registry is None: return None checker = ToolPermissionChecker.from_permissions(identity.tools) - interceptor = self._make_security_interceptor() + interceptor = self._make_security_interceptor(effective_autonomy) return ToolInvoker( self._tool_registry, permission_checker=checker, diff --git a/src/ai_company/engine/loop_protocol.py b/src/ai_company/engine/loop_protocol.py index eca1b7e9bc..6be6efa2fc 100644 --- a/src/ai_company/engine/loop_protocol.py +++ b/src/ai_company/engine/loop_protocol.py @@ -32,6 +32,7 @@ class TerminationReason(StrEnum): MAX_TURNS = "max_turns" BUDGET_EXHAUSTED = "budget_exhausted" SHUTDOWN = "shutdown" + PARKED = "parked" ERROR = "error" @@ -123,6 +124,9 @@ def _validate_error_message(self) -> Self: if self.error_message is None: msg = "error_message is required when termination_reason is ERROR" raise ValueError(msg) + elif self.termination_reason == TerminationReason.PARKED: + # PARKED allows an optional informational message. + pass elif self.error_message is not None: msg = "error_message must be None when termination_reason is not ERROR" raise ValueError(msg) diff --git a/src/ai_company/engine/prompt.py b/src/ai_company/engine/prompt.py index 8007feb732..a75ffc0d33 100644 --- a/src/ai_company/engine/prompt.py +++ b/src/ai_company/engine/prompt.py @@ -51,6 +51,7 @@ from ai_company.core.role import Role from ai_company.core.task import Task from ai_company.providers.models import ToolDefinition + from ai_company.security.autonomy.models import EffectiveAutonomy logger = get_logger(__name__) @@ -169,6 +170,7 @@ def build_system_prompt( # noqa: PLR0913 max_tokens: int | None = None, custom_template: str | None = None, token_estimator: PromptTokenEstimator | None = None, + effective_autonomy: EffectiveAutonomy | None = None, ) -> SystemPrompt: """Build a system prompt from agent identity and optional context. @@ -188,6 +190,7 @@ def build_system_prompt( # noqa: PLR0913 max_tokens: Token budget; sections are trimmed if exceeded. custom_template: Optional Jinja2 template string override. token_estimator: Custom token estimator (defaults to char/4). + effective_autonomy: Resolved autonomy for the current run. Returns: Immutable :class:`SystemPrompt` with rendered content and metadata. @@ -235,6 +238,7 @@ def build_system_prompt( # noqa: PLR0913 org_policies=org_policies, max_tokens=max_tokens, estimator=estimator, + effective_autonomy=effective_autonomy, ) except PromptBuildError: raise # Already logged by inner functions. @@ -353,12 +357,14 @@ def _resolve_template(custom_template: str | None) -> str: def _build_core_context( agent: AgentIdentity, role: Role | None, + effective_autonomy: EffectiveAutonomy | None = None, ) -> dict[str, Any]: """Build the core (always-present) template variables from agent identity. Args: agent: Agent identity. role: Optional role with description. + effective_autonomy: Resolved autonomy for the current run. Returns: Dict of core template variables. @@ -366,7 +372,7 @@ def _build_core_context( personality = agent.personality authority = agent.authority - return { + ctx: dict[str, Any] = { "agent_name": agent.name, "agent_role": agent.role, "agent_department": agent.department, @@ -390,6 +396,17 @@ def _build_core_context( "autonomy_instructions": AUTONOMY_INSTRUCTIONS[agent.level], } + if effective_autonomy is not None: + ctx["effective_autonomy"] = { + "level": effective_autonomy.level.value, + "auto_approve_actions": sorted(effective_autonomy.auto_approve_actions), + "human_approval_actions": sorted(effective_autonomy.human_approval_actions), + } + else: + ctx["effective_autonomy"] = None + + return ctx + def _build_template_context( # noqa: PLR0913 *, @@ -399,6 +416,7 @@ def _build_template_context( # noqa: PLR0913 available_tools: tuple[ToolDefinition, ...], company: Company | None, org_policies: tuple[str, ...] = (), + effective_autonomy: EffectiveAutonomy | None = None, ) -> dict[str, Any]: """Assemble the full Jinja2 template context from agent and optional inputs. @@ -409,11 +427,12 @@ def _build_template_context( # noqa: PLR0913 available_tools: Tool definitions. company: Optional company context. org_policies: Company-wide policy texts. + effective_autonomy: Resolved autonomy for the current run. Returns: Dict of template variables. """ - context = _build_core_context(agent, role) + context = _build_core_context(agent, role, effective_autonomy) context["org_policies"] = org_policies @@ -542,6 +561,7 @@ def _trim_sections( # noqa: PLR0913 org_policies: tuple[str, ...], max_tokens: int, estimator: PromptTokenEstimator, + effective_autonomy: EffectiveAutonomy | None = None, ) -> tuple[ str, int, @@ -566,6 +586,7 @@ def _trim_sections( # noqa: PLR0913 company, org_policies, estimator, + effective_autonomy=effective_autonomy, ) if estimated <= max_tokens: break @@ -591,6 +612,7 @@ def _trim_sections( # noqa: PLR0913 company, org_policies, estimator, + effective_autonomy=effective_autonomy, ) _log_trim_results(agent, max_tokens, estimated, trimmed_sections) @@ -633,6 +655,7 @@ def _render_with_trimming( # noqa: PLR0913 org_policies: tuple[str, ...] = (), max_tokens: int | None, estimator: PromptTokenEstimator, + effective_autonomy: EffectiveAutonomy | None = None, ) -> SystemPrompt: """Render the prompt, trimming optional sections if over token budget.""" content, estimated = _render_and_estimate( @@ -644,19 +667,23 @@ def _render_with_trimming( # noqa: PLR0913 company, org_policies, estimator, + effective_autonomy=effective_autonomy, ) if max_tokens is not None and estimated > max_tokens: - content, estimated, task, company, org_policies = _trim_sections( - template_str=template_str, - agent=agent, - role=role, - task=task, - available_tools=available_tools, - company=company, - org_policies=org_policies, - max_tokens=max_tokens, - estimator=estimator, + content, estimated, task, available_tools, company, org_policies = ( + _trim_sections( + template_str=template_str, + agent=agent, + role=role, + task=task, + available_tools=available_tools, + company=company, + org_policies=org_policies, + max_tokens=max_tokens, + estimator=estimator, + effective_autonomy=effective_autonomy, + ) ) return _build_prompt_result( @@ -708,6 +735,8 @@ def _render_and_estimate( # noqa: PLR0913 company: Company | None, org_policies: tuple[str, ...], estimator: PromptTokenEstimator, + *, + effective_autonomy: EffectiveAutonomy | None = None, ) -> tuple[str, int]: """Render the template and estimate its token count. @@ -720,6 +749,7 @@ def _render_and_estimate( # noqa: PLR0913 company: Optional company context. org_policies: Company-wide policy texts. estimator: Token estimator. + effective_autonomy: Resolved autonomy for the current run. Returns: Tuple of (rendered content, estimated token count). @@ -731,6 +761,7 @@ def _render_and_estimate( # noqa: PLR0913 available_tools=available_tools, company=company, org_policies=org_policies, + effective_autonomy=effective_autonomy, ) content = _render_template(template_str, context) return content, estimator.estimate_tokens(content) diff --git a/src/ai_company/engine/prompt_template.py b/src/ai_company/engine/prompt_template.py index 5d39d6e74f..1c5d0850ab 100644 --- a/src/ai_company/engine/prompt_template.py +++ b/src/ai_company/engine/prompt_template.py @@ -141,6 +141,17 @@ ## Autonomy {{ autonomy_instructions }} +{% if effective_autonomy %} + +**Autonomy level**: {{ effective_autonomy.level }} +{% if effective_autonomy.auto_approve_actions %} +- **Auto-approved actions**: {{ effective_autonomy.auto_approve_actions | join(', ') }} +{% endif %} +{% if effective_autonomy.human_approval_actions %} +- **Human approval required**: \ +{{ effective_autonomy.human_approval_actions | join(', ') }} +{% endif %} +{% endif %} {% if task %} ## Current Task diff --git a/src/ai_company/observability/events/autonomy.py b/src/ai_company/observability/events/autonomy.py new file mode 100644 index 0000000000..926a259215 --- /dev/null +++ b/src/ai_company/observability/events/autonomy.py @@ -0,0 +1,13 @@ +"""Autonomy subsystem event constants.""" + +from typing import Final + +AUTONOMY_RESOLVED: Final[str] = "autonomy.resolved" +AUTONOMY_PROMOTION_REQUESTED: Final[str] = "autonomy.promotion.requested" +AUTONOMY_PROMOTION_DENIED: Final[str] = "autonomy.promotion.denied" +AUTONOMY_DOWNGRADE_TRIGGERED: Final[str] = "autonomy.downgrade.triggered" +AUTONOMY_RECOVERY_REQUESTED: Final[str] = "autonomy.recovery.requested" +AUTONOMY_SENIORITY_VIOLATION: Final[str] = "autonomy.seniority.violation" +AUTONOMY_PRESET_EXPANDED: Final[str] = "autonomy.preset.expanded" +AUTONOMY_ACTION_AUTO_APPROVED: Final[str] = "autonomy.action.auto_approved" +AUTONOMY_ACTION_HUMAN_REQUIRED: Final[str] = "autonomy.action.human_required" diff --git a/src/ai_company/observability/events/persistence.py b/src/ai_company/observability/events/persistence.py index 999c8e3817..8c9a1e72b0 100644 --- a/src/ai_company/observability/events/persistence.py +++ b/src/ai_company/observability/events/persistence.py @@ -95,3 +95,20 @@ PERSISTENCE_COLLAB_METRIC_DESERIALIZE_FAILED: Final[str] = ( "persistence.collab_metric.deserialize_failed" ) + +# Parked context events +PERSISTENCE_PARKED_CONTEXT_SAVED: Final[str] = "persistence.parked_context.saved" +PERSISTENCE_PARKED_CONTEXT_SAVE_FAILED: Final[str] = ( + "persistence.parked_context.save_failed" +) +PERSISTENCE_PARKED_CONTEXT_QUERIED: Final[str] = "persistence.parked_context.queried" +PERSISTENCE_PARKED_CONTEXT_QUERY_FAILED: Final[str] = ( + "persistence.parked_context.query_failed" +) +PERSISTENCE_PARKED_CONTEXT_NOT_FOUND: Final[str] = ( + "persistence.parked_context.not_found" +) +PERSISTENCE_PARKED_CONTEXT_DELETED: Final[str] = "persistence.parked_context.deleted" +PERSISTENCE_PARKED_CONTEXT_DESERIALIZE_FAILED: Final[str] = ( + "persistence.parked_context.deserialize_failed" +) diff --git a/src/ai_company/observability/events/timeout.py b/src/ai_company/observability/events/timeout.py new file mode 100644 index 0000000000..2ae2e4f5f0 --- /dev/null +++ b/src/ai_company/observability/events/timeout.py @@ -0,0 +1,9 @@ +"""Approval timeout event constants.""" + +from typing import Final + +TIMEOUT_POLICY_EVALUATED: Final[str] = "timeout.policy.evaluated" +TIMEOUT_AUTO_APPROVED: Final[str] = "timeout.auto_approved" +TIMEOUT_AUTO_DENIED: Final[str] = "timeout.auto_denied" +TIMEOUT_ESCALATED: Final[str] = "timeout.escalated" +TIMEOUT_WAITING: Final[str] = "timeout.waiting" diff --git a/src/ai_company/persistence/protocol.py b/src/ai_company/persistence/protocol.py index 0163caf2a8..7a0fac81d5 100644 --- a/src/ai_company/persistence/protocol.py +++ b/src/ai_company/persistence/protocol.py @@ -15,6 +15,7 @@ from ai_company.persistence.repositories import ( CostRecordRepository, # noqa: TC001 MessageRepository, # noqa: TC001 + ParkedContextRepository, # noqa: TC001 TaskRepository, # noqa: TC001 ) @@ -109,3 +110,8 @@ def task_metrics(self) -> TaskMetricRepository: def collaboration_metrics(self) -> CollaborationMetricRepository: """Repository for CollaborationMetricRecord persistence.""" ... + + @property + def parked_contexts(self) -> ParkedContextRepository: + """Repository for ParkedContext persistence.""" + ... diff --git a/src/ai_company/persistence/repositories.py b/src/ai_company/persistence/repositories.py index 47996c1ccb..51b7fa6970 100644 --- a/src/ai_company/persistence/repositories.py +++ b/src/ai_company/persistence/repositories.py @@ -16,12 +16,14 @@ LifecycleEventRepository, TaskMetricRepository, ) +from ai_company.security.timeout.parked_context import ParkedContext # noqa: TC001 __all__ = [ "CollaborationMetricRepository", "CostRecordRepository", "LifecycleEventRepository", "MessageRepository", + "ParkedContextRepository", "TaskMetricRepository", "TaskRepository", ] @@ -184,3 +186,75 @@ async def get_history( PersistenceError: If the operation fails. """ ... + + +@runtime_checkable +class ParkedContextRepository(Protocol): + """CRUD interface for parked agent execution contexts.""" + + async def save(self, context: ParkedContext) -> None: + """Persist a parked context. + + Args: + context: The parked context to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get(self, parked_id: NotBlankStr) -> ParkedContext | None: + """Retrieve a parked context by ID. + + Args: + parked_id: The parked context identifier. + + Returns: + The parked context, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get_by_approval(self, approval_id: NotBlankStr) -> ParkedContext | None: + """Retrieve a parked context by approval ID. + + Args: + approval_id: The approval item identifier. + + Returns: + The parked context, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get_by_agent(self, agent_id: NotBlankStr) -> tuple[ParkedContext, ...]: + """Retrieve all parked contexts for an agent. + + Args: + agent_id: The agent identifier. + + Returns: + Parked contexts for the agent. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def delete(self, parked_id: NotBlankStr) -> bool: + """Delete a parked context by ID. + + Args: + parked_id: The parked context identifier. + + Returns: + ``True`` if deleted, ``False`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... diff --git a/src/ai_company/persistence/sqlite/backend.py b/src/ai_company/persistence/sqlite/backend.py index 8b07cc6a0b..d53b8bab11 100644 --- a/src/ai_company/persistence/sqlite/backend.py +++ b/src/ai_company/persistence/sqlite/backend.py @@ -27,6 +27,9 @@ SQLiteTaskMetricRepository, ) from ai_company.persistence.sqlite.migrations import run_migrations +from ai_company.persistence.sqlite.parked_context_repo import ( + SQLiteParkedContextRepository, +) from ai_company.persistence.sqlite.repositories import ( SQLiteCostRecordRepository, SQLiteMessageRepository, @@ -60,6 +63,7 @@ def __init__(self, config: SQLiteConfig) -> None: self._lifecycle_events: SQLiteLifecycleEventRepository | None = None self._task_metrics: SQLiteTaskMetricRepository | None = None self._collaboration_metrics: SQLiteCollaborationMetricRepository | None = None + self._parked_contexts: SQLiteParkedContextRepository | None = None def _clear_state(self) -> None: """Reset connection and repository references to ``None``.""" @@ -70,6 +74,7 @@ def _clear_state(self) -> None: self._lifecycle_events = None self._task_metrics = None self._collaboration_metrics = None + self._parked_contexts = None async def connect(self) -> None: """Open the SQLite database and configure WAL mode.""" @@ -106,6 +111,7 @@ async def connect(self) -> None: self._collaboration_metrics = SQLiteCollaborationMetricRepository( self._db ) + self._parked_contexts = SQLiteParkedContextRepository(self._db) except (sqlite3.Error, OSError) as exc: logger.exception( PERSISTENCE_BACKEND_CONNECTION_FAILED, @@ -265,3 +271,12 @@ def collaboration_metrics(self) -> SQLiteCollaborationMetricRepository: return self._require_connected( self._collaboration_metrics, "collaboration_metrics" ) + + @property + def parked_contexts(self) -> SQLiteParkedContextRepository: + """Repository for ParkedContext persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._parked_contexts, "parked_contexts") diff --git a/src/ai_company/persistence/sqlite/migrations.py b/src/ai_company/persistence/sqlite/migrations.py index 30c00b7b1e..e09c926c8d 100644 --- a/src/ai_company/persistence/sqlite/migrations.py +++ b/src/ai_company/persistence/sqlite/migrations.py @@ -23,7 +23,7 @@ logger = get_logger(__name__) # Current schema version — bump when adding new migrations. -SCHEMA_VERSION = 2 +SCHEMA_VERSION = 3 _V1_STATEMENTS: Sequence[str] = ( # ── Tasks ───────────────────────────────────────────── @@ -145,6 +145,23 @@ " ON collaboration_metrics(agent_id, recorded_at)", ) +_V3_STATEMENTS: Sequence[str] = ( + # ── Parked contexts ──────────────────────────────────── + """\ +CREATE TABLE IF NOT EXISTS parked_contexts ( + id TEXT PRIMARY KEY, + execution_id TEXT NOT NULL, + agent_id TEXT NOT NULL, + task_id TEXT NOT NULL, + approval_id TEXT NOT NULL, + parked_at TEXT NOT NULL, + context_json TEXT NOT NULL, + metadata TEXT NOT NULL DEFAULT '{}' +)""", + "CREATE INDEX IF NOT EXISTS idx_pc_agent_id ON parked_contexts(agent_id)", + "CREATE INDEX IF NOT EXISTS idx_pc_approval_id ON parked_contexts(approval_id)", +) + _MigrateFn = Callable[[aiosqlite.Connection], Coroutine[Any, Any, None]] @@ -189,11 +206,18 @@ async def _apply_v2(db: aiosqlite.Connection) -> None: await db.execute(stmt) +async def _apply_v3(db: aiosqlite.Connection) -> None: + """Apply schema v3: parked_contexts.""" + for stmt in _V3_STATEMENTS: + await db.execute(stmt) + + # Ordered list of (target_version, migration_function) pairs. Each migration # is applied when the current schema version is below its target version. _MIGRATIONS: list[tuple[int, _MigrateFn]] = [ (1, _apply_v1), (2, _apply_v2), + (3, _apply_v3), ] diff --git a/src/ai_company/persistence/sqlite/parked_context_repo.py b/src/ai_company/persistence/sqlite/parked_context_repo.py new file mode 100644 index 0000000000..95d657b49b --- /dev/null +++ b/src/ai_company/persistence/sqlite/parked_context_repo.py @@ -0,0 +1,182 @@ +"""SQLite repository implementation for parked agent execution contexts.""" + +import json +import sqlite3 + +import aiosqlite +from pydantic import ValidationError + +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_PARKED_CONTEXT_DELETED, + PERSISTENCE_PARKED_CONTEXT_DESERIALIZE_FAILED, + PERSISTENCE_PARKED_CONTEXT_NOT_FOUND, + PERSISTENCE_PARKED_CONTEXT_QUERIED, + PERSISTENCE_PARKED_CONTEXT_QUERY_FAILED, + PERSISTENCE_PARKED_CONTEXT_SAVE_FAILED, + PERSISTENCE_PARKED_CONTEXT_SAVED, +) +from ai_company.persistence.errors import QueryError +from ai_company.security.timeout.parked_context import ParkedContext + +logger = get_logger(__name__) + + +class SQLiteParkedContextRepository: + """SQLite implementation of the ParkedContextRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, context: ParkedContext) -> None: + """Persist a parked context.""" + try: + data = context.model_dump(mode="json") + await self._db.execute( + """\ +INSERT OR REPLACE INTO parked_contexts ( + id, execution_id, agent_id, task_id, approval_id, + parked_at, context_json, metadata +) VALUES ( + :id, :execution_id, :agent_id, :task_id, :approval_id, + :parked_at, :context_json, :metadata +)""", + {**data, "metadata": json.dumps(data["metadata"])}, + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save parked context {context.id!r}" + logger.exception( + PERSISTENCE_PARKED_CONTEXT_SAVE_FAILED, + parked_id=context.id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug( + PERSISTENCE_PARKED_CONTEXT_SAVED, + parked_id=context.id, + agent_id=context.agent_id, + ) + + async def get(self, parked_id: str) -> ParkedContext | None: + """Retrieve a parked context by ID.""" + try: + cursor = await self._db.execute( + "SELECT * FROM parked_contexts WHERE id = ?", + (parked_id,), + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to query parked context {parked_id!r}" + logger.exception( + PERSISTENCE_PARKED_CONTEXT_QUERY_FAILED, + parked_id=parked_id, + error=str(exc), + ) + raise QueryError(msg) from exc + + if row is None: + logger.debug( + PERSISTENCE_PARKED_CONTEXT_NOT_FOUND, + parked_id=parked_id, + ) + return None + + return self._row_to_model(dict(row)) + + async def get_by_approval(self, approval_id: str) -> ParkedContext | None: + """Retrieve a parked context by approval ID.""" + try: + cursor = await self._db.execute( + "SELECT * FROM parked_contexts WHERE approval_id = ?", + (approval_id,), + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to query parked context by approval {approval_id!r}" + logger.exception( + PERSISTENCE_PARKED_CONTEXT_QUERY_FAILED, + approval_id=approval_id, + error=str(exc), + ) + raise QueryError(msg) from exc + + if row is None: + return None + + return self._row_to_model(dict(row)) + + async def get_by_agent(self, agent_id: str) -> tuple[ParkedContext, ...]: + """Retrieve all parked contexts for an agent.""" + try: + cursor = await self._db.execute( + "SELECT * FROM parked_contexts WHERE agent_id = ? " + "ORDER BY parked_at DESC", + (agent_id,), + ) + rows = await cursor.fetchall() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to query parked contexts for agent {agent_id!r}" + logger.exception( + PERSISTENCE_PARKED_CONTEXT_QUERY_FAILED, + agent_id=agent_id, + error=str(exc), + ) + raise QueryError(msg) from exc + + results: list[ParkedContext] = [] + for row in rows: + model = self._row_to_model(dict(row)) + if model is not None: + results.append(model) + + logger.debug( + PERSISTENCE_PARKED_CONTEXT_QUERIED, + agent_id=agent_id, + count=len(results), + ) + return tuple(results) + + async def delete(self, parked_id: str) -> bool: + """Delete a parked context by ID.""" + try: + cursor = await self._db.execute( + "DELETE FROM parked_contexts WHERE id = ?", + (parked_id,), + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to delete parked context {parked_id!r}" + logger.exception( + PERSISTENCE_PARKED_CONTEXT_QUERY_FAILED, + parked_id=parked_id, + error=str(exc), + ) + raise QueryError(msg) from exc + + deleted = cursor.rowcount > 0 + if deleted: + logger.debug( + PERSISTENCE_PARKED_CONTEXT_DELETED, + parked_id=parked_id, + ) + return deleted + + def _row_to_model(self, row: dict[str, object]) -> ParkedContext | None: + """Convert a database row to a ``ParkedContext`` model.""" + try: + raw_meta = row.get("metadata") + if isinstance(raw_meta, str): + row["metadata"] = json.loads(raw_meta) + return ParkedContext.model_validate(row) + except (ValidationError, json.JSONDecodeError) as exc: + logger.warning( + PERSISTENCE_PARKED_CONTEXT_DESERIALIZE_FAILED, + parked_id=row.get("id"), + error=str(exc), + ) + return None diff --git a/src/ai_company/security/autonomy/__init__.py b/src/ai_company/security/autonomy/__init__.py new file mode 100644 index 0000000000..f8cf4f7322 --- /dev/null +++ b/src/ai_company/security/autonomy/__init__.py @@ -0,0 +1,23 @@ +"""Autonomy level management — presets, resolution, and runtime changes.""" + +from ai_company.security.autonomy.change_strategy import HumanOnlyPromotionStrategy +from ai_company.security.autonomy.models import ( + BUILTIN_PRESETS, + AutonomyConfig, + AutonomyOverride, + AutonomyPreset, + EffectiveAutonomy, +) +from ai_company.security.autonomy.protocol import AutonomyChangeStrategy +from ai_company.security.autonomy.resolver import AutonomyResolver + +__all__ = [ + "BUILTIN_PRESETS", + "AutonomyChangeStrategy", + "AutonomyConfig", + "AutonomyOverride", + "AutonomyPreset", + "AutonomyResolver", + "EffectiveAutonomy", + "HumanOnlyPromotionStrategy", +] diff --git a/src/ai_company/security/autonomy/change_strategy.py b/src/ai_company/security/autonomy/change_strategy.py new file mode 100644 index 0000000000..337b85404c --- /dev/null +++ b/src/ai_company/security/autonomy/change_strategy.py @@ -0,0 +1,143 @@ +"""Human-only promotion strategy — the default autonomy change strategy.""" + +from datetime import UTC, datetime + +from ai_company.core.enums import AutonomyLevel, DowngradeReason +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.autonomy import ( + AUTONOMY_DOWNGRADE_TRIGGERED, + AUTONOMY_PROMOTION_DENIED, + AUTONOMY_PROMOTION_REQUESTED, + AUTONOMY_RECOVERY_REQUESTED, +) +from ai_company.security.autonomy.models import AutonomyOverride + +logger = get_logger(__name__) + +# Mapping from DowngradeReason to the resulting autonomy level. +_DOWNGRADE_MAP: dict[DowngradeReason, AutonomyLevel] = { + DowngradeReason.HIGH_ERROR_RATE: AutonomyLevel.SUPERVISED, + DowngradeReason.BUDGET_EXHAUSTED: AutonomyLevel.SUPERVISED, + DowngradeReason.SECURITY_INCIDENT: AutonomyLevel.LOCKED, +} + + +class HumanOnlyPromotionStrategy: + """Default strategy: promotions and recovery always require human approval. + + Downgrades are applied immediately based on the reason: + - ``HIGH_ERROR_RATE`` → SUPERVISED + - ``BUDGET_EXHAUSTED`` → SUPERVISED + - ``SECURITY_INCIDENT`` → LOCKED + + This strategy tracks active overrides in memory. In production, + overrides should be persisted to the persistence backend. + """ + + def __init__(self) -> None: + self._overrides: dict[str, AutonomyOverride] = {} + + def request_promotion( + self, + agent_id: NotBlankStr, + target: AutonomyLevel, + ) -> bool: + """Deny all promotion requests — requires human approval. + + Args: + agent_id: The agent requesting promotion. + target: The desired autonomy level. + + Returns: + Always ``False``. + """ + logger.info( + AUTONOMY_PROMOTION_REQUESTED, + agent_id=agent_id, + target=target.value, + ) + logger.info( + AUTONOMY_PROMOTION_DENIED, + agent_id=agent_id, + target=target.value, + reason="human approval required", + ) + return False + + def auto_downgrade( + self, + agent_id: NotBlankStr, + reason: DowngradeReason, + ) -> AutonomyLevel: + """Immediately downgrade to a level determined by the reason. + + Args: + agent_id: The agent to downgrade. + reason: Why the downgrade is happening. + + Returns: + The new autonomy level after downgrade. + """ + new_level = _DOWNGRADE_MAP[reason] + existing = self._overrides.get(agent_id) + original = existing.original_level if existing else AutonomyLevel.SEMI + + override = AutonomyOverride( + agent_id=agent_id, + original_level=original, + current_level=new_level, + reason=reason, + downgraded_at=datetime.now(UTC), + requires_human_recovery=True, + ) + self._overrides[agent_id] = override + + logger.warning( + AUTONOMY_DOWNGRADE_TRIGGERED, + agent_id=agent_id, + reason=reason.value, + new_level=new_level.value, + original_level=original.value, + ) + return new_level + + def request_recovery( + self, + agent_id: NotBlankStr, + ) -> bool: + """Deny all recovery requests — requires human approval. + + Args: + agent_id: The agent requesting recovery. + + Returns: + Always ``False``. + """ + logger.info( + AUTONOMY_RECOVERY_REQUESTED, + agent_id=agent_id, + ) + return False + + def get_override(self, agent_id: str) -> AutonomyOverride | None: + """Return the active override for an agent, if any. + + Args: + agent_id: The agent to look up. + + Returns: + The override record, or ``None`` if no override exists. + """ + return self._overrides.get(agent_id) + + def clear_override(self, agent_id: str) -> bool: + """Remove an override (used after human recovery approval). + + Args: + agent_id: The agent whose override to clear. + + Returns: + ``True`` if an override was removed, ``False`` if none existed. + """ + return self._overrides.pop(agent_id, None) is not None diff --git a/src/ai_company/security/autonomy/models.py b/src/ai_company/security/autonomy/models.py new file mode 100644 index 0000000000..a7d37afbc5 --- /dev/null +++ b/src/ai_company/security/autonomy/models.py @@ -0,0 +1,199 @@ +"""Autonomy data models — presets, config, effective resolution, overrides.""" + +from typing import Final, Self + +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, model_validator + +from ai_company.core.enums import AutonomyLevel, DowngradeReason +from ai_company.core.types import NotBlankStr # noqa: TC001 + + +class AutonomyPreset(BaseModel): + """A named autonomy preset defining action routing rules. + + Actions listed in ``auto_approve`` are executed without human + review. Actions in ``human_approval`` require a human decision. + The two sets must be disjoint — an action cannot be both + auto-approved and human-approval. + + Attributes: + level: The autonomy level this preset represents. + description: Human-readable description. + auto_approve: Action type patterns that are auto-approved. + The special value ``"all"`` means every action type. + Category shortcuts (e.g. ``"code"``) are expanded via + :class:`~ai_company.security.action_types.ActionTypeRegistry`. + human_approval: Action type patterns requiring human approval. + Same expansion rules as ``auto_approve``. + security_agent: Whether a security agent reviews escalated + actions before they reach a human. + """ + + model_config = ConfigDict(frozen=True) + + level: AutonomyLevel = Field(description="Autonomy level") + description: NotBlankStr = Field(description="Human-readable description") + auto_approve: tuple[str, ...] = Field( + default=(), + description="Action patterns that are auto-approved", + ) + human_approval: tuple[str, ...] = Field( + default=(), + description="Action patterns requiring human approval", + ) + security_agent: bool = Field( + default=True, + description="Whether security agent reviews escalations", + ) + + @model_validator(mode="after") + def _validate_disjoint(self) -> Self: + """Ensure auto_approve and human_approval are disjoint.""" + overlap = set(self.auto_approve) & set(self.human_approval) + if overlap: + msg = ( + f"auto_approve and human_approval must be disjoint, " + f"overlapping entries: {sorted(overlap)}" + ) + raise ValueError(msg) + return self + + +BUILTIN_PRESETS: Final[dict[str, AutonomyPreset]] = { + AutonomyLevel.FULL: AutonomyPreset( + level=AutonomyLevel.FULL, + description="Fully autonomous — all actions auto-approved", + auto_approve=("all",), + human_approval=(), + security_agent=False, + ), + AutonomyLevel.SEMI: AutonomyPreset( + level=AutonomyLevel.SEMI, + description=( + "Semi-autonomous — code, test, docs auto-approved; " + "deploy, org, budget require human approval" + ), + auto_approve=("code", "test", "docs", "vcs", "db:query"), + human_approval=("deploy", "org", "budget", "comms:external"), + security_agent=True, + ), + AutonomyLevel.SUPERVISED: AutonomyPreset( + level=AutonomyLevel.SUPERVISED, + description=( + "Supervised — read-only and test actions auto-approved; " + "all mutations require human approval" + ), + auto_approve=("code:read", "vcs:read", "test:run", "db:query"), + human_approval=( + "code:write", + "code:create", + "code:delete", + "code:refactor", + "test:write", + "docs:write", + "vcs:commit", + "vcs:push", + "vcs:branch", + "deploy", + "comms", + "budget", + "org", + "db:mutate", + "db:admin", + "arch:decide", + ), + security_agent=True, + ), + AutonomyLevel.LOCKED: AutonomyPreset( + level=AutonomyLevel.LOCKED, + description="Locked — all actions require human approval", + auto_approve=(), + human_approval=("all",), + security_agent=True, + ), +} + + +class AutonomyConfig(BaseModel): + """Company-level autonomy configuration. + + Attributes: + level: Default autonomy level for the company. + presets: Available autonomy presets keyed by level name. + Defaults to ``BUILTIN_PRESETS``. + """ + + model_config = ConfigDict(frozen=True) + + level: AutonomyLevel = Field( + default=AutonomyLevel.SEMI, + description="Default company autonomy level", + ) + presets: dict[str, AutonomyPreset] = Field( + default_factory=lambda: dict(BUILTIN_PRESETS), + description="Available autonomy presets", + ) + + @model_validator(mode="after") + def _validate_level_in_presets(self) -> Self: + """Ensure the configured level has a matching preset.""" + if self.level not in self.presets: + msg = ( + f"Autonomy level {self.level!r} not found in presets " + f"(available: {sorted(self.presets)})" + ) + raise ValueError(msg) + return self + + +class EffectiveAutonomy(BaseModel): + """Resolved, expanded autonomy for an agent's execution run. + + Produced by :class:`~ai_company.security.autonomy.resolver.AutonomyResolver` + by resolving the three-level chain (agent → department → company) + and expanding category shortcuts into concrete action types. + + Attributes: + level: Resolved autonomy level. + auto_approve_actions: Concrete action types that are auto-approved. + human_approval_actions: Concrete action types requiring human approval. + security_agent: Whether the security agent reviews escalations. + """ + + model_config = ConfigDict(frozen=True) + + level: AutonomyLevel = Field(description="Resolved autonomy level") + auto_approve_actions: frozenset[str] = Field( + description="Expanded auto-approve action types", + ) + human_approval_actions: frozenset[str] = Field( + description="Expanded human-approval action types", + ) + security_agent: bool = Field( + description="Whether security agent reviews escalations", + ) + + +class AutonomyOverride(BaseModel): + """Record of a runtime autonomy downgrade for an agent. + + Attributes: + agent_id: The agent whose autonomy was changed. + original_level: Level before the downgrade. + current_level: Level after the downgrade. + reason: Why the downgrade occurred. + downgraded_at: When the downgrade happened. + requires_human_recovery: Whether a human must restore the level. + """ + + model_config = ConfigDict(frozen=True) + + agent_id: NotBlankStr = Field(description="Agent identifier") + original_level: AutonomyLevel = Field(description="Level before downgrade") + current_level: AutonomyLevel = Field(description="Level after downgrade") + reason: DowngradeReason = Field(description="Reason for downgrade") + downgraded_at: AwareDatetime = Field(description="Timestamp of downgrade") + requires_human_recovery: bool = Field( + default=True, + description="Whether human approval is needed to restore level", + ) diff --git a/src/ai_company/security/autonomy/protocol.py b/src/ai_company/security/autonomy/protocol.py new file mode 100644 index 0000000000..c0724d6d53 --- /dev/null +++ b/src/ai_company/security/autonomy/protocol.py @@ -0,0 +1,63 @@ +"""Autonomy change strategy protocol (DESIGN_SPEC §12.2 D7).""" + +from typing import Protocol, runtime_checkable + +from ai_company.core.enums import AutonomyLevel, DowngradeReason # noqa: TC001 +from ai_company.core.types import NotBlankStr # noqa: TC001 + + +@runtime_checkable +class AutonomyChangeStrategy(Protocol): + """Strategy for managing runtime autonomy level changes. + + Implementations control how promotion requests, automatic + downgrades, and recovery requests are handled. + """ + + def request_promotion( + self, + agent_id: NotBlankStr, + target: AutonomyLevel, + ) -> bool: + """Request a promotion to a higher autonomy level. + + Args: + agent_id: The agent requesting promotion. + target: The desired autonomy level. + + Returns: + ``True`` if the promotion is immediately granted, + ``False`` if it requires human approval. + """ + ... + + def auto_downgrade( + self, + agent_id: NotBlankStr, + reason: DowngradeReason, + ) -> AutonomyLevel: + """Automatically downgrade an agent's autonomy level. + + Args: + agent_id: The agent to downgrade. + reason: Why the downgrade is happening. + + Returns: + The new (lower) autonomy level. + """ + ... + + def request_recovery( + self, + agent_id: NotBlankStr, + ) -> bool: + """Request recovery from a previous downgrade. + + Args: + agent_id: The agent requesting recovery. + + Returns: + ``True`` if recovery is immediately granted, + ``False`` if it requires human approval. + """ + ... diff --git a/src/ai_company/security/autonomy/resolver.py b/src/ai_company/security/autonomy/resolver.py new file mode 100644 index 0000000000..f1e4d81fc8 --- /dev/null +++ b/src/ai_company/security/autonomy/resolver.py @@ -0,0 +1,173 @@ +"""Autonomy resolver — three-level chain and category expansion.""" + +from ai_company.core.enums import AutonomyLevel, SeniorityLevel, compare_seniority +from ai_company.observability import get_logger +from ai_company.observability.events.autonomy import ( + AUTONOMY_PRESET_EXPANDED, + AUTONOMY_RESOLVED, + AUTONOMY_SENIORITY_VIOLATION, +) +from ai_company.security.action_types import ActionTypeRegistry # noqa: TC001 +from ai_company.security.autonomy.models import ( + AutonomyConfig, + EffectiveAutonomy, +) + +logger = get_logger(__name__) + +# Seniority threshold: JUNIOR agents cannot have FULL autonomy. +_JUNIOR_MAX_AUTONOMY = AutonomyLevel.SEMI + + +class AutonomyResolver: + """Resolves effective autonomy via a three-level chain. + + Resolution order (most specific wins): + 1. Agent-level override + 2. Department-level override + 3. Company-level default + + After resolution, category shortcuts (e.g. ``"code"``) are expanded + into concrete action types via the ``ActionTypeRegistry``, and the + ``"all"`` shortcut is expanded to every registered action type. + + Args: + registry: Action type registry for category expansion. + config: Company-level autonomy configuration with presets. + """ + + def __init__( + self, + *, + registry: ActionTypeRegistry, + config: AutonomyConfig, + ) -> None: + self._registry = registry + self._config = config + + def resolve( + self, + agent_level: AutonomyLevel | None = None, + department_level: AutonomyLevel | None = None, + ) -> EffectiveAutonomy: + """Resolve effective autonomy from the three-level chain. + + Args: + agent_level: Per-agent override (highest priority). + department_level: Per-department override. + + Returns: + Fully expanded :class:`EffectiveAutonomy`. + + Raises: + ValueError: If the resolved level has no matching preset. + """ + level = agent_level or department_level or self._config.level + preset = self._config.presets.get(level) + if preset is None: + msg = ( + f"No preset found for autonomy level {level!r} " + f"(available: {sorted(self._config.presets)})" + ) + raise ValueError(msg) + + auto_approve = self._expand_patterns(preset.auto_approve) + human_approval = self._expand_patterns(preset.human_approval) + + result = EffectiveAutonomy( + level=level, + auto_approve_actions=auto_approve, + human_approval_actions=human_approval, + security_agent=preset.security_agent, + ) + + logger.info( + AUTONOMY_RESOLVED, + resolved_level=level.value, + agent_override=agent_level.value if agent_level else None, + department_override=department_level.value if department_level else None, + auto_approve_count=len(auto_approve), + human_approval_count=len(human_approval), + ) + return result + + def validate_seniority( + self, + seniority: SeniorityLevel, + autonomy: AutonomyLevel, + ) -> None: + """Reject JUNIOR agents with FULL autonomy (D6). + + Args: + seniority: The agent's seniority level. + autonomy: The requested autonomy level. + + Raises: + ValueError: If a JUNIOR agent requests FULL autonomy. + """ + if ( + compare_seniority(seniority, SeniorityLevel.JUNIOR) <= 0 + and autonomy == AutonomyLevel.FULL + ): + logger.warning( + AUTONOMY_SENIORITY_VIOLATION, + seniority=seniority.value, + autonomy=autonomy.value, + ) + msg = ( + f"Seniority level {seniority.value!r} cannot have " + f"FULL autonomy — maximum is {_JUNIOR_MAX_AUTONOMY.value!r}" + ) + raise ValueError(msg) + + def _expand_patterns( + self, + patterns: tuple[str, ...], + ) -> frozenset[str]: + """Expand category shortcuts and ``"all"`` into concrete types. + + Args: + patterns: Action type patterns from a preset. Each entry + can be a concrete type (``"code:read"``), a category + shortcut (``"code"``), or the literal ``"all"``. + + Returns: + Frozenset of expanded, concrete action type strings. + """ + if not patterns: + return frozenset() + + result: set[str] = set() + + for pattern in patterns: + if pattern == "all": + expanded = self._registry.all_types() + result.update(expanded) + logger.debug( + AUTONOMY_PRESET_EXPANDED, + pattern=pattern, + expanded_count=len(expanded), + ) + continue + + # Try category expansion first. + category_types = self._registry.expand_category(pattern) + if category_types: + result.update(category_types) + logger.debug( + AUTONOMY_PRESET_EXPANDED, + pattern=pattern, + expanded_count=len(category_types), + ) + continue + + # Treat as a concrete action type. + if self._registry.is_registered(pattern): + result.add(pattern) + else: + # Unknown pattern — still include it so the security + # layer can match it. Custom action types registered + # later may use this pattern. + result.add(pattern) + + return frozenset(result) diff --git a/src/ai_company/security/service.py b/src/ai_company/security/service.py index cd99e22ca3..b1d31f49dd 100644 --- a/src/ai_company/security/service.py +++ b/src/ai_company/security/service.py @@ -14,6 +14,10 @@ from ai_company.core.approval import ApprovalItem from ai_company.core.enums import ApprovalRiskLevel, ApprovalStatus from ai_company.observability import get_logger +from ai_company.observability.events.autonomy import ( + AUTONOMY_ACTION_AUTO_APPROVED, + AUTONOMY_ACTION_HUMAN_REQUIRED, +) from ai_company.observability.events.security import ( SECURITY_AUDIT_RECORD_ERROR, SECURITY_CONFIG_LOADED, @@ -28,6 +32,7 @@ SECURITY_VERDICT_ESCALATE, ) from ai_company.security.audit import AuditLog # noqa: TC001 +from ai_company.security.autonomy.models import EffectiveAutonomy # noqa: TC001 from ai_company.security.config import SecurityConfig # noqa: TC001 from ai_company.security.models import ( OUTPUT_SCAN_VERDICT, @@ -69,7 +74,7 @@ class SecOpsService: and returns the verdict with ``approval_id`` set. """ - def __init__( + def __init__( # noqa: PLR0913 self, *, config: SecurityConfig, @@ -77,6 +82,7 @@ def __init__( audit_log: AuditLog, output_scanner: OutputScanner, approval_store: ApprovalStore | None = None, + effective_autonomy: EffectiveAutonomy | None = None, ) -> None: """Initialize the SecOps service. @@ -86,12 +92,16 @@ def __init__( audit_log: Audit log for recording evaluations. output_scanner: Post-tool output scanner. approval_store: Optional store for escalation items. + effective_autonomy: Resolved autonomy for the current run. + When provided, actions are routed based on autonomy + level before the rule engine is consulted. """ self._config = config self._rule_engine = rule_engine self._audit_log = audit_log self._output_scanner = output_scanner self._approval_store = approval_store + self._effective_autonomy = effective_autonomy if config.custom_policies: logger.warning( @@ -135,6 +145,12 @@ async def evaluate_pre_tool( agent_id=context.agent_id, ) + # Autonomy pre-check: route based on effective autonomy before + # the full rule engine. Hard-deny is always checked first. + autonomy_result = await self._apply_autonomy_precheck(context) + if autonomy_result is not None: + return autonomy_result + try: verdict = self._rule_engine.evaluate(context) except MemoryError, RecursionError: @@ -222,6 +238,85 @@ async def scan_output( return result + async def _apply_autonomy_precheck( + self, + context: SecurityContext, + ) -> SecurityVerdict | None: + """Apply autonomy-based routing and finalize the verdict. + + Returns a complete verdict (with escalation/audit handled) if + autonomy routing applies, or ``None`` to fall through. + """ + if self._effective_autonomy is None: + return None + verdict = self._check_autonomy(context) + if verdict is None: + return None + if verdict.verdict == SecurityVerdictType.ESCALATE: + verdict = await self._handle_escalation(context, verdict) + if self._config.audit_enabled: + self._record_audit(context, verdict) + return verdict + + def _check_autonomy( + self, + context: SecurityContext, + ) -> SecurityVerdict | None: + """Check autonomy routing for an action type. + + Returns a verdict if the action is routed by autonomy config, + or ``None`` to fall through to the rule engine. + + Hard-deny actions always fall through so the rule engine + produces its standard DENY verdict. + """ + action = context.action_type + + # Hard-deny always bypasses autonomy — let the rule engine deny it. + if action in self._config.hard_deny_action_types: + return None + + autonomy = self._effective_autonomy + assert autonomy is not None # noqa: S101 — guarded by caller + + now = datetime.now(UTC) + + if action in autonomy.auto_approve_actions: + logger.info( + AUTONOMY_ACTION_AUTO_APPROVED, + tool_name=context.tool_name, + action_type=action, + autonomy_level=autonomy.level.value, + ) + return SecurityVerdict( + verdict=SecurityVerdictType.ALLOW, + reason=f"Auto-approved by autonomy level '{autonomy.level.value}'", + risk_level=ApprovalRiskLevel.LOW, + evaluated_at=now, + evaluation_duration_ms=0.0, + ) + + if action in autonomy.human_approval_actions: + logger.info( + AUTONOMY_ACTION_HUMAN_REQUIRED, + tool_name=context.tool_name, + action_type=action, + autonomy_level=autonomy.level.value, + ) + return SecurityVerdict( + verdict=SecurityVerdictType.ESCALATE, + reason=( + f"Human approval required by autonomy level " + f"'{autonomy.level.value}'" + ), + risk_level=ApprovalRiskLevel.MEDIUM, + evaluated_at=now, + evaluation_duration_ms=0.0, + ) + + # Action not classified by autonomy — fall through to rule engine. + return None + def _record_audit( self, context: SecurityContext, diff --git a/src/ai_company/security/timeout/__init__.py b/src/ai_company/security/timeout/__init__.py new file mode 100644 index 0000000000..052f30ee67 --- /dev/null +++ b/src/ai_company/security/timeout/__init__.py @@ -0,0 +1,18 @@ +"""Approval timeout policies — wait, deny, tiered, escalation chain.""" + +from ai_company.security.timeout.factory import create_timeout_policy +from ai_company.security.timeout.models import TimeoutAction +from ai_company.security.timeout.park_service import ParkService +from ai_company.security.timeout.parked_context import ParkedContext +from ai_company.security.timeout.protocol import RiskTierClassifier, TimeoutPolicy +from ai_company.security.timeout.timeout_checker import TimeoutChecker + +__all__ = [ + "ParkService", + "ParkedContext", + "RiskTierClassifier", + "TimeoutAction", + "TimeoutChecker", + "TimeoutPolicy", + "create_timeout_policy", +] diff --git a/src/ai_company/security/timeout/config.py b/src/ai_company/security/timeout/config.py new file mode 100644 index 0000000000..6ce1f60360 --- /dev/null +++ b/src/ai_company/security/timeout/config.py @@ -0,0 +1,145 @@ +"""Timeout policy configuration models — discriminated union of 4 policies.""" + +from typing import Annotated, Literal + +from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag + +from ai_company.core.enums import TimeoutActionType +from ai_company.core.types import NotBlankStr # noqa: TC001 + + +class WaitForeverConfig(BaseModel): + """Wait indefinitely for human approval — the default. + + Attributes: + policy: Discriminator tag. + """ + + model_config = ConfigDict(frozen=True) + + policy: Literal["wait"] = "wait" + + +class DenyOnTimeoutConfig(BaseModel): + """Deny the action after a fixed timeout. + + Attributes: + policy: Discriminator tag. + timeout_minutes: Minutes before auto-deny. + """ + + model_config = ConfigDict(frozen=True) + + policy: Literal["deny"] = "deny" + timeout_minutes: float = Field( + default=240.0, + gt=0, + description="Minutes before auto-deny", + ) + + +class TierConfig(BaseModel): + """Per-risk-tier timeout configuration. + + Attributes: + timeout_minutes: Minutes before the on_timeout action. + on_timeout: What to do when the tier times out. + actions: Optional set of specific action types in this tier + (if empty, the tier is matched by risk level). + """ + + model_config = ConfigDict(frozen=True) + + timeout_minutes: float = Field( + gt=0, + description="Minutes before the timeout action", + ) + on_timeout: TimeoutActionType = Field( + default=TimeoutActionType.DENY, + description="Action when this tier times out", + ) + actions: tuple[str, ...] = Field( + default=(), + description="Specific action types in this tier", + ) + + +class TieredTimeoutConfig(BaseModel): + """Per-risk-tier timeout policy. + + Each tier defines its own timeout and action. Unknown risk + tiers fall back to HIGH (fail-safe per D19). + + Attributes: + policy: Discriminator tag. + tiers: Tier configurations keyed by risk level name. + """ + + model_config = ConfigDict(frozen=True) + + policy: Literal["tiered"] = "tiered" + tiers: dict[str, TierConfig] = Field( + default_factory=dict, + description="Tier configs keyed by risk level (low/medium/high/critical)", + ) + + +class EscalationStep(BaseModel): + """A single step in an escalation chain. + + Attributes: + role: The role to escalate to at this step. + timeout_minutes: Minutes to wait at this step before + moving to the next. + """ + + model_config = ConfigDict(frozen=True) + + role: NotBlankStr = Field(description="Escalation target role") + timeout_minutes: float = Field( + gt=0, + description="Minutes to wait at this escalation step", + ) + + +class EscalationChainConfig(BaseModel): + """Escalation chain timeout policy. + + Approval is escalated through a chain of roles, each with its + own timeout. If the entire chain is exhausted, the + ``on_chain_exhausted`` action is taken. + + Attributes: + policy: Discriminator tag. + chain: Ordered escalation steps. + on_chain_exhausted: Action when all steps exhaust. + """ + + model_config = ConfigDict(frozen=True) + + policy: Literal["escalation"] = "escalation" + chain: tuple[EscalationStep, ...] = Field( + default=(), + description="Ordered escalation steps", + ) + on_chain_exhausted: TimeoutActionType = Field( + default=TimeoutActionType.DENY, + description="Action when the entire chain is exhausted", + ) + + +def _timeout_discriminator(value: object) -> str: + """Extract the ``policy`` discriminator from raw or model data.""" + if isinstance(value, dict): + return str(value.get("policy", "wait")) + return getattr(value, "policy", "wait") + + +ApprovalTimeoutConfig = Annotated[ + Annotated[WaitForeverConfig, Tag("wait")] + | Annotated[DenyOnTimeoutConfig, Tag("deny")] + | Annotated[TieredTimeoutConfig, Tag("tiered")] + | Annotated[EscalationChainConfig, Tag("escalation")], + Discriminator(_timeout_discriminator), +] +"""Discriminated union of the four timeout policy configurations.""" diff --git a/src/ai_company/security/timeout/factory.py b/src/ai_company/security/timeout/factory.py new file mode 100644 index 0000000000..239880c994 --- /dev/null +++ b/src/ai_company/security/timeout/factory.py @@ -0,0 +1,57 @@ +"""Factory for creating timeout policy instances from configuration.""" + +from ai_company.security.timeout.config import ( + ApprovalTimeoutConfig, + DenyOnTimeoutConfig, + EscalationChainConfig, + TieredTimeoutConfig, + WaitForeverConfig, +) +from ai_company.security.timeout.policies import ( + DenyOnTimeoutPolicy, + EscalationChainPolicy, + TieredTimeoutPolicy, + WaitForeverPolicy, +) +from ai_company.security.timeout.protocol import TimeoutPolicy # noqa: TC001 +from ai_company.security.timeout.risk_tier_classifier import YamlRiskTierClassifier + +_SECONDS_PER_MINUTE = 60.0 + + +def create_timeout_policy( + config: ApprovalTimeoutConfig, +) -> TimeoutPolicy: + """Create a timeout policy from its configuration. + + Args: + config: One of the four timeout policy configurations. + + Returns: + A configured timeout policy instance. + + Raises: + TypeError: If config type is not recognized. + """ + if isinstance(config, WaitForeverConfig): + return WaitForeverPolicy() + + if isinstance(config, DenyOnTimeoutConfig): + return DenyOnTimeoutPolicy( + timeout_seconds=config.timeout_minutes * _SECONDS_PER_MINUTE, + ) + + if isinstance(config, TieredTimeoutConfig): + return TieredTimeoutPolicy( + tiers=config.tiers, + classifier=YamlRiskTierClassifier(), + ) + + if isinstance(config, EscalationChainConfig): + return EscalationChainPolicy( + chain=config.chain, + on_chain_exhausted=config.on_chain_exhausted, + ) + + msg = f"Unknown timeout policy config type: {type(config).__name__}" # type: ignore[unreachable] + raise TypeError(msg) diff --git a/src/ai_company/security/timeout/models.py b/src/ai_company/security/timeout/models.py new file mode 100644 index 0000000000..f445ac56cc --- /dev/null +++ b/src/ai_company/security/timeout/models.py @@ -0,0 +1,26 @@ +"""Timeout action model — the result of evaluating a timeout policy.""" + +from pydantic import BaseModel, ConfigDict, Field + +from ai_company.core.enums import TimeoutActionType # noqa: TC001 +from ai_company.core.types import NotBlankStr # noqa: TC001 + + +class TimeoutAction(BaseModel): + """Action to take when an approval item times out. + + Attributes: + action: The timeout action type (wait, approve, deny, escalate). + reason: Human-readable explanation for the action. + escalate_to: Target role/agent for escalation (only when + action is ESCALATE). + """ + + model_config = ConfigDict(frozen=True) + + action: TimeoutActionType = Field(description="Timeout action type") + reason: NotBlankStr = Field(description="Explanation for the action") + escalate_to: NotBlankStr | None = Field( + default=None, + description="Escalation target (when action is ESCALATE)", + ) diff --git a/src/ai_company/security/timeout/park_service.py b/src/ai_company/security/timeout/park_service.py new file mode 100644 index 0000000000..27a96cfb21 --- /dev/null +++ b/src/ai_company/security/timeout/park_service.py @@ -0,0 +1,84 @@ +"""Park/resume service for agent execution contexts. + +Serializes an ``AgentContext`` into a ``ParkedContext`` for persistence +when an agent is parked awaiting approval, and deserializes it back +when the approval decision arrives. +""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from ai_company.observability import get_logger + +if TYPE_CHECKING: + from ai_company.engine.context import AgentContext +from ai_company.observability.events.timeout import ( + TIMEOUT_WAITING, +) +from ai_company.security.timeout.parked_context import ParkedContext + +logger = get_logger(__name__) + + +class ParkService: + """Handles parking and resuming agent execution contexts. + + Parking serializes the full ``AgentContext`` as JSON and stores it + via the ``ParkedContextRepository``. Resuming deserializes and + deletes the parked record. + """ + + def park( + self, + *, + context: AgentContext, + approval_id: str, + agent_id: str, + task_id: str, + metadata: dict[str, str] | None = None, + ) -> ParkedContext: + """Serialize and create a ``ParkedContext`` from an agent context. + + Args: + context: The agent context to park. + approval_id: The approval item that triggered parking. + agent_id: Agent identifier. + task_id: Task identifier. + metadata: Optional additional metadata. + + Returns: + A ``ParkedContext`` ready for persistence. + """ + context_json = context.model_dump_json() + + parked = ParkedContext( + execution_id=str(context.execution_id), + agent_id=agent_id, + task_id=task_id, + approval_id=approval_id, + parked_at=datetime.now(UTC), + context_json=context_json, + metadata=metadata or {}, + ) + + logger.info( + TIMEOUT_WAITING, + parked_id=parked.id, + agent_id=agent_id, + task_id=task_id, + approval_id=approval_id, + ) + return parked + + def resume(self, parked: ParkedContext) -> AgentContext: + """Deserialize a ``ParkedContext`` back into an ``AgentContext``. + + Args: + parked: The parked context to resume. + + Returns: + The restored ``AgentContext``. + """ + from ai_company.engine.context import AgentContext # noqa: PLC0415 + + return AgentContext.model_validate_json(parked.context_json) diff --git a/src/ai_company/security/timeout/parked_context.py b/src/ai_company/security/timeout/parked_context.py new file mode 100644 index 0000000000..10f2d643db --- /dev/null +++ b/src/ai_company/security/timeout/parked_context.py @@ -0,0 +1,44 @@ +"""Parked context model for suspended agent executions. + +When an agent's execution is parked (awaiting human approval), the +full ``AgentContext`` is serialized and stored as a ``ParkedContext`` +so it can be resumed when the approval decision arrives. +""" + +from uuid import uuid4 + +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field + +from ai_company.core.types import NotBlankStr # noqa: TC001 + + +class ParkedContext(BaseModel): + """Serialized snapshot of a parked agent execution. + + Attributes: + id: Unique identifier for this parked context. + execution_id: The execution run ID from ``AgentContext``. + agent_id: Agent whose execution was parked. + task_id: Task the agent was working on. + approval_id: Approval item that caused the park. + parked_at: Timestamp when the context was parked. + context_json: JSON-serialized ``AgentContext``. + metadata: Additional metadata (e.g. tool name, action type). + """ + + model_config = ConfigDict(frozen=True) + + id: NotBlankStr = Field( + default_factory=lambda: str(uuid4()), + description="Unique parked context identifier", + ) + execution_id: NotBlankStr = Field(description="Execution run identifier") + agent_id: NotBlankStr = Field(description="Agent identifier") + task_id: NotBlankStr = Field(description="Task identifier") + approval_id: NotBlankStr = Field(description="Approval item identifier") + parked_at: AwareDatetime = Field(description="When the context was parked") + context_json: str = Field(description="JSON-serialized AgentContext") + metadata: dict[str, str] = Field( + default_factory=dict, + description="Additional metadata", + ) diff --git a/src/ai_company/security/timeout/policies.py b/src/ai_company/security/timeout/policies.py new file mode 100644 index 0000000000..938f1820c0 --- /dev/null +++ b/src/ai_company/security/timeout/policies.py @@ -0,0 +1,278 @@ +"""Timeout policy implementations — wait, deny, tiered, escalation chain.""" + +from ai_company.core.approval import ApprovalItem # noqa: TC001 +from ai_company.core.enums import TimeoutActionType +from ai_company.observability import get_logger +from ai_company.observability.events.timeout import ( + TIMEOUT_AUTO_DENIED, + TIMEOUT_ESCALATED, + TIMEOUT_POLICY_EVALUATED, + TIMEOUT_WAITING, +) +from ai_company.security.timeout.config import ( + EscalationStep, # noqa: TC001 + TierConfig, # noqa: TC001 +) +from ai_company.security.timeout.models import TimeoutAction +from ai_company.security.timeout.protocol import RiskTierClassifier # noqa: TC001 + +logger = get_logger(__name__) + +_SECONDS_PER_MINUTE = 60.0 + + +class WaitForeverPolicy: + """Always returns WAIT — no automatic timeout action. + + This is the safest default: approvals remain pending until a + human responds. + """ + + async def determine_action( + self, + item: ApprovalItem, + elapsed_seconds: float, + ) -> TimeoutAction: + """Always wait. + + Args: + item: The pending approval item. + elapsed_seconds: Seconds since creation. + + Returns: + WAIT action. + """ + logger.debug( + TIMEOUT_WAITING, + approval_id=item.id, + elapsed_seconds=elapsed_seconds, + ) + return TimeoutAction( + action=TimeoutActionType.WAIT, + reason="Wait-forever policy — no automatic action", + ) + + +class DenyOnTimeoutPolicy: + """Deny the action after a fixed timeout. + + Args: + timeout_seconds: Seconds before auto-deny. + """ + + def __init__(self, *, timeout_seconds: float) -> None: + self._timeout_seconds = timeout_seconds + + async def determine_action( + self, + item: ApprovalItem, + elapsed_seconds: float, + ) -> TimeoutAction: + """WAIT if under timeout, DENY if over. + + Args: + item: The pending approval item. + elapsed_seconds: Seconds since creation. + + Returns: + WAIT or DENY action. + """ + if elapsed_seconds < self._timeout_seconds: + logger.debug( + TIMEOUT_WAITING, + approval_id=item.id, + elapsed_seconds=elapsed_seconds, + timeout_seconds=self._timeout_seconds, + ) + return TimeoutAction( + action=TimeoutActionType.WAIT, + reason=( + f"Waiting — {elapsed_seconds:.0f}s of " + f"{self._timeout_seconds:.0f}s elapsed" + ), + ) + + logger.info( + TIMEOUT_AUTO_DENIED, + approval_id=item.id, + elapsed_seconds=elapsed_seconds, + timeout_seconds=self._timeout_seconds, + ) + return TimeoutAction( + action=TimeoutActionType.DENY, + reason=( + f"Auto-denied after {elapsed_seconds:.0f}s " + f"(timeout: {self._timeout_seconds:.0f}s)" + ), + ) + + +class TieredTimeoutPolicy: + """Per-risk-tier timeout with configurable actions. + + Uses a :class:`RiskTierClassifier` to determine the risk tier + of each approval item, then applies the corresponding tier + configuration. + + Args: + tiers: Tier configurations keyed by risk level name. + classifier: Risk tier classifier for action types. + """ + + def __init__( + self, + *, + tiers: dict[str, TierConfig], + classifier: RiskTierClassifier, + ) -> None: + self._tiers = tiers + self._classifier = classifier + + async def determine_action( + self, + item: ApprovalItem, + elapsed_seconds: float, + ) -> TimeoutAction: + """Apply the tier-specific timeout policy. + + Args: + item: The pending approval item. + elapsed_seconds: Seconds since creation. + + Returns: + WAIT, DENY, APPROVE, or ESCALATE based on tier config. + """ + risk_level = self._classifier.classify(item.action_type) + tier_config = self._tiers.get(risk_level.value) + + if tier_config is None: + # No tier configured for this risk level — wait (safe default). + logger.debug( + TIMEOUT_WAITING, + approval_id=item.id, + risk_level=risk_level.value, + note="no tier config — defaulting to wait", + ) + return TimeoutAction( + action=TimeoutActionType.WAIT, + reason=( + f"No tier config for risk level {risk_level.value!r} — waiting" + ), + ) + + timeout_seconds = tier_config.timeout_minutes * _SECONDS_PER_MINUTE + + if elapsed_seconds < timeout_seconds: + logger.debug( + TIMEOUT_WAITING, + approval_id=item.id, + risk_level=risk_level.value, + elapsed_seconds=elapsed_seconds, + timeout_seconds=timeout_seconds, + ) + return TimeoutAction( + action=TimeoutActionType.WAIT, + reason=( + f"Tier {risk_level.value}: {elapsed_seconds:.0f}s of " + f"{timeout_seconds:.0f}s elapsed" + ), + ) + + logger.info( + TIMEOUT_POLICY_EVALUATED, + approval_id=item.id, + risk_level=risk_level.value, + on_timeout=tier_config.on_timeout.value, + elapsed_seconds=elapsed_seconds, + ) + return TimeoutAction( + action=tier_config.on_timeout, + reason=( + f"Tier {risk_level.value} timeout: auto-" + f"{tier_config.on_timeout.value} after " + f"{elapsed_seconds:.0f}s" + ), + ) + + +class EscalationChainPolicy: + """Escalate through a chain of roles, each with its own timeout. + + When the entire chain is exhausted, applies the + ``on_chain_exhausted`` action. + + Args: + chain: Ordered escalation steps. + on_chain_exhausted: Action when all steps exhaust. + """ + + def __init__( + self, + *, + chain: tuple[EscalationStep, ...], + on_chain_exhausted: TimeoutActionType, + ) -> None: + self._chain = chain + self._on_chain_exhausted = on_chain_exhausted + + async def determine_action( + self, + item: ApprovalItem, + elapsed_seconds: float, + ) -> TimeoutAction: + """Determine the current escalation step. + + Calculates cumulative timeouts to find which step the + approval is currently at. + + Args: + item: The pending approval item. + elapsed_seconds: Seconds since creation. + + Returns: + WAIT, ESCALATE, or the chain-exhausted action. + """ + if not self._chain: + return TimeoutAction( + action=self._on_chain_exhausted, + reason="Empty escalation chain — applying exhausted action", + ) + + cumulative_seconds = 0.0 + for step in self._chain: + step_timeout = step.timeout_minutes * _SECONDS_PER_MINUTE + if elapsed_seconds < cumulative_seconds + step_timeout: + # Still within this step's window. + if elapsed_seconds < cumulative_seconds: + # Before this step — shouldn't happen but safe. + break + logger.debug( + TIMEOUT_WAITING, + approval_id=item.id, + escalation_role=step.role, + elapsed_seconds=elapsed_seconds, + ) + return TimeoutAction( + action=TimeoutActionType.ESCALATE, + reason=( + f"Escalated to {step.role!r} — {elapsed_seconds:.0f}s elapsed" + ), + escalate_to=step.role, + ) + cumulative_seconds += step_timeout + + # All steps exhausted. + logger.info( + TIMEOUT_ESCALATED, + approval_id=item.id, + elapsed_seconds=elapsed_seconds, + on_exhausted=self._on_chain_exhausted.value, + note="escalation chain exhausted", + ) + return TimeoutAction( + action=self._on_chain_exhausted, + reason=( + f"Escalation chain exhausted after {elapsed_seconds:.0f}s " + f"— {self._on_chain_exhausted.value}" + ), + ) diff --git a/src/ai_company/security/timeout/protocol.py b/src/ai_company/security/timeout/protocol.py new file mode 100644 index 0000000000..bc3568613d --- /dev/null +++ b/src/ai_company/security/timeout/protocol.py @@ -0,0 +1,48 @@ +"""Timeout policy and risk tier classifier protocols.""" + +from typing import Protocol, runtime_checkable + +from ai_company.core.approval import ApprovalItem # noqa: TC001 +from ai_company.core.enums import ApprovalRiskLevel # noqa: TC001 +from ai_company.security.timeout.models import TimeoutAction # noqa: TC001 + + +@runtime_checkable +class TimeoutPolicy(Protocol): + """Protocol for approval timeout policies (DESIGN_SPEC §12.4). + + Implementations determine what happens when a human does not + respond to an approval request within a configured timeframe. + """ + + async def determine_action( + self, + item: ApprovalItem, + elapsed_seconds: float, + ) -> TimeoutAction: + """Determine the timeout action for a pending approval. + + Args: + item: The pending approval item. + elapsed_seconds: Seconds since the item was created. + + Returns: + The action to take (wait, approve, deny, or escalate). + """ + ... + + +@runtime_checkable +class RiskTierClassifier(Protocol): + """Classifies action types into risk tiers for tiered timeouts.""" + + def classify(self, action_type: str) -> ApprovalRiskLevel: + """Classify an action type's risk level. + + Args: + action_type: The ``category:action`` string. + + Returns: + The risk tier for timeout policy selection. + """ + ... diff --git a/src/ai_company/security/timeout/risk_tier_classifier.py b/src/ai_company/security/timeout/risk_tier_classifier.py new file mode 100644 index 0000000000..cfae61bfd7 --- /dev/null +++ b/src/ai_company/security/timeout/risk_tier_classifier.py @@ -0,0 +1,87 @@ +"""YAML-configurable risk tier classifier for timeout policies.""" + +from types import MappingProxyType +from typing import Final + +from ai_company.core.enums import ActionType, ApprovalRiskLevel +from ai_company.observability import get_logger +from ai_company.observability.events.timeout import TIMEOUT_POLICY_EVALUATED + +logger = get_logger(__name__) + +# Reuses the same risk assignments as security/rules/risk_classifier.py. +_DEFAULT_RISK_MAP: Final[MappingProxyType[str, ApprovalRiskLevel]] = MappingProxyType( + { + # CRITICAL + ActionType.DEPLOY_PRODUCTION: ApprovalRiskLevel.CRITICAL, + ActionType.DB_ADMIN: ApprovalRiskLevel.CRITICAL, + ActionType.ORG_FIRE: ApprovalRiskLevel.CRITICAL, + # HIGH + ActionType.DEPLOY_STAGING: ApprovalRiskLevel.HIGH, + ActionType.DB_MUTATE: ApprovalRiskLevel.HIGH, + ActionType.CODE_DELETE: ApprovalRiskLevel.HIGH, + ActionType.VCS_PUSH: ApprovalRiskLevel.HIGH, + ActionType.COMMS_EXTERNAL: ApprovalRiskLevel.HIGH, + ActionType.BUDGET_EXCEED: ApprovalRiskLevel.HIGH, + # MEDIUM + ActionType.CODE_CREATE: ApprovalRiskLevel.MEDIUM, + ActionType.CODE_WRITE: ApprovalRiskLevel.MEDIUM, + ActionType.CODE_REFACTOR: ApprovalRiskLevel.MEDIUM, + ActionType.VCS_COMMIT: ApprovalRiskLevel.MEDIUM, + ActionType.ARCH_DECIDE: ApprovalRiskLevel.MEDIUM, + ActionType.ORG_HIRE: ApprovalRiskLevel.MEDIUM, + ActionType.ORG_PROMOTE: ApprovalRiskLevel.MEDIUM, + ActionType.BUDGET_SPEND: ApprovalRiskLevel.MEDIUM, + # LOW + ActionType.CODE_READ: ApprovalRiskLevel.LOW, + ActionType.VCS_READ: ApprovalRiskLevel.LOW, + ActionType.TEST_RUN: ApprovalRiskLevel.LOW, + ActionType.TEST_WRITE: ApprovalRiskLevel.LOW, + ActionType.DOCS_WRITE: ApprovalRiskLevel.LOW, + ActionType.VCS_BRANCH: ApprovalRiskLevel.LOW, + ActionType.COMMS_INTERNAL: ApprovalRiskLevel.LOW, + ActionType.DB_QUERY: ApprovalRiskLevel.LOW, + } +) + + +class YamlRiskTierClassifier: + """Maps action types to risk tiers for tiered timeout policies. + + Unknown action types default to HIGH (fail-safe per D19). + + Args: + custom_map: Optional overrides for the default risk mapping. + """ + + def __init__( + self, + *, + custom_map: dict[str, ApprovalRiskLevel] | None = None, + ) -> None: + if custom_map: + merged = dict(_DEFAULT_RISK_MAP) + merged.update(custom_map) + self._risk_map = MappingProxyType(merged) + else: + self._risk_map = _DEFAULT_RISK_MAP + + def classify(self, action_type: str) -> ApprovalRiskLevel: + """Classify an action type's risk tier. + + Args: + action_type: The ``category:action`` string. + + Returns: + Risk tier. Defaults to HIGH for unknown types. + """ + result = self._risk_map.get(action_type) + if result is None: + logger.debug( + TIMEOUT_POLICY_EVALUATED, + action_type=action_type, + risk_tier="high", + note="unknown action type — defaulting to HIGH", + ) + return ApprovalRiskLevel.HIGH + return result diff --git a/src/ai_company/security/timeout/timeout_checker.py b/src/ai_company/security/timeout/timeout_checker.py new file mode 100644 index 0000000000..43bddbe117 --- /dev/null +++ b/src/ai_company/security/timeout/timeout_checker.py @@ -0,0 +1,110 @@ +"""Timeout checker — evaluates pending approvals against timeout policy. + +Periodically called (by the engine or a background task) to check +whether pending approval items have exceeded their timeout thresholds +and apply the configured ``TimeoutPolicy``. +""" + +from datetime import UTC, datetime + +from ai_company.core.approval import ApprovalItem # noqa: TC001 +from ai_company.core.enums import ApprovalStatus, TimeoutActionType +from ai_company.observability import get_logger +from ai_company.observability.events.timeout import ( + TIMEOUT_AUTO_APPROVED, + TIMEOUT_AUTO_DENIED, + TIMEOUT_ESCALATED, + TIMEOUT_POLICY_EVALUATED, + TIMEOUT_WAITING, +) +from ai_company.security.timeout.models import TimeoutAction # noqa: TC001 +from ai_company.security.timeout.protocol import TimeoutPolicy # noqa: TC001 + +logger = get_logger(__name__) + + +class TimeoutChecker: + """Evaluates pending approvals against the configured timeout policy. + + Args: + policy: The timeout policy to apply. + """ + + def __init__(self, *, policy: TimeoutPolicy) -> None: + self._policy = policy + + async def check( + self, + item: ApprovalItem, + ) -> TimeoutAction: + """Evaluate a single pending approval item. + + Args: + item: The approval item to check. + + Returns: + The ``TimeoutAction`` determined by the policy. + """ + now = datetime.now(UTC) + elapsed = (now - item.created_at).total_seconds() + + action = await self._policy.determine_action(item, elapsed) + + event = { + TimeoutActionType.WAIT: TIMEOUT_WAITING, + TimeoutActionType.APPROVE: TIMEOUT_AUTO_APPROVED, + TimeoutActionType.DENY: TIMEOUT_AUTO_DENIED, + TimeoutActionType.ESCALATE: TIMEOUT_ESCALATED, + }.get(action.action, TIMEOUT_POLICY_EVALUATED) + + logger.info( + event, + approval_id=item.id, + action_type=item.action_type, + elapsed_seconds=elapsed, + timeout_action=action.action.value, + reason=action.reason, + ) + return action + + async def check_and_resolve( + self, + item: ApprovalItem, + ) -> tuple[ApprovalItem, TimeoutAction]: + """Check an approval and return the updated item with the action. + + If the policy returns APPROVE or DENY, the item's status is + updated accordingly. WAIT and ESCALATE leave the item in + PENDING status (escalation is handled by the caller). + + Args: + item: The approval item to check. + + Returns: + Tuple of (possibly updated item, timeout action). + """ + action = await self.check(item) + + if action.action == TimeoutActionType.APPROVE: + updated = item.model_copy( + update={ + "status": ApprovalStatus.APPROVED, + "decided_at": datetime.now(UTC), + "decided_by": "timeout_policy", + "decision_reason": action.reason, + }, + ) + return updated, action + + if action.action == TimeoutActionType.DENY: + updated = item.model_copy( + update={ + "status": ApprovalStatus.REJECTED, + "decided_at": datetime.now(UTC), + "decided_by": "timeout_policy", + "decision_reason": action.reason, + }, + ) + return updated, action + + return item, action diff --git a/src/ai_company/templates/renderer.py b/src/ai_company/templates/renderer.py index fa93f4dae2..43c0e8a996 100644 --- a/src/ai_company/templates/renderer.py +++ b/src/ai_company/templates/renderer.py @@ -526,14 +526,21 @@ def _validate_list( def _extract_numeric_config( company: dict[str, Any], template: CompanyTemplate, -) -> tuple[float, float]: - """Extract autonomy and budget_monthly as floats.""" +) -> tuple[float | dict[str, Any], float]: + """Extract autonomy and budget_monthly. + + Autonomy may be a float (backward compat) or a dict. When it's + a float, we pass it through — the ``CompanyConfig.model_validator`` + converts it to ``AutonomyConfig``. + """ source_name = template.metadata.name + raw_autonomy = company.get("autonomy", template.autonomy) try: - autonomy = to_float( - company.get("autonomy", template.autonomy), - field_name="autonomy", - ) + if isinstance(raw_autonomy, dict): + # Already an AutonomyConfig-like dict — pass through. + autonomy: float | dict[str, Any] = raw_autonomy + else: + autonomy = to_float(raw_autonomy, field_name="autonomy") budget_monthly = to_float( company.get("budget_monthly", template.budget_monthly), field_name="budget_monthly", diff --git a/tests/unit/api/conftest.py b/tests/unit/api/conftest.py index eecb62844c..364105adc9 100644 --- a/tests/unit/api/conftest.py +++ b/tests/unit/api/conftest.py @@ -21,6 +21,7 @@ TaskStatus, ) from ai_company.core.task import Task +from ai_company.security.timeout.parked_context import ParkedContext # noqa: TC001 # ── Fake Repositories ──────────────────────────────────────────── @@ -188,6 +189,31 @@ async def query( return tuple(result) +class FakeParkedContextRepository: + """In-memory parked context repository for tests.""" + + def __init__(self) -> None: + self._contexts: dict[str, ParkedContext] = {} + + async def save(self, context: ParkedContext) -> None: + self._contexts[context.id] = context + + async def get(self, parked_id: str) -> ParkedContext | None: + return self._contexts.get(parked_id) + + async def get_by_approval(self, approval_id: str) -> ParkedContext | None: + for ctx in self._contexts.values(): + if ctx.approval_id == approval_id: + return ctx + return None + + async def get_by_agent(self, agent_id: str) -> tuple[ParkedContext, ...]: + return tuple(ctx for ctx in self._contexts.values() if ctx.agent_id == agent_id) + + async def delete(self, parked_id: str) -> bool: + return self._contexts.pop(parked_id, None) is not None + + class FakePersistenceBackend: """In-memory persistence backend for tests.""" @@ -198,6 +224,7 @@ def __init__(self) -> None: self._lifecycle_events = FakeLifecycleEventRepository() self._task_metrics = FakeTaskMetricRepository() self._collaboration_metrics = FakeCollaborationMetricRepository() + self._parked_contexts = FakeParkedContextRepository() self._connected = False async def connect(self) -> None: @@ -244,6 +271,10 @@ def task_metrics(self) -> FakeTaskMetricRepository: def collaboration_metrics(self) -> FakeCollaborationMetricRepository: return self._collaboration_metrics + @property + def parked_contexts(self) -> FakeParkedContextRepository: + return self._parked_contexts + # ── Fake Message Bus ──────────────────────────────────────────── diff --git a/tests/unit/core/conftest.py b/tests/unit/core/conftest.py index 0447f7bb53..9b6c09de70 100644 --- a/tests/unit/core/conftest.py +++ b/tests/unit/core/conftest.py @@ -43,6 +43,7 @@ from ai_company.core.project import Project from ai_company.core.role import Authority, CustomRole, Role, SeniorityInfo, Skill from ai_company.core.task import AcceptanceCriterion, Task +from ai_company.security.autonomy.models import AutonomyConfig # ── Factories ────────────────────────────────────────────────────── @@ -130,6 +131,7 @@ class DepartmentFactory(ModelFactory[Department]): class CompanyConfigFactory(ModelFactory[CompanyConfig]): __model__ = CompanyConfig + autonomy = AutonomyConfig() class HRRegistryFactory(ModelFactory[HRRegistry]): @@ -153,6 +155,7 @@ class CompanyFactory(ModelFactory[Company]): departments = () workflow_handoffs = () escalation_paths = () + config = CompanyConfigFactory class ExpectedArtifactFactory(ModelFactory[ExpectedArtifact]): diff --git a/tests/unit/core/test_company.py b/tests/unit/core/test_company.py index 5f1958994b..80aebcaf55 100644 --- a/tests/unit/core/test_company.py +++ b/tests/unit/core/test_company.py @@ -16,7 +16,11 @@ Team, WorkflowHandoff, ) -from ai_company.core.enums import CompanyType +from ai_company.core.enums import AutonomyLevel, CompanyType +from ai_company.security.timeout.config import ( + DenyOnTimeoutConfig, + WaitForeverConfig, +) from .conftest import ( CompanyConfigFactory, @@ -218,29 +222,28 @@ class TestCompanyConfig: """Tests for CompanyConfig defaults, autonomy bounds, and validation.""" def test_defaults(self) -> None: - """Verify default autonomy, budget, and communication pattern.""" + """Verify default autonomy config, budget, and communication pattern.""" cfg = CompanyConfig() - assert cfg.autonomy == 0.5 + assert cfg.autonomy.level == AutonomyLevel.SEMI assert cfg.budget_monthly == 100.0 assert cfg.communication_pattern == "hybrid" assert cfg.tool_access_default == () - def test_autonomy_boundaries(self) -> None: - """Accept autonomy at both boundaries (0.0 and 1.0).""" - low = CompanyConfig(autonomy=0.0) - high = CompanyConfig(autonomy=1.0) - assert low.autonomy == 0.0 - assert high.autonomy == 1.0 - - def test_autonomy_below_zero_rejected(self) -> None: - """Reject autonomy below 0.0.""" - with pytest.raises(ValidationError): - CompanyConfig(autonomy=-0.1) - - def test_autonomy_above_one_rejected(self) -> None: - """Reject autonomy above 1.0.""" - with pytest.raises(ValidationError): - CompanyConfig(autonomy=1.1) + def test_autonomy_float_backward_compat(self) -> None: + """Accept bare float and convert to AutonomyConfig via before-validator.""" + low = CompanyConfig(autonomy=0.0) # type: ignore[arg-type] + high = CompanyConfig(autonomy=1.0) # type: ignore[arg-type] + mid = CompanyConfig(autonomy=0.5) # type: ignore[arg-type] + supervised = CompanyConfig(autonomy=0.3) # type: ignore[arg-type] + assert low.autonomy.level == AutonomyLevel.LOCKED + assert high.autonomy.level == AutonomyLevel.FULL + assert mid.autonomy.level == AutonomyLevel.SEMI + assert supervised.autonomy.level == AutonomyLevel.SUPERVISED + + def test_autonomy_config_direct(self) -> None: + """Accept AutonomyConfig dict directly.""" + cfg = CompanyConfig(autonomy={"level": "full"}) # type: ignore[arg-type] + assert cfg.autonomy.level == AutonomyLevel.FULL def test_budget_negative_rejected(self) -> None: """Reject negative monthly budget.""" @@ -266,7 +269,7 @@ def test_frozen(self) -> None: """Ensure CompanyConfig is immutable.""" cfg = CompanyConfig() with pytest.raises(ValidationError): - cfg.autonomy = 1.0 # type: ignore[misc] + cfg.budget_monthly = 999.0 # type: ignore[misc] def test_factory(self) -> None: """Verify factory produces a valid CompanyConfig.""" @@ -832,3 +835,32 @@ def test_duplicate_subordinates_whitespace_insensitive(self) -> None: ReportingLine(subordinate=" Alice ", supervisor="manager"), ), ) + + +# ── CompanyConfig approval timeout ──────────────────────────────── + + +@pytest.mark.unit +class TestCompanyConfigApprovalTimeout: + """Tests for CompanyConfig.approval_timeout field.""" + + def test_default_approval_timeout(self) -> None: + """CompanyConfig() defaults to WaitForeverConfig.""" + cfg = CompanyConfig() + assert isinstance(cfg.approval_timeout, WaitForeverConfig) + assert cfg.approval_timeout.policy == "wait" + + def test_custom_approval_timeout(self) -> None: + """Can pass a DenyOnTimeoutConfig as approval_timeout.""" + deny_cfg = DenyOnTimeoutConfig(timeout_minutes=60.0) + cfg = CompanyConfig(approval_timeout=deny_cfg) + assert isinstance(cfg.approval_timeout, DenyOnTimeoutConfig) + assert cfg.approval_timeout.timeout_minutes == 60.0 + + def test_approval_timeout_from_dict(self) -> None: + """Can construct from dict with discriminated union.""" + cfg = CompanyConfig.model_validate( + {"approval_timeout": {"policy": "deny", "timeout_minutes": 60}} + ) + assert isinstance(cfg.approval_timeout, DenyOnTimeoutConfig) + assert cfg.approval_timeout.timeout_minutes == 60.0 diff --git a/tests/unit/engine/test_loop_protocol.py b/tests/unit/engine/test_loop_protocol.py index 47554d83c0..ba5741ce1e 100644 --- a/tests/unit/engine/test_loop_protocol.py +++ b/tests/unit/engine/test_loop_protocol.py @@ -30,9 +30,10 @@ def test_values(self) -> None: assert TerminationReason.BUDGET_EXHAUSTED.value == "budget_exhausted" assert TerminationReason.SHUTDOWN.value == "shutdown" assert TerminationReason.ERROR.value == "error" + assert TerminationReason.PARKED.value == "parked" def test_member_count(self) -> None: - assert len(TerminationReason) == 5 + assert len(TerminationReason) == 6 @pytest.mark.unit diff --git a/tests/unit/engine/test_prompt.py b/tests/unit/engine/test_prompt.py index 2ec243d875..c9f45a4974 100644 --- a/tests/unit/engine/test_prompt.py +++ b/tests/unit/engine/test_prompt.py @@ -9,6 +9,7 @@ from ai_company.core.agent import AgentIdentity, ModelConfig, PersonalityConfig from ai_company.core.enums import ( + AutonomyLevel, CollaborationPreference, CommunicationVerbosity, ConflictApproach, @@ -33,6 +34,7 @@ PROMPT_BUILD_SUCCESS, PROMPT_BUILD_TOKEN_TRIMMED, ) +from ai_company.security.autonomy.models import EffectiveAutonomy if TYPE_CHECKING: from ai_company.core.company import Company @@ -940,3 +942,76 @@ def _broken_render(*_args: object, **_kwargs: object) -> None: build_system_prompt(agent=sample_agent_with_personality) assert isinstance(exc_info.value.__cause__, RuntimeError) + + +# ── TestEffectiveAutonomyInPrompt ────────────────────────────── + + +class TestEffectiveAutonomyInPrompt: + """Tests for effective autonomy info in the system prompt.""" + + @pytest.mark.unit + def test_autonomy_level_in_prompt( + self, + sample_agent_with_personality: AgentIdentity, + ) -> None: + """Effective autonomy level appears in the rendered prompt.""" + autonomy = EffectiveAutonomy( + level=AutonomyLevel.SEMI, + auto_approve_actions=frozenset({"code:read", "code:write"}), + human_approval_actions=frozenset({"infra:deploy"}), + security_agent=False, + ) + result = build_system_prompt( + agent=sample_agent_with_personality, + effective_autonomy=autonomy, + ) + assert "semi" in result.content + + @pytest.mark.unit + def test_auto_approve_actions_in_prompt( + self, + sample_agent_with_personality: AgentIdentity, + ) -> None: + """Auto-approved actions are listed in the prompt.""" + autonomy = EffectiveAutonomy( + level=AutonomyLevel.FULL, + auto_approve_actions=frozenset({"code:read", "code:write"}), + human_approval_actions=frozenset(), + security_agent=False, + ) + result = build_system_prompt( + agent=sample_agent_with_personality, + effective_autonomy=autonomy, + ) + assert "code:read" in result.content + assert "code:write" in result.content + + @pytest.mark.unit + def test_human_approval_actions_in_prompt( + self, + sample_agent_with_personality: AgentIdentity, + ) -> None: + """Human-approval-required actions are listed in the prompt.""" + autonomy = EffectiveAutonomy( + level=AutonomyLevel.SUPERVISED, + auto_approve_actions=frozenset(), + human_approval_actions=frozenset({"infra:deploy", "budget:spend"}), + security_agent=False, + ) + result = build_system_prompt( + agent=sample_agent_with_personality, + effective_autonomy=autonomy, + ) + assert "infra:deploy" in result.content + assert "budget:spend" in result.content + + @pytest.mark.unit + def test_no_autonomy_omits_section( + self, + sample_agent_with_personality: AgentIdentity, + ) -> None: + """When no effective_autonomy is provided, no autonomy level section.""" + result = build_system_prompt(agent=sample_agent_with_personality) + assert "Autonomy level" not in result.content + assert "Auto-approved actions" not in result.content diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index a53b0d45f7..80c72915ab 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -175,24 +175,31 @@ def test_has_at_least_20_events(self) -> None: def test_all_domain_modules_discovered(self) -> None: """Every expected domain module is found by pkgutil discovery.""" expected = { + "api", + "autonomy", "budget", "cfo", "classification", + "code_runner", "communication", "company", "config", "conflict", - "code_runner", + "consolidation", "correlation", "decomposition", "delegation", "docker", "execution", "git", + "hr", "mcp", "meeting", "memory", + "org_memory", "parallel", + "performance", + "persistence", "personality", "prompt", "provider", @@ -200,17 +207,14 @@ def test_all_domain_modules_discovered(self) -> None: "role", "routing", "sandbox", - "api", + "security", "task", "task_assignment", "task_routing", "template", + "timeout", "tool", - "persistence", "workspace", - "consolidation", - "org_memory", - "security", "hr", "performance", "trust", diff --git a/tests/unit/persistence/test_migrations_v2.py b/tests/unit/persistence/test_migrations_v2.py index 2303f880f0..c5fe75c939 100644 --- a/tests/unit/persistence/test_migrations_v2.py +++ b/tests/unit/persistence/test_migrations_v2.py @@ -28,8 +28,8 @@ async def memory_db() -> AsyncGenerator[aiosqlite.Connection]: @pytest.mark.unit class TestV2Migration: - async def test_schema_version_is_two(self) -> None: - assert SCHEMA_VERSION == 2 + async def test_schema_version_is_three(self) -> None: + assert SCHEMA_VERSION == 3 async def test_fresh_db_creates_all_v2_tables( self, memory_db: aiosqlite.Connection @@ -58,7 +58,7 @@ async def test_v1_to_v2_migration(self, memory_db: aiosqlite.Connection) -> None assert await get_user_version(memory_db) == 1 await run_migrations(memory_db) - assert await get_user_version(memory_db) == 2 + assert await get_user_version(memory_db) == 3 cursor = await memory_db.execute( "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" diff --git a/tests/unit/persistence/test_protocol.py b/tests/unit/persistence/test_protocol.py index 68f3b84950..665ec35c79 100644 --- a/tests/unit/persistence/test_protocol.py +++ b/tests/unit/persistence/test_protocol.py @@ -27,6 +27,7 @@ CollaborationMetricRecord, TaskMetricRecord, ) + from ai_company.security.timeout.parked_context import ParkedContext class _FakeTaskRepository: @@ -122,6 +123,23 @@ async def query( return () +class _FakeParkedContextRepository: + async def save(self, context: ParkedContext) -> None: + pass + + async def get(self, parked_id: str) -> ParkedContext | None: + return None + + async def get_by_approval(self, approval_id: str) -> ParkedContext | None: + return None + + async def get_by_agent(self, agent_id: str) -> tuple[ParkedContext, ...]: + return () + + async def delete(self, parked_id: str) -> bool: + return False + + class _FakeBackend: async def connect(self) -> None: pass @@ -163,6 +181,10 @@ def lifecycle_events(self) -> _FakeLifecycleEventRepository: def task_metrics(self) -> _FakeTaskMetricRepository: return _FakeTaskMetricRepository() + @property + def parked_contexts(self) -> _FakeParkedContextRepository: + return _FakeParkedContextRepository() + @property def collaboration_metrics(self) -> _FakeCollaborationMetricRepository: return _FakeCollaborationMetricRepository() diff --git a/tests/unit/security/autonomy/__init__.py b/tests/unit/security/autonomy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/security/autonomy/test_change_strategy.py b/tests/unit/security/autonomy/test_change_strategy.py new file mode 100644 index 0000000000..7b5eee044a --- /dev/null +++ b/tests/unit/security/autonomy/test_change_strategy.py @@ -0,0 +1,85 @@ +"""Tests for HumanOnlyPromotionStrategy.""" + +import pytest + +from ai_company.core.enums import AutonomyLevel, DowngradeReason +from ai_company.security.autonomy.change_strategy import HumanOnlyPromotionStrategy + + +class TestPromotion: + """Promotion is always denied in human-only strategy.""" + + @pytest.mark.unit + def test_promotion_denied(self) -> None: + strategy = HumanOnlyPromotionStrategy() + result = strategy.request_promotion("agent-1", AutonomyLevel.FULL) + assert result is False + + @pytest.mark.unit + @pytest.mark.parametrize("target", list(AutonomyLevel)) + def test_all_promotions_denied(self, target: AutonomyLevel) -> None: + strategy = HumanOnlyPromotionStrategy() + assert strategy.request_promotion("agent-x", target) is False + + +class TestAutoDowngrade: + """Auto-downgrade maps reasons to specific levels.""" + + @pytest.mark.unit + def test_high_error_rate_to_supervised(self) -> None: + strategy = HumanOnlyPromotionStrategy() + result = strategy.auto_downgrade("agent-1", DowngradeReason.HIGH_ERROR_RATE) + assert result == AutonomyLevel.SUPERVISED + + @pytest.mark.unit + def test_budget_exhausted_to_supervised(self) -> None: + strategy = HumanOnlyPromotionStrategy() + result = strategy.auto_downgrade("agent-1", DowngradeReason.BUDGET_EXHAUSTED) + assert result == AutonomyLevel.SUPERVISED + + @pytest.mark.unit + def test_security_incident_to_locked(self) -> None: + strategy = HumanOnlyPromotionStrategy() + result = strategy.auto_downgrade("agent-1", DowngradeReason.SECURITY_INCIDENT) + assert result == AutonomyLevel.LOCKED + + @pytest.mark.unit + def test_override_tracked(self) -> None: + strategy = HumanOnlyPromotionStrategy() + strategy.auto_downgrade("agent-1", DowngradeReason.HIGH_ERROR_RATE) + override = strategy.get_override("agent-1") + assert override is not None + assert override.current_level == AutonomyLevel.SUPERVISED + assert override.reason == DowngradeReason.HIGH_ERROR_RATE + assert override.requires_human_recovery is True + + @pytest.mark.unit + def test_no_override_when_not_downgraded(self) -> None: + strategy = HumanOnlyPromotionStrategy() + assert strategy.get_override("agent-1") is None + + +class TestRecovery: + """Recovery is always denied in human-only strategy.""" + + @pytest.mark.unit + def test_recovery_denied(self) -> None: + strategy = HumanOnlyPromotionStrategy() + result = strategy.request_recovery("agent-1") + assert result is False + + +class TestOverrideManagement: + """Override clear/get operations.""" + + @pytest.mark.unit + def test_clear_existing_override(self) -> None: + strategy = HumanOnlyPromotionStrategy() + strategy.auto_downgrade("agent-1", DowngradeReason.HIGH_ERROR_RATE) + assert strategy.clear_override("agent-1") is True + assert strategy.get_override("agent-1") is None + + @pytest.mark.unit + def test_clear_nonexistent_override(self) -> None: + strategy = HumanOnlyPromotionStrategy() + assert strategy.clear_override("agent-1") is False diff --git a/tests/unit/security/autonomy/test_models.py b/tests/unit/security/autonomy/test_models.py new file mode 100644 index 0000000000..433f1af1ff --- /dev/null +++ b/tests/unit/security/autonomy/test_models.py @@ -0,0 +1,182 @@ +"""Tests for autonomy models — presets, config, effective autonomy, overrides.""" + +from datetime import UTC, datetime + +import pytest +from pydantic import ValidationError + +from ai_company.core.enums import AutonomyLevel, DowngradeReason +from ai_company.security.autonomy.models import ( + BUILTIN_PRESETS, + AutonomyConfig, + AutonomyOverride, + AutonomyPreset, + EffectiveAutonomy, +) + + +class TestAutonomyPreset: + """AutonomyPreset validation tests.""" + + @pytest.mark.unit + def test_valid_preset(self) -> None: + preset = AutonomyPreset( + level=AutonomyLevel.SEMI, + description="Test preset", + auto_approve=("code:read",), + human_approval=("deploy:production",), + ) + assert preset.level == AutonomyLevel.SEMI + assert preset.auto_approve == ("code:read",) + assert preset.human_approval == ("deploy:production",) + assert preset.security_agent is True + + @pytest.mark.unit + def test_disjoint_enforcement(self) -> None: + with pytest.raises(ValueError, match="disjoint"): + AutonomyPreset( + level=AutonomyLevel.SEMI, + description="Overlapping", + auto_approve=("code:read", "code:write"), + human_approval=("code:write",), + ) + + @pytest.mark.unit + def test_empty_lists_valid(self) -> None: + preset = AutonomyPreset( + level=AutonomyLevel.LOCKED, + description="Empty", + auto_approve=(), + human_approval=(), + ) + assert preset.auto_approve == () + assert preset.human_approval == () + + +class TestBuiltinPresets: + """Validate the four built-in presets.""" + + @pytest.mark.unit + def test_all_levels_present(self) -> None: + for level in AutonomyLevel: + assert level in BUILTIN_PRESETS, f"Missing preset for {level}" + + @pytest.mark.unit + def test_full_preset_auto_approves_all(self) -> None: + full = BUILTIN_PRESETS[AutonomyLevel.FULL] + assert "all" in full.auto_approve + assert full.human_approval == () + assert full.security_agent is False + + @pytest.mark.unit + def test_locked_preset_requires_all_human(self) -> None: + locked = BUILTIN_PRESETS[AutonomyLevel.LOCKED] + assert locked.auto_approve == () + assert "all" in locked.human_approval + assert locked.security_agent is True + + @pytest.mark.unit + def test_semi_preset_has_both(self) -> None: + semi = BUILTIN_PRESETS[AutonomyLevel.SEMI] + assert len(semi.auto_approve) > 0 + assert len(semi.human_approval) > 0 + + @pytest.mark.unit + def test_supervised_preset_read_only_auto(self) -> None: + supervised = BUILTIN_PRESETS[AutonomyLevel.SUPERVISED] + assert "code:read" in supervised.auto_approve + assert "code:write" in supervised.human_approval + + @pytest.mark.unit + def test_presets_are_disjoint(self) -> None: + for level, preset in BUILTIN_PRESETS.items(): + overlap = set(preset.auto_approve) & set(preset.human_approval) + assert overlap == set(), ( + f"Preset {level} has overlapping entries: {overlap}" + ) + + +class TestAutonomyConfig: + """AutonomyConfig validation tests.""" + + @pytest.mark.unit + def test_default_config(self) -> None: + config = AutonomyConfig() + assert config.level == AutonomyLevel.SEMI + assert len(config.presets) == len(AutonomyLevel) + + @pytest.mark.unit + def test_custom_level(self) -> None: + config = AutonomyConfig(level=AutonomyLevel.FULL) + assert config.level == AutonomyLevel.FULL + + @pytest.mark.unit + def test_level_must_be_in_presets(self) -> None: + custom_presets: dict[str, AutonomyPreset] = { + "semi": BUILTIN_PRESETS[AutonomyLevel.SEMI], + } + with pytest.raises(ValueError, match="not found in presets"): + AutonomyConfig(level=AutonomyLevel.FULL, presets=custom_presets) + + @pytest.mark.unit + def test_config_frozen(self) -> None: + config = AutonomyConfig() + with pytest.raises(ValidationError): + config.level = AutonomyLevel.FULL # type: ignore[misc] + + +class TestEffectiveAutonomy: + """EffectiveAutonomy model tests.""" + + @pytest.mark.unit + def test_creation(self) -> None: + effective = EffectiveAutonomy( + level=AutonomyLevel.SEMI, + auto_approve_actions=frozenset({"code:read"}), + human_approval_actions=frozenset({"deploy:production"}), + security_agent=True, + ) + assert effective.level == AutonomyLevel.SEMI + assert "code:read" in effective.auto_approve_actions + assert "deploy:production" in effective.human_approval_actions + + @pytest.mark.unit + def test_frozen(self) -> None: + effective = EffectiveAutonomy( + level=AutonomyLevel.FULL, + auto_approve_actions=frozenset(), + human_approval_actions=frozenset(), + security_agent=False, + ) + with pytest.raises(ValidationError): + effective.level = AutonomyLevel.LOCKED # type: ignore[misc] + + +class TestAutonomyOverride: + """AutonomyOverride model tests.""" + + @pytest.mark.unit + def test_creation(self) -> None: + now = datetime.now(UTC) + override = AutonomyOverride( + agent_id="agent-1", + original_level=AutonomyLevel.SEMI, + current_level=AutonomyLevel.SUPERVISED, + reason=DowngradeReason.HIGH_ERROR_RATE, + downgraded_at=now, + ) + assert override.agent_id == "agent-1" + assert override.requires_human_recovery is True + + @pytest.mark.unit + def test_override_frozen(self) -> None: + now = datetime.now(UTC) + override = AutonomyOverride( + agent_id="agent-1", + original_level=AutonomyLevel.FULL, + current_level=AutonomyLevel.LOCKED, + reason=DowngradeReason.SECURITY_INCIDENT, + downgraded_at=now, + ) + with pytest.raises(ValidationError): + override.current_level = AutonomyLevel.FULL # type: ignore[misc] diff --git a/tests/unit/security/autonomy/test_resolver.py b/tests/unit/security/autonomy/test_resolver.py new file mode 100644 index 0000000000..0e8195704e --- /dev/null +++ b/tests/unit/security/autonomy/test_resolver.py @@ -0,0 +1,146 @@ +"""Tests for AutonomyResolver — resolution chain, expansion, seniority.""" + +import pytest + +from ai_company.core.enums import ActionType, AutonomyLevel, SeniorityLevel +from ai_company.security.action_types import ActionTypeRegistry +from ai_company.security.autonomy.models import ( + BUILTIN_PRESETS, + AutonomyConfig, + AutonomyPreset, +) +from ai_company.security.autonomy.resolver import AutonomyResolver + + +def _make_resolver( + *, + level: AutonomyLevel = AutonomyLevel.SEMI, + custom_types: frozenset[str] = frozenset(), +) -> AutonomyResolver: + """Create a resolver with the given default level.""" + registry = ActionTypeRegistry(custom_types=custom_types) + config = AutonomyConfig(level=level) + return AutonomyResolver(registry=registry, config=config) + + +class TestResolutionChain: + """Three-level resolution chain: agent → department → company.""" + + @pytest.mark.unit + def test_company_default(self) -> None: + resolver = _make_resolver(level=AutonomyLevel.SEMI) + result = resolver.resolve() + assert result.level == AutonomyLevel.SEMI + + @pytest.mark.unit + def test_department_override(self) -> None: + resolver = _make_resolver(level=AutonomyLevel.SEMI) + result = resolver.resolve(department_level=AutonomyLevel.SUPERVISED) + assert result.level == AutonomyLevel.SUPERVISED + + @pytest.mark.unit + def test_agent_overrides_department(self) -> None: + resolver = _make_resolver(level=AutonomyLevel.SEMI) + result = resolver.resolve( + agent_level=AutonomyLevel.FULL, + department_level=AutonomyLevel.SUPERVISED, + ) + assert result.level == AutonomyLevel.FULL + + @pytest.mark.unit + def test_agent_overrides_company(self) -> None: + resolver = _make_resolver(level=AutonomyLevel.LOCKED) + result = resolver.resolve(agent_level=AutonomyLevel.FULL) + assert result.level == AutonomyLevel.FULL + + +class TestCategoryExpansion: + """Category shortcut and 'all' expansion.""" + + @pytest.mark.unit + def test_category_expansion(self) -> None: + resolver = _make_resolver(level=AutonomyLevel.SEMI) + result = resolver.resolve() + # SEMI auto-approves "code" category — includes code:read, etc. + assert ActionType.CODE_READ in result.auto_approve_actions + assert ActionType.CODE_WRITE in result.auto_approve_actions + assert ActionType.CODE_CREATE in result.auto_approve_actions + + @pytest.mark.unit + def test_all_shortcut_full(self) -> None: + resolver = _make_resolver(level=AutonomyLevel.FULL) + result = resolver.resolve() + # FULL auto-approves "all" — should include every registered type. + all_types = ActionTypeRegistry().all_types() + assert result.auto_approve_actions == all_types + + @pytest.mark.unit + def test_all_shortcut_locked(self) -> None: + resolver = _make_resolver(level=AutonomyLevel.LOCKED) + result = resolver.resolve() + # LOCKED human_approval = "all" + all_types = ActionTypeRegistry().all_types() + assert result.human_approval_actions == all_types + assert result.auto_approve_actions == frozenset() + + @pytest.mark.unit + def test_concrete_action_types(self) -> None: + resolver = _make_resolver(level=AutonomyLevel.SUPERVISED) + result = resolver.resolve() + # SUPERVISED auto-approves code:read, vcs:read, test:run, db:query + assert ActionType.CODE_READ in result.auto_approve_actions + assert ActionType.VCS_READ in result.auto_approve_actions + assert ActionType.TEST_RUN in result.auto_approve_actions + + @pytest.mark.unit + def test_custom_action_types_included(self) -> None: + resolver = _make_resolver( + level=AutonomyLevel.SEMI, + custom_types=frozenset({"code:lint"}), + ) + result = resolver.resolve() + # "code" category expansion should include custom code:lint. + assert "code:lint" in result.auto_approve_actions + + +class TestSeniorityValidation: + """Seniority constraint: JUNIOR + FULL is rejected.""" + + @pytest.mark.unit + def test_junior_full_rejected(self) -> None: + resolver = _make_resolver() + with pytest.raises(ValueError, match="FULL autonomy"): + resolver.validate_seniority(SeniorityLevel.JUNIOR, AutonomyLevel.FULL) + + @pytest.mark.unit + def test_junior_semi_allowed(self) -> None: + resolver = _make_resolver() + resolver.validate_seniority(SeniorityLevel.JUNIOR, AutonomyLevel.SEMI) + + @pytest.mark.unit + def test_mid_full_allowed(self) -> None: + resolver = _make_resolver() + resolver.validate_seniority(SeniorityLevel.MID, AutonomyLevel.FULL) + + @pytest.mark.unit + @pytest.mark.parametrize("level", list(SeniorityLevel)) + def test_locked_always_allowed(self, level: SeniorityLevel) -> None: + resolver = _make_resolver() + resolver.validate_seniority(level, AutonomyLevel.LOCKED) + + +class TestMissingPreset: + """Error when the resolved level has no preset.""" + + @pytest.mark.unit + def test_missing_preset_raises(self) -> None: + custom_presets: dict[str, AutonomyPreset] = { + "semi": BUILTIN_PRESETS[AutonomyLevel.SEMI], + } + config = AutonomyConfig(level=AutonomyLevel.SEMI, presets=custom_presets) + resolver = AutonomyResolver( + registry=ActionTypeRegistry(), + config=config, + ) + with pytest.raises(ValueError, match="No preset found"): + resolver.resolve(agent_level=AutonomyLevel.FULL) diff --git a/tests/unit/security/test_service.py b/tests/unit/security/test_service.py index 4260afd74e..44518549f2 100644 --- a/tests/unit/security/test_service.py +++ b/tests/unit/security/test_service.py @@ -5,8 +5,14 @@ import pytest -from ai_company.core.enums import ApprovalRiskLevel, ApprovalStatus, ToolCategory +from ai_company.core.enums import ( + ApprovalRiskLevel, + ApprovalStatus, + AutonomyLevel, + ToolCategory, +) from ai_company.security.audit import AuditLog +from ai_company.security.autonomy.models import EffectiveAutonomy from ai_company.security.config import SecurityConfig from ai_company.security.models import ( OutputScanResult, @@ -443,3 +449,131 @@ async def test_store_failure_converts_to_deny(self) -> None: assert verdict.verdict == SecurityVerdictType.DENY assert "store error" in verdict.reason.lower() + + +# ── Tests: autonomy pre-check ──────────────────────────────────── + + +@pytest.mark.unit +class TestAutonomyPrecheck: + """Autonomy-based action routing before the rule engine.""" + + def _make_service_with_autonomy( + self, + *, + effective_autonomy: EffectiveAutonomy | None = None, + config: SecurityConfig | None = None, + engine_verdict: SecurityVerdict | None = None, + approval_store: AsyncMock | None = None, + ) -> SecOpsService: + """Construct a SecOpsService with autonomy support.""" + cfg = config or SecurityConfig() + rule_engine = MagicMock(spec=RuleEngine) + rule_engine.evaluate.return_value = engine_verdict or _make_allow_verdict() + audit_log = AuditLog() + output_scanner = MagicMock(spec=OutputScanner) + output_scanner.scan.return_value = OutputScanResult() + + service = SecOpsService( + config=cfg, + rule_engine=rule_engine, + audit_log=audit_log, + output_scanner=output_scanner, + approval_store=approval_store, + effective_autonomy=effective_autonomy, + ) + service._test_rule_engine = rule_engine # type: ignore[attr-defined] + service._test_audit_log = audit_log # type: ignore[attr-defined] + return service + + async def test_auto_approve_returns_allow(self) -> None: + """When action is in auto_approve_actions, returns ALLOW without rule engine.""" + autonomy = EffectiveAutonomy( + level=AutonomyLevel.SEMI, + auto_approve_actions=frozenset({"code:read"}), + human_approval_actions=frozenset({"infra:deploy"}), + security_agent=False, + ) + service = self._make_service_with_autonomy(effective_autonomy=autonomy) + ctx = _make_context(action_type="code:read") + + verdict = await service.evaluate_pre_tool(ctx) + + assert verdict.verdict == SecurityVerdictType.ALLOW + assert "auto-approved" in verdict.reason.lower() + service._test_rule_engine.evaluate.assert_not_called() # type: ignore[attr-defined] + + async def test_human_approval_returns_escalate_as_deny(self) -> None: + """Human approval with no store converts ESCALATE to DENY.""" + autonomy = EffectiveAutonomy( + level=AutonomyLevel.SEMI, + auto_approve_actions=frozenset({"code:read"}), + human_approval_actions=frozenset({"infra:deploy"}), + security_agent=False, + ) + service = self._make_service_with_autonomy( + effective_autonomy=autonomy, + approval_store=None, + ) + ctx = _make_context(action_type="infra:deploy") + + verdict = await service.evaluate_pre_tool(ctx) + + assert verdict.verdict == SecurityVerdictType.DENY + assert "escalation unavailable" in verdict.reason.lower() + service._test_rule_engine.evaluate.assert_not_called() # type: ignore[attr-defined] + + async def test_hard_deny_falls_through_to_rule_engine(self) -> None: + """When action is in hard_deny_action_types, autonomy is skipped.""" + autonomy = EffectiveAutonomy( + level=AutonomyLevel.SEMI, + auto_approve_actions=frozenset({"deploy:production"}), + human_approval_actions=frozenset(), + security_agent=False, + ) + deny_verdict = _make_deny_verdict() + service = self._make_service_with_autonomy( + effective_autonomy=autonomy, + engine_verdict=deny_verdict, + ) + # deploy:production is in the default SecurityConfig.hard_deny_action_types + ctx = _make_context(action_type="deploy:production") + + verdict = await service.evaluate_pre_tool(ctx) + + assert verdict.verdict == SecurityVerdictType.DENY + service._test_rule_engine.evaluate.assert_called_once() # type: ignore[attr-defined] + + async def test_unknown_action_falls_through(self) -> None: + """When action is not in any autonomy set, falls through to rule engine.""" + autonomy = EffectiveAutonomy( + level=AutonomyLevel.SEMI, + auto_approve_actions=frozenset({"code:read"}), + human_approval_actions=frozenset({"infra:deploy"}), + security_agent=False, + ) + allow_verdict = _make_allow_verdict() + service = self._make_service_with_autonomy( + effective_autonomy=autonomy, + engine_verdict=allow_verdict, + ) + ctx = _make_context(action_type="test:run") + + verdict = await service.evaluate_pre_tool(ctx) + + assert verdict.verdict == SecurityVerdictType.ALLOW + service._test_rule_engine.evaluate.assert_called_once() # type: ignore[attr-defined] + + async def test_no_autonomy_uses_rule_engine(self) -> None: + """When effective_autonomy=None, rule engine is used normally.""" + allow_verdict = _make_allow_verdict() + service = self._make_service_with_autonomy( + effective_autonomy=None, + engine_verdict=allow_verdict, + ) + ctx = _make_context(action_type="code:read") + + verdict = await service.evaluate_pre_tool(ctx) + + assert verdict.verdict == SecurityVerdictType.ALLOW + service._test_rule_engine.evaluate.assert_called_once() # type: ignore[attr-defined] diff --git a/tests/unit/security/timeout/__init__.py b/tests/unit/security/timeout/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/security/timeout/test_config.py b/tests/unit/security/timeout/test_config.py new file mode 100644 index 0000000000..3b332f334d --- /dev/null +++ b/tests/unit/security/timeout/test_config.py @@ -0,0 +1,125 @@ +"""Tests for timeout policy configuration models.""" + +import pytest +from pydantic import TypeAdapter, ValidationError + +from ai_company.core.enums import TimeoutActionType +from ai_company.security.timeout.config import ( + ApprovalTimeoutConfig, + DenyOnTimeoutConfig, + EscalationChainConfig, + EscalationStep, + TierConfig, + TieredTimeoutConfig, + WaitForeverConfig, +) + +_adapter: TypeAdapter[ApprovalTimeoutConfig] = TypeAdapter(ApprovalTimeoutConfig) + + +class TestWaitForeverConfig: + """WaitForeverConfig tests.""" + + @pytest.mark.unit + def test_default(self) -> None: + config = WaitForeverConfig() + assert config.policy == "wait" + + @pytest.mark.unit + def test_discriminator(self) -> None: + result = _adapter.validate_python({"policy": "wait"}) + assert isinstance(result, WaitForeverConfig) + + +class TestDenyOnTimeoutConfig: + """DenyOnTimeoutConfig tests.""" + + @pytest.mark.unit + def test_default_timeout(self) -> None: + config = DenyOnTimeoutConfig() + assert config.timeout_minutes == 240.0 + + @pytest.mark.unit + def test_custom_timeout(self) -> None: + config = DenyOnTimeoutConfig(timeout_minutes=60.0) + assert config.timeout_minutes == 60.0 + + @pytest.mark.unit + def test_discriminator(self) -> None: + result = _adapter.validate_python({"policy": "deny", "timeout_minutes": 30}) + assert isinstance(result, DenyOnTimeoutConfig) + assert result.timeout_minutes == 30.0 + + @pytest.mark.unit + def test_zero_timeout_rejected(self) -> None: + with pytest.raises(ValidationError): + DenyOnTimeoutConfig(timeout_minutes=0) + + +class TestTieredTimeoutConfig: + """TieredTimeoutConfig tests.""" + + @pytest.mark.unit + def test_empty_tiers(self) -> None: + config = TieredTimeoutConfig() + assert config.tiers == {} + + @pytest.mark.unit + def test_tier_config(self) -> None: + tier = TierConfig( + timeout_minutes=60, + on_timeout=TimeoutActionType.DENY, + ) + config = TieredTimeoutConfig(tiers={"high": tier}) + assert "high" in config.tiers + assert config.tiers["high"].on_timeout == TimeoutActionType.DENY + + @pytest.mark.unit + def test_discriminator(self) -> None: + result = _adapter.validate_python( + { + "policy": "tiered", + "tiers": { + "low": {"timeout_minutes": 480, "on_timeout": "approve"}, + "high": {"timeout_minutes": 60, "on_timeout": "deny"}, + }, + } + ) + assert isinstance(result, TieredTimeoutConfig) + assert result.tiers["low"].on_timeout == TimeoutActionType.APPROVE + + +class TestEscalationChainConfig: + """EscalationChainConfig tests.""" + + @pytest.mark.unit + def test_empty_chain(self) -> None: + config = EscalationChainConfig() + assert config.chain == () + assert config.on_chain_exhausted == TimeoutActionType.DENY + + @pytest.mark.unit + def test_chain_steps(self) -> None: + config = EscalationChainConfig( + chain=( + EscalationStep(role="lead", timeout_minutes=30), + EscalationStep(role="director", timeout_minutes=60), + ), + on_chain_exhausted=TimeoutActionType.DENY, + ) + assert len(config.chain) == 2 + assert config.chain[0].role == "lead" + + @pytest.mark.unit + def test_discriminator(self) -> None: + result = _adapter.validate_python( + { + "policy": "escalation", + "chain": [ + {"role": "lead", "timeout_minutes": 30}, + ], + "on_chain_exhausted": "deny", + } + ) + assert isinstance(result, EscalationChainConfig) + assert len(result.chain) == 1 diff --git a/tests/unit/security/timeout/test_factory.py b/tests/unit/security/timeout/test_factory.py new file mode 100644 index 0000000000..3efd66be22 --- /dev/null +++ b/tests/unit/security/timeout/test_factory.py @@ -0,0 +1,47 @@ +"""Tests for timeout policy factory.""" + +import pytest + +from ai_company.core.enums import TimeoutActionType +from ai_company.security.timeout.config import ( + DenyOnTimeoutConfig, + EscalationChainConfig, + EscalationStep, + TieredTimeoutConfig, + WaitForeverConfig, +) +from ai_company.security.timeout.factory import create_timeout_policy +from ai_company.security.timeout.policies import ( + DenyOnTimeoutPolicy, + EscalationChainPolicy, + TieredTimeoutPolicy, + WaitForeverPolicy, +) + + +class TestFactory: + """create_timeout_policy returns the correct implementation.""" + + @pytest.mark.unit + def test_wait_forever(self) -> None: + result = create_timeout_policy(WaitForeverConfig()) + assert isinstance(result, WaitForeverPolicy) + + @pytest.mark.unit + def test_deny_on_timeout(self) -> None: + result = create_timeout_policy(DenyOnTimeoutConfig(timeout_minutes=60)) + assert isinstance(result, DenyOnTimeoutPolicy) + + @pytest.mark.unit + def test_tiered(self) -> None: + result = create_timeout_policy(TieredTimeoutConfig()) + assert isinstance(result, TieredTimeoutPolicy) + + @pytest.mark.unit + def test_escalation_chain(self) -> None: + config = EscalationChainConfig( + chain=(EscalationStep(role="lead", timeout_minutes=30),), + on_chain_exhausted=TimeoutActionType.DENY, + ) + result = create_timeout_policy(config) + assert isinstance(result, EscalationChainPolicy) diff --git a/tests/unit/security/timeout/test_park_service.py b/tests/unit/security/timeout/test_park_service.py new file mode 100644 index 0000000000..35cbf53e95 --- /dev/null +++ b/tests/unit/security/timeout/test_park_service.py @@ -0,0 +1,92 @@ +"""Tests for the ParkService (park/resume agent contexts).""" + +import json +from datetime import UTC, date, datetime +from uuid import uuid4 + +import pytest + +from ai_company.core.agent import AgentIdentity, ModelConfig, PersonalityConfig +from ai_company.core.enums import SeniorityLevel +from ai_company.engine.context import AgentContext +from ai_company.security.timeout.park_service import ParkService + +pytestmark = pytest.mark.timeout(30) + + +def _make_agent_context() -> AgentContext: + """Create a minimal AgentContext for testing.""" + identity = AgentIdentity( + name="test-agent", + role="developer", + department="engineering", + level=SeniorityLevel.MID, + personality=PersonalityConfig(), + model=ModelConfig(provider="test-provider", model_id="test-small-001"), + hiring_date=date(2026, 1, 1), + ) + return AgentContext( + execution_id=str(uuid4()), + identity=identity, + turn_count=1, + started_at=datetime.now(UTC), + ) + + +@pytest.mark.unit +class TestParkService: + """Tests for ParkService park/resume round-trip.""" + + def test_park_creates_parked_context(self) -> None: + """Parks an AgentContext and verifies ParkedContext fields.""" + context = _make_agent_context() + service = ParkService() + + parked = service.park( + context=context, + approval_id="approval-1", + agent_id="agent-1", + task_id="task-1", + ) + + assert parked.agent_id == "agent-1" + assert parked.approval_id == "approval-1" + assert parked.task_id == "task-1" + assert parked.execution_id == context.execution_id + assert parked.id # non-empty UUID default + + def test_park_serializes_context_json(self) -> None: + """Verifies context_json is valid JSON.""" + context = _make_agent_context() + service = ParkService() + + parked = service.park( + context=context, + approval_id="approval-1", + agent_id="agent-1", + task_id="task-1", + ) + + assert parked.context_json # non-empty + parsed = json.loads(parked.context_json) + assert isinstance(parsed, dict) + assert "execution_id" in parsed + + def test_resume_restores_context(self) -> None: + """Parks then resumes, verifies round-trip fidelity.""" + context = _make_agent_context() + service = ParkService() + + parked = service.park( + context=context, + approval_id="approval-1", + agent_id="agent-1", + task_id="task-1", + ) + + restored = service.resume(parked) + + assert restored.execution_id == context.execution_id + assert restored.turn_count == context.turn_count + assert restored.identity.name == context.identity.name + assert restored.identity.role == context.identity.role diff --git a/tests/unit/security/timeout/test_parked_context.py b/tests/unit/security/timeout/test_parked_context.py new file mode 100644 index 0000000000..2df5917057 --- /dev/null +++ b/tests/unit/security/timeout/test_parked_context.py @@ -0,0 +1,83 @@ +"""Tests for the ParkedContext model.""" + +from datetime import UTC, datetime +from typing import Any + +import pytest +from pydantic import ValidationError + +from ai_company.security.timeout.parked_context import ParkedContext + +pytestmark = pytest.mark.timeout(30) + + +def _make_parked_context(**overrides: Any) -> ParkedContext: + """Create a valid ParkedContext with sensible defaults.""" + defaults: dict[str, Any] = { + "execution_id": "exec-1", + "agent_id": "agent-1", + "task_id": "task-1", + "approval_id": "approval-1", + "parked_at": datetime.now(UTC), + "context_json": '{"key": "value"}', + } + defaults.update(overrides) + return ParkedContext(**defaults) + + +@pytest.mark.unit +class TestParkedContext: + """Tests for ParkedContext model validation and immutability.""" + + def test_creation(self) -> None: + """Valid creation with all fields.""" + now = datetime.now(UTC) + parked = ParkedContext( + id="custom-id", + execution_id="exec-1", + agent_id="agent-1", + task_id="task-1", + approval_id="approval-1", + parked_at=now, + context_json='{"data": true}', + metadata={"tool": "git"}, + ) + assert parked.id == "custom-id" + assert parked.execution_id == "exec-1" + assert parked.agent_id == "agent-1" + assert parked.task_id == "task-1" + assert parked.approval_id == "approval-1" + assert parked.parked_at == now + assert parked.context_json == '{"data": true}' + assert parked.metadata == {"tool": "git"} + + def test_frozen(self) -> None: + """Cannot modify fields on a frozen model.""" + parked = _make_parked_context() + with pytest.raises(ValidationError): + parked.agent_id = "other" # type: ignore[misc] + + def test_default_id_generated(self) -> None: + """id gets a UUID default when not provided.""" + parked = _make_parked_context() + assert parked.id # non-empty + assert len(parked.id) > 0 + + def test_unique_ids(self) -> None: + """Two instances get different default IDs.""" + a = _make_parked_context() + b = _make_parked_context() + assert a.id != b.id + + def test_empty_metadata_default(self) -> None: + """metadata defaults to empty dict.""" + parked = _make_parked_context() + assert parked.metadata == {} + + def test_blank_agent_id_rejected(self) -> None: + """Blank agent_id raises ValidationError.""" + with pytest.raises(ValidationError): + _make_parked_context(agent_id="") + + with pytest.raises(ValidationError): + _make_parked_context(agent_id=" ") diff --git a/tests/unit/security/timeout/test_policies.py b/tests/unit/security/timeout/test_policies.py new file mode 100644 index 0000000000..3f74e10b60 --- /dev/null +++ b/tests/unit/security/timeout/test_policies.py @@ -0,0 +1,191 @@ +"""Tests for timeout policy implementations.""" + +from datetime import UTC, datetime + +import pytest + +from ai_company.core.approval import ApprovalItem +from ai_company.core.enums import ApprovalRiskLevel, ApprovalStatus, TimeoutActionType +from ai_company.security.timeout.config import EscalationStep, TierConfig +from ai_company.security.timeout.policies import ( + DenyOnTimeoutPolicy, + EscalationChainPolicy, + TieredTimeoutPolicy, + WaitForeverPolicy, +) +from ai_company.security.timeout.risk_tier_classifier import YamlRiskTierClassifier + + +def _make_item( + action_type: str = "code:write", + risk_level: ApprovalRiskLevel = ApprovalRiskLevel.MEDIUM, +) -> ApprovalItem: + """Create a minimal pending approval item.""" + return ApprovalItem( + id="test-approval-1", + action_type=action_type, + title="Test approval", + description="Test description", + requested_by="agent-1", + risk_level=risk_level, + status=ApprovalStatus.PENDING, + created_at=datetime.now(UTC), + ) + + +class TestWaitForeverPolicy: + """WaitForeverPolicy always returns WAIT.""" + + @pytest.mark.unit + async def test_always_waits(self) -> None: + policy = WaitForeverPolicy() + item = _make_item() + result = await policy.determine_action(item, 0.0) + assert result.action == TimeoutActionType.WAIT + + @pytest.mark.unit + async def test_waits_after_long_time(self) -> None: + policy = WaitForeverPolicy() + item = _make_item() + result = await policy.determine_action(item, 999999.0) + assert result.action == TimeoutActionType.WAIT + + +class TestDenyOnTimeoutPolicy: + """DenyOnTimeoutPolicy: WAIT before timeout, DENY after.""" + + @pytest.mark.unit + async def test_wait_before_timeout(self) -> None: + policy = DenyOnTimeoutPolicy(timeout_seconds=3600.0) + item = _make_item() + result = await policy.determine_action(item, 1800.0) + assert result.action == TimeoutActionType.WAIT + + @pytest.mark.unit + async def test_deny_at_timeout(self) -> None: + policy = DenyOnTimeoutPolicy(timeout_seconds=3600.0) + item = _make_item() + result = await policy.determine_action(item, 3600.0) + assert result.action == TimeoutActionType.DENY + + @pytest.mark.unit + async def test_deny_after_timeout(self) -> None: + policy = DenyOnTimeoutPolicy(timeout_seconds=3600.0) + item = _make_item() + result = await policy.determine_action(item, 7200.0) + assert result.action == TimeoutActionType.DENY + + +class TestTieredTimeoutPolicy: + """TieredTimeoutPolicy: per-risk-tier timeout behavior.""" + + @pytest.mark.unit + async def test_wait_within_tier_timeout(self) -> None: + tiers = { + "medium": TierConfig(timeout_minutes=60, on_timeout=TimeoutActionType.DENY), + } + policy = TieredTimeoutPolicy( + tiers=tiers, + classifier=YamlRiskTierClassifier(), + ) + item = _make_item(action_type="code:write") # MEDIUM risk + result = await policy.determine_action(item, 1800.0) # 30 min + assert result.action == TimeoutActionType.WAIT + + @pytest.mark.unit + async def test_deny_after_tier_timeout(self) -> None: + tiers = { + "medium": TierConfig(timeout_minutes=60, on_timeout=TimeoutActionType.DENY), + } + policy = TieredTimeoutPolicy( + tiers=tiers, + classifier=YamlRiskTierClassifier(), + ) + item = _make_item(action_type="code:write") # MEDIUM risk + result = await policy.determine_action(item, 3601.0) # > 60 min + assert result.action == TimeoutActionType.DENY + + @pytest.mark.unit + async def test_approve_on_tier_timeout(self) -> None: + tiers = { + "low": TierConfig( + timeout_minutes=480, on_timeout=TimeoutActionType.APPROVE + ), + } + policy = TieredTimeoutPolicy( + tiers=tiers, + classifier=YamlRiskTierClassifier(), + ) + item = _make_item(action_type="code:read") # LOW risk + result = await policy.determine_action(item, 30000.0) # > 480 min + assert result.action == TimeoutActionType.APPROVE + + @pytest.mark.unit + async def test_no_tier_config_waits(self) -> None: + policy = TieredTimeoutPolicy( + tiers={}, + classifier=YamlRiskTierClassifier(), + ) + item = _make_item() + result = await policy.determine_action(item, 999999.0) + assert result.action == TimeoutActionType.WAIT + + +class TestEscalationChainPolicy: + """EscalationChainPolicy: chain of escalation steps.""" + + @pytest.mark.unit + async def test_first_step_escalation(self) -> None: + chain = ( + EscalationStep(role="lead", timeout_minutes=30), + EscalationStep(role="director", timeout_minutes=60), + ) + policy = EscalationChainPolicy( + chain=chain, + on_chain_exhausted=TimeoutActionType.DENY, + ) + item = _make_item() + result = await policy.determine_action(item, 600.0) # 10 min + assert result.action == TimeoutActionType.ESCALATE + assert result.escalate_to == "lead" + + @pytest.mark.unit + async def test_second_step_escalation(self) -> None: + chain = ( + EscalationStep(role="lead", timeout_minutes=30), + EscalationStep(role="director", timeout_minutes=60), + ) + policy = EscalationChainPolicy( + chain=chain, + on_chain_exhausted=TimeoutActionType.DENY, + ) + item = _make_item() + # 40 min = past first step (30min), within second (30+60=90min) + result = await policy.determine_action(item, 2400.0) + assert result.action == TimeoutActionType.ESCALATE + assert result.escalate_to == "director" + + @pytest.mark.unit + async def test_chain_exhausted(self) -> None: + chain = ( + EscalationStep(role="lead", timeout_minutes=30), + EscalationStep(role="director", timeout_minutes=60), + ) + policy = EscalationChainPolicy( + chain=chain, + on_chain_exhausted=TimeoutActionType.DENY, + ) + item = _make_item() + # 100 min = past both steps (30+60=90min) + result = await policy.determine_action(item, 6000.0) + assert result.action == TimeoutActionType.DENY + + @pytest.mark.unit + async def test_empty_chain_exhausted_immediately(self) -> None: + policy = EscalationChainPolicy( + chain=(), + on_chain_exhausted=TimeoutActionType.DENY, + ) + item = _make_item() + result = await policy.determine_action(item, 0.0) + assert result.action == TimeoutActionType.DENY diff --git a/tests/unit/security/timeout/test_risk_tier_classifier.py b/tests/unit/security/timeout/test_risk_tier_classifier.py new file mode 100644 index 0000000000..5f86c781b8 --- /dev/null +++ b/tests/unit/security/timeout/test_risk_tier_classifier.py @@ -0,0 +1,64 @@ +"""Tests for YamlRiskTierClassifier.""" + +import pytest + +from ai_company.core.enums import ActionType, ApprovalRiskLevel +from ai_company.security.timeout.risk_tier_classifier import YamlRiskTierClassifier + + +class TestDefaultMapping: + """Default risk tier mapping.""" + + @pytest.mark.unit + def test_critical_actions(self) -> None: + classifier = YamlRiskTierClassifier() + expected = ApprovalRiskLevel.CRITICAL + assert classifier.classify(ActionType.DEPLOY_PRODUCTION) == expected + assert classifier.classify(ActionType.DB_ADMIN) == expected + + @pytest.mark.unit + def test_high_actions(self) -> None: + classifier = YamlRiskTierClassifier() + assert classifier.classify(ActionType.VCS_PUSH) == ApprovalRiskLevel.HIGH + assert classifier.classify(ActionType.CODE_DELETE) == ApprovalRiskLevel.HIGH + + @pytest.mark.unit + def test_medium_actions(self) -> None: + classifier = YamlRiskTierClassifier() + assert classifier.classify(ActionType.CODE_WRITE) == ApprovalRiskLevel.MEDIUM + + @pytest.mark.unit + def test_low_actions(self) -> None: + classifier = YamlRiskTierClassifier() + assert classifier.classify(ActionType.CODE_READ) == ApprovalRiskLevel.LOW + assert classifier.classify(ActionType.TEST_RUN) == ApprovalRiskLevel.LOW + + +class TestUnknownFallback: + """Unknown action types default to HIGH (D19).""" + + @pytest.mark.unit + def test_unknown_defaults_to_high(self) -> None: + classifier = YamlRiskTierClassifier() + assert classifier.classify("unknown:action") == ApprovalRiskLevel.HIGH + + +class TestCustomMap: + """Custom risk overrides.""" + + @pytest.mark.unit + def test_custom_override(self) -> None: + classifier = YamlRiskTierClassifier( + custom_map={ActionType.CODE_READ: ApprovalRiskLevel.CRITICAL} + ) + assert classifier.classify(ActionType.CODE_READ) == ApprovalRiskLevel.CRITICAL + + @pytest.mark.unit + def test_custom_preserves_defaults(self) -> None: + classifier = YamlRiskTierClassifier( + custom_map={"custom:action": ApprovalRiskLevel.LOW} + ) + # Default still works. + assert classifier.classify(ActionType.CODE_READ) == ApprovalRiskLevel.LOW + # Custom also works. + assert classifier.classify("custom:action") == ApprovalRiskLevel.LOW diff --git a/tests/unit/security/timeout/test_timeout_checker.py b/tests/unit/security/timeout/test_timeout_checker.py new file mode 100644 index 0000000000..4c9d289009 --- /dev/null +++ b/tests/unit/security/timeout/test_timeout_checker.py @@ -0,0 +1,129 @@ +"""Tests for the TimeoutChecker.""" + +from datetime import UTC, datetime +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from ai_company.core.approval import ApprovalItem +from ai_company.core.enums import ApprovalRiskLevel, ApprovalStatus, TimeoutActionType +from ai_company.security.timeout.models import TimeoutAction +from ai_company.security.timeout.timeout_checker import TimeoutChecker + +pytestmark = pytest.mark.timeout(30) + + +def _make_approval_item(**overrides: Any) -> ApprovalItem: + """Create a valid pending ApprovalItem with sensible defaults.""" + defaults: dict[str, Any] = { + "id": "approval-1", + "action_type": "code:write", + "title": "Test approval", + "description": "Testing", + "requested_by": "agent-1", + "risk_level": ApprovalRiskLevel.MEDIUM, + "status": ApprovalStatus.PENDING, + "created_at": datetime.now(UTC), + } + defaults.update(overrides) + return ApprovalItem(**defaults) + + +def _make_mock_policy( + *, + action: TimeoutActionType = TimeoutActionType.WAIT, + reason: str = "test reason", +) -> AsyncMock: + """Create a mock TimeoutPolicy returning the given action.""" + mock_policy = AsyncMock() + mock_policy.determine_action.return_value = TimeoutAction( + action=action, + reason=reason, + ) + return mock_policy + + +@pytest.mark.unit +class TestTimeoutCheckerCheck: + """Tests for TimeoutChecker.check().""" + + async def test_check_returns_action(self) -> None: + """Checker delegates to policy and returns its action.""" + mock_policy = _make_mock_policy( + action=TimeoutActionType.WAIT, + reason="Still waiting", + ) + checker = TimeoutChecker(policy=mock_policy) + item = _make_approval_item() + + result = await checker.check(item) + + assert result.action == TimeoutActionType.WAIT + assert result.reason == "Still waiting" + mock_policy.determine_action.assert_called_once() + + +@pytest.mark.unit +class TestTimeoutCheckerCheckAndResolve: + """Tests for TimeoutChecker.check_and_resolve().""" + + async def test_check_and_resolve_approve(self) -> None: + """When policy returns APPROVE, item status is updated to APPROVED.""" + mock_policy = _make_mock_policy( + action=TimeoutActionType.APPROVE, + reason="Auto-approved after timeout", + ) + checker = TimeoutChecker(policy=mock_policy) + item = _make_approval_item() + + updated_item, action = await checker.check_and_resolve(item) + + assert action.action == TimeoutActionType.APPROVE + assert updated_item.status == ApprovalStatus.APPROVED + assert updated_item.decided_by == "timeout_policy" + assert updated_item.decided_at is not None + + async def test_check_and_resolve_deny(self) -> None: + """When policy returns DENY, item status is updated to REJECTED.""" + mock_policy = _make_mock_policy( + action=TimeoutActionType.DENY, + reason="Denied after timeout", + ) + checker = TimeoutChecker(policy=mock_policy) + item = _make_approval_item() + + updated_item, action = await checker.check_and_resolve(item) + + assert action.action == TimeoutActionType.DENY + assert updated_item.status == ApprovalStatus.REJECTED + assert updated_item.decided_by == "timeout_policy" + assert updated_item.decided_at is not None + + async def test_check_and_resolve_wait(self) -> None: + """When policy returns WAIT, item status stays PENDING.""" + mock_policy = _make_mock_policy( + action=TimeoutActionType.WAIT, + reason="Still waiting", + ) + checker = TimeoutChecker(policy=mock_policy) + item = _make_approval_item() + + updated_item, action = await checker.check_and_resolve(item) + + assert action.action == TimeoutActionType.WAIT + assert updated_item.status == ApprovalStatus.PENDING + + async def test_check_and_resolve_escalate(self) -> None: + """When policy returns ESCALATE, item status stays PENDING.""" + mock_policy = _make_mock_policy( + action=TimeoutActionType.ESCALATE, + reason="Escalating to manager", + ) + checker = TimeoutChecker(policy=mock_policy) + item = _make_approval_item() + + updated_item, action = await checker.check_and_resolve(item) + + assert action.action == TimeoutActionType.ESCALATE + assert updated_item.status == ApprovalStatus.PENDING From 24d6500b1177f08d091d46cef93b945890f494bb Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Tue, 10 Mar 2026 12:56:39 +0100 Subject: [PATCH 2/4] fix: address pre-PR review findings for autonomy/timeout feature Pre-reviewed by 10 agents, 51 findings addressed: - Fix autonomy controller returning requested level instead of current - Add disjoint validator on EffectiveAutonomy action sets - Add escalate_to consistency validator on TimeoutAction - Add seniority constraint enforcement in AutonomyResolver - Add MemoryError/RecursionError re-raise in security service - Fix _row_to_model to raise QueryError instead of returning None - Rename YamlRiskTierClassifier to DefaultRiskTierClassifier - Move Jinja2 env to module-level singleton in renderer - Fix personality mutation pattern (return instead of mutate) - Add security guard blocking auto-approve for HIGH/CRITICAL risk - Fix immutability violations (deepcopy metadata, immutable dicts) - Enumerate columns explicitly in SELECT queries - Register AutonomyController with app router - Add comprehensive tests for new code paths --- src/ai_company/api/controllers/__init__.py | 3 + src/ai_company/api/controllers/autonomy.py | 15 +- src/ai_company/core/company.py | 2 +- src/ai_company/engine/loop_protocol.py | 5 +- .../observability/events/timeout.py | 3 + .../persistence/sqlite/parked_context_repo.py | 35 +-- .../security/autonomy/change_strategy.py | 20 +- src/ai_company/security/autonomy/models.py | 16 +- src/ai_company/security/autonomy/protocol.py | 3 + src/ai_company/security/autonomy/resolver.py | 38 +++- src/ai_company/security/service.py | 6 + src/ai_company/security/timeout/config.py | 17 +- src/ai_company/security/timeout/factory.py | 11 +- src/ai_company/security/timeout/models.py | 20 +- .../security/timeout/park_service.py | 76 +++++-- src/ai_company/security/timeout/policies.py | 48 ++++- .../security/timeout/risk_tier_classifier.py | 14 +- src/ai_company/templates/renderer.py | 43 ++-- tests/unit/api/controllers/test_autonomy.py | 62 ++++++ tests/unit/observability/test_events.py | 79 +++++++ .../persistence/sqlite/test_migrations.py | 23 ++ .../sqlite/test_parked_context_repo.py | 202 ++++++++++++++++++ .../security/autonomy/test_change_strategy.py | 11 + tests/unit/security/autonomy/test_models.py | 10 + tests/unit/security/timeout/test_config.py | 22 ++ tests/unit/security/timeout/test_policies.py | 31 ++- .../timeout/test_risk_tier_classifier.py | 18 +- .../security/timeout/test_timeout_checker.py | 3 + 28 files changed, 717 insertions(+), 119 deletions(-) create mode 100644 tests/unit/api/controllers/test_autonomy.py create mode 100644 tests/unit/persistence/sqlite/test_parked_context_repo.py diff --git a/src/ai_company/api/controllers/__init__.py b/src/ai_company/api/controllers/__init__.py index d90e8223ac..8343349809 100644 --- a/src/ai_company/api/controllers/__init__.py +++ b/src/ai_company/api/controllers/__init__.py @@ -6,6 +6,7 @@ from ai_company.api.controllers.analytics import AnalyticsController from ai_company.api.controllers.approvals import ApprovalsController from ai_company.api.controllers.artifacts import ArtifactController +from ai_company.api.controllers.autonomy import AutonomyController from ai_company.api.controllers.budget import BudgetController from ai_company.api.controllers.company import CompanyController from ai_company.api.controllers.departments import DepartmentController @@ -31,6 +32,7 @@ AnalyticsController, ProviderController, ApprovalsController, + AutonomyController, ) __all__ = [ @@ -39,6 +41,7 @@ "AnalyticsController", "ApprovalsController", "ArtifactController", + "AutonomyController", "BudgetController", "CompanyController", "Controller", diff --git a/src/ai_company/api/controllers/autonomy.py b/src/ai_company/api/controllers/autonomy.py index 4ea62e31da..1d99c3a245 100644 --- a/src/ai_company/api/controllers/autonomy.py +++ b/src/ai_company/api/controllers/autonomy.py @@ -102,29 +102,32 @@ async def update_autonomy( Returns: Updated autonomy level info. """ - app_state: AppState = state.app_state # noqa: F841 + app_state: AppState = state.app_state + config = app_state.config.config + current_level = config.autonomy.level requested_level = data.level logger.info( AUTONOMY_PROMOTION_REQUESTED, agent_id=agent_id, requested_level=requested_level.value, + current_level=current_level.value, ) - # Promotions require human approval — return pending status. - # The actual change would be applied via the AutonomyChangeStrategy - # when the approval system is wired up. + # All changes route through human approval — return current + # level with pending status. The AutonomyChangeStrategy will + # apply the change when the approval system is wired up. logger.info( AUTONOMY_PROMOTION_DENIED, agent_id=agent_id, requested_level=requested_level.value, - reason="Autonomy promotions require human approval", + reason="Autonomy level changes require human approval", ) return ApiResponse( data=AutonomyLevelResponse( agent_id=agent_id, - level=requested_level, + level=current_level, promotion_pending=True, ), ) diff --git a/src/ai_company/core/company.py b/src/ai_company/core/company.py index 0e94f861b1..460973a268 100644 --- a/src/ai_company/core/company.py +++ b/src/ai_company/core/company.py @@ -379,7 +379,7 @@ def _coerce_autonomy_float(cls, data: object) -> object: raw = data.get("autonomy") if isinstance(raw, (int, float)) and not isinstance(raw, bool): level = _float_to_autonomy_level(float(raw)) - data["autonomy"] = {"level": level.value} + return {**data, "autonomy": {"level": level.value}} return data budget_monthly: float = Field( diff --git a/src/ai_company/engine/loop_protocol.py b/src/ai_company/engine/loop_protocol.py index 6be6efa2fc..664a25887e 100644 --- a/src/ai_company/engine/loop_protocol.py +++ b/src/ai_company/engine/loop_protocol.py @@ -125,8 +125,9 @@ def _validate_error_message(self) -> Self: msg = "error_message is required when termination_reason is ERROR" raise ValueError(msg) elif self.termination_reason == TerminationReason.PARKED: - # PARKED allows an optional informational message. - pass + if self.error_message is not None: + msg = "error_message must be None for PARKED termination" + raise ValueError(msg) elif self.error_message is not None: msg = "error_message must be None when termination_reason is not ERROR" raise ValueError(msg) diff --git a/src/ai_company/observability/events/timeout.py b/src/ai_company/observability/events/timeout.py index 2ae2e4f5f0..4ac77f59a0 100644 --- a/src/ai_company/observability/events/timeout.py +++ b/src/ai_company/observability/events/timeout.py @@ -7,3 +7,6 @@ TIMEOUT_AUTO_DENIED: Final[str] = "timeout.auto_denied" TIMEOUT_ESCALATED: Final[str] = "timeout.escalated" TIMEOUT_WAITING: Final[str] = "timeout.waiting" +TIMEOUT_CONTEXT_PARKED: Final[str] = "timeout.context.parked" +TIMEOUT_CONTEXT_RESUMED: Final[str] = "timeout.context.resumed" +TIMEOUT_UNKNOWN_ACTION_TYPE: Final[str] = "timeout.unknown_action_type" diff --git a/src/ai_company/persistence/sqlite/parked_context_repo.py b/src/ai_company/persistence/sqlite/parked_context_repo.py index 95d657b49b..120670335a 100644 --- a/src/ai_company/persistence/sqlite/parked_context_repo.py +++ b/src/ai_company/persistence/sqlite/parked_context_repo.py @@ -66,7 +66,9 @@ async def get(self, parked_id: str) -> ParkedContext | None: """Retrieve a parked context by ID.""" try: cursor = await self._db.execute( - "SELECT * FROM parked_contexts WHERE id = ?", + "SELECT id, execution_id, agent_id, task_id, approval_id, " + "parked_at, context_json, metadata " + "FROM parked_contexts WHERE id = ?", (parked_id,), ) row = await cursor.fetchone() @@ -92,7 +94,9 @@ async def get_by_approval(self, approval_id: str) -> ParkedContext | None: """Retrieve a parked context by approval ID.""" try: cursor = await self._db.execute( - "SELECT * FROM parked_contexts WHERE approval_id = ?", + "SELECT id, execution_id, agent_id, task_id, approval_id, " + "parked_at, context_json, metadata " + "FROM parked_contexts WHERE approval_id = ?", (approval_id,), ) row = await cursor.fetchone() @@ -114,7 +118,9 @@ async def get_by_agent(self, agent_id: str) -> tuple[ParkedContext, ...]: """Retrieve all parked contexts for an agent.""" try: cursor = await self._db.execute( - "SELECT * FROM parked_contexts WHERE agent_id = ? " + "SELECT id, execution_id, agent_id, task_id, approval_id, " + "parked_at, context_json, metadata " + "FROM parked_contexts WHERE agent_id = ? " "ORDER BY parked_at DESC", (agent_id,), ) @@ -128,18 +134,14 @@ async def get_by_agent(self, agent_id: str) -> tuple[ParkedContext, ...]: ) raise QueryError(msg) from exc - results: list[ParkedContext] = [] - for row in rows: - model = self._row_to_model(dict(row)) - if model is not None: - results.append(model) + results = tuple(self._row_to_model(dict(row)) for row in rows) logger.debug( PERSISTENCE_PARKED_CONTEXT_QUERIED, agent_id=agent_id, count=len(results), ) - return tuple(results) + return results async def delete(self, parked_id: str) -> bool: """Delete a parked context by ID.""" @@ -166,17 +168,22 @@ async def delete(self, parked_id: str) -> bool: ) return deleted - def _row_to_model(self, row: dict[str, object]) -> ParkedContext | None: - """Convert a database row to a ``ParkedContext`` model.""" + def _row_to_model(self, row: dict[str, object]) -> ParkedContext: + """Convert a database row to a ``ParkedContext`` model. + + Raises: + QueryError: If the row cannot be deserialized. + """ try: raw_meta = row.get("metadata") if isinstance(raw_meta, str): - row["metadata"] = json.loads(raw_meta) + row = {**row, "metadata": json.loads(raw_meta)} return ParkedContext.model_validate(row) except (ValidationError, json.JSONDecodeError) as exc: - logger.warning( + msg = f"Failed to deserialize parked context {row.get('id')!r}" + logger.exception( PERSISTENCE_PARKED_CONTEXT_DESERIALIZE_FAILED, parked_id=row.get("id"), error=str(exc), ) - return None + raise QueryError(msg) from exc diff --git a/src/ai_company/security/autonomy/change_strategy.py b/src/ai_company/security/autonomy/change_strategy.py index 337b85404c..91aadd90b2 100644 --- a/src/ai_company/security/autonomy/change_strategy.py +++ b/src/ai_company/security/autonomy/change_strategy.py @@ -22,6 +22,12 @@ DowngradeReason.SECURITY_INCIDENT: AutonomyLevel.LOCKED, } +# Validate exhaustiveness at module load time. +_missing_reasons = set(DowngradeReason) - set(_DOWNGRADE_MAP) +if _missing_reasons: + _msg = f"_DOWNGRADE_MAP missing entries for: {_missing_reasons}" + raise RuntimeError(_msg) + class HumanOnlyPromotionStrategy: """Default strategy: promotions and recovery always require human approval. @@ -69,19 +75,27 @@ def auto_downgrade( self, agent_id: NotBlankStr, reason: DowngradeReason, + current_level: AutonomyLevel | None = None, ) -> AutonomyLevel: """Immediately downgrade to a level determined by the reason. Args: agent_id: The agent to downgrade. reason: Why the downgrade is happening. + current_level: The agent's current effective autonomy level. + Used as ``original_level`` when no prior override exists. + Defaults to the company default (SEMI) if not provided. Returns: The new autonomy level after downgrade. """ new_level = _DOWNGRADE_MAP[reason] existing = self._overrides.get(agent_id) - original = existing.original_level if existing else AutonomyLevel.SEMI + original = ( + existing.original_level + if existing + else (current_level or AutonomyLevel.SEMI) + ) override = AutonomyOverride( agent_id=agent_id, @@ -120,7 +134,7 @@ def request_recovery( ) return False - def get_override(self, agent_id: str) -> AutonomyOverride | None: + def get_override(self, agent_id: NotBlankStr) -> AutonomyOverride | None: """Return the active override for an agent, if any. Args: @@ -131,7 +145,7 @@ def get_override(self, agent_id: str) -> AutonomyOverride | None: """ return self._overrides.get(agent_id) - def clear_override(self, agent_id: str) -> bool: + def clear_override(self, agent_id: NotBlankStr) -> bool: """Remove an override (used after human recovery approval). Args: diff --git a/src/ai_company/security/autonomy/models.py b/src/ai_company/security/autonomy/models.py index a7d37afbc5..40b6b2728a 100644 --- a/src/ai_company/security/autonomy/models.py +++ b/src/ai_company/security/autonomy/models.py @@ -33,11 +33,11 @@ class AutonomyPreset(BaseModel): level: AutonomyLevel = Field(description="Autonomy level") description: NotBlankStr = Field(description="Human-readable description") - auto_approve: tuple[str, ...] = Field( + auto_approve: tuple[NotBlankStr, ...] = Field( default=(), description="Action patterns that are auto-approved", ) - human_approval: tuple[str, ...] = Field( + human_approval: tuple[NotBlankStr, ...] = Field( default=(), description="Action patterns requiring human approval", ) @@ -173,6 +173,18 @@ class EffectiveAutonomy(BaseModel): description="Whether security agent reviews escalations", ) + @model_validator(mode="after") + def _validate_disjoint(self) -> Self: + """Ensure expanded action sets are disjoint.""" + overlap = self.auto_approve_actions & self.human_approval_actions + if overlap: + msg = ( + f"auto_approve_actions and human_approval_actions must be " + f"disjoint, overlapping: {sorted(overlap)}" + ) + raise ValueError(msg) + return self + class AutonomyOverride(BaseModel): """Record of a runtime autonomy downgrade for an agent. diff --git a/src/ai_company/security/autonomy/protocol.py b/src/ai_company/security/autonomy/protocol.py index c0724d6d53..57356ebf82 100644 --- a/src/ai_company/security/autonomy/protocol.py +++ b/src/ai_company/security/autonomy/protocol.py @@ -35,12 +35,15 @@ def auto_downgrade( self, agent_id: NotBlankStr, reason: DowngradeReason, + current_level: AutonomyLevel | None = None, ) -> AutonomyLevel: """Automatically downgrade an agent's autonomy level. Args: agent_id: The agent to downgrade. reason: Why the downgrade is happening. + current_level: The agent's current effective autonomy level. + Used as ``original_level`` when no prior override exists. Returns: The new (lower) autonomy level. diff --git a/src/ai_company/security/autonomy/resolver.py b/src/ai_company/security/autonomy/resolver.py index f1e4d81fc8..e77355ecd4 100644 --- a/src/ai_company/security/autonomy/resolver.py +++ b/src/ai_company/security/autonomy/resolver.py @@ -30,10 +30,6 @@ class AutonomyResolver: After resolution, category shortcuts (e.g. ``"code"``) are expanded into concrete action types via the ``ActionTypeRegistry``, and the ``"all"`` shortcut is expanded to every registered action type. - - Args: - registry: Action type registry for category expansion. - config: Company-level autonomy configuration with presets. """ def __init__( @@ -42,6 +38,12 @@ def __init__( registry: ActionTypeRegistry, config: AutonomyConfig, ) -> None: + """Initialize the resolver. + + Args: + registry: Action type registry for category expansion. + config: Company-level autonomy configuration with presets. + """ self._registry = registry self._config = config @@ -49,26 +51,41 @@ def resolve( self, agent_level: AutonomyLevel | None = None, department_level: AutonomyLevel | None = None, + seniority: SeniorityLevel | None = None, ) -> EffectiveAutonomy: """Resolve effective autonomy from the three-level chain. + When ``seniority`` is provided, the JUNIOR/FULL constraint + (D6) is enforced automatically. + Args: agent_level: Per-agent override (highest priority). department_level: Per-department override. + seniority: Agent seniority level for constraint checks. Returns: Fully expanded :class:`EffectiveAutonomy`. Raises: - ValueError: If the resolved level has no matching preset. + ValueError: If the resolved level has no matching preset + or seniority constraints are violated. """ level = agent_level or department_level or self._config.level + + if seniority is not None: + self.validate_seniority(seniority, level) + preset = self._config.presets.get(level) if preset is None: msg = ( f"No preset found for autonomy level {level!r} " f"(available: {sorted(self._config.presets)})" ) + logger.warning( + AUTONOMY_RESOLVED, + resolved_level=level.value if hasattr(level, "value") else str(level), + error=msg, + ) raise ValueError(msg) auto_approve = self._expand_patterns(preset.auto_approve) @@ -165,9 +182,14 @@ def _expand_patterns( if self._registry.is_registered(pattern): result.add(pattern) else: - # Unknown pattern — still include it so the security - # layer can match it. Custom action types registered - # later may use this pattern. + logger.warning( + AUTONOMY_PRESET_EXPANDED, + pattern=pattern, + note=( + "pattern not currently registered — included for " + "forward compatibility, verify this is not a typo" + ), + ) result.add(pattern) return frozenset(result) diff --git a/src/ai_company/security/service.py b/src/ai_company/security/service.py index b1d31f49dd..1a8051b337 100644 --- a/src/ai_company/security/service.py +++ b/src/ai_company/security/service.py @@ -229,6 +229,8 @@ async def scan_output( ) try: self._audit_log.record(entry) + except MemoryError, RecursionError: + raise except Exception: logger.exception( SECURITY_AUDIT_RECORD_ERROR, @@ -345,6 +347,8 @@ def _record_audit( approval_id=verdict.approval_id, ) self._audit_log.record(entry) + except MemoryError, RecursionError: + raise except Exception: logger.exception( SECURITY_AUDIT_RECORD_ERROR, @@ -395,6 +399,8 @@ async def _handle_escalation( ) try: await self._approval_store.add(item) + except MemoryError, RecursionError: + raise except Exception: logger.exception( SECURITY_ESCALATION_STORE_ERROR, diff --git a/src/ai_company/security/timeout/config.py b/src/ai_company/security/timeout/config.py index 6ce1f60360..12e55f8c90 100644 --- a/src/ai_company/security/timeout/config.py +++ b/src/ai_company/security/timeout/config.py @@ -67,8 +67,10 @@ class TierConfig(BaseModel): class TieredTimeoutConfig(BaseModel): """Per-risk-tier timeout policy. - Each tier defines its own timeout and action. Unknown risk - tiers fall back to HIGH (fail-safe per D19). + Each tier defines its own timeout and action. Unknown *action types* + are classified as HIGH risk by the ``RiskTierClassifier`` (D19). If + a risk level has no tier configuration entry, the policy defaults to + WAIT (safe fallback). Attributes: policy: Discriminator tag. @@ -129,10 +131,15 @@ class EscalationChainConfig(BaseModel): def _timeout_discriminator(value: object) -> str: - """Extract the ``policy`` discriminator from raw or model data.""" + """Extract the ``policy`` discriminator from raw or model data. + + Returns the raw ``policy`` value without a default so Pydantic + produces a clear "no match in discriminated union" error for + invalid or missing policy fields. + """ if isinstance(value, dict): - return str(value.get("policy", "wait")) - return getattr(value, "policy", "wait") + return str(value.get("policy", "")) + return getattr(value, "policy", "") ApprovalTimeoutConfig = Annotated[ diff --git a/src/ai_company/security/timeout/factory.py b/src/ai_company/security/timeout/factory.py index 239880c994..54b056bbfc 100644 --- a/src/ai_company/security/timeout/factory.py +++ b/src/ai_company/security/timeout/factory.py @@ -1,5 +1,6 @@ """Factory for creating timeout policy instances from configuration.""" +from ai_company.observability import get_logger from ai_company.security.timeout.config import ( ApprovalTimeoutConfig, DenyOnTimeoutConfig, @@ -14,7 +15,9 @@ WaitForeverPolicy, ) from ai_company.security.timeout.protocol import TimeoutPolicy # noqa: TC001 -from ai_company.security.timeout.risk_tier_classifier import YamlRiskTierClassifier +from ai_company.security.timeout.risk_tier_classifier import DefaultRiskTierClassifier + +logger = get_logger(__name__) _SECONDS_PER_MINUTE = 60.0 @@ -44,7 +47,7 @@ def create_timeout_policy( if isinstance(config, TieredTimeoutConfig): return TieredTimeoutPolicy( tiers=config.tiers, - classifier=YamlRiskTierClassifier(), + classifier=DefaultRiskTierClassifier(), ) if isinstance(config, EscalationChainConfig): @@ -54,4 +57,8 @@ def create_timeout_policy( ) msg = f"Unknown timeout policy config type: {type(config).__name__}" # type: ignore[unreachable] + logger.warning( + "timeout.factory.unknown_config", + config_type=type(config).__name__, + ) raise TypeError(msg) diff --git a/src/ai_company/security/timeout/models.py b/src/ai_company/security/timeout/models.py index f445ac56cc..006891af0b 100644 --- a/src/ai_company/security/timeout/models.py +++ b/src/ai_company/security/timeout/models.py @@ -1,8 +1,10 @@ """Timeout action model — the result of evaluating a timeout policy.""" -from pydantic import BaseModel, ConfigDict, Field +from typing import Self -from ai_company.core.enums import TimeoutActionType # noqa: TC001 +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ai_company.core.enums import TimeoutActionType from ai_company.core.types import NotBlankStr # noqa: TC001 @@ -24,3 +26,17 @@ class TimeoutAction(BaseModel): default=None, description="Escalation target (when action is ESCALATE)", ) + + @model_validator(mode="after") + def _validate_escalate_to(self) -> Self: + """Enforce ``escalate_to`` consistency with action type.""" + if self.action == TimeoutActionType.ESCALATE and self.escalate_to is None: + msg = "escalate_to is required when action is ESCALATE" + raise ValueError(msg) + if self.action != TimeoutActionType.ESCALATE and self.escalate_to is not None: + msg = ( + f"escalate_to must be None when action is " + f"{self.action.value!r}, got {self.escalate_to!r}" + ) + raise ValueError(msg) + return self diff --git a/src/ai_company/security/timeout/park_service.py b/src/ai_company/security/timeout/park_service.py index 27a96cfb21..ed8b63b2e8 100644 --- a/src/ai_company/security/timeout/park_service.py +++ b/src/ai_company/security/timeout/park_service.py @@ -1,19 +1,25 @@ """Park/resume service for agent execution contexts. -Serializes an ``AgentContext`` into a ``ParkedContext`` for persistence -when an agent is parked awaiting approval, and deserializes it back -when the approval decision arrives. +Creates ``ParkedContext`` objects by serializing an ``AgentContext`` to +JSON, and restores them by deserializing. Actual persistence (store / +delete) is the responsibility of the calling code via the +``ParkedContextRepository``. """ +import copy from datetime import UTC, datetime from typing import TYPE_CHECKING +from pydantic import ValidationError + +from ai_company.core.types import NotBlankStr # noqa: TC001 from ai_company.observability import get_logger if TYPE_CHECKING: from ai_company.engine.context import AgentContext from ai_company.observability.events.timeout import ( - TIMEOUT_WAITING, + TIMEOUT_CONTEXT_PARKED, + TIMEOUT_CONTEXT_RESUMED, ) from ai_company.security.timeout.parked_context import ParkedContext @@ -21,20 +27,20 @@ class ParkService: - """Handles parking and resuming agent execution contexts. + """Handles creating and deserializing parked agent execution contexts. - Parking serializes the full ``AgentContext`` as JSON and stores it - via the ``ParkedContextRepository``. Resuming deserializes and - deletes the parked record. + The ``park`` method serializes an ``AgentContext`` into a + ``ParkedContext`` for the caller to persist. The ``resume`` method + deserializes a ``ParkedContext`` back into an ``AgentContext``. """ def park( self, *, context: AgentContext, - approval_id: str, - agent_id: str, - task_id: str, + approval_id: NotBlankStr, + agent_id: NotBlankStr, + task_id: NotBlankStr, metadata: dict[str, str] | None = None, ) -> ParkedContext: """Serialize and create a ``ParkedContext`` from an agent context. @@ -48,8 +54,23 @@ def park( Returns: A ``ParkedContext`` ready for persistence. + + Raises: + ValueError: If the agent context cannot be serialized. """ - context_json = context.model_dump_json() + try: + context_json = context.model_dump_json() + except (ValueError, TypeError) as exc: + logger.exception( + TIMEOUT_CONTEXT_PARKED, + agent_id=agent_id, + task_id=task_id, + approval_id=approval_id, + error=str(exc), + note="Failed to serialize agent context", + ) + msg = f"Failed to serialize agent context for agent {agent_id!r}" + raise ValueError(msg) from exc parked = ParkedContext( execution_id=str(context.execution_id), @@ -58,11 +79,11 @@ def park( approval_id=approval_id, parked_at=datetime.now(UTC), context_json=context_json, - metadata=metadata or {}, + metadata=copy.deepcopy(metadata) if metadata else {}, ) logger.info( - TIMEOUT_WAITING, + TIMEOUT_CONTEXT_PARKED, parked_id=parked.id, agent_id=agent_id, task_id=task_id, @@ -78,7 +99,32 @@ def resume(self, parked: ParkedContext) -> AgentContext: Returns: The restored ``AgentContext``. + + Raises: + ValueError: If the parked context cannot be deserialized. """ from ai_company.engine.context import AgentContext # noqa: PLC0415 - return AgentContext.model_validate_json(parked.context_json) + try: + context = AgentContext.model_validate_json(parked.context_json) + except (ValidationError, ValueError) as exc: + logger.exception( + TIMEOUT_CONTEXT_RESUMED, + parked_id=parked.id, + agent_id=parked.agent_id, + approval_id=parked.approval_id, + error=str(exc), + note="Failed to deserialize parked agent context", + ) + msg = ( + f"Failed to resume parked context {parked.id!r} " + f"for agent {parked.agent_id!r}" + ) + raise ValueError(msg) from exc + + logger.info( + TIMEOUT_CONTEXT_RESUMED, + parked_id=parked.id, + agent_id=parked.agent_id, + ) + return context diff --git a/src/ai_company/security/timeout/policies.py b/src/ai_company/security/timeout/policies.py index 938f1820c0..34a2826e78 100644 --- a/src/ai_company/security/timeout/policies.py +++ b/src/ai_company/security/timeout/policies.py @@ -1,7 +1,7 @@ """Timeout policy implementations — wait, deny, tiered, escalation chain.""" from ai_company.core.approval import ApprovalItem # noqa: TC001 -from ai_company.core.enums import TimeoutActionType +from ai_company.core.enums import ApprovalRiskLevel, TimeoutActionType from ai_company.observability import get_logger from ai_company.observability.events.timeout import ( TIMEOUT_AUTO_DENIED, @@ -61,6 +61,9 @@ class DenyOnTimeoutPolicy: """ def __init__(self, *, timeout_seconds: float) -> None: + if timeout_seconds <= 0: + msg = f"timeout_seconds must be positive, got {timeout_seconds}" + raise ValueError(msg) self._timeout_seconds = timeout_seconds async def determine_action( @@ -147,11 +150,12 @@ async def determine_action( if tier_config is None: # No tier configured for this risk level — wait (safe default). - logger.debug( + logger.warning( TIMEOUT_WAITING, approval_id=item.id, risk_level=risk_level.value, - note="no tier config — defaulting to wait", + available_tiers=sorted(self._tiers.keys()), + note="no tier config for this risk level — defaulting to wait", ) return TimeoutAction( action=TimeoutActionType.WAIT, @@ -178,18 +182,37 @@ async def determine_action( ), ) + effective_action = tier_config.on_timeout + + # Guard: never auto-approve HIGH or CRITICAL actions. + _approve_forbidden = {ApprovalRiskLevel.HIGH, ApprovalRiskLevel.CRITICAL} + if ( + effective_action == TimeoutActionType.APPROVE + and risk_level in _approve_forbidden + ): + logger.warning( + TIMEOUT_POLICY_EVALUATED, + approval_id=item.id, + risk_level=risk_level.value, + configured_action=effective_action.value, + note=( + "auto-approve blocked for high/critical risk — overriding to DENY" + ), + ) + effective_action = TimeoutActionType.DENY + logger.info( TIMEOUT_POLICY_EVALUATED, approval_id=item.id, risk_level=risk_level.value, - on_timeout=tier_config.on_timeout.value, + on_timeout=effective_action.value, elapsed_seconds=elapsed_seconds, ) return TimeoutAction( - action=tier_config.on_timeout, + action=effective_action, reason=( f"Tier {risk_level.value} timeout: auto-" - f"{tier_config.on_timeout.value} after " + f"{effective_action.value} after " f"{elapsed_seconds:.0f}s" ), ) @@ -230,9 +253,16 @@ async def determine_action( elapsed_seconds: Seconds since creation. Returns: - WAIT, ESCALATE, or the chain-exhausted action. + ESCALATE (to the current step's role) or the + chain-exhausted action. """ if not self._chain: + logger.warning( + TIMEOUT_ESCALATED, + approval_id=item.id, + on_exhausted=self._on_chain_exhausted.value, + note="empty escalation chain — likely a configuration error", + ) return TimeoutAction( action=self._on_chain_exhausted, reason="Empty escalation chain — applying exhausted action", @@ -242,10 +272,6 @@ async def determine_action( for step in self._chain: step_timeout = step.timeout_minutes * _SECONDS_PER_MINUTE if elapsed_seconds < cumulative_seconds + step_timeout: - # Still within this step's window. - if elapsed_seconds < cumulative_seconds: - # Before this step — shouldn't happen but safe. - break logger.debug( TIMEOUT_WAITING, approval_id=item.id, diff --git a/src/ai_company/security/timeout/risk_tier_classifier.py b/src/ai_company/security/timeout/risk_tier_classifier.py index cfae61bfd7..427166b8e8 100644 --- a/src/ai_company/security/timeout/risk_tier_classifier.py +++ b/src/ai_company/security/timeout/risk_tier_classifier.py @@ -1,11 +1,11 @@ -"""YAML-configurable risk tier classifier for timeout policies.""" +"""Configurable risk tier classifier for timeout policies.""" from types import MappingProxyType from typing import Final from ai_company.core.enums import ActionType, ApprovalRiskLevel from ai_company.observability import get_logger -from ai_company.observability.events.timeout import TIMEOUT_POLICY_EVALUATED +from ai_company.observability.events.timeout import TIMEOUT_UNKNOWN_ACTION_TYPE logger = get_logger(__name__) @@ -45,7 +45,7 @@ ) -class YamlRiskTierClassifier: +class DefaultRiskTierClassifier: """Maps action types to risk tiers for tiered timeout policies. Unknown action types default to HIGH (fail-safe per D19). @@ -77,11 +77,11 @@ def classify(self, action_type: str) -> ApprovalRiskLevel: """ result = self._risk_map.get(action_type) if result is None: - logger.debug( - TIMEOUT_POLICY_EVALUATED, + logger.warning( + TIMEOUT_UNKNOWN_ACTION_TYPE, action_type=action_type, - risk_tier="high", - note="unknown action type — defaulting to HIGH", + default_tier="high", + note="unknown action type — defaulting to HIGH (D19)", ) return ApprovalRiskLevel.HIGH return result diff --git a/src/ai_company/templates/renderer.py b/src/ai_company/templates/renderer.py index 43c0e8a996..410da80f91 100644 --- a/src/ai_company/templates/renderer.py +++ b/src/ai_company/templates/renderer.py @@ -65,6 +65,10 @@ logger = get_logger(__name__) +# Module-level Jinja2 environment — stateless and safe to reuse. +_JINJA_ENV = SandboxedEnvironment(keep_trailing_newline=True) +_JINJA_ENV.filters["auto"] = lambda value: value or "" + def render_template( loaded: LoadedTemplate, @@ -322,22 +326,6 @@ def _collect_variables( return result -def _create_jinja_env() -> SandboxedEnvironment: - """Create a sandboxed Jinja2 environment with custom filters. - - Returns: - Configured :class:`SandboxedEnvironment`. - """ - env = SandboxedEnvironment( - keep_trailing_newline=True, - ) - # ``auto`` filter: converts falsy values to empty string, which - # triggers auto-name generation downstream (empty names are - # detected by ``_expand_agents``). - env.filters["auto"] = lambda value: value or "" - return env - - def _render_jinja2( raw_yaml: str, variables: dict[str, Any], @@ -357,9 +345,8 @@ def _render_jinja2( Raises: TemplateRenderError: If Jinja2 rendering fails. """ - env = _create_jinja_env() try: - jinja_template = env.from_string(raw_yaml) + jinja_template = _JINJA_ENV.from_string(raw_yaml) return jinja_template.render(**variables) except Jinja2TemplateError as exc: logger.exception( @@ -620,7 +607,9 @@ def _expand_single_agent( "level": agent.get("level", "mid"), } - _resolve_agent_personality(agent, name, agent_dict) + personality = _resolve_agent_personality(agent, name) + if personality is not None: + agent_dict["personality"] = personality model_tier = agent.get("model", "medium") agent_dict["model"] = {"provider": _DEFAULT_PROVIDER, "model_id": model_tier} @@ -652,16 +641,15 @@ def _expand_single_agent( def _resolve_agent_personality( agent: dict[str, Any], name: str, - agent_dict: dict[str, Any], -) -> None: +) -> dict[str, Any] | None: """Resolve personality from inline config or named preset. - Mutates *agent_dict* to add the ``personality`` key when resolved. - Args: agent: Raw agent dict from rendered YAML. name: Resolved agent name for error context. - agent_dict: Partially-built agent dict (mutated in place). + + Returns: + Personality dict, or ``None`` if no personality configured. Raises: TemplateRenderError: If personality config is invalid or preset @@ -683,10 +671,10 @@ def _resolve_agent_personality( ) raise TemplateRenderError(msg) _validate_inline_personality(inline_personality, name) - agent_dict["personality"] = inline_personality - elif preset_name: + return dict(inline_personality) + if preset_name: try: - agent_dict["personality"] = get_personality_preset(preset_name) + return get_personality_preset(preset_name) except KeyError as exc: msg = f"Unknown personality preset {preset_name!r} for agent {name!r}" logger.warning( @@ -695,6 +683,7 @@ def _resolve_agent_personality( preset=preset_name, ) raise TemplateRenderError(msg) from exc + return None def _validate_inline_personality( diff --git a/tests/unit/api/controllers/test_autonomy.py b/tests/unit/api/controllers/test_autonomy.py new file mode 100644 index 0000000000..c460503db9 --- /dev/null +++ b/tests/unit/api/controllers/test_autonomy.py @@ -0,0 +1,62 @@ +"""Tests for autonomy controller.""" + +from typing import Any + +import pytest +from litestar.testing import TestClient # noqa: TC002 + +_BASE = "/api/v1/agents" +_WRITE_HEADERS = {"X-Human-Role": "ceo"} +_READ_HEADERS = {"X-Human-Role": "observer"} + + +def _url(agent_id: str = "agent-001") -> str: + return f"{_BASE}/{agent_id}/autonomy" + + +@pytest.mark.unit +class TestGetAutonomy: + def test_get_autonomy(self, test_client: TestClient[Any]) -> None: + resp = test_client.get(_url("agent-42"), headers=_READ_HEADERS) + assert resp.status_code == 200 + body = resp.json() + assert body["success"] is True + data = body["data"] + assert data["agent_id"] == "agent-42" + assert data["level"] == "semi" + assert data["promotion_pending"] is False + + def test_get_autonomy_requires_read_access( + self, test_client: TestClient[Any] + ) -> None: + resp = test_client.get(_url(), headers={"X-Human-Role": "invalid"}) + assert resp.status_code == 403 + + +@pytest.mark.unit +class TestUpdateAutonomy: + def test_update_autonomy_returns_pending( + self, test_client: TestClient[Any] + ) -> None: + resp = test_client.post( + _url("agent-42"), + json={"level": "full"}, + headers=_WRITE_HEADERS, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["success"] is True + data = body["data"] + assert data["agent_id"] == "agent-42" + assert data["level"] == "semi" + assert data["promotion_pending"] is True + + def test_update_autonomy_requires_write_access( + self, test_client: TestClient[Any] + ) -> None: + resp = test_client.post( + _url(), + json={"level": "full"}, + headers=_READ_HEADERS, + ) + assert resp.status_code == 403 diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index 80c72915ab..419ba3fe8c 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -552,6 +552,85 @@ def test_persistence_events_exist(self, constant_name: str, expected: str) -> No assert getattr(mod, constant_name) == expected + def test_autonomy_events_exist(self) -> None: + from ai_company.observability.events.autonomy import ( + AUTONOMY_ACTION_AUTO_APPROVED, + AUTONOMY_ACTION_HUMAN_REQUIRED, + AUTONOMY_DOWNGRADE_TRIGGERED, + AUTONOMY_PRESET_EXPANDED, + AUTONOMY_PROMOTION_DENIED, + AUTONOMY_PROMOTION_REQUESTED, + AUTONOMY_RECOVERY_REQUESTED, + AUTONOMY_RESOLVED, + AUTONOMY_SENIORITY_VIOLATION, + ) + + assert AUTONOMY_RESOLVED == "autonomy.resolved" + assert AUTONOMY_PROMOTION_REQUESTED == "autonomy.promotion.requested" + assert AUTONOMY_PROMOTION_DENIED == "autonomy.promotion.denied" + assert AUTONOMY_DOWNGRADE_TRIGGERED == "autonomy.downgrade.triggered" + assert AUTONOMY_RECOVERY_REQUESTED == "autonomy.recovery.requested" + assert AUTONOMY_SENIORITY_VIOLATION == "autonomy.seniority.violation" + assert AUTONOMY_PRESET_EXPANDED == "autonomy.preset.expanded" + assert AUTONOMY_ACTION_AUTO_APPROVED == "autonomy.action.auto_approved" + assert AUTONOMY_ACTION_HUMAN_REQUIRED == "autonomy.action.human_required" + + def test_timeout_events_exist(self) -> None: + from ai_company.observability.events.timeout import ( + TIMEOUT_AUTO_APPROVED, + TIMEOUT_AUTO_DENIED, + TIMEOUT_CONTEXT_PARKED, + TIMEOUT_CONTEXT_RESUMED, + TIMEOUT_ESCALATED, + TIMEOUT_POLICY_EVALUATED, + TIMEOUT_UNKNOWN_ACTION_TYPE, + TIMEOUT_WAITING, + ) + + assert TIMEOUT_POLICY_EVALUATED == "timeout.policy.evaluated" + assert TIMEOUT_AUTO_APPROVED == "timeout.auto_approved" + assert TIMEOUT_AUTO_DENIED == "timeout.auto_denied" + assert TIMEOUT_ESCALATED == "timeout.escalated" + assert TIMEOUT_WAITING == "timeout.waiting" + assert TIMEOUT_CONTEXT_PARKED == "timeout.context.parked" + assert TIMEOUT_CONTEXT_RESUMED == "timeout.context.resumed" + assert TIMEOUT_UNKNOWN_ACTION_TYPE == "timeout.unknown_action_type" + + def test_parked_context_persistence_events_exist(self) -> None: + from ai_company.observability.events.persistence import ( + PERSISTENCE_PARKED_CONTEXT_DELETED, + PERSISTENCE_PARKED_CONTEXT_DESERIALIZE_FAILED, + PERSISTENCE_PARKED_CONTEXT_NOT_FOUND, + PERSISTENCE_PARKED_CONTEXT_QUERIED, + PERSISTENCE_PARKED_CONTEXT_QUERY_FAILED, + PERSISTENCE_PARKED_CONTEXT_SAVE_FAILED, + PERSISTENCE_PARKED_CONTEXT_SAVED, + ) + + assert PERSISTENCE_PARKED_CONTEXT_SAVED == "persistence.parked_context.saved" + assert ( + PERSISTENCE_PARKED_CONTEXT_SAVE_FAILED + == "persistence.parked_context.save_failed" + ) + assert ( + PERSISTENCE_PARKED_CONTEXT_QUERIED == "persistence.parked_context.queried" + ) + assert ( + PERSISTENCE_PARKED_CONTEXT_QUERY_FAILED + == "persistence.parked_context.query_failed" + ) + assert ( + PERSISTENCE_PARKED_CONTEXT_NOT_FOUND + == "persistence.parked_context.not_found" + ) + assert ( + PERSISTENCE_PARKED_CONTEXT_DELETED == "persistence.parked_context.deleted" + ) + assert ( + PERSISTENCE_PARKED_CONTEXT_DESERIALIZE_FAILED + == "persistence.parked_context.deserialize_failed" + ) + def test_classification_events_exist(self) -> None: assert CLASSIFICATION_START == "classification.start" assert CLASSIFICATION_COMPLETE == "classification.complete" diff --git a/tests/unit/persistence/sqlite/test_migrations.py b/tests/unit/persistence/sqlite/test_migrations.py index 35f61d0b0d..440e814c82 100644 --- a/tests/unit/persistence/sqlite/test_migrations.py +++ b/tests/unit/persistence/sqlite/test_migrations.py @@ -91,6 +91,29 @@ async def test_skips_when_already_at_version( await run_migrations(migrated_db) assert await get_user_version(migrated_db) == version_before + async def test_v3_creates_parked_contexts_table( + self, memory_db: aiosqlite.Connection + ) -> None: + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master " + "WHERE type='table' AND name='parked_contexts'" + ) + row = await cursor.fetchone() + assert row is not None + + async def test_v3_creates_parked_context_indexes( + self, memory_db: aiosqlite.Connection + ) -> None: + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='index' " + "AND name LIKE 'idx_pc_%' ORDER BY name" + ) + indexes = {row[0] for row in await cursor.fetchall()} + assert "idx_pc_agent_id" in indexes + assert "idx_pc_approval_id" in indexes + async def test_migration_failure_raises_migration_error( self, memory_db: aiosqlite.Connection ) -> None: diff --git a/tests/unit/persistence/sqlite/test_parked_context_repo.py b/tests/unit/persistence/sqlite/test_parked_context_repo.py new file mode 100644 index 0000000000..7b77b36a35 --- /dev/null +++ b/tests/unit/persistence/sqlite/test_parked_context_repo.py @@ -0,0 +1,202 @@ +"""Tests for SQLiteParkedContextRepository.""" + +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING +from uuid import uuid4 + +import pytest + +from ai_company.persistence.errors import QueryError +from ai_company.persistence.sqlite.parked_context_repo import ( + SQLiteParkedContextRepository, +) +from ai_company.security.timeout.parked_context import ParkedContext + +if TYPE_CHECKING: + import aiosqlite + + +def _make_context( # noqa: PLR0913 + *, + parked_id: str | None = None, + execution_id: str = "exec-001", + agent_id: str = "agent-001", + task_id: str = "task-001", + approval_id: str = "approval-001", + parked_at: datetime | None = None, + context_json: str = '{"state": "running"}', + metadata: dict[str, str] | None = None, +) -> ParkedContext: + return ParkedContext( + id=parked_id or str(uuid4()), + execution_id=execution_id, + agent_id=agent_id, + task_id=task_id, + approval_id=approval_id, + parked_at=parked_at or datetime.now(UTC), + context_json=context_json, + metadata=metadata or {}, + ) + + +@pytest.mark.unit +class TestSQLiteParkedContextRepository: + async def test_save_and_get(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteParkedContextRepository(migrated_db) + ctx = _make_context(parked_id="parked-001") + await repo.save(ctx) + + result = await repo.get("parked-001") + assert result is not None + assert result.id == ctx.id + assert result.execution_id == ctx.execution_id + assert result.agent_id == ctx.agent_id + assert result.task_id == ctx.task_id + assert result.approval_id == ctx.approval_id + assert result.parked_at == ctx.parked_at + assert result.context_json == ctx.context_json + assert result.metadata == ctx.metadata + + async def test_get_returns_none_for_missing( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteParkedContextRepository(migrated_db) + assert await repo.get("nonexistent") is None + + async def test_get_by_approval(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteParkedContextRepository(migrated_db) + ctx = _make_context(approval_id="approval-xyz") + await repo.save(ctx) + + result = await repo.get_by_approval("approval-xyz") + assert result is not None + assert result.approval_id == "approval-xyz" + assert result.id == ctx.id + + async def test_get_by_approval_returns_none( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteParkedContextRepository(migrated_db) + assert await repo.get_by_approval("nonexistent") is None + + async def test_get_by_agent(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteParkedContextRepository(migrated_db) + ctx1 = _make_context(agent_id="agent-a", approval_id="ap-1") + ctx2 = _make_context(agent_id="agent-a", approval_id="ap-2") + await repo.save(ctx1) + await repo.save(ctx2) + + results = await repo.get_by_agent("agent-a") + assert len(results) == 2 + ids = {r.id for r in results} + assert ctx1.id in ids + assert ctx2.id in ids + + async def test_get_by_agent_returns_empty( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteParkedContextRepository(migrated_db) + results = await repo.get_by_agent("nonexistent") + assert results == () + + async def test_get_by_agent_ordered_by_parked_at_desc( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteParkedContextRepository(migrated_db) + now = datetime.now(UTC) + earlier = now - timedelta(hours=1) + + ctx_old = _make_context( + agent_id="agent-b", + approval_id="ap-old", + parked_at=earlier, + ) + ctx_new = _make_context( + agent_id="agent-b", + approval_id="ap-new", + parked_at=now, + ) + # Save in chronological order to ensure DB ordering is not + # merely insertion order. + await repo.save(ctx_old) + await repo.save(ctx_new) + + results = await repo.get_by_agent("agent-b") + assert len(results) == 2 + assert results[0].id == ctx_new.id + assert results[1].id == ctx_old.id + + async def test_delete(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteParkedContextRepository(migrated_db) + ctx = _make_context(parked_id="del-me") + await repo.save(ctx) + + assert await repo.delete("del-me") is True + assert await repo.get("del-me") is None + + async def test_delete_returns_false_for_missing( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteParkedContextRepository(migrated_db) + assert await repo.delete("nonexistent") is False + + async def test_save_upsert(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteParkedContextRepository(migrated_db) + ctx = _make_context( + parked_id="upsert-id", + context_json='{"step": 1}', + metadata={"key": "original"}, + ) + await repo.save(ctx) + + updated = ParkedContext( + id="upsert-id", + execution_id=ctx.execution_id, + agent_id=ctx.agent_id, + task_id=ctx.task_id, + approval_id=ctx.approval_id, + parked_at=ctx.parked_at, + context_json='{"step": 2}', + metadata={"key": "updated"}, + ) + await repo.save(updated) + + result = await repo.get("upsert-id") + assert result is not None + assert result.context_json == '{"step": 2}' + assert result.metadata == {"key": "updated"} + + async def test_save_round_trips_metadata( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Metadata dict survives JSON serialization round-trip.""" + repo = SQLiteParkedContextRepository(migrated_db) + ctx = _make_context( + parked_id="meta-rt", + metadata={"tool": "shell", "action": "execute"}, + ) + await repo.save(ctx) + + result = await repo.get("meta-rt") + assert result is not None + assert result.metadata == {"tool": "shell", "action": "execute"} + + async def test_row_to_model_raises_on_corrupt_data( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Corrupt metadata JSON triggers QueryError in _row_to_model.""" + await migrated_db.execute( + """\ +INSERT INTO parked_contexts ( + id, execution_id, agent_id, task_id, approval_id, + parked_at, context_json, metadata +) VALUES ( + 'corrupt-1', 'exec-1', 'agent-1', 'task-1', 'approval-1', + '2026-03-01T12:00:00+00:00', '{}', '{BAD JSON}' +)""" + ) + await migrated_db.commit() + + repo = SQLiteParkedContextRepository(migrated_db) + with pytest.raises(QueryError, match="deserialize parked context"): + await repo.get("corrupt-1") diff --git a/tests/unit/security/autonomy/test_change_strategy.py b/tests/unit/security/autonomy/test_change_strategy.py index 7b5eee044a..ab41967b04 100644 --- a/tests/unit/security/autonomy/test_change_strategy.py +++ b/tests/unit/security/autonomy/test_change_strategy.py @@ -58,6 +58,17 @@ def test_no_override_when_not_downgraded(self) -> None: strategy = HumanOnlyPromotionStrategy() assert strategy.get_override("agent-1") is None + @pytest.mark.unit + def test_double_downgrade_preserves_original(self) -> None: + strategy = HumanOnlyPromotionStrategy() + strategy.auto_downgrade("agent-1", DowngradeReason.HIGH_ERROR_RATE) + strategy.auto_downgrade("agent-1", DowngradeReason.SECURITY_INCIDENT) + override = strategy.get_override("agent-1") + assert override is not None + # Second downgrade replaces the first + assert override.current_level == AutonomyLevel.LOCKED + assert override.reason == DowngradeReason.SECURITY_INCIDENT + class TestRecovery: """Recovery is always denied in human-only strategy.""" diff --git a/tests/unit/security/autonomy/test_models.py b/tests/unit/security/autonomy/test_models.py index 433f1af1ff..6442e5ea52 100644 --- a/tests/unit/security/autonomy/test_models.py +++ b/tests/unit/security/autonomy/test_models.py @@ -151,6 +151,16 @@ def test_frozen(self) -> None: with pytest.raises(ValidationError): effective.level = AutonomyLevel.LOCKED # type: ignore[misc] + @pytest.mark.unit + def test_disjoint_overlap_raises(self) -> None: + with pytest.raises(ValidationError, match="disjoint"): + EffectiveAutonomy( + level=AutonomyLevel.SEMI, + auto_approve_actions=frozenset({"code:read", "code:write"}), + human_approval_actions=frozenset({"code:write", "deploy:prod"}), + security_agent=True, + ) + class TestAutonomyOverride: """AutonomyOverride model tests.""" diff --git a/tests/unit/security/timeout/test_config.py b/tests/unit/security/timeout/test_config.py index 3b332f334d..c7b830085b 100644 --- a/tests/unit/security/timeout/test_config.py +++ b/tests/unit/security/timeout/test_config.py @@ -13,6 +13,7 @@ TieredTimeoutConfig, WaitForeverConfig, ) +from ai_company.security.timeout.models import TimeoutAction _adapter: TypeAdapter[ApprovalTimeoutConfig] = TypeAdapter(ApprovalTimeoutConfig) @@ -123,3 +124,24 @@ def test_discriminator(self) -> None: ) assert isinstance(result, EscalationChainConfig) assert len(result.chain) == 1 + + +class TestTimeoutAction: + """TimeoutAction escalate_to validator tests.""" + + @pytest.mark.unit + def test_escalate_without_target_raises(self) -> None: + with pytest.raises(ValidationError, match="escalate_to is required"): + TimeoutAction( + action=TimeoutActionType.ESCALATE, + reason="test", + ) + + @pytest.mark.unit + def test_non_escalate_with_target_raises(self) -> None: + with pytest.raises(ValidationError, match="escalate_to must be None"): + TimeoutAction( + action=TimeoutActionType.DENY, + reason="test", + escalate_to="lead", + ) diff --git a/tests/unit/security/timeout/test_policies.py b/tests/unit/security/timeout/test_policies.py index 3f74e10b60..926674eb4b 100644 --- a/tests/unit/security/timeout/test_policies.py +++ b/tests/unit/security/timeout/test_policies.py @@ -13,7 +13,7 @@ TieredTimeoutPolicy, WaitForeverPolicy, ) -from ai_company.security.timeout.risk_tier_classifier import YamlRiskTierClassifier +from ai_company.security.timeout.risk_tier_classifier import DefaultRiskTierClassifier def _make_item( @@ -75,6 +75,16 @@ async def test_deny_after_timeout(self) -> None: result = await policy.determine_action(item, 7200.0) assert result.action == TimeoutActionType.DENY + @pytest.mark.unit + async def test_negative_timeout_raises(self) -> None: + with pytest.raises(ValueError, match="timeout_seconds must be positive"): + DenyOnTimeoutPolicy(timeout_seconds=-1.0) + + @pytest.mark.unit + async def test_zero_timeout_raises(self) -> None: + with pytest.raises(ValueError, match="timeout_seconds must be positive"): + DenyOnTimeoutPolicy(timeout_seconds=0.0) + class TestTieredTimeoutPolicy: """TieredTimeoutPolicy: per-risk-tier timeout behavior.""" @@ -86,7 +96,7 @@ async def test_wait_within_tier_timeout(self) -> None: } policy = TieredTimeoutPolicy( tiers=tiers, - classifier=YamlRiskTierClassifier(), + classifier=DefaultRiskTierClassifier(), ) item = _make_item(action_type="code:write") # MEDIUM risk result = await policy.determine_action(item, 1800.0) # 30 min @@ -99,7 +109,7 @@ async def test_deny_after_tier_timeout(self) -> None: } policy = TieredTimeoutPolicy( tiers=tiers, - classifier=YamlRiskTierClassifier(), + classifier=DefaultRiskTierClassifier(), ) item = _make_item(action_type="code:write") # MEDIUM risk result = await policy.determine_action(item, 3601.0) # > 60 min @@ -114,7 +124,7 @@ async def test_approve_on_tier_timeout(self) -> None: } policy = TieredTimeoutPolicy( tiers=tiers, - classifier=YamlRiskTierClassifier(), + classifier=DefaultRiskTierClassifier(), ) item = _make_item(action_type="code:read") # LOW risk result = await policy.determine_action(item, 30000.0) # > 480 min @@ -124,7 +134,7 @@ async def test_approve_on_tier_timeout(self) -> None: async def test_no_tier_config_waits(self) -> None: policy = TieredTimeoutPolicy( tiers={}, - classifier=YamlRiskTierClassifier(), + classifier=DefaultRiskTierClassifier(), ) item = _make_item() result = await policy.determine_action(item, 999999.0) @@ -180,6 +190,17 @@ async def test_chain_exhausted(self) -> None: result = await policy.determine_action(item, 6000.0) assert result.action == TimeoutActionType.DENY + @pytest.mark.unit + async def test_chain_exhausted_approve(self) -> None: + chain = (EscalationStep(role="lead", timeout_minutes=30),) + policy = EscalationChainPolicy( + chain=chain, + on_chain_exhausted=TimeoutActionType.APPROVE, + ) + item = _make_item() + result = await policy.determine_action(item, 3600.0) + assert result.action == TimeoutActionType.APPROVE + @pytest.mark.unit async def test_empty_chain_exhausted_immediately(self) -> None: policy = EscalationChainPolicy( diff --git a/tests/unit/security/timeout/test_risk_tier_classifier.py b/tests/unit/security/timeout/test_risk_tier_classifier.py index 5f86c781b8..7190328141 100644 --- a/tests/unit/security/timeout/test_risk_tier_classifier.py +++ b/tests/unit/security/timeout/test_risk_tier_classifier.py @@ -1,9 +1,9 @@ -"""Tests for YamlRiskTierClassifier.""" +"""Tests for DefaultRiskTierClassifier.""" import pytest from ai_company.core.enums import ActionType, ApprovalRiskLevel -from ai_company.security.timeout.risk_tier_classifier import YamlRiskTierClassifier +from ai_company.security.timeout.risk_tier_classifier import DefaultRiskTierClassifier class TestDefaultMapping: @@ -11,25 +11,25 @@ class TestDefaultMapping: @pytest.mark.unit def test_critical_actions(self) -> None: - classifier = YamlRiskTierClassifier() + classifier = DefaultRiskTierClassifier() expected = ApprovalRiskLevel.CRITICAL assert classifier.classify(ActionType.DEPLOY_PRODUCTION) == expected assert classifier.classify(ActionType.DB_ADMIN) == expected @pytest.mark.unit def test_high_actions(self) -> None: - classifier = YamlRiskTierClassifier() + classifier = DefaultRiskTierClassifier() assert classifier.classify(ActionType.VCS_PUSH) == ApprovalRiskLevel.HIGH assert classifier.classify(ActionType.CODE_DELETE) == ApprovalRiskLevel.HIGH @pytest.mark.unit def test_medium_actions(self) -> None: - classifier = YamlRiskTierClassifier() + classifier = DefaultRiskTierClassifier() assert classifier.classify(ActionType.CODE_WRITE) == ApprovalRiskLevel.MEDIUM @pytest.mark.unit def test_low_actions(self) -> None: - classifier = YamlRiskTierClassifier() + classifier = DefaultRiskTierClassifier() assert classifier.classify(ActionType.CODE_READ) == ApprovalRiskLevel.LOW assert classifier.classify(ActionType.TEST_RUN) == ApprovalRiskLevel.LOW @@ -39,7 +39,7 @@ class TestUnknownFallback: @pytest.mark.unit def test_unknown_defaults_to_high(self) -> None: - classifier = YamlRiskTierClassifier() + classifier = DefaultRiskTierClassifier() assert classifier.classify("unknown:action") == ApprovalRiskLevel.HIGH @@ -48,14 +48,14 @@ class TestCustomMap: @pytest.mark.unit def test_custom_override(self) -> None: - classifier = YamlRiskTierClassifier( + classifier = DefaultRiskTierClassifier( custom_map={ActionType.CODE_READ: ApprovalRiskLevel.CRITICAL} ) assert classifier.classify(ActionType.CODE_READ) == ApprovalRiskLevel.CRITICAL @pytest.mark.unit def test_custom_preserves_defaults(self) -> None: - classifier = YamlRiskTierClassifier( + classifier = DefaultRiskTierClassifier( custom_map={"custom:action": ApprovalRiskLevel.LOW} ) # Default still works. diff --git a/tests/unit/security/timeout/test_timeout_checker.py b/tests/unit/security/timeout/test_timeout_checker.py index 4c9d289009..8cf811892a 100644 --- a/tests/unit/security/timeout/test_timeout_checker.py +++ b/tests/unit/security/timeout/test_timeout_checker.py @@ -34,12 +34,14 @@ def _make_mock_policy( *, action: TimeoutActionType = TimeoutActionType.WAIT, reason: str = "test reason", + escalate_to: str | None = None, ) -> AsyncMock: """Create a mock TimeoutPolicy returning the given action.""" mock_policy = AsyncMock() mock_policy.determine_action.return_value = TimeoutAction( action=action, reason=reason, + escalate_to=escalate_to, ) return mock_policy @@ -119,6 +121,7 @@ async def test_check_and_resolve_escalate(self) -> None: mock_policy = _make_mock_policy( action=TimeoutActionType.ESCALATE, reason="Escalating to manager", + escalate_to="manager", ) checker = TimeoutChecker(policy=mock_policy) item = _make_approval_item() From 8ce32f91d18eaa27194147f5dc6e05b64b8e1340 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Tue, 10 Mar 2026 12:59:15 +0100 Subject: [PATCH 3/4] fix: correct _trim_sections unpacking after rebase conflict resolution --- src/ai_company/engine/prompt.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/src/ai_company/engine/prompt.py b/src/ai_company/engine/prompt.py index a75ffc0d33..a4ed1fd3d6 100644 --- a/src/ai_company/engine/prompt.py +++ b/src/ai_company/engine/prompt.py @@ -671,19 +671,17 @@ def _render_with_trimming( # noqa: PLR0913 ) if max_tokens is not None and estimated > max_tokens: - content, estimated, task, available_tools, company, org_policies = ( - _trim_sections( - template_str=template_str, - agent=agent, - role=role, - task=task, - available_tools=available_tools, - company=company, - org_policies=org_policies, - max_tokens=max_tokens, - estimator=estimator, - effective_autonomy=effective_autonomy, - ) + content, estimated, task, company, org_policies = _trim_sections( + template_str=template_str, + agent=agent, + role=role, + task=task, + available_tools=available_tools, + company=company, + org_policies=org_policies, + max_tokens=max_tokens, + estimator=estimator, + effective_autonomy=effective_autonomy, ) return _build_prompt_result( From 4deae2a46486a2c2de876630fa2b95fcf8847a4d Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Tue, 10 Mar 2026 14:27:45 +0100 Subject: [PATCH 4/4] =?UTF-8?q?fix:=20address=20PR=20review=20findings=20?= =?UTF-8?q?=E2=80=94=20circular=20import,=20test=20fixes,=20doc=20updates?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix circular import in security/autonomy/__init__.py (removed eager AutonomyResolver import that caused core→security→core cycle) - Fix CompanyConfigFactory to pin approval_timeout=WaitForeverConfig() - Update prompt template version assertions to 1.4.0 - Fix test_non_pending_item_raises to supply decided_at/decided_by - Replace assert with restructured control flow in TieredTimeoutPolicy - Update DESIGN_SPEC.md, CLAUDE.md, README.md for autonomy/timeout docs --- CLAUDE.md | 4 +- DESIGN_SPEC.md | 29 +++- README.md | 2 +- src/ai_company/core/agent.py | 14 ++ src/ai_company/core/company.py | 6 +- src/ai_company/core/enums.py | 25 +++ src/ai_company/engine/agent_engine.py | 21 ++- src/ai_company/engine/prompt.py | 1 + src/ai_company/engine/prompt_template.py | 2 +- .../observability/events/timeout.py | 1 + src/ai_company/security/autonomy/__init__.py | 13 +- .../security/autonomy/change_strategy.py | 20 ++- src/ai_company/security/autonomy/models.py | 119 ++++++++------- src/ai_company/security/service.py | 142 ++++++++---------- src/ai_company/security/timeout/config.py | 44 +++++- src/ai_company/security/timeout/factory.py | 3 +- .../security/timeout/park_service.py | 14 ++ .../security/timeout/parked_context.py | 22 ++- src/ai_company/security/timeout/policies.py | 38 ++++- .../security/timeout/risk_tier_classifier.py | 14 ++ .../security/timeout/timeout_checker.py | 34 ++++- src/ai_company/templates/renderer.py | 5 +- tests/unit/core/conftest.py | 2 + tests/unit/engine/test_prompt.py | 2 +- .../memory/org/test_prompt_integration.py | 2 +- tests/unit/observability/test_events.py | 2 - tests/unit/persistence/test_protocol.py | 9 ++ .../security/autonomy/test_change_strategy.py | 19 +++ tests/unit/security/autonomy/test_models.py | 27 ++++ tests/unit/security/test_service.py | 47 ++++-- tests/unit/security/timeout/test_config.py | 47 +++++- .../security/timeout/test_parked_context.py | 19 +++ tests/unit/security/timeout/test_policies.py | 61 +++++++- .../security/timeout/test_timeout_checker.py | 23 +++ 34 files changed, 647 insertions(+), 186 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index ab0a25cc36..324b936139 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -55,7 +55,7 @@ src/ai_company/ persistence/ # Operational data persistence — pluggable PersistenceBackend protocol, SQLite initial (§7.6) observability/ # Structured logging, correlation tracking, log sinks providers/ # LLM provider abstraction (LiteLLM adapter) - security/ # SecOps agent, rule engine (soft-allow/hard-deny, fail-closed), audit log, output scanner, risk classifier, action type registry, ToolInvoker security integration, progressive trust (4 strategies: disabled/weighted/per-category/milestone) + security/ # SecOps agent, rule engine (soft-allow/hard-deny, fail-closed), audit log, output scanner, risk classifier, risk tier classifier, action type registry, ToolInvoker security integration, progressive trust (4 strategies: disabled/weighted/per-category/milestone), autonomy levels (presets, resolver, change strategy), timeout policies (park/resume) templates/ # Pre-built company templates, personality presets, and builder tools/ # Tool registry, built-in tools (file_system/, git, sandbox/, code_runner), MCP bridge (mcp/), role-based access ``` @@ -84,7 +84,7 @@ src/ai_company/ - **Every module** with business logic MUST have: `from ai_company.observability import get_logger` then `logger = get_logger(__name__)` - **Never** use `import logging` / `logging.getLogger()` / `print()` in application code - **Variable name**: always `logger` (not `_logger`, not `log`) -- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` (e.g. `PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`). Import directly: `from ai_company.observability.events. import EVENT_CONSTANT` +- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` (e.g. `PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`, `AUTONOMY_ACTION_AUTO_APPROVED` from `events.autonomy`, `TIMEOUT_POLICY_EVALUATED` from `events.timeout`). Import directly: `from ai_company.observability.events. import EVENT_CONSTANT` - **Structured kwargs**: always `logger.info(EVENT, key=value)` — never `logger.info("msg %s", val)` - **All error paths** must log at WARNING or ERROR with context before raising - **All state transitions** must log at INFO diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index f9f75fe3e3..06f89da0c7 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -80,7 +80,7 @@ The MVP validates the core hypothesis: **a single agent can complete a real task > **How to read this spec:** Sections describe the full vision. Each section with deferred features includes an **MVP** callout box indicating what ships in M3 and what is deferred. The full design is documented upfront to inform architecture decisions — protocol interfaces are designed even for features that won't be built until later milestones. > **Implementation snapshot (2026-03-10):** -> - **Done:** M0–M6 (tooling, config/core, providers, single-agent engine, multi-agent orchestration, API/CLI surface) + Docker sandbox (#50), MCP bridge (#53), code runner + HR engine (hiring/firing/onboarding/offboarding/registry) + performance tracking (task metrics, quality scoring, collaboration scoring, trend detection, rolling windows). Memory layer backend selected ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). Persistence backend (§7.6) completed. Memory retrieval pipeline (#41: ranking, token-budget formatting, context injection, non-inferable filtering) complete. Budget enforcement complete (BudgetEnforcer + configurable cost tiers + quota/subscription tracking). CFO cost optimization complete (CostOptimizer: anomaly detection, efficiency analysis, downgrade recommendations, routing optimization, approval decisions; ReportGenerator: multi-dimensional spending reports). Shared org memory (#125: HybridPromptRetrievalBackend, OrgFactStore, access control, factory) complete. Memory consolidation/archival (#48: ConsolidationService, SimpleConsolidationStrategy, RetentionEnforcer, ArchivalStore protocol) complete. SecOps agent (rule engine, audit log, output scanner, risk classifier, ToolInvoker integration), progressive trust (4 strategies: disabled/weighted/per-category/milestone behind TrustStrategy protocol), promotion/demotion (criteria evaluation, approval strategies, model mapping). +> - **Done:** M0–M6 (tooling, config/core, providers, single-agent engine, multi-agent orchestration, API/CLI surface) + Docker sandbox (#50), MCP bridge (#53), code runner + HR engine (hiring/firing/onboarding/offboarding/registry) + performance tracking (task metrics, quality scoring, collaboration scoring, trend detection, rolling windows). Memory layer backend selected ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). Persistence backend (§7.6) completed. Memory retrieval pipeline (#41: ranking, token-budget formatting, context injection, non-inferable filtering) complete. Budget enforcement complete (BudgetEnforcer + configurable cost tiers + quota/subscription tracking). CFO cost optimization complete (CostOptimizer: anomaly detection, efficiency analysis, downgrade recommendations, routing optimization, approval decisions; ReportGenerator: multi-dimensional spending reports). Shared org memory (#125: HybridPromptRetrievalBackend, OrgFactStore, access control, factory) complete. Memory consolidation/archival (#48: ConsolidationService, SimpleConsolidationStrategy, RetentionEnforcer, ArchivalStore protocol) complete. SecOps agent (rule engine, audit log, output scanner, risk classifier, ToolInvoker integration), progressive trust (4 strategies: disabled/weighted/per-category/milestone behind TrustStrategy protocol), promotion/demotion (criteria evaluation, approval strategies, model mapping). Autonomy levels (#42: AutonomyLevel enum, presets, 3-level resolver, rule-based auto-downgrade/human-only promotion change strategy) + approval timeout policies (#126: 4 timeout policies, park/resume service, risk tier classifier, timeout checker) complete. > - **Remaining:** JWT/OAuth auth, approval workflow gates. ### 1.5 Configuration Philosophy @@ -232,6 +232,7 @@ agent: reports_to: "engineering_lead" can_delegate_to: ["junior_developers"] budget_limit: 5.00 # max USD per task + autonomy_level: null # optional: full, semi, supervised, locked (overrides department/company default, §12.2) hiring_date: "2026-02-27" status: "active" # active, on_leave, terminated (on config model today) @@ -1558,7 +1559,7 @@ persistence: | `CostRecord` | `budget/cost_record.py` | `CostRecordRepository` | by agent, by task, aggregations | | `Message` | `communication/message.py` | `MessageRepository` | by channel | | Audit entries (planned — M7) | `security/` | `AuditRepository` (planned) | by agent, by action type, time range | -| `ParkedContext` (planned — M7) | `engine/` | `ParkedContextRepository` (planned) | by execution_id, by agent_id, by task_id | +| `ParkedContext` | `security/timeout/parked_context.py` | `ParkedContextRepository` | by execution_id, by agent_id, by task_id | | Agent runtime state (planned — M7) | `engine/` | `AgentStateRepository` (planned) | by agent_id, active agents | #### Migration Strategy @@ -2963,7 +2964,7 @@ ai-company/ │ ├── persistence/ # Operational data persistence (§7.6) │ │ ├── __init__.py # Package exports │ │ ├── protocol.py # PersistenceBackend protocol (M5) -│ │ ├── repositories.py # Repository protocols: TaskRepository, CostRecordRepository, MessageRepository (M5); AuditRepository planned (M7) +│ │ ├── repositories.py # Repository protocols: TaskRepository, CostRecordRepository, MessageRepository, ParkedContextRepository (M5); AuditRepository planned (M7) │ │ ├── config.py # PersistenceConfig model (M5) │ │ ├── errors.py # Persistence error hierarchy (M5) │ │ ├── factory.py # create_backend() factory (M5) @@ -2972,6 +2973,7 @@ ai-company/ │ │ ├── backend.py # SQLitePersistenceBackend │ │ ├── repositories.py # SQLite repository implementations │ │ ├── hr_repositories.py # SQLite HR repositories (LifecycleEvent, TaskMetricRecord, CollaborationMetricRecord) +│ │ ├── parked_context_repo.py # SQLiteParkedContextRepository (park/resume serialized agent state) │ │ └── migrations.py # Schema migrations (user_version pragma) │ ├── observability/ # Structured logging & correlation │ │ ├── __init__.py # get_logger() entry point @@ -2982,6 +2984,7 @@ ai-company/ │ │ ├── events/ # Per-domain event constants │ │ │ ├── __init__.py # Package marker with usage docs; no re-exports │ │ │ ├── api.py # API_* event constants +│ │ │ ├── autonomy.py # AUTONOMY_* constants │ │ │ ├── budget.py # BUDGET_* constants │ │ │ ├── cfo.py # CFO_* constants │ │ │ ├── classification.py # CLASSIFICATION_* constants @@ -3014,6 +3017,7 @@ ai-company/ │ │ │ ├── task_assignment.py # TASK_ASSIGNMENT_* constants │ │ │ ├── task_routing.py # TASK_ROUTING_* constants │ │ │ ├── template.py # TEMPLATE_* constants +│ │ │ ├── timeout.py # TIMEOUT_* constants │ │ │ ├── tool.py # TOOL_* constants │ │ │ ├── workspace.py # WORKSPACE_* constants │ │ │ ├── code_runner.py # CODE_RUNNER_* constants @@ -3100,6 +3104,23 @@ ai-company/ │ │ ├── output_scanner.py # Post-tool output scanning (regex-based redaction) │ │ ├── protocol.py # SecurityInterceptionStrategy protocol │ │ ├── service.py # SecOpsService — meta-agent coordinating security +│ │ ├── autonomy/ # Autonomy levels, presets, resolver, change strategy (§12.2) +│ │ │ ├── __init__.py # Package exports +│ │ │ ├── models.py # AutonomyLevel enum, AutonomyPreset, AutonomyConfig, AutonomyChangeEvent +│ │ │ ├── protocol.py # AutonomyChangeStrategy protocol +│ │ │ ├── change_strategy.py # Rule-based auto-downgrade + human-only promotion strategy +│ │ │ └── resolver.py # AutonomyResolver (agent → department → company chain) +│ │ ├── timeout/ # Approval timeout policies, park/resume, risk tier classifier (§12.4) +│ │ │ ├── __init__.py # Package exports +│ │ │ ├── config.py # TimeoutPolicyConfig +│ │ │ ├── factory.py # build_timeout_policy() factory +│ │ │ ├── models.py # TimeoutDecision, RiskTier +│ │ │ ├── park_service.py # ParkResumeService (park/resume blocked tasks) +│ │ │ ├── parked_context.py # ParkedContext model (serialized agent state) +│ │ │ ├── policies.py # WaitForeverPolicy, AutoDenyPolicy, TieredPolicy, EscalationChainPolicy +│ │ │ ├── protocol.py # TimeoutPolicy protocol +│ │ │ ├── risk_tier_classifier.py # RiskTierClassifier (ActionType → RiskTier) +│ │ │ └── timeout_checker.py # TimeoutChecker (polls pending approvals) │ │ └── rules/ # Rule engine and detectors │ │ ├── engine.py # RuleEngine (soft-allow + hard-deny, fail-closed) │ │ ├── protocol.py # SecurityRule protocol @@ -3147,7 +3168,7 @@ ai-company/ │ │ ├── bus_bridge.py # Message-bus → WebSocket bridge │ │ ├── channels.py # WebSocket channel definitions │ │ ├── config.py # API configuration models (ServerConfig, CorsConfig) -│ │ ├── controllers/ # 13 class-based controllers + 1 WebSocket handler (14 route modules) +│ │ ├── controllers/ # 14 class-based controllers + 1 WebSocket handler (15 route modules) │ │ ├── dto.py # Request/response DTOs and envelopes │ │ ├── errors.py # API error hierarchy (ApiError, NotFoundError, etc.) │ │ ├── exception_handlers.py # Litestar exception handler registration diff --git a/README.md b/README.md index 36a0979099..6ed91b5422 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Memory Backend Adapter (M5)** - Memory protocols, retrieval pipeline, org memory, and consolidation are complete; initial Mem0 adapter backend ([ADR-001](docs/decisions/ADR-001-memory-layer.md)) pending; research backends (GraphRAG, Temporal KG) planned - **CLI Surface** - `cli/` package is placeholder-only -- **Security/Approval System (M7)** - SecOps agent with rule engine (soft-allow/hard-deny, fail-closed), audit log, output scanner, risk classifier, and ToolInvoker integration are implemented; real authentication (JWT/OAuth) and approval workflow gates are planned +- **Security/Approval System (M7)** - SecOps agent with rule engine (soft-allow/hard-deny, fail-closed), audit log, output scanner, risk classifier, and ToolInvoker integration are implemented; progressive trust (4 strategies), promotion/demotion, autonomy levels (5 tiers with presets, resolver, change strategies) and approval timeout policies (wait-forever, auto-deny, tiered, escalation-chain with task park/resume) are implemented; real authentication (JWT/OAuth) and approval workflow gates are planned - **Advanced Product Surface** - web dashboard, external integrations ## Status diff --git a/src/ai_company/core/agent.py b/src/ai_company/core/agent.py index 2d521aadab..e475251ecb 100644 --- a/src/ai_company/core/agent.py +++ b/src/ai_company/core/agent.py @@ -326,3 +326,17 @@ class AgentIdentity(BaseModel): default=AgentStatus.ACTIVE, description="Current lifecycle status", ) + + @model_validator(mode="after") + def _validate_seniority_autonomy(self) -> Self: + """Reject JUNIOR agents with FULL autonomy (D6).""" + if ( + self.autonomy_level == AutonomyLevel.FULL + and self.level == SeniorityLevel.JUNIOR + ): + msg = ( + "JUNIOR agents cannot have FULL autonomy — " + "maximum is SEMI (DESIGN_SPEC D6)" + ) + raise ValueError(msg) + return self diff --git a/src/ai_company/core/company.py b/src/ai_company/core/company.py index 460973a268..07ef1011bb 100644 --- a/src/ai_company/core/company.py +++ b/src/ai_company/core/company.py @@ -378,7 +378,11 @@ def _coerce_autonomy_float(cls, data: object) -> object: return data raw = data.get("autonomy") if isinstance(raw, (int, float)) and not isinstance(raw, bool): - level = _float_to_autonomy_level(float(raw)) + value = float(raw) + if not (0.0 <= value <= 1.0): + msg = f"autonomy float must be 0.0-1.0, got {value}" + raise ValueError(msg) + level = _float_to_autonomy_level(value) return {**data, "autonomy": {"level": level.value}} return data diff --git a/src/ai_company/core/enums.py b/src/ai_company/core/enums.py index c1c915b41c..7065bb2bd0 100644 --- a/src/ai_company/core/enums.py +++ b/src/ai_company/core/enums.py @@ -469,6 +469,31 @@ class AutonomyLevel(StrEnum): LOCKED = "locked" +# Ordering: LOCKED (most restrictive) < SUPERVISED < SEMI < FULL (least restrictive). +_AUTONOMY_RANK: dict[AutonomyLevel, int] = { + AutonomyLevel.LOCKED: 0, + AutonomyLevel.SUPERVISED: 1, + AutonomyLevel.SEMI: 2, + AutonomyLevel.FULL: 3, +} + + +def compare_autonomy(a: AutonomyLevel, b: AutonomyLevel) -> int: + """Compare two autonomy levels. + + Returns negative if *a* is more restrictive than *b*, zero if equal, + positive if *a* is less restrictive than *b*. + + Args: + a: First autonomy level. + b: Second autonomy level. + + Returns: + Integer indicating relative autonomy. + """ + return _AUTONOMY_RANK[a] - _AUTONOMY_RANK[b] + + class DowngradeReason(StrEnum): """Reason an agent's autonomy was downgraded at runtime.""" diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 22a07d2dd6..132b616fb4 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -70,6 +70,7 @@ from ai_company.security.rules.protocol import SecurityRule # noqa: TC001 from ai_company.security.rules.risk_classifier import RiskClassifier from ai_company.security.service import SecOpsService +from ai_company.security.timeout.risk_tier_classifier import DefaultRiskTierClassifier from ai_company.tools.invoker import ToolInvoker from ai_company.tools.permissions import ToolPermissionChecker @@ -684,14 +685,31 @@ def _make_security_interceptor( self, effective_autonomy: EffectiveAutonomy | None = None, ) -> SecurityInterceptionStrategy | None: - """Build the SecOps security interceptor if configured.""" + """Build the SecOps security interceptor if configured. + + Raises: + ExecutionStateError: If effective_autonomy is provided but + no SecurityConfig is configured — autonomy cannot be + enforced without the security subsystem. + """ if self._security_config is None: + if effective_autonomy is not None: + msg = ( + "effective_autonomy cannot be enforced without " + "SecurityConfig — configure security or remove autonomy" + ) + logger.error(SECURITY_DISABLED, note=msg) + raise ExecutionStateError(msg) logger.warning( SECURITY_DISABLED, note="No SecurityConfig provided — all security checks skipped", ) return None if not self._security_config.enabled: + if effective_autonomy is not None: + msg = "effective_autonomy cannot be enforced when security is disabled" + logger.error(SECURITY_DISABLED, note=msg) + raise ExecutionStateError(msg) return None cfg = self._security_config @@ -723,6 +741,7 @@ def _make_security_interceptor( output_scanner=OutputScanner(), approval_store=self._approval_store, effective_autonomy=effective_autonomy, + risk_classifier=DefaultRiskTierClassifier(), ) def _make_tool_invoker( diff --git a/src/ai_company/engine/prompt.py b/src/ai_company/engine/prompt.py index a4ed1fd3d6..f2f4cf572f 100644 --- a/src/ai_company/engine/prompt.py +++ b/src/ai_company/engine/prompt.py @@ -401,6 +401,7 @@ def _build_core_context( "level": effective_autonomy.level.value, "auto_approve_actions": sorted(effective_autonomy.auto_approve_actions), "human_approval_actions": sorted(effective_autonomy.human_approval_actions), + "security_agent": effective_autonomy.security_agent, } else: ctx["effective_autonomy"] = None diff --git a/src/ai_company/engine/prompt_template.py b/src/ai_company/engine/prompt_template.py index 1c5d0850ab..05c3af4d6d 100644 --- a/src/ai_company/engine/prompt_template.py +++ b/src/ai_company/engine/prompt_template.py @@ -17,7 +17,7 @@ from ai_company.core.enums import SeniorityLevel -PROMPT_TEMPLATE_VERSION: Final[str] = "1.3.0" +PROMPT_TEMPLATE_VERSION: Final[str] = "1.4.0" # ── Autonomy instructions by seniority level ───────────────────── diff --git a/src/ai_company/observability/events/timeout.py b/src/ai_company/observability/events/timeout.py index 4ac77f59a0..7828a5646a 100644 --- a/src/ai_company/observability/events/timeout.py +++ b/src/ai_company/observability/events/timeout.py @@ -10,3 +10,4 @@ TIMEOUT_CONTEXT_PARKED: Final[str] = "timeout.context.parked" TIMEOUT_CONTEXT_RESUMED: Final[str] = "timeout.context.resumed" TIMEOUT_UNKNOWN_ACTION_TYPE: Final[str] = "timeout.unknown_action_type" +TIMEOUT_FACTORY_UNKNOWN_CONFIG: Final[str] = "timeout.factory.unknown_config" diff --git a/src/ai_company/security/autonomy/__init__.py b/src/ai_company/security/autonomy/__init__.py index f8cf4f7322..2d777aa6e6 100644 --- a/src/ai_company/security/autonomy/__init__.py +++ b/src/ai_company/security/autonomy/__init__.py @@ -1,6 +1,12 @@ -"""Autonomy level management — presets, resolution, and runtime changes.""" +"""Autonomy level management — presets, resolution, and runtime changes. + +Note: ``AutonomyResolver`` and ``HumanOnlyPromotionStrategy`` are **not** +re-exported here to avoid a circular import chain +(``core.company`` → ``security.autonomy.models`` → this ``__init__`` → +``resolver`` → ``security.action_types`` → ``core.enums`` → ``core``). +Import them directly from their modules when needed. +""" -from ai_company.security.autonomy.change_strategy import HumanOnlyPromotionStrategy from ai_company.security.autonomy.models import ( BUILTIN_PRESETS, AutonomyConfig, @@ -9,7 +15,6 @@ EffectiveAutonomy, ) from ai_company.security.autonomy.protocol import AutonomyChangeStrategy -from ai_company.security.autonomy.resolver import AutonomyResolver __all__ = [ "BUILTIN_PRESETS", @@ -17,7 +22,5 @@ "AutonomyConfig", "AutonomyOverride", "AutonomyPreset", - "AutonomyResolver", "EffectiveAutonomy", - "HumanOnlyPromotionStrategy", ] diff --git a/src/ai_company/security/autonomy/change_strategy.py b/src/ai_company/security/autonomy/change_strategy.py index 91aadd90b2..2df53454e2 100644 --- a/src/ai_company/security/autonomy/change_strategy.py +++ b/src/ai_company/security/autonomy/change_strategy.py @@ -2,7 +2,7 @@ from datetime import UTC, datetime -from ai_company.core.enums import AutonomyLevel, DowngradeReason +from ai_company.core.enums import AutonomyLevel, DowngradeReason, compare_autonomy from ai_company.core.types import NotBlankStr # noqa: TC001 from ai_company.observability import get_logger from ai_company.observability.events.autonomy import ( @@ -33,10 +33,13 @@ class HumanOnlyPromotionStrategy: """Default strategy: promotions and recovery always require human approval. Downgrades are applied immediately based on the reason: - - ``HIGH_ERROR_RATE`` → SUPERVISED - - ``BUDGET_EXHAUSTED`` → SUPERVISED + - ``HIGH_ERROR_RATE`` → SUPERVISED (or current level if already more restrictive) + - ``BUDGET_EXHAUSTED`` → SUPERVISED (or current level if already more restrictive) - ``SECURITY_INCIDENT`` → LOCKED + Downgrades never *increase* autonomy: if the agent is already at + LOCKED, a HIGH_ERROR_RATE event keeps it at LOCKED (not SUPERVISED). + This strategy tracks active overrides in memory. In production, overrides should be persisted to the persistence backend. """ @@ -89,7 +92,7 @@ def auto_downgrade( Returns: The new autonomy level after downgrade. """ - new_level = _DOWNGRADE_MAP[reason] + target_level = _DOWNGRADE_MAP[reason] existing = self._overrides.get(agent_id) original = ( existing.original_level @@ -97,6 +100,15 @@ def auto_downgrade( else (current_level or AutonomyLevel.SEMI) ) + # Never increase autonomy — if the agent is already at or below + # the target level, keep the current (more restrictive) level. + effective_current = existing.current_level if existing else original + new_level = ( + effective_current + if compare_autonomy(effective_current, target_level) <= 0 + else target_level + ) + override = AutonomyOverride( agent_id=agent_id, original_level=original, diff --git a/src/ai_company/security/autonomy/models.py b/src/ai_company/security/autonomy/models.py index 40b6b2728a..e74b6ab45b 100644 --- a/src/ai_company/security/autonomy/models.py +++ b/src/ai_company/security/autonomy/models.py @@ -1,10 +1,11 @@ """Autonomy data models — presets, config, effective resolution, overrides.""" +from types import MappingProxyType from typing import Final, Self from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, model_validator -from ai_company.core.enums import AutonomyLevel, DowngradeReason +from ai_company.core.enums import AutonomyLevel, DowngradeReason, compare_autonomy from ai_company.core.types import NotBlankStr # noqa: TC001 @@ -59,59 +60,63 @@ def _validate_disjoint(self) -> Self: return self -BUILTIN_PRESETS: Final[dict[str, AutonomyPreset]] = { - AutonomyLevel.FULL: AutonomyPreset( - level=AutonomyLevel.FULL, - description="Fully autonomous — all actions auto-approved", - auto_approve=("all",), - human_approval=(), - security_agent=False, - ), - AutonomyLevel.SEMI: AutonomyPreset( - level=AutonomyLevel.SEMI, - description=( - "Semi-autonomous — code, test, docs auto-approved; " - "deploy, org, budget require human approval" +BUILTIN_PRESETS: Final[MappingProxyType[str, AutonomyPreset]] = MappingProxyType( + { + AutonomyLevel.FULL: AutonomyPreset( + level=AutonomyLevel.FULL, + description="Fully autonomous — all actions auto-approved", + auto_approve=("all",), + human_approval=(), + security_agent=False, ), - auto_approve=("code", "test", "docs", "vcs", "db:query"), - human_approval=("deploy", "org", "budget", "comms:external"), - security_agent=True, - ), - AutonomyLevel.SUPERVISED: AutonomyPreset( - level=AutonomyLevel.SUPERVISED, - description=( - "Supervised — read-only and test actions auto-approved; " - "all mutations require human approval" + # SEMI extends DESIGN_SPEC §12.2 with vcs and db:query auto-approve + # (safe read/commit operations) and broader human_approval categories. + AutonomyLevel.SEMI: AutonomyPreset( + level=AutonomyLevel.SEMI, + description=( + "Semi-autonomous — code, test, docs, vcs auto-approved; " + "deploy, org, budget require human approval" + ), + auto_approve=("code", "test", "docs", "vcs", "comms:internal", "db:query"), + human_approval=("deploy", "org", "budget", "comms:external"), + security_agent=True, ), - auto_approve=("code:read", "vcs:read", "test:run", "db:query"), - human_approval=( - "code:write", - "code:create", - "code:delete", - "code:refactor", - "test:write", - "docs:write", - "vcs:commit", - "vcs:push", - "vcs:branch", - "deploy", - "comms", - "budget", - "org", - "db:mutate", - "db:admin", - "arch:decide", + AutonomyLevel.SUPERVISED: AutonomyPreset( + level=AutonomyLevel.SUPERVISED, + description=( + "Supervised — read-only and test actions auto-approved; " + "all mutations require human approval" + ), + auto_approve=("code:read", "vcs:read", "test:run", "db:query"), + human_approval=( + "code:write", + "code:create", + "code:delete", + "code:refactor", + "test:write", + "docs:write", + "vcs:commit", + "vcs:push", + "vcs:branch", + "deploy", + "comms", + "budget", + "org", + "db:mutate", + "db:admin", + "arch:decide", + ), + security_agent=True, ), - security_agent=True, - ), - AutonomyLevel.LOCKED: AutonomyPreset( - level=AutonomyLevel.LOCKED, - description="Locked — all actions require human approval", - auto_approve=(), - human_approval=("all",), - security_agent=True, - ), -} + AutonomyLevel.LOCKED: AutonomyPreset( + level=AutonomyLevel.LOCKED, + description="Locked — all actions require human approval", + auto_approve=(), + human_approval=("all",), + security_agent=True, + ), + } +) class AutonomyConfig(BaseModel): @@ -209,3 +214,15 @@ class AutonomyOverride(BaseModel): default=True, description="Whether human approval is needed to restore level", ) + + @model_validator(mode="after") + def _validate_downgrade(self) -> Self: + """Ensure current_level is not higher than original_level.""" + if compare_autonomy(self.current_level, self.original_level) > 0: + msg = ( + f"current_level {self.current_level.value!r} is higher than " + f"original_level {self.original_level.value!r} — " + f"downgrades must not increase autonomy" + ) + raise ValueError(msg) + return self diff --git a/src/ai_company/security/service.py b/src/ai_company/security/service.py index 1a8051b337..a1ad3bb5b6 100644 --- a/src/ai_company/security/service.py +++ b/src/ai_company/security/service.py @@ -44,6 +44,7 @@ ) from ai_company.security.output_scanner import OutputScanner # noqa: TC001 from ai_company.security.rules.engine import RuleEngine # noqa: TC001 +from ai_company.security.timeout.protocol import RiskTierClassifier # noqa: TC001 if TYPE_CHECKING: from ai_company.api.approval_store import ApprovalStore @@ -83,6 +84,7 @@ def __init__( # noqa: PLR0913 output_scanner: OutputScanner, approval_store: ApprovalStore | None = None, effective_autonomy: EffectiveAutonomy | None = None, + risk_classifier: RiskTierClassifier | None = None, ) -> None: """Initialize the SecOps service. @@ -93,8 +95,11 @@ def __init__( # noqa: PLR0913 output_scanner: Post-tool output scanner. approval_store: Optional store for escalation items. effective_autonomy: Resolved autonomy for the current run. - When provided, actions are routed based on autonomy - level before the rule engine is consulted. + When provided, autonomy routing is applied *after* + the rule engine — never bypassing security detectors. + risk_classifier: Optional classifier for determining action + risk levels in autonomy escalations. Defaults to HIGH + when absent (fail-safe). """ self._config = config self._rule_engine = rule_engine @@ -102,6 +107,7 @@ def __init__( # noqa: PLR0913 self._output_scanner = output_scanner self._approval_store = approval_store self._effective_autonomy = effective_autonomy + self._risk_classifier = risk_classifier if config.custom_policies: logger.warning( @@ -145,12 +151,8 @@ async def evaluate_pre_tool( agent_id=context.agent_id, ) - # Autonomy pre-check: route based on effective autonomy before - # the full rule engine. Hard-deny is always checked first. - autonomy_result = await self._apply_autonomy_precheck(context) - if autonomy_result is not None: - return autonomy_result - + # Always run the rule engine first — security detectors must + # never be bypassed, regardless of autonomy configuration. try: verdict = self._rule_engine.evaluate(context) except MemoryError, RecursionError: @@ -169,6 +171,11 @@ async def evaluate_pre_tool( evaluation_duration_ms=0.0, ) + # Apply autonomy augmentation *after* the rule engine. + # Autonomy can only add stricter requirements (ALLOW → ESCALATE), + # never weaken a DENY or ESCALATE from security detectors. + verdict = self._apply_autonomy_augmentation(context, verdict) + # Handle escalation. if verdict.verdict == SecurityVerdictType.ESCALATE: verdict = await self._handle_escalation(context, verdict) @@ -240,48 +247,28 @@ async def scan_output( return result - async def _apply_autonomy_precheck( + def _apply_autonomy_augmentation( self, context: SecurityContext, - ) -> SecurityVerdict | None: - """Apply autonomy-based routing and finalize the verdict. + verdict: SecurityVerdict, + ) -> SecurityVerdict: + """Augment the rule engine verdict with autonomy routing. + + Autonomy can only *tighten* a verdict (ALLOW → ESCALATE), never + weaken one. DENY and ESCALATE from the rule engine are always + preserved — security detectors take precedence over autonomy. - Returns a complete verdict (with escalation/audit handled) if - autonomy routing applies, or ``None`` to fall through. + Returns the (possibly upgraded) verdict. """ if self._effective_autonomy is None: - return None - verdict = self._check_autonomy(context) - if verdict is None: - return None - if verdict.verdict == SecurityVerdictType.ESCALATE: - verdict = await self._handle_escalation(context, verdict) - if self._config.audit_enabled: - self._record_audit(context, verdict) - return verdict - - def _check_autonomy( - self, - context: SecurityContext, - ) -> SecurityVerdict | None: - """Check autonomy routing for an action type. + return verdict - Returns a verdict if the action is routed by autonomy config, - or ``None`` to fall through to the rule engine. + # Security DENY/ESCALATE always takes precedence. + if verdict.verdict != SecurityVerdictType.ALLOW: + return verdict - Hard-deny actions always fall through so the rule engine - produces its standard DENY verdict. - """ action = context.action_type - - # Hard-deny always bypasses autonomy — let the rule engine deny it. - if action in self._config.hard_deny_action_types: - return None - autonomy = self._effective_autonomy - assert autonomy is not None # noqa: S101 — guarded by caller - - now = datetime.now(UTC) if action in autonomy.auto_approve_actions: logger.info( @@ -290,34 +277,34 @@ def _check_autonomy( action_type=action, autonomy_level=autonomy.level.value, ) - return SecurityVerdict( - verdict=SecurityVerdictType.ALLOW, - reason=f"Auto-approved by autonomy level '{autonomy.level.value}'", - risk_level=ApprovalRiskLevel.LOW, - evaluated_at=now, - evaluation_duration_ms=0.0, - ) + return verdict if action in autonomy.human_approval_actions: + risk_level = ( + self._risk_classifier.classify(action) + if self._risk_classifier + else ApprovalRiskLevel.HIGH + ) logger.info( AUTONOMY_ACTION_HUMAN_REQUIRED, tool_name=context.tool_name, action_type=action, autonomy_level=autonomy.level.value, + risk_level=risk_level.value, ) - return SecurityVerdict( - verdict=SecurityVerdictType.ESCALATE, - reason=( - f"Human approval required by autonomy level " - f"'{autonomy.level.value}'" - ), - risk_level=ApprovalRiskLevel.MEDIUM, - evaluated_at=now, - evaluation_duration_ms=0.0, + return verdict.model_copy( + update={ + "verdict": SecurityVerdictType.ESCALATE, + "reason": ( + f"Human approval required by autonomy level " + f"'{autonomy.level.value}'" + ), + "risk_level": risk_level, + }, ) - # Action not classified by autonomy — fall through to rule engine. - return None + # Not classified by autonomy — keep rule engine's verdict. + return verdict def _record_audit( self, @@ -326,26 +313,27 @@ def _record_audit( ) -> None: """Record an audit entry for a pre-tool evaluation. - Audit recording failures are caught and logged — they must - never prevent the verdict from being returned. + Model construction errors propagate (they indicate programming + bugs). Storage errors are caught and logged — they must never + prevent the verdict from being returned. """ + entry = AuditEntry( + id=str(uuid.uuid4()), + timestamp=verdict.evaluated_at, + agent_id=context.agent_id, + task_id=context.task_id, + tool_name=context.tool_name, + tool_category=context.tool_category, + action_type=context.action_type, + arguments_hash=_hash_arguments(context.arguments), + verdict=verdict.verdict.value, + risk_level=verdict.risk_level, + reason=verdict.reason, + matched_rules=verdict.matched_rules, + evaluation_duration_ms=verdict.evaluation_duration_ms, + approval_id=verdict.approval_id, + ) try: - entry = AuditEntry( - id=str(uuid.uuid4()), - timestamp=verdict.evaluated_at, - agent_id=context.agent_id, - task_id=context.task_id, - tool_name=context.tool_name, - tool_category=context.tool_category, - action_type=context.action_type, - arguments_hash=_hash_arguments(context.arguments), - verdict=verdict.verdict.value, - risk_level=verdict.risk_level, - reason=verdict.reason, - matched_rules=verdict.matched_rules, - evaluation_duration_ms=verdict.evaluation_duration_ms, - approval_id=verdict.approval_id, - ) self._audit_log.record(entry) except MemoryError, RecursionError: raise diff --git a/src/ai_company/security/timeout/config.py b/src/ai_company/security/timeout/config.py index 12e55f8c90..d6ac74807b 100644 --- a/src/ai_company/security/timeout/config.py +++ b/src/ai_company/security/timeout/config.py @@ -1,10 +1,10 @@ """Timeout policy configuration models — discriminated union of 4 policies.""" -from typing import Annotated, Literal +from typing import Annotated, Literal, Self -from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag +from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, model_validator -from ai_company.core.enums import TimeoutActionType +from ai_company.core.enums import ApprovalRiskLevel, TimeoutActionType from ai_company.core.types import NotBlankStr # noqa: TC001 @@ -63,6 +63,17 @@ class TierConfig(BaseModel): description="Specific action types in this tier", ) + @model_validator(mode="after") + def _validate_no_escalate(self) -> Self: + """Reject ESCALATE — tier configs cannot provide a target.""" + if self.on_timeout == TimeoutActionType.ESCALATE: + msg = ( + "on_timeout cannot be ESCALATE (no escalation target " + "available — use the escalation chain policy instead)" + ) + raise ValueError(msg) + return self + class TieredTimeoutConfig(BaseModel): """Per-risk-tier timeout policy. @@ -85,6 +96,19 @@ class TieredTimeoutConfig(BaseModel): description="Tier configs keyed by risk level (low/medium/high/critical)", ) + @model_validator(mode="after") + def _validate_tier_keys(self) -> Self: + """Ensure tier keys are valid ApprovalRiskLevel values.""" + valid_keys = {level.value for level in ApprovalRiskLevel} + invalid = set(self.tiers) - valid_keys + if invalid: + msg = ( + f"Invalid tier keys: {sorted(invalid)} " + f"(must be one of {sorted(valid_keys)})" + ) + raise ValueError(msg) + return self + class EscalationStep(BaseModel): """A single step in an escalation chain. @@ -129,6 +153,20 @@ class EscalationChainConfig(BaseModel): description="Action when the entire chain is exhausted", ) + @model_validator(mode="after") + def _validate_chain(self) -> Self: + """Validate chain constraints.""" + if not self.chain: + msg = "escalation chain must have at least one step" + raise ValueError(msg) + if self.on_chain_exhausted == TimeoutActionType.ESCALATE: + msg = ( + "on_chain_exhausted cannot be ESCALATE " + "(no escalation target after chain is exhausted)" + ) + raise ValueError(msg) + return self + def _timeout_discriminator(value: object) -> str: """Extract the ``policy`` discriminator from raw or model data. diff --git a/src/ai_company/security/timeout/factory.py b/src/ai_company/security/timeout/factory.py index 54b056bbfc..b8395a7f27 100644 --- a/src/ai_company/security/timeout/factory.py +++ b/src/ai_company/security/timeout/factory.py @@ -1,6 +1,7 @@ """Factory for creating timeout policy instances from configuration.""" from ai_company.observability import get_logger +from ai_company.observability.events.timeout import TIMEOUT_FACTORY_UNKNOWN_CONFIG from ai_company.security.timeout.config import ( ApprovalTimeoutConfig, DenyOnTimeoutConfig, @@ -58,7 +59,7 @@ def create_timeout_policy( msg = f"Unknown timeout policy config type: {type(config).__name__}" # type: ignore[unreachable] logger.warning( - "timeout.factory.unknown_config", + TIMEOUT_FACTORY_UNKNOWN_CONFIG, config_type=type(config).__name__, ) raise TypeError(msg) diff --git a/src/ai_company/security/timeout/park_service.py b/src/ai_company/security/timeout/park_service.py index ed8b63b2e8..ecf66d6809 100644 --- a/src/ai_company/security/timeout/park_service.py +++ b/src/ai_company/security/timeout/park_service.py @@ -82,6 +82,20 @@ def park( metadata=copy.deepcopy(metadata) if metadata else {}, ) + # Validate that metadata IDs match serialized context IDs. + if parked.agent_id != agent_id: + msg = ( + f"ParkedContext agent_id {parked.agent_id!r} does not " + f"match provided agent_id {agent_id!r}" + ) + raise ValueError(msg) + if parked.task_id != task_id: + msg = ( + f"ParkedContext task_id {parked.task_id!r} does not " + f"match provided task_id {task_id!r}" + ) + raise ValueError(msg) + logger.info( TIMEOUT_CONTEXT_PARKED, parked_id=parked.id, diff --git a/src/ai_company/security/timeout/parked_context.py b/src/ai_company/security/timeout/parked_context.py index 10f2d643db..3e44a8beb7 100644 --- a/src/ai_company/security/timeout/parked_context.py +++ b/src/ai_company/security/timeout/parked_context.py @@ -5,9 +5,13 @@ so it can be resumed when the approval decision arrives. """ +import copy +import json +from types import MappingProxyType +from typing import Self from uuid import uuid4 -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, model_validator from ai_company.core.types import NotBlankStr # noqa: TC001 @@ -42,3 +46,19 @@ class ParkedContext(BaseModel): default_factory=dict, description="Additional metadata", ) + + @model_validator(mode="after") + def _validate_and_protect(self) -> Self: + """Validate context_json and deep-copy metadata.""" + try: + json.loads(self.context_json) + except (json.JSONDecodeError, TypeError) as exc: + msg = f"context_json must be valid JSON: {exc}" + raise ValueError(msg) from exc + object.__setattr__(self, "metadata", copy.deepcopy(self.metadata)) + return self + + @property + def metadata_view(self) -> MappingProxyType[str, str]: + """Read-only view of metadata.""" + return MappingProxyType(self.metadata) diff --git a/src/ai_company/security/timeout/policies.py b/src/ai_company/security/timeout/policies.py index 34a2826e78..2dda36dc67 100644 --- a/src/ai_company/security/timeout/policies.py +++ b/src/ai_company/security/timeout/policies.py @@ -145,8 +145,18 @@ async def determine_action( Returns: WAIT, DENY, APPROVE, or ESCALATE based on tier config. """ + # Default: classify by risk level, then check explicit tier overrides. risk_level = self._classifier.classify(item.action_type) - tier_config = self._tiers.get(risk_level.value) + tier_config = None + for tier_key, cfg in self._tiers.items(): + if cfg.actions and item.action_type in cfg.actions: + tier_config = cfg + risk_level = ApprovalRiskLevel(tier_key) + break + + # Fall back to risk-level-based tier lookup. + if tier_config is None: + tier_config = self._tiers.get(risk_level.value) if tier_config is None: # No tier configured for this risk level — wait (safe default). @@ -269,11 +279,29 @@ async def determine_action( ) cumulative_seconds = 0.0 - for step in self._chain: + for idx, step in enumerate(self._chain): step_timeout = step.timeout_minutes * _SECONDS_PER_MINUTE - if elapsed_seconds < cumulative_seconds + step_timeout: - logger.debug( - TIMEOUT_WAITING, + step_end = cumulative_seconds + step_timeout + if elapsed_seconds < step_end: + if idx == 0: + # First step hasn't timed out yet — WAIT. + logger.debug( + TIMEOUT_WAITING, + approval_id=item.id, + escalation_role=step.role, + elapsed_seconds=elapsed_seconds, + ) + return TimeoutAction( + action=TimeoutActionType.WAIT, + reason=( + f"Waiting at {step.role!r} — " + f"{elapsed_seconds:.0f}s of " + f"{step_end:.0f}s elapsed" + ), + ) + # Previous step timed out — escalate to this step's role. + logger.info( + TIMEOUT_ESCALATED, approval_id=item.id, escalation_role=step.role, elapsed_seconds=elapsed_seconds, diff --git a/src/ai_company/security/timeout/risk_tier_classifier.py b/src/ai_company/security/timeout/risk_tier_classifier.py index 427166b8e8..78dad163be 100644 --- a/src/ai_company/security/timeout/risk_tier_classifier.py +++ b/src/ai_company/security/timeout/risk_tier_classifier.py @@ -44,6 +44,20 @@ } ) +# Validate exhaustiveness at module load time — log a warning for any +# ActionType members missing from the default map. +_missing_action_types = {m.value for m in ActionType} - set(_DEFAULT_RISK_MAP) +if _missing_action_types: + logger.warning( + TIMEOUT_UNKNOWN_ACTION_TYPE, + missing_types=sorted(_missing_action_types), + note=( + "ActionType members missing from _DEFAULT_RISK_MAP — " + "they will default to HIGH at classify() time" + ), + ) +del _missing_action_types + class DefaultRiskTierClassifier: """Maps action types to risk tiers for tiered timeout policies. diff --git a/src/ai_company/security/timeout/timeout_checker.py b/src/ai_company/security/timeout/timeout_checker.py index 43bddbe117..e0d60c7000 100644 --- a/src/ai_company/security/timeout/timeout_checker.py +++ b/src/ai_company/security/timeout/timeout_checker.py @@ -17,11 +17,13 @@ TIMEOUT_POLICY_EVALUATED, TIMEOUT_WAITING, ) -from ai_company.security.timeout.models import TimeoutAction # noqa: TC001 +from ai_company.security.timeout.models import TimeoutAction from ai_company.security.timeout.protocol import TimeoutPolicy # noqa: TC001 logger = get_logger(__name__) +_TIMEOUT_POLICY_DECIDER: str = "timeout_policy" + class TimeoutChecker: """Evaluates pending approvals against the configured timeout policy. @@ -44,11 +46,35 @@ async def check( Returns: The ``TimeoutAction`` determined by the policy. + + Raises: + ValueError: If the item is not in PENDING status. """ + if item.status != ApprovalStatus.PENDING: + msg = ( + f"Cannot check timeout for non-PENDING item " + f"{item.id!r} (status: {item.status.value!r})" + ) + raise ValueError(msg) + now = datetime.now(UTC) elapsed = (now - item.created_at).total_seconds() - action = await self._policy.determine_action(item, elapsed) + try: + action = await self._policy.determine_action(item, elapsed) + except MemoryError, RecursionError: + raise + except Exception: + logger.exception( + TIMEOUT_POLICY_EVALUATED, + approval_id=item.id, + elapsed_seconds=elapsed, + note="policy.determine_action failed — defaulting to WAIT", + ) + action = TimeoutAction( + action=TimeoutActionType.WAIT, + reason="Policy evaluation error — defaulting to WAIT", + ) event = { TimeoutActionType.WAIT: TIMEOUT_WAITING, @@ -90,7 +116,7 @@ async def check_and_resolve( update={ "status": ApprovalStatus.APPROVED, "decided_at": datetime.now(UTC), - "decided_by": "timeout_policy", + "decided_by": _TIMEOUT_POLICY_DECIDER, "decision_reason": action.reason, }, ) @@ -101,7 +127,7 @@ async def check_and_resolve( update={ "status": ApprovalStatus.REJECTED, "decided_at": datetime.now(UTC), - "decided_by": "timeout_policy", + "decided_by": _TIMEOUT_POLICY_DECIDER, "decision_reason": action.reason, }, ) diff --git a/src/ai_company/templates/renderer.py b/src/ai_company/templates/renderer.py index 410da80f91..baafbf7664 100644 --- a/src/ai_company/templates/renderer.py +++ b/src/ai_company/templates/renderer.py @@ -524,8 +524,9 @@ def _extract_numeric_config( raw_autonomy = company.get("autonomy", template.autonomy) try: if isinstance(raw_autonomy, dict): - # Already an AutonomyConfig-like dict — pass through. - autonomy: float | dict[str, Any] = raw_autonomy + # Already an AutonomyConfig-like dict — deep-copy to prevent + # mutation of the original rendered data. + autonomy: float | dict[str, Any] = dict(raw_autonomy) else: autonomy = to_float(raw_autonomy, field_name="autonomy") budget_monthly = to_float( diff --git a/tests/unit/core/conftest.py b/tests/unit/core/conftest.py index 9b6c09de70..fbe9d5f019 100644 --- a/tests/unit/core/conftest.py +++ b/tests/unit/core/conftest.py @@ -44,6 +44,7 @@ from ai_company.core.role import Authority, CustomRole, Role, SeniorityInfo, Skill from ai_company.core.task import AcceptanceCriterion, Task from ai_company.security.autonomy.models import AutonomyConfig +from ai_company.security.timeout.config import WaitForeverConfig # ── Factories ────────────────────────────────────────────────────── @@ -132,6 +133,7 @@ class DepartmentFactory(ModelFactory[Department]): class CompanyConfigFactory(ModelFactory[CompanyConfig]): __model__ = CompanyConfig autonomy = AutonomyConfig() + approval_timeout = WaitForeverConfig() class HRRegistryFactory(ModelFactory[HRRegistry]): diff --git a/tests/unit/engine/test_prompt.py b/tests/unit/engine/test_prompt.py index c9f45a4974..e3cbae7b2c 100644 --- a/tests/unit/engine/test_prompt.py +++ b/tests/unit/engine/test_prompt.py @@ -540,7 +540,7 @@ class TestPromptVersioning: @pytest.mark.unit def test_template_version_is_1_3_0(self) -> None: """PROMPT_TEMPLATE_VERSION is '1.3.0' (D22 tools removal).""" - assert PROMPT_TEMPLATE_VERSION == "1.3.0" + assert PROMPT_TEMPLATE_VERSION == "1.4.0" @pytest.mark.unit def test_template_version_in_result( diff --git a/tests/unit/memory/org/test_prompt_integration.py b/tests/unit/memory/org/test_prompt_integration.py index 6f86226fde..19518bda3c 100644 --- a/tests/unit/memory/org/test_prompt_integration.py +++ b/tests/unit/memory/org/test_prompt_integration.py @@ -49,7 +49,7 @@ def test_policies_rendered_in_prompt(self) -> None: assert "org_policies" in result.sections def test_template_version_updated(self) -> None: - assert PROMPT_TEMPLATE_VERSION == "1.3.0" + assert PROMPT_TEMPLATE_VERSION == "1.4.0" def test_policies_trimmed_under_budget(self) -> None: agent = _make_agent() diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index 419ba3fe8c..9e51eb994d 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -215,8 +215,6 @@ def test_all_domain_modules_discovered(self) -> None: "timeout", "tool", "workspace", - "hr", - "performance", "trust", "promotion", } diff --git a/tests/unit/persistence/test_protocol.py b/tests/unit/persistence/test_protocol.py index 665ec35c79..7d76a9b189 100644 --- a/tests/unit/persistence/test_protocol.py +++ b/tests/unit/persistence/test_protocol.py @@ -14,6 +14,7 @@ from ai_company.persistence.repositories import ( CostRecordRepository, MessageRepository, + ParkedContextRepository, TaskRepository, ) @@ -217,3 +218,11 @@ def test_fake_collab_metric_repo_is_collaboration_metric_repository( _FakeCollaborationMetricRepository(), CollaborationMetricRepository, ) + + def test_fake_parked_context_repo_is_parked_context_repository( + self, + ) -> None: + assert isinstance( + _FakeParkedContextRepository(), + ParkedContextRepository, + ) diff --git a/tests/unit/security/autonomy/test_change_strategy.py b/tests/unit/security/autonomy/test_change_strategy.py index ab41967b04..07f07f6077 100644 --- a/tests/unit/security/autonomy/test_change_strategy.py +++ b/tests/unit/security/autonomy/test_change_strategy.py @@ -68,6 +68,25 @@ def test_double_downgrade_preserves_original(self) -> None: # Second downgrade replaces the first assert override.current_level == AutonomyLevel.LOCKED assert override.reason == DowngradeReason.SECURITY_INCIDENT + # Original level is preserved from the FIRST downgrade + assert override.original_level == AutonomyLevel.SEMI + + @pytest.mark.unit + def test_downgrade_never_increases_autonomy(self) -> None: + """LOCKED agent + HIGH_ERROR_RATE should stay LOCKED, not go to SUPERVISED.""" + strategy = HumanOnlyPromotionStrategy() + strategy.auto_downgrade( + "agent-1", + DowngradeReason.SECURITY_INCIDENT, + current_level=AutonomyLevel.SEMI, + ) + # Now agent is LOCKED. HIGH_ERROR_RATE targets SUPERVISED — but + # that's higher than LOCKED, so agent should stay LOCKED. + result = strategy.auto_downgrade("agent-1", DowngradeReason.HIGH_ERROR_RATE) + assert result == AutonomyLevel.LOCKED + override = strategy.get_override("agent-1") + assert override is not None + assert override.current_level == AutonomyLevel.LOCKED class TestRecovery: diff --git a/tests/unit/security/autonomy/test_models.py b/tests/unit/security/autonomy/test_models.py index 6442e5ea52..38322c9d56 100644 --- a/tests/unit/security/autonomy/test_models.py +++ b/tests/unit/security/autonomy/test_models.py @@ -162,6 +162,20 @@ def test_disjoint_overlap_raises(self) -> None: ) +class TestBuiltinPresetsImmutability: + """BUILTIN_PRESETS should be a read-only mapping.""" + + @pytest.mark.unit + def test_cannot_assign_new_key(self) -> None: + with pytest.raises(TypeError): + BUILTIN_PRESETS["new"] = BUILTIN_PRESETS[AutonomyLevel.FULL] # type: ignore[index] + + @pytest.mark.unit + def test_cannot_delete_key(self) -> None: + with pytest.raises(TypeError): + del BUILTIN_PRESETS[AutonomyLevel.FULL] # type: ignore[attr-defined] + + class TestAutonomyOverride: """AutonomyOverride model tests.""" @@ -190,3 +204,16 @@ def test_override_frozen(self) -> None: ) with pytest.raises(ValidationError): override.current_level = AutonomyLevel.FULL # type: ignore[misc] + + @pytest.mark.unit + def test_current_above_original_rejected(self) -> None: + """Downgrade validator rejects current_level > original_level.""" + now = datetime.now(UTC) + with pytest.raises(ValidationError, match="higher than"): + AutonomyOverride( + agent_id="agent-1", + original_level=AutonomyLevel.SUPERVISED, + current_level=AutonomyLevel.FULL, + reason=DowngradeReason.HIGH_ERROR_RATE, + downgraded_at=now, + ) diff --git a/tests/unit/security/test_service.py b/tests/unit/security/test_service.py index 44518549f2..9867ec973e 100644 --- a/tests/unit/security/test_service.py +++ b/tests/unit/security/test_service.py @@ -486,8 +486,8 @@ def _make_service_with_autonomy( service._test_audit_log = audit_log # type: ignore[attr-defined] return service - async def test_auto_approve_returns_allow(self) -> None: - """When action is in auto_approve_actions, returns ALLOW without rule engine.""" + async def test_auto_approve_keeps_allow(self) -> None: + """When rule engine ALLOWs and action is auto-approved, stays ALLOW.""" autonomy = EffectiveAutonomy( level=AutonomyLevel.SEMI, auto_approve_actions=frozenset({"code:read"}), @@ -500,11 +500,35 @@ async def test_auto_approve_returns_allow(self) -> None: verdict = await service.evaluate_pre_tool(ctx) assert verdict.verdict == SecurityVerdictType.ALLOW - assert "auto-approved" in verdict.reason.lower() - service._test_rule_engine.evaluate.assert_not_called() # type: ignore[attr-defined] + # Rule engine always runs first — even for auto-approved actions. + service._test_rule_engine.evaluate.assert_called_once() # type: ignore[attr-defined] - async def test_human_approval_returns_escalate_as_deny(self) -> None: - """Human approval with no store converts ESCALATE to DENY.""" + async def test_human_approval_escalates_with_store(self) -> None: + """Human-approval action with store → ESCALATE after rule engine ALLOW.""" + autonomy = EffectiveAutonomy( + level=AutonomyLevel.SEMI, + auto_approve_actions=frozenset({"code:read"}), + human_approval_actions=frozenset({"infra:deploy"}), + security_agent=False, + ) + store = AsyncMock() + store.add = AsyncMock() + service = self._make_service_with_autonomy( + effective_autonomy=autonomy, + approval_store=store, + ) + ctx = _make_context(action_type="infra:deploy") + + verdict = await service.evaluate_pre_tool(ctx) + + assert verdict.verdict == SecurityVerdictType.ESCALATE + assert verdict.approval_id is not None + store.add.assert_called_once() + # Rule engine always runs first. + service._test_rule_engine.evaluate.assert_called_once() # type: ignore[attr-defined] + + async def test_human_approval_without_store_becomes_deny(self) -> None: + """Human-approval action without store → DENY.""" autonomy = EffectiveAutonomy( level=AutonomyLevel.SEMI, auto_approve_actions=frozenset({"code:read"}), @@ -521,10 +545,11 @@ async def test_human_approval_returns_escalate_as_deny(self) -> None: assert verdict.verdict == SecurityVerdictType.DENY assert "escalation unavailable" in verdict.reason.lower() - service._test_rule_engine.evaluate.assert_not_called() # type: ignore[attr-defined] + # Rule engine always runs first — even when escalation fails. + service._test_rule_engine.evaluate.assert_called_once() # type: ignore[attr-defined] - async def test_hard_deny_falls_through_to_rule_engine(self) -> None: - """When action is in hard_deny_action_types, autonomy is skipped.""" + async def test_rule_engine_deny_overrides_auto_approve(self) -> None: + """Rule engine DENY takes precedence over autonomy auto-approve.""" autonomy = EffectiveAutonomy( level=AutonomyLevel.SEMI, auto_approve_actions=frozenset({"deploy:production"}), @@ -536,16 +561,16 @@ async def test_hard_deny_falls_through_to_rule_engine(self) -> None: effective_autonomy=autonomy, engine_verdict=deny_verdict, ) - # deploy:production is in the default SecurityConfig.hard_deny_action_types ctx = _make_context(action_type="deploy:production") verdict = await service.evaluate_pre_tool(ctx) + # Security detectors take precedence over autonomy. assert verdict.verdict == SecurityVerdictType.DENY service._test_rule_engine.evaluate.assert_called_once() # type: ignore[attr-defined] async def test_unknown_action_falls_through(self) -> None: - """When action is not in any autonomy set, falls through to rule engine.""" + """When action is not in any autonomy set, rule engine verdict used.""" autonomy = EffectiveAutonomy( level=AutonomyLevel.SEMI, auto_approve_actions=frozenset({"code:read"}), diff --git a/tests/unit/security/timeout/test_config.py b/tests/unit/security/timeout/test_config.py index c7b830085b..34bc518ed5 100644 --- a/tests/unit/security/timeout/test_config.py +++ b/tests/unit/security/timeout/test_config.py @@ -94,10 +94,9 @@ class TestEscalationChainConfig: """EscalationChainConfig tests.""" @pytest.mark.unit - def test_empty_chain(self) -> None: - config = EscalationChainConfig() - assert config.chain == () - assert config.on_chain_exhausted == TimeoutActionType.DENY + def test_empty_chain_rejected(self) -> None: + with pytest.raises(ValidationError, match="at least one step"): + EscalationChainConfig() @pytest.mark.unit def test_chain_steps(self) -> None: @@ -126,6 +125,46 @@ def test_discriminator(self) -> None: assert len(result.chain) == 1 +class TestTierConfigValidation: + """TierConfig validator tests.""" + + @pytest.mark.unit + def test_escalate_on_timeout_rejected(self) -> None: + with pytest.raises(ValidationError, match="ESCALATE"): + TierConfig( + timeout_minutes=60, + on_timeout=TimeoutActionType.ESCALATE, + ) + + +class TestTieredTimeoutConfigValidation: + """TieredTimeoutConfig validator tests.""" + + @pytest.mark.unit + def test_invalid_tier_key_rejected(self) -> None: + tier = TierConfig(timeout_minutes=60, on_timeout=TimeoutActionType.DENY) + with pytest.raises(ValidationError, match="Invalid tier keys"): + TieredTimeoutConfig(tiers={"invalid_key": tier}) + + @pytest.mark.unit + def test_valid_tier_keys_accepted(self) -> None: + tier = TierConfig(timeout_minutes=60, on_timeout=TimeoutActionType.DENY) + config = TieredTimeoutConfig(tiers={"low": tier, "high": tier}) + assert len(config.tiers) == 2 + + +class TestEscalationChainConfigValidation: + """EscalationChainConfig validator tests.""" + + @pytest.mark.unit + def test_escalate_on_chain_exhausted_rejected(self) -> None: + with pytest.raises(ValidationError, match="ESCALATE"): + EscalationChainConfig( + chain=(EscalationStep(role="lead", timeout_minutes=30),), + on_chain_exhausted=TimeoutActionType.ESCALATE, + ) + + class TestTimeoutAction: """TimeoutAction escalate_to validator tests.""" diff --git a/tests/unit/security/timeout/test_parked_context.py b/tests/unit/security/timeout/test_parked_context.py index 2df5917057..c1aa4b020d 100644 --- a/tests/unit/security/timeout/test_parked_context.py +++ b/tests/unit/security/timeout/test_parked_context.py @@ -81,3 +81,22 @@ def test_blank_agent_id_rejected(self) -> None: with pytest.raises(ValidationError): _make_parked_context(agent_id=" ") + + def test_metadata_deep_copied(self) -> None: + """Metadata dict is deep-copied — mutations don't affect the model.""" + original = {"key": "value"} + parked = _make_parked_context(metadata=original) + original["key"] = "mutated" + assert parked.metadata["key"] == "value" + + def test_metadata_view_read_only(self) -> None: + """metadata_view returns a read-only MappingProxyType.""" + parked = _make_parked_context(metadata={"key": "value"}) + view = parked.metadata_view + with pytest.raises(TypeError): + view["new_key"] = "fail" # type: ignore[index] + + def test_invalid_context_json_rejected(self) -> None: + """context_json must be valid JSON.""" + with pytest.raises(ValidationError, match="valid JSON"): + _make_parked_context(context_json="not-valid-json{{") diff --git a/tests/unit/security/timeout/test_policies.py b/tests/unit/security/timeout/test_policies.py index 926674eb4b..2fd637671c 100644 --- a/tests/unit/security/timeout/test_policies.py +++ b/tests/unit/security/timeout/test_policies.py @@ -140,12 +140,64 @@ async def test_no_tier_config_waits(self) -> None: result = await policy.determine_action(item, 999999.0) assert result.action == TimeoutActionType.WAIT + @pytest.mark.unit + async def test_high_risk_auto_approve_blocked(self) -> None: + """HIGH risk tier with on_timeout=APPROVE should be overridden to DENY.""" + tiers = { + "high": TierConfig( + timeout_minutes=60, on_timeout=TimeoutActionType.APPROVE + ), + } + policy = TieredTimeoutPolicy( + tiers=tiers, + classifier=DefaultRiskTierClassifier(), + ) + item = _make_item(action_type="deploy:staging") # HIGH risk + result = await policy.determine_action(item, 3601.0) + assert result.action == TimeoutActionType.DENY + + @pytest.mark.unit + async def test_critical_risk_auto_approve_blocked(self) -> None: + """CRITICAL risk tier with on_timeout=APPROVE should be overridden to DENY.""" + tiers = { + "critical": TierConfig( + timeout_minutes=60, on_timeout=TimeoutActionType.APPROVE + ), + } + policy = TieredTimeoutPolicy( + tiers=tiers, + classifier=DefaultRiskTierClassifier(), + ) + item = _make_item(action_type="deploy:production") # CRITICAL risk + result = await policy.determine_action(item, 3601.0) + assert result.action == TimeoutActionType.DENY + + @pytest.mark.unit + async def test_action_type_based_tier_lookup(self) -> None: + """TierConfig.actions tuple overrides risk-level-based lookup.""" + tiers = { + "low": TierConfig( + timeout_minutes=10, + on_timeout=TimeoutActionType.APPROVE, + actions=("deploy:staging",), # normally HIGH + ), + } + policy = TieredTimeoutPolicy( + tiers=tiers, + classifier=DefaultRiskTierClassifier(), + ) + item = _make_item(action_type="deploy:staging") + result = await policy.determine_action(item, 601.0) # > 10 min + # Despite deploy:staging being HIGH risk, the actions tuple + # places it in the LOW tier, so APPROVE is allowed. + assert result.action == TimeoutActionType.APPROVE + class TestEscalationChainPolicy: """EscalationChainPolicy: chain of escalation steps.""" @pytest.mark.unit - async def test_first_step_escalation(self) -> None: + async def test_first_step_waits(self) -> None: chain = ( EscalationStep(role="lead", timeout_minutes=30), EscalationStep(role="director", timeout_minutes=60), @@ -155,9 +207,9 @@ async def test_first_step_escalation(self) -> None: on_chain_exhausted=TimeoutActionType.DENY, ) item = _make_item() - result = await policy.determine_action(item, 600.0) # 10 min - assert result.action == TimeoutActionType.ESCALATE - assert result.escalate_to == "lead" + # 10 min — still within first step, should WAIT (not ESCALATE) + result = await policy.determine_action(item, 600.0) + assert result.action == TimeoutActionType.WAIT @pytest.mark.unit async def test_second_step_escalation(self) -> None: @@ -203,6 +255,7 @@ async def test_chain_exhausted_approve(self) -> None: @pytest.mark.unit async def test_empty_chain_exhausted_immediately(self) -> None: + # Bypass config validation to test policy behavior directly. policy = EscalationChainPolicy( chain=(), on_chain_exhausted=TimeoutActionType.DENY, diff --git a/tests/unit/security/timeout/test_timeout_checker.py b/tests/unit/security/timeout/test_timeout_checker.py index 8cf811892a..c83ae224ec 100644 --- a/tests/unit/security/timeout/test_timeout_checker.py +++ b/tests/unit/security/timeout/test_timeout_checker.py @@ -130,3 +130,26 @@ async def test_check_and_resolve_escalate(self) -> None: assert action.action == TimeoutActionType.ESCALATE assert updated_item.status == ApprovalStatus.PENDING + + async def test_non_pending_item_raises(self) -> None: + """Checking a non-PENDING item raises ValueError.""" + mock_policy = _make_mock_policy() + checker = TimeoutChecker(policy=mock_policy) + item = _make_approval_item( + status=ApprovalStatus.APPROVED, + decided_at=datetime.now(UTC), + decided_by="human-1", + ) + + with pytest.raises(ValueError, match="non-PENDING"): + await checker.check(item) + + async def test_policy_error_defaults_to_wait(self) -> None: + """When policy.determine_action raises, checker defaults to WAIT.""" + mock_policy = AsyncMock() + mock_policy.determine_action.side_effect = RuntimeError("boom") + checker = TimeoutChecker(policy=mock_policy) + item = _make_approval_item() + + result = await checker.check(item) + assert result.action == TimeoutActionType.WAIT