diff --git a/CLAUDE.md b/CLAUDE.md index 565c572f4a..2751138ff5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -44,7 +44,7 @@ uv run pre-commit run --all-files # all pre-commit hooks ```text src/ai_company/ api/ # FastAPI REST + WebSocket routes - budget/ # Per-agent cost tracking and spending controls + budget/ # Cost tracking, budget enforcement (pre-flight/in-flight checks, auto-downgrade), billing periods cli/ # Typer CLI commands communication/ # Message bus, dispatcher, messenger, channels, delegation, loop prevention, conflict resolution, meeting protocol config/ # YAML company config loading and validation diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 389c5a2179..f2f1010302 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -971,21 +971,22 @@ hybrid: Pipeline steps: 1. **Validate inputs** — agent must be `ACTIVE`, task must be `ASSIGNED` or `IN_PROGRESS`. Raises `ExecutionStateError` on violation. -2. **Build system prompt** — calls `build_system_prompt()` with agent identity, task, and available tool definitions. -3. **Create context** — `AgentContext.from_identity()` with the configured `max_turns`. -4. **Seed conversation** — injects system prompt, optional memory messages, and formatted task instruction as initial messages. -5. **Transition task** — `ASSIGNED` → `IN_PROGRESS` (pass-through if already `IN_PROGRESS`). -6. **Prepare tools and budget** — creates `ToolInvoker` from registry and `BudgetChecker` from task budget limit. -7. **Delegate to loop** — calls `ExecutionLoop.execute()` with context, provider, tool invoker, budget checker, and completion config. If `timeout_seconds` is set, wraps the call in `asyncio.wait_for`; on expiry the run returns with `TerminationReason.ERROR` but cost recording and post-execution processing still occur. -8. **Record costs** — records accumulated `TokenUsage` to `CostTracker` (if available). Cost recording failures are logged but do not affect the result. -9. **Apply post-execution transitions** — on `COMPLETED` termination: IN_PROGRESS → IN_REVIEW → COMPLETED (two-hop auto-complete in M3; reviewers deferred to M4+). On `SHUTDOWN` termination: current status → INTERRUPTED (see §6.7). On `ERROR` termination: recovery strategy is applied (default `FailAndReassignStrategy` transitions to FAILED; see §6.6). All other termination reasons (`MAX_TURNS`, `BUDGET_EXHAUSTED`) leave the task in its current state. Transition failures are logged but do not discard the successful execution result. -10. **Return result** — wraps `ExecutionResult` in `AgentRunResult` with engine-level metadata. +2. **Pre-flight budget enforcement** — if `BudgetEnforcer` is provided, check monthly hard stop and daily limit via `check_can_execute()`, then apply auto-downgrade via `resolve_model()`. Raises `BudgetExhaustedError` or `DailyLimitExceededError` on violation. +3. **Build system prompt** — calls `build_system_prompt()` with agent identity, task, and available tool definitions. +4. **Create context** — `AgentContext.from_identity()` with the configured `max_turns`. +5. **Seed conversation** — injects system prompt, optional memory messages, and formatted task instruction as initial messages. +6. **Transition task** — `ASSIGNED` → `IN_PROGRESS` (pass-through if already `IN_PROGRESS`). +7. **Prepare tools and budget** — creates `ToolInvoker` from registry and `BudgetChecker` from `BudgetEnforcer` (task + monthly + daily limits with pre-computed baselines and alert deduplication) or from task budget limit alone when no enforcer is configured. +8. **Delegate to loop** — calls `ExecutionLoop.execute()` with context, provider, tool invoker, budget checker, and completion config. If `timeout_seconds` is set, wraps the call in `asyncio.wait_for`; on expiry the run returns with `TerminationReason.ERROR` but cost recording and post-execution processing still occur. +9. **Record costs** — records accumulated `TokenUsage` to `CostTracker` (if available). Cost recording failures are logged but do not affect the result. +10. **Apply post-execution transitions** — on `COMPLETED` termination: IN_PROGRESS → IN_REVIEW → COMPLETED (two-hop auto-complete in M3; reviewers deferred to M4+). On `SHUTDOWN` termination: current status → INTERRUPTED (see §6.7). On `ERROR` termination: recovery strategy is applied (default `FailAndReassignStrategy` transitions to FAILED; see §6.6). All other termination reasons (`MAX_TURNS`, `BUDGET_EXHAUSTED`) leave the task in its current state. Transition failures are logged but do not discard the successful execution result. +11. **Return result** — wraps `ExecutionResult` in `AgentRunResult` with engine-level metadata. -Error handling: `MemoryError` and `RecursionError` propagate unconditionally. All other exceptions are caught and wrapped in an `AgentRunResult` with `TerminationReason.ERROR`. +Error handling: `MemoryError` and `RecursionError` propagate unconditionally. `BudgetExhaustedError` (including `DailyLimitExceededError`) returns `TerminationReason.BUDGET_EXHAUSTED` without recovery — budget exhaustion is a controlled stop, not a crash. All other exceptions are caught and wrapped in an `AgentRunResult` with `TerminationReason.ERROR`. -Constructor accepts: `provider` (required), `execution_loop` (defaults to `ReactLoop`), `tool_registry`, `cost_tracker`. The `run()` method also accepts `memory_messages` — optional working memory to inject between the system prompt and task instruction (memory retrieval is M5; the engine provides the injection hook). +Constructor accepts: `provider` (required), `execution_loop` (defaults to `ReactLoop`), `tool_registry`, `cost_tracker`, `recovery_strategy` (defaults to `FailAndReassignStrategy`), `shutdown_checker`, `budget_enforcer`. The `run()` method also accepts `memory_messages` — optional working memory to inject between the system prompt and task instruction (memory retrieval is M5; the engine provides the injection hook). -Logs structured events under the `execution.engine.*` namespace (12 constants in `events/execution.py`): creation, start, prompt built, completion, errors, invalid input, task transitions, cost recording outcomes, task metrics, and timeout. +Logs structured events under the `execution.engine.*` namespace (13 constants in `events/execution.py`): creation, start, prompt built, completion, errors, budget stopped, invalid input, task transitions, cost recording outcomes, task metrics, and timeout. **`AgentRunResult`** — frozen Pydantic model wrapping `ExecutionResult` with engine metadata: @@ -1778,7 +1779,7 @@ Every API call is tracked (illustrative schema): ### 10.3 CFO Agent Responsibilities -> **MVP: Not in M3.** Budget tracking and per-task cost recording exist (M2), but the CFO agent is M5+. Cost controls (§10.4) are enforced by the engine, not by an agent. +> **MVP: Not in M3.** Budget tracking and per-task cost recording exist (M2); cost controls (§10.4) are now enforced by `BudgetEnforcer` (a service the engine composes, not an agent — M5). The CFO agent is M5+. The CFO agent (when enabled) acts as a cost management system: @@ -1804,6 +1805,7 @@ The CFO agent (when enabled) acts as a cost management system: ```yaml budget: total_monthly: 100.00 + reset_day: 1 alerts: warn_at: 75 # percent critical_at: 90 @@ -1822,6 +1824,15 @@ budget: > **Auto-downgrade boundary:** Model downgrades apply only at **task assignment time**, never mid-execution. An agent halfway through an architecture review cannot be switched to a cheaper model — the task completes on its assigned model. The next task assignment respects the downgrade threshold. This prevents quality degradation from mid-thought model switches. +> **Implementation note (M5):** `BudgetEnforcer` composes `CostTracker` + +> `BudgetConfig` to provide three enforcement layers: (1) pre-flight checks +> via `check_can_execute` (monthly hard stop + per-agent daily limit), (2) +> in-flight budget checking via a sync `BudgetChecker` closure with +> pre-computed baselines (task + monthly + daily limits, alert deduplication), +> and (3) task-boundary auto-downgrade via `resolve_model`. Billing periods +> are scoped by `billing_period_start(reset_day)`. `DailyLimitExceededError` +> is a subclass of `BudgetExhaustedError` for granular error handling. + ### 10.5 LLM Call Analytics > **Current state:** Proxy metrics (M3), call categorization + coordination metric data models (M4 models, brought forward), and error taxonomy classification pipeline (M5) are implemented. Runtime collection pipeline for coordination metrics and full analytics layer are M5+. @@ -2637,6 +2648,7 @@ ai-company/ │ │ ├── recovery.py # Crash recovery strategies (RecoveryStrategy protocol) │ │ ├── cost_recording.py # Per-turn cost recording helpers │ │ ├── run_result.py # AgentRunResult outcome model +│ │ ├── _validation.py # Input validation helpers for AgentEngine │ │ ├── agent_engine.py # Agent execution engine │ │ ├── parallel.py # Parallel agent executor (TaskGroup + Semaphore) │ │ ├── parallel_models.py # AgentAssignment, ParallelExecutionGroup, AgentOutcome, ParallelExecutionResult, ParallelProgress @@ -2864,7 +2876,8 @@ ai-company/ │ │ ├── spending_summary.py # _SpendingTotals base + spending summary models │ │ ├── hierarchy.py # BudgetHierarchy, BudgetConfig │ │ ├── enums.py # Budget-related enums -│ │ ├── limits.py # Budget enforcement (M5) +│ │ ├── billing.py # Billing period computation utilities +│ │ ├── enforcer.py # BudgetEnforcer service (pre-flight, in-flight, auto-downgrade) │ │ ├── optimizer.py # Cost optimization / CFO logic (M5) │ │ └── reports.py # Spending reports (M5) │ ├── api/ # REST + WebSocket API (M6, stubs only) diff --git a/README.md b/README.md index 75ef23eacd..c1ef236027 100644 --- a/README.md +++ b/README.md @@ -23,11 +23,11 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Persistence Layer (M5)** - Pluggable `PersistenceBackend` protocol with SQLite backend (aiosqlite), repository protocols, schema migrations - **Memory Interface (M5)** - Pluggable `MemoryBackend` protocol with capability discovery, shared knowledge protocol, domain models, config, and factory - **Coordination Error Taxonomy (M5)** - Post-execution classification pipeline detecting logical contradictions, numerical drift, context omissions, and coordination failures +- **Budget Enforcement (M5)** - `BudgetEnforcer` service with pre-flight checks, in-flight budget checking, and auto-downgrade; CFO agent and advanced reporting pending ### Not implemented yet (planned milestones) - **Memory Backends (M5)** - Mem0 adapter ([ADR-001](docs/decisions/ADR-001-memory-layer.md), #41) pending; shared knowledge store backends planned -- **Budget Controls (M5)** - Per-agent spending limits, budget hierarchy enforcement - **API Layer (M6)** - `api/` package and route modules are placeholders - **CLI Surface (M6)** - `cli/` package is placeholder-only - **Security/Approval System (M7)** - `security/` package is placeholder-only diff --git a/src/ai_company/budget/__init__.py b/src/ai_company/budget/__init__.py index b5d049cab9..f077c7e697 100644 --- a/src/ai_company/budget/__init__.py +++ b/src/ai_company/budget/__init__.py @@ -5,6 +5,7 @@ DESIGN_SPEC Section 10. """ +from ai_company.budget.billing import billing_period_start, daily_period_start from ai_company.budget.call_category import LLMCallCategory, OrchestrationAlertLevel from ai_company.budget.category_analytics import CategoryBreakdown, OrchestrationRatio from ai_company.budget.config import ( @@ -28,6 +29,7 @@ RedundancyRate, ) from ai_company.budget.cost_record import CostRecord +from ai_company.budget.enforcer import BudgetEnforcer from ai_company.budget.enums import BudgetAlertLevel from ai_company.budget.hierarchy import ( BudgetHierarchy, @@ -48,6 +50,7 @@ "BudgetAlertConfig", "BudgetAlertLevel", "BudgetConfig", + "BudgetEnforcer", "BudgetHierarchy", "CategoryBreakdown", "CoordinationEfficiency", @@ -71,4 +74,6 @@ "RedundancyRate", "SpendingSummary", "TeamBudget", + "billing_period_start", + "daily_period_start", ] diff --git a/src/ai_company/budget/billing.py b/src/ai_company/budget/billing.py new file mode 100644 index 0000000000..77c586033c --- /dev/null +++ b/src/ai_company/budget/billing.py @@ -0,0 +1,59 @@ +"""Billing period computation utilities. + +Pure functions for determining billing period boundaries based on a +configurable reset day. Used by :class:`~ai_company.budget.enforcer.BudgetEnforcer` +to scope cost queries to the current billing cycle. +""" + +from datetime import UTC, datetime + + +def billing_period_start( + reset_day: int, + *, + now: datetime | None = None, +) -> datetime: + """Compute the UTC-aware start of the current billing period. + + If ``now.day >= reset_day``, returns current month's ``reset_day`` + at 00:00 UTC. Otherwise, returns previous month's ``reset_day`` + at 00:00 UTC. + + Args: + reset_day: Day of month when the billing period resets (1-28). + now: Reference timestamp. Defaults to ``datetime.now(UTC)``. + + Returns: + UTC-aware datetime at midnight on the billing period start day. + + Raises: + ValueError: If ``reset_day`` is not in ``[1, 28]``. + """ + if not 1 <= reset_day <= 28: # noqa: PLR2004 + msg = f"reset_day must be 1-28, got {reset_day}" + raise ValueError(msg) + + if now is None: + now = datetime.now(UTC) + + if now.day >= reset_day: + return datetime(now.year, now.month, reset_day, tzinfo=UTC) + + # Roll back to previous month + if now.month == 1: + return datetime(now.year - 1, 12, reset_day, tzinfo=UTC) + return datetime(now.year, now.month - 1, reset_day, tzinfo=UTC) + + +def daily_period_start(*, now: datetime | None = None) -> datetime: + """Compute the UTC-aware start of today (midnight UTC). + + Args: + now: Reference timestamp. Defaults to ``datetime.now(UTC)``. + + Returns: + UTC-aware datetime at midnight of the current day. + """ + if now is None: + now = datetime.now(UTC) + return datetime(now.year, now.month, now.day, tzinfo=UTC) diff --git a/src/ai_company/budget/config.py b/src/ai_company/budget/config.py index 68327c5c57..b345770a10 100644 --- a/src/ai_company/budget/config.py +++ b/src/ai_company/budget/config.py @@ -5,7 +5,7 @@ """ from collections import Counter -from typing import Any, Self +from typing import Any, Literal, Self from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -74,6 +74,8 @@ class AutoDowngradeConfig(BaseModel): enabled: Whether auto-downgrade is active. threshold: Budget percent that triggers downgrade. downgrade_map: Ordered pairs of (from_alias, to_alias). + boundary: When to apply downgrade (task_assignment only, + never mid-execution per DESIGN_SPEC §10.4). """ model_config = ConfigDict(frozen=True) @@ -93,6 +95,12 @@ class AutoDowngradeConfig(BaseModel): default=(), description="Ordered pairs of (from_alias, to_alias)", ) + boundary: Literal["task_assignment"] = Field( + default="task_assignment", + description=( + "When to apply downgrade (task_assignment only, never mid-execution)" + ), + ) @model_validator(mode="before") @classmethod @@ -152,9 +160,11 @@ class BudgetConfig(BaseModel): per_task_limit: Maximum USD per task. per_agent_daily_limit: Maximum USD per agent per day. auto_downgrade: Automatic model downgrade configuration. + reset_day: Day of month when budget resets (1-28, avoids + month-length issues). """ - model_config = ConfigDict(frozen=True) + model_config = ConfigDict(frozen=True, allow_inf_nan=False) total_monthly: float = Field( default=100.0, @@ -179,6 +189,15 @@ class BudgetConfig(BaseModel): default_factory=AutoDowngradeConfig, description="Automatic model downgrade configuration", ) + reset_day: int = Field( + default=1, + ge=1, + le=28, + strict=True, + description=( + "Day of month when budget resets (1-28, avoids month-length issues)" + ), + ) @model_validator(mode="after") def _validate_per_task_limit_within_monthly(self) -> Self: diff --git a/src/ai_company/budget/enforcer.py b/src/ai_company/budget/enforcer.py new file mode 100644 index 0000000000..8bcdb53355 --- /dev/null +++ b/src/ai_company/budget/enforcer.py @@ -0,0 +1,665 @@ +"""Budget enforcement service. + +Composes :class:`~ai_company.budget.tracker.CostTracker` and +:class:`~ai_company.budget.config.BudgetConfig` to provide pre-flight +checks, in-flight budget checking, and task-boundary auto-downgrade +as described in DESIGN_SPEC Section 10.4. +""" + +from typing import TYPE_CHECKING, NamedTuple + +from ai_company.budget.billing import billing_period_start, daily_period_start +from ai_company.budget.enums import BudgetAlertLevel +from ai_company.constants import BUDGET_ROUNDING_PRECISION +from ai_company.engine.errors import BudgetExhaustedError, DailyLimitExceededError +from ai_company.observability import get_logger +from ai_company.observability.events.budget import ( + BUDGET_ALERT_THRESHOLD_CROSSED, + BUDGET_BASELINE_ERROR, + BUDGET_DAILY_LIMIT_EXCEEDED, + BUDGET_DAILY_LIMIT_HIT, + BUDGET_DOWNGRADE_APPLIED, + BUDGET_DOWNGRADE_SKIPPED, + BUDGET_ENFORCEMENT_CHECK, + BUDGET_HARD_STOP_EXCEEDED, + BUDGET_HARD_STOP_TRIGGERED, + BUDGET_PREFLIGHT_ERROR, + BUDGET_RESOLVE_MODEL_ERROR, + BUDGET_TASK_LIMIT_HIT, +) + +if TYPE_CHECKING: + from ai_company.budget.config import BudgetConfig + from ai_company.budget.tracker import CostTracker + from ai_company.core.agent import AgentIdentity, ModelConfig + from ai_company.core.task import Task + from ai_company.engine.context import AgentContext + from ai_company.engine.loop_protocol import BudgetChecker + from ai_company.providers.routing.models import ResolvedModel + from ai_company.providers.routing.resolver import ModelResolver + +logger = get_logger(__name__) + + +class BudgetEnforcer: + """Budget enforcement service composing CostTracker + BudgetConfig. + + Provides pre-flight checks (can this agent start?), in-flight budget + checking (monthly + daily + task limits with alert emission), and + task-boundary auto-downgrade. Concurrency-safe via CostTracker's + asyncio.Lock. + + Note: Pre-flight checks are best-effort under concurrency (TOCTOU). + The in-flight checker is the true safety net, though it also uses + pre-computed baselines that are snapshot-in-time and will not + reflect concurrent spend by other agents. + + Args: + budget_config: Budget configuration for limits and thresholds. + cost_tracker: Cost tracking service for querying spend. + model_resolver: Optional model resolver for auto-downgrade + alias lookup. + """ + + def __init__( + self, + *, + budget_config: BudgetConfig, + cost_tracker: CostTracker, + model_resolver: ModelResolver | None = None, + ) -> None: + self._budget_config = budget_config + self._cost_tracker = cost_tracker + self._model_resolver = model_resolver + + @property + def cost_tracker(self) -> CostTracker: + """The underlying cost tracker.""" + return self._cost_tracker + + async def check_can_execute(self, agent_id: str) -> None: + """Pre-flight: verify monthly + daily limits allow execution. + + Raises: + BudgetExhaustedError: Monthly hard stop exceeded. + DailyLimitExceededError: Agent daily limit exceeded. + """ + cfg = self._budget_config + + # Skip if enforcement disabled (total_monthly <= 0) + if cfg.total_monthly <= 0: + logger.debug( + BUDGET_ENFORCEMENT_CHECK, + agent_id=agent_id, + result="pass", + reason="enforcement_disabled", + ) + return + + try: + await self._check_monthly_hard_stop(cfg, agent_id) + await self._check_daily_limit(cfg, agent_id) + except BudgetExhaustedError, DailyLimitExceededError: + raise + except MemoryError, RecursionError: + raise + except Exception: + logger.exception( + BUDGET_PREFLIGHT_ERROR, + agent_id=agent_id, + reason="falling_back_to_allow_execution", + ) + return + + logger.debug( + BUDGET_ENFORCEMENT_CHECK, + agent_id=agent_id, + result="pass", + ) + + async def _check_monthly_hard_stop( + self, + cfg: BudgetConfig, + agent_id: str, + ) -> None: + """Check monthly hard stop and raise if exceeded.""" + period_start = billing_period_start(cfg.reset_day) + monthly_cost = await self._cost_tracker.get_total_cost( + start=period_start, + ) + hard_stop_limit = round( + cfg.total_monthly * cfg.alerts.hard_stop_at / 100, + BUDGET_ROUNDING_PRECISION, + ) + + if monthly_cost >= hard_stop_limit: + logger.warning( + BUDGET_HARD_STOP_EXCEEDED, + agent_id=agent_id, + total_cost=monthly_cost, + monthly_budget=cfg.total_monthly, + hard_stop_limit=hard_stop_limit, + ) + msg = ( + f"Monthly budget exhausted: ${monthly_cost:.2f} >= " + f"${hard_stop_limit:.2f} " + f"({cfg.alerts.hard_stop_at}% of " + f"${cfg.total_monthly:.2f})" + ) + raise BudgetExhaustedError(msg) + + async def _check_daily_limit( + self, + cfg: BudgetConfig, + agent_id: str, + ) -> None: + """Check per-agent daily limit and raise if exceeded.""" + if cfg.per_agent_daily_limit <= 0: + return + + day_start = daily_period_start() + daily_cost = await self._cost_tracker.get_agent_cost( + agent_id, + start=day_start, + ) + if daily_cost >= cfg.per_agent_daily_limit: + logger.warning( + BUDGET_DAILY_LIMIT_EXCEEDED, + agent_id=agent_id, + daily_cost=daily_cost, + daily_limit=cfg.per_agent_daily_limit, + ) + msg = ( + f"Agent {agent_id!r} daily limit exceeded: " + f"${daily_cost:.2f} >= " + f"${cfg.per_agent_daily_limit:.2f}" + ) + raise DailyLimitExceededError(msg) + + async def resolve_model( + self, + identity: AgentIdentity, + ) -> AgentIdentity: + """Apply auto-downgrade at task boundary if threshold exceeded. + + Returns identity unchanged when: + - ``auto_downgrade.enabled`` is ``False`` + - ``total_monthly`` is zero or negative (enforcement disabled) + - no ``model_resolver`` provided + - budget usage below threshold + - ``model_id`` not found in resolver + - model alias not in ``downgrade_map`` + - target alias not resolvable + - CostTracker query fails (graceful degradation) + + Returns new ``AgentIdentity`` with downgraded ``ModelConfig`` + otherwise. + """ + cfg = self._budget_config + downgrade = cfg.auto_downgrade + + if ( + not downgrade.enabled + or cfg.total_monthly <= 0 + or self._model_resolver is None + ): + return identity + + try: + period_start = billing_period_start(cfg.reset_day) + monthly_cost = await self._cost_tracker.get_total_cost( + start=period_start, + ) + except MemoryError, RecursionError: + raise + except Exception: + logger.exception( + BUDGET_RESOLVE_MODEL_ERROR, + agent_id=str(identity.id), + reason="cost_tracker_query_failed", + ) + return identity + + used_pct = round( + monthly_cost / cfg.total_monthly * 100, + BUDGET_ROUNDING_PRECISION, + ) + + if used_pct < downgrade.threshold: + return identity + + return _apply_downgrade( + identity, + self._model_resolver, + downgrade.downgrade_map, + used_pct, + downgrade.threshold, + ) + + async def make_budget_checker( + self, + task: Task, + agent_id: str, + ) -> BudgetChecker | None: + """Create a sync BudgetChecker with pre-computed baselines. + + Queries CostTracker once for monthly and daily baselines, then + returns a sync closure that checks: + + 1. Task budget limit + (``ctx.accumulated_cost.cost_usd >= task.budget_limit``) + 2. Monthly total + (baseline + running cost >= hard_stop threshold) + 3. Agent daily + (baseline + running cost >= per_agent_daily_limit) + + Baselines are snapshot-in-time: they will not reflect + concurrent spend by other agents during this task's execution. + The pre-flight check (``check_can_execute``) is the + authoritative gate. + + Alert deduplication: the closure tracks the last emitted alert + level and only logs upward transitions + (NORMAL -> WARNING -> CRITICAL -> HARD_STOP). + + Returns ``None`` when all limits are disabled (monthly, + task, and daily all off). + """ + cfg = self._budget_config + task_limit = task.budget_limit + monthly_budget = cfg.total_monthly + daily_limit = cfg.per_agent_daily_limit + + # All enforcement disabled — monthly, task, and daily all off. + # Note: total_monthly=0 disables monthly/daily checks but task + # limits are independent (set on the Task, not the budget). + if monthly_budget <= 0 and task_limit <= 0 and daily_limit <= 0: + return None + + monthly_baseline, daily_baseline = await self._compute_baselines_safe( + cfg, + monthly_budget, + daily_limit, + agent_id, + ) + + thresholds = _compute_thresholds(cfg, monthly_budget) + + return _build_checker_closure( + task_limit=task_limit, + monthly_budget=monthly_budget, + daily_limit=daily_limit, + monthly_baseline=monthly_baseline, + daily_baseline=daily_baseline, + thresholds=thresholds, + agent_id=agent_id, + ) + + # ── Private helpers ────────────────────────────────────────── + + async def _compute_baselines_safe( + self, + cfg: BudgetConfig, + monthly_budget: float, + daily_limit: float, + agent_id: str, + ) -> tuple[float, float]: + """Compute baselines, falling back to zero baselines on error. + + When CostTracker queries fail, returns ``(0.0, 0.0)`` so the + caller can still build a checker that enforces task-level + limits. Monthly/daily enforcement uses zero baselines, meaning + only the running task cost is tracked (no historical context). + """ + try: + return await self._compute_baselines( + cfg, + monthly_budget, + daily_limit, + agent_id, + ) + except MemoryError, RecursionError: + raise + except Exception: + logger.exception( + BUDGET_BASELINE_ERROR, + agent_id=agent_id, + reason="falling_back_to_zero_baselines", + ) + return 0.0, 0.0 + + async def _compute_baselines( + self, + cfg: BudgetConfig, + monthly_budget: float, + daily_limit: float, + agent_id: str, + ) -> tuple[float, float]: + """Compute monthly and daily cost baselines.""" + monthly_baseline = 0.0 + daily_baseline = 0.0 + + if monthly_budget > 0: + period_start = billing_period_start(cfg.reset_day) + monthly_baseline = await self._cost_tracker.get_total_cost( + start=period_start, + ) + + if daily_limit > 0: + day_start = daily_period_start() + daily_baseline = await self._cost_tracker.get_agent_cost( + agent_id, + start=day_start, + ) + + return monthly_baseline, daily_baseline + + +# ── Module-level pure helpers ──────────────────────────────────── + + +def _apply_downgrade( + identity: AgentIdentity, + resolver: ModelResolver, + downgrade_map: tuple[tuple[str, str], ...], + used_pct: float, + threshold: int, +) -> AgentIdentity: + """Attempt model downgrade, returning identity unchanged on skip.""" + current_model_id = identity.model.model_id + agent_id_str = str(identity.id) + + resolved = resolver.resolve_safe(current_model_id) + if resolved is None: + logger.debug( + BUDGET_DOWNGRADE_SKIPPED, + agent_id=agent_id_str, + model_id=current_model_id, + reason="model_not_in_resolver", + ) + return identity + + source_alias = resolved.alias + if source_alias is None: + logger.debug( + BUDGET_DOWNGRADE_SKIPPED, + agent_id=agent_id_str, + model_id=current_model_id, + reason="no_alias", + ) + return identity + + target_alias = _find_downgrade_target(source_alias, downgrade_map) + if target_alias is None: + logger.debug( + BUDGET_DOWNGRADE_SKIPPED, + agent_id=agent_id_str, + model_id=current_model_id, + source_alias=source_alias, + reason="no_mapping", + ) + return identity + + target_resolved = resolver.resolve_safe(target_alias) + if target_resolved is None: + logger.warning( + BUDGET_DOWNGRADE_SKIPPED, + agent_id=agent_id_str, + source_alias=source_alias, + target_alias=target_alias, + reason="target_not_resolvable", + ) + return identity + + new_model = _build_downgraded_model_config( + identity.model, + target_resolved, + ) + + logger.info( + BUDGET_DOWNGRADE_APPLIED, + agent_id=agent_id_str, + from_model=current_model_id, + from_alias=source_alias, + to_model=target_resolved.model_id, + to_alias=target_alias, + used_pct=used_pct, + threshold=threshold, + ) + + return identity.model_copy(update={"model": new_model}) + + +def _find_downgrade_target( + source_alias: str, + downgrade_map: tuple[tuple[str, str], ...], +) -> str | None: + """Find the target alias for a source in the downgrade map.""" + for src, tgt in downgrade_map: + if src == source_alias: + return tgt + return None + + +def _build_downgraded_model_config( + current: ModelConfig, + target: ResolvedModel, +) -> ModelConfig: + """Build a new ModelConfig with the downgraded model and provider.""" + return current.model_copy( + update={ + "provider": target.provider_name, + "model_id": target.model_id, + }, + ) + + +_ALERT_LEVEL_ORDER: dict[BudgetAlertLevel, int] = { + BudgetAlertLevel.NORMAL: 0, + BudgetAlertLevel.WARNING: 1, + BudgetAlertLevel.CRITICAL: 2, + BudgetAlertLevel.HARD_STOP: 3, +} + +assert set(_ALERT_LEVEL_ORDER) == set(BudgetAlertLevel), ( # noqa: S101 + f"_ALERT_LEVEL_ORDER keys {set(_ALERT_LEVEL_ORDER)} do not match " + f"BudgetAlertLevel members {set(BudgetAlertLevel)}" +) + + +def _emit_alert( + level: BudgetAlertLevel, + last_alert: list[BudgetAlertLevel], + agent_id: str, + total_cost: float, + monthly_budget: float, +) -> None: + """Log an alert if the level is higher than the last emitted. + + ``last_alert`` is a single-element list used as a mutable cell + to track state across closure invocations. + """ + if _ALERT_LEVEL_ORDER[level] <= _ALERT_LEVEL_ORDER[last_alert[0]]: + return + + last_alert[0] = level + + if level in (BudgetAlertLevel.WARNING, BudgetAlertLevel.CRITICAL): + logger.warning( + BUDGET_ALERT_THRESHOLD_CROSSED, + agent_id=agent_id, + alert_level=level.value, + total_cost=total_cost, + monthly_budget=monthly_budget, + ) + elif level == BudgetAlertLevel.HARD_STOP: + logger.error( + BUDGET_HARD_STOP_TRIGGERED, + agent_id=agent_id, + total_cost=total_cost, + monthly_budget=monthly_budget, + ) + + +class _AlertThresholds(NamedTuple): + """Pre-computed alert thresholds in ascending order.""" + + warn: float + critical: float + hard_stop: float + + +def _compute_thresholds( + cfg: BudgetConfig, + monthly_budget: float, +) -> _AlertThresholds: + """Pre-compute warn, critical, and hard_stop limits.""" + if monthly_budget <= 0: + return _AlertThresholds(0.0, 0.0, 0.0) + return _AlertThresholds( + warn=round( + monthly_budget * cfg.alerts.warn_at / 100, + BUDGET_ROUNDING_PRECISION, + ), + critical=round( + monthly_budget * cfg.alerts.critical_at / 100, + BUDGET_ROUNDING_PRECISION, + ), + hard_stop=round( + monthly_budget * cfg.alerts.hard_stop_at / 100, + BUDGET_ROUNDING_PRECISION, + ), + ) + + +def _build_checker_closure( # noqa: PLR0913 + *, + task_limit: float, + monthly_budget: float, + daily_limit: float, + monthly_baseline: float, + daily_baseline: float, + thresholds: _AlertThresholds, + agent_id: str, +) -> BudgetChecker: + """Build the sync budget checker closure. + + Args: + task_limit: Per-task cost limit (0 = disabled). + monthly_budget: Total monthly budget (0 = disabled). + daily_limit: Per-agent daily limit (0 = disabled). + monthly_baseline: Pre-computed monthly spend at task start. + daily_baseline: Pre-computed daily spend at task start. + thresholds: Pre-computed alert thresholds. + agent_id: Agent identifier for logging. + + Returns: + Sync callable returning ``True`` when budget is exhausted. + """ + last_alert: list[BudgetAlertLevel] = [BudgetAlertLevel.NORMAL] + + def _check(ctx: AgentContext) -> bool: + running_cost = ctx.accumulated_cost.cost_usd + return ( + _check_task_limit(running_cost, task_limit, agent_id) + or _check_monthly_limit( + running_cost, + monthly_budget, + monthly_baseline, + thresholds, + last_alert, + agent_id, + ) + or _check_daily_limit( + running_cost, + daily_limit, + daily_baseline, + agent_id, + ) + ) + + return _check + + +def _check_task_limit( + running_cost: float, + task_limit: float, + agent_id: str, +) -> bool: + """Return True if task budget limit is exhausted.""" + if task_limit > 0 and running_cost >= task_limit: + logger.warning( + BUDGET_TASK_LIMIT_HIT, + agent_id=agent_id, + running_cost=running_cost, + task_limit=task_limit, + ) + return True + return False + + +def _check_monthly_limit( # noqa: PLR0913 + running_cost: float, + monthly_budget: float, + monthly_baseline: float, + thresholds: _AlertThresholds, + last_alert: list[BudgetAlertLevel], + agent_id: str, +) -> bool: + """Return True if monthly hard stop is hit; emit alerts.""" + if monthly_budget <= 0: + return False + total_monthly = round( + monthly_baseline + running_cost, + BUDGET_ROUNDING_PRECISION, + ) + if total_monthly >= thresholds.hard_stop: + _emit_alert( + BudgetAlertLevel.HARD_STOP, + last_alert, + agent_id, + total_monthly, + monthly_budget, + ) + return True + if total_monthly >= thresholds.critical: + _emit_alert( + BudgetAlertLevel.CRITICAL, + last_alert, + agent_id, + total_monthly, + monthly_budget, + ) + elif total_monthly >= thresholds.warn: + _emit_alert( + BudgetAlertLevel.WARNING, + last_alert, + agent_id, + total_monthly, + monthly_budget, + ) + return False + + +def _check_daily_limit( + running_cost: float, + daily_limit: float, + daily_baseline: float, + agent_id: str, +) -> bool: + """Return True if daily limit is exhausted.""" + if daily_limit <= 0: + return False + total_daily = round( + daily_baseline + running_cost, + BUDGET_ROUNDING_PRECISION, + ) + if total_daily >= daily_limit: + logger.warning( + BUDGET_DAILY_LIMIT_HIT, + agent_id=agent_id, + total_daily=total_daily, + daily_limit=daily_limit, + ) + return True + return False diff --git a/src/ai_company/engine/__init__.py b/src/ai_company/engine/__init__.py index eb13904cc5..326af58790 100644 --- a/src/ai_company/engine/__init__.py +++ b/src/ai_company/engine/__init__.py @@ -56,6 +56,7 @@ ) from ai_company.engine.errors import ( BudgetExhaustedError, + DailyLimitExceededError, DecompositionCycleError, DecompositionDepthError, DecompositionError, @@ -174,6 +175,7 @@ "CleanupCallback", "CooperativeTimeoutStrategy", "CostOptimizedAssignmentStrategy", + "DailyLimitExceededError", "DecompositionContext", "DecompositionCycleError", "DecompositionDepthError", diff --git a/src/ai_company/engine/_validation.py b/src/ai_company/engine/_validation.py new file mode 100644 index 0000000000..9009d5fc7b --- /dev/null +++ b/src/ai_company/engine/_validation.py @@ -0,0 +1,105 @@ +"""Input validation helpers for AgentEngine. + +Pure validation functions extracted from :mod:`agent_engine` to keep +the main orchestrator under the 800-line limit. +""" + +from typing import TYPE_CHECKING + +from ai_company.core.enums import AgentStatus, TaskStatus +from ai_company.engine.errors import ExecutionStateError +from ai_company.observability import get_logger +from ai_company.observability.events.execution import ( + EXECUTION_ENGINE_INVALID_INPUT, +) + +if TYPE_CHECKING: + from ai_company.core.agent import AgentIdentity + from ai_company.core.task import Task + +logger = get_logger(__name__) + +_EXECUTABLE_STATUSES = frozenset( + {TaskStatus.ASSIGNED, TaskStatus.IN_PROGRESS}, +) +"""Task statuses the engine will accept for execution. + +CREATED tasks lack an assignee; terminal statuses (COMPLETED, CANCELLED), +BLOCKED, IN_REVIEW, FAILED, and INTERRUPTED are not executable. FAILED +and INTERRUPTED tasks must be reassigned (-> ASSIGNED) before re-execution. +""" + + +def validate_run_inputs( + *, + agent_id: str, + task_id: str, + max_turns: int, + timeout_seconds: float | None, +) -> None: + """Validate scalar ``run()`` arguments before execution.""" + if max_turns < 1: + msg = f"max_turns must be >= 1, got {max_turns}" + logger.warning( + EXECUTION_ENGINE_INVALID_INPUT, + agent_id=agent_id, + task_id=task_id, + reason=msg, + ) + raise ValueError(msg) + if timeout_seconds is not None and timeout_seconds <= 0: + msg = f"timeout_seconds must be > 0, got {timeout_seconds}" + logger.warning( + EXECUTION_ENGINE_INVALID_INPUT, + agent_id=agent_id, + task_id=task_id, + reason=msg, + ) + raise ValueError(msg) + + +def validate_agent(identity: AgentIdentity, agent_id: str) -> None: + """Raise if agent is not ACTIVE.""" + if identity.status != AgentStatus.ACTIVE: + msg = ( + f"Agent {agent_id} has status {identity.status.value!r}; " + f"only 'active' agents can run tasks" + ) + logger.warning( + EXECUTION_ENGINE_INVALID_INPUT, + agent_id=agent_id, + reason=msg, + ) + raise ExecutionStateError(msg) + + +def validate_task( + task: Task, + agent_id: str, + task_id: str, +) -> None: + """Raise if task is not executable or not assigned to this agent.""" + if task.status not in _EXECUTABLE_STATUSES: + msg = ( + f"Task {task_id!r} has status {task.status.value!r}; " + f"only 'assigned' or 'in_progress' tasks can be executed" + ) + logger.warning( + EXECUTION_ENGINE_INVALID_INPUT, + agent_id=agent_id, + task_id=task_id, + reason=msg, + ) + raise ExecutionStateError(msg) + if task.assigned_to is not None and task.assigned_to != agent_id: + msg = ( + f"Task {task_id!r} is assigned to {task.assigned_to!r}, " + f"not to agent {agent_id!r}" + ) + logger.warning( + EXECUTION_ENGINE_INVALID_INPUT, + agent_id=agent_id, + task_id=task_id, + reason=msg, + ) + raise ExecutionStateError(msg) diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 71b372fb1a..bcdc0756ea 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -10,10 +10,18 @@ from typing import TYPE_CHECKING from ai_company.core.enums import TaskStatus +from ai_company.engine._validation import ( + validate_agent, + validate_run_inputs, + validate_task, +) from ai_company.engine.classification.pipeline import classify_execution_errors from ai_company.engine.context import DEFAULT_MAX_TURNS, AgentContext from ai_company.engine.cost_recording import record_execution_costs -from ai_company.engine.errors import ExecutionStateError +from ai_company.engine.errors import ( + BudgetExhaustedError, + ExecutionStateError, +) from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, @@ -29,13 +37,9 @@ from ai_company.engine.react_loop import ReactLoop from ai_company.engine.recovery import FailAndReassignStrategy, RecoveryStrategy from ai_company.engine.run_result import AgentRunResult -from ai_company.engine.validation import ( - validate_agent, - validate_run_inputs, - validate_task, -) from ai_company.observability import get_logger from ai_company.observability.events.execution import ( + EXECUTION_ENGINE_BUDGET_STOPPED, EXECUTION_ENGINE_COMPLETE, EXECUTION_ENGINE_CREATED, EXECUTION_ENGINE_ERROR, @@ -53,6 +57,7 @@ if TYPE_CHECKING: from ai_company.budget.coordination_config import ErrorTaxonomyConfig + from ai_company.budget.enforcer import BudgetEnforcer from ai_company.budget.tracker import CostTracker from ai_company.core.agent import AgentIdentity from ai_company.core.task import Task @@ -61,7 +66,7 @@ ExecutionLoop, ShutdownChecker, ) - from ai_company.providers.models import CompletionConfig, ToolDefinition + from ai_company.providers.models import CompletionConfig from ai_company.providers.protocol import CompletionProvider from ai_company.tools.registry import ToolRegistry @@ -80,19 +85,16 @@ class AgentEngine: Args: provider: LLM completion provider (required). - execution_loop: Loop implementation. Defaults to ``ReactLoop()``. + execution_loop: Defaults to ``ReactLoop()``. tool_registry: Optional tools available to the agent. - cost_tracker: Optional cost recording service. When ``None``, - cost recording is skipped silently. - recovery_strategy: Crash recovery strategy. Defaults to a - shared ``FailAndReassignStrategy`` instance. Pass ``None`` - to disable. - shutdown_checker: Optional callback; returns ``True`` when a - graceful shutdown has been requested. Passed through to - the execution loop. - error_taxonomy_config: Optional error taxonomy configuration. - When provided and enabled, runs post-execution - classification of coordination errors. + cost_tracker: Falls back to ``budget_enforcer.cost_tracker`` + when ``None`` and ``budget_enforcer`` is provided. Must + match ``budget_enforcer.cost_tracker`` if both supplied. + recovery_strategy: Defaults to ``FailAndReassignStrategy``. + shutdown_checker: Returns ``True`` for graceful shutdown. + error_taxonomy_config: Post-execution error classification. + budget_enforcer: Pre-flight checks, auto-downgrade, and + enhanced in-flight budget checking. """ def __init__( # noqa: PLR0913 @@ -105,11 +107,26 @@ def __init__( # noqa: PLR0913 recovery_strategy: RecoveryStrategy | None = _DEFAULT_RECOVERY_STRATEGY, shutdown_checker: ShutdownChecker | None = None, error_taxonomy_config: ErrorTaxonomyConfig | None = None, + budget_enforcer: BudgetEnforcer | None = None, ) -> None: self._provider = provider self._loop: ExecutionLoop = execution_loop or ReactLoop() self._tool_registry = tool_registry - self._cost_tracker = cost_tracker + self._budget_enforcer = budget_enforcer + self._cost_tracker: CostTracker | None + if budget_enforcer is not None: + if ( + cost_tracker is not None + and cost_tracker is not budget_enforcer.cost_tracker + ): + msg = ( + "cost_tracker must match budget_enforcer.cost_tracker " + "when budget_enforcer is provided" + ) + raise ValueError(msg) + self._cost_tracker = budget_enforcer.cost_tracker + else: + self._cost_tracker = cost_tracker self._recovery_strategy = recovery_strategy self._shutdown_checker = shutdown_checker self._error_taxonomy_config = error_taxonomy_config @@ -118,6 +135,7 @@ def __init__( # noqa: PLR0913 loop_type=self._loop.get_loop_type(), has_tool_registry=self._tool_registry is not None, has_cost_tracker=self._cost_tracker is not None, + has_budget_enforcer=self._budget_enforcer is not None, ) async def run( # noqa: PLR0913 @@ -162,6 +180,11 @@ async def run( # noqa: PLR0913 ctx: AgentContext | None = None system_prompt: SystemPrompt | None = None try: + # Pre-flight budget enforcement + if self._budget_enforcer: + await self._budget_enforcer.check_can_execute(agent_id) + identity = await self._budget_enforcer.resolve_model(identity) + tool_invoker = self._make_tool_invoker(identity) ctx, system_prompt = self._prepare_context( identity=identity, @@ -185,14 +208,24 @@ async def run( # noqa: PLR0913 tool_invoker=tool_invoker, ) except MemoryError, RecursionError: - logger.error( + logger.exception( EXECUTION_ENGINE_ERROR, agent_id=agent_id, task_id=task_id, error="non-recoverable error in run()", - exc_info=True, ) raise + except BudgetExhaustedError as exc: + return self._handle_budget_error( + exc=exc, + identity=identity, + task=task, + agent_id=agent_id, + task_id=task_id, + duration_seconds=time.monotonic() - start, + ctx=ctx, + system_prompt=system_prompt, + ) except Exception as exc: return await self._handle_fatal_error( exc=exc, @@ -220,7 +253,14 @@ async def _execute( # noqa: PLR0913 tool_invoker: ToolInvoker | None = None, ) -> AgentRunResult: """Run execution loop, record costs, apply transitions, and build result.""" - budget_checker = make_budget_checker(task) + budget_checker: BudgetChecker | None + if self._budget_enforcer: + budget_checker = await self._budget_enforcer.make_budget_checker( + task, + agent_id, + ) + else: + budget_checker = make_budget_checker(task) logger.debug( EXECUTION_ENGINE_PROMPT_BUILT, @@ -402,7 +442,7 @@ def _prepare_context( # noqa: PLR0913 tool_invoker: ToolInvoker | None = None, ) -> tuple[AgentContext, SystemPrompt]: """Build system prompt and prepare execution context.""" - tool_defs = self._get_tool_definitions(tool_invoker) + tool_defs = tool_invoker.get_permitted_definitions() if tool_invoker else () system_prompt = build_system_prompt( agent=identity, task=task, @@ -431,15 +471,6 @@ def _prepare_context( # noqa: PLR0913 # ── Helpers ────────────────────────────────────────────────── - def _get_tool_definitions( - self, - tool_invoker: ToolInvoker | None, - ) -> tuple[ToolDefinition, ...]: - """Extract permitted tool definitions for prompt building.""" - if tool_invoker is None: - return () - return tool_invoker.get_permitted_definitions() - def _transition_task_if_needed( self, ctx: AgentContext, @@ -653,6 +684,54 @@ def _log_completion( duration_seconds=metrics.duration_seconds, ) + def _handle_budget_error( # noqa: PLR0913 + self, + *, + exc: BudgetExhaustedError, + identity: AgentIdentity, + task: Task, + agent_id: str, + task_id: str, + duration_seconds: float, + ctx: AgentContext | None = None, + system_prompt: SystemPrompt | None = None, + ) -> AgentRunResult: + """Build a BUDGET_EXHAUSTED result (no recovery — controlled stop).""" + logger.warning( + EXECUTION_ENGINE_BUDGET_STOPPED, + agent_id=agent_id, + task_id=task_id, + error=f"{type(exc).__name__}: {exc}", + ) + try: + error_ctx = ctx or AgentContext.from_identity(identity, task=task) + budget_result = ExecutionResult( + context=error_ctx, + termination_reason=TerminationReason.BUDGET_EXHAUSTED, + ) + error_prompt = build_error_prompt( + identity, + agent_id, + system_prompt, + ) + return AgentRunResult( + execution_result=budget_result, + system_prompt=error_prompt, + duration_seconds=duration_seconds, + agent_id=agent_id, + task_id=task_id, + ) + except MemoryError, RecursionError: + raise + except Exception as build_exc: + logger.exception( + EXECUTION_ENGINE_ERROR, + agent_id=agent_id, + task_id=task_id, + error=f"Failed to build budget-exhausted result: {build_exc}", + ) + raise exc from None + async def _handle_fatal_error( # noqa: PLR0913 self, *, @@ -700,12 +779,11 @@ async def _handle_fatal_error( # noqa: PLR0913 task_id=task_id, ) except MemoryError, RecursionError: - logger.error( + logger.exception( EXECUTION_ENGINE_ERROR, agent_id=agent_id, task_id=task_id, error="non-recoverable error while building error result", - exc_info=True, ) raise except Exception as build_exc: @@ -716,7 +794,7 @@ async def _handle_fatal_error( # noqa: PLR0913 error=f"Failed to build error result: {build_exc}", original_error=error_msg, ) - raise exc from build_exc + raise exc from None async def _build_error_execution( # noqa: PLR0913 self, diff --git a/src/ai_company/engine/errors.py b/src/ai_company/engine/errors.py index 32e793bff7..b1a2efeaa3 100644 --- a/src/ai_company/engine/errors.py +++ b/src/ai_company/engine/errors.py @@ -24,12 +24,20 @@ class MaxTurnsExceededError(EngineError): class BudgetExhaustedError(EngineError): """Budget exhaustion signal for the engine layer. - The execution loop returns ``TerminationReason.BUDGET_EXHAUSTED`` - internally. This exception is available for the engine layer above - the loop to convert that result into a raised error when appropriate. + Used in two contexts: + + 1. Raised directly by :meth:`BudgetEnforcer.check_can_execute` + when pre-flight budget checks fail (monthly hard stop or daily + limit exceeded). + 2. Available for converting ``TerminationReason.BUDGET_EXHAUSTED`` + loop results into a raised error at the engine layer. """ +class DailyLimitExceededError(BudgetExhaustedError): + """Per-agent daily spending limit exceeded.""" + + class LoopExecutionError(EngineError): """Non-recoverable execution loop error for the engine layer. diff --git a/src/ai_company/observability/events/budget.py b/src/ai_company/observability/events/budget.py index 19e2b1cd8b..e0157d79f6 100644 --- a/src/ai_company/observability/events/budget.py +++ b/src/ai_company/observability/events/budget.py @@ -13,3 +13,16 @@ BUDGET_CATEGORY_BREAKDOWN_QUERIED: Final[str] = "budget.category_breakdown.queried" BUDGET_ORCHESTRATION_RATIO_QUERIED: Final[str] = "budget.orchestration_ratio.queried" BUDGET_ORCHESTRATION_RATIO_ALERT: Final[str] = "budget.orchestration_ratio.alert" + +BUDGET_ALERT_THRESHOLD_CROSSED: Final[str] = "budget.alert.threshold_crossed" +BUDGET_HARD_STOP_EXCEEDED: Final[str] = "budget.hard_stop.exceeded" +BUDGET_HARD_STOP_TRIGGERED: Final[str] = "budget.hard_stop.triggered" +BUDGET_DAILY_LIMIT_EXCEEDED: Final[str] = "budget.daily_limit.exceeded" +BUDGET_DOWNGRADE_APPLIED: Final[str] = "budget.downgrade.applied" +BUDGET_DOWNGRADE_SKIPPED: Final[str] = "budget.downgrade.skipped" +BUDGET_ENFORCEMENT_CHECK: Final[str] = "budget.enforcement.check" +BUDGET_TASK_LIMIT_HIT: Final[str] = "budget.task_limit.hit" +BUDGET_DAILY_LIMIT_HIT: Final[str] = "budget.daily_limit.hit" +BUDGET_BASELINE_ERROR: Final[str] = "budget.baseline.error" +BUDGET_PREFLIGHT_ERROR: Final[str] = "budget.preflight.error" +BUDGET_RESOLVE_MODEL_ERROR: Final[str] = "budget.resolve_model.error" diff --git a/src/ai_company/observability/events/execution.py b/src/ai_company/observability/events/execution.py index 8c5562ac4d..700224a46e 100644 --- a/src/ai_company/observability/events/execution.py +++ b/src/ai_company/observability/events/execution.py @@ -34,6 +34,7 @@ EXECUTION_ENGINE_COST_FAILED: Final[str] = "execution.engine.cost_failed" EXECUTION_ENGINE_TASK_METRICS: Final[str] = "execution.engine.task_metrics" EXECUTION_ENGINE_TIMEOUT: Final[str] = "execution.engine.timeout" +EXECUTION_ENGINE_BUDGET_STOPPED: Final[str] = "execution.engine.budget_stopped" EXECUTION_SHUTDOWN_SIGNAL: Final[str] = "execution.shutdown.signal" EXECUTION_SHUTDOWN_MANAGER_CREATED: Final[str] = "execution.shutdown.manager_created" diff --git a/tests/unit/budget/test_billing.py b/tests/unit/budget/test_billing.py new file mode 100644 index 0000000000..79e9ec7597 --- /dev/null +++ b/tests/unit/budget/test_billing.py @@ -0,0 +1,142 @@ +"""Tests for billing period computation utilities.""" + +from datetime import UTC, datetime + +import pytest + +from ai_company.budget.billing import billing_period_start, daily_period_start + +pytestmark = pytest.mark.timeout(30) + + +@pytest.mark.unit +class TestBillingPeriodStart: + """Tests for billing_period_start().""" + + @pytest.mark.parametrize( + ("reset_day", "now", "expected"), + [ + # Same month: now.day >= reset_day + ( + 1, + datetime(2026, 3, 15, 10, 30, tzinfo=UTC), + datetime(2026, 3, 1, tzinfo=UTC), + ), + ( + 15, + datetime(2026, 3, 20, 8, 0, tzinfo=UTC), + datetime(2026, 3, 15, tzinfo=UTC), + ), + # Exact boundary: now.day == reset_day + ( + 1, + datetime(2026, 1, 1, 0, 0, tzinfo=UTC), + datetime(2026, 1, 1, tzinfo=UTC), + ), + ( + 15, + datetime(2026, 3, 15, 23, 59, tzinfo=UTC), + datetime(2026, 3, 15, tzinfo=UTC), + ), + # Previous month: now.day < reset_day + ( + 15, + datetime(2026, 3, 10, 12, 0, tzinfo=UTC), + datetime(2026, 2, 15, tzinfo=UTC), + ), + ( + 10, + datetime(2026, 3, 5, 0, 0, tzinfo=UTC), + datetime(2026, 2, 10, tzinfo=UTC), + ), + # Year boundary: January with rollback to December + ( + 10, + datetime(2026, 1, 5, 0, 0, tzinfo=UTC), + datetime(2025, 12, 10, tzinfo=UTC), + ), + # reset_day=28 (max allowed), February is safe + ( + 28, + datetime(2026, 3, 1, 0, 0, tzinfo=UTC), + datetime(2026, 2, 28, tzinfo=UTC), + ), + # reset_day=1 always stays in current month + ( + 1, + datetime(2026, 12, 31, 23, 59, tzinfo=UTC), + datetime(2026, 12, 1, tzinfo=UTC), + ), + ], + ids=[ + "same_month_day1", + "same_month_day15", + "exact_boundary_jan1", + "exact_boundary_day15", + "prev_month_day15", + "prev_month_day10", + "year_boundary", + "feb_28_safe", + "dec_31_day1", + ], + ) + def test_billing_period_start( + self, + reset_day: int, + now: datetime, + expected: datetime, + ) -> None: + """Verify billing period start for various date/reset_day combos.""" + result = billing_period_start(reset_day, now=now) + assert result == expected + + def test_defaults_to_utc_now(self) -> None: + """Verify billing_period_start works without explicit now.""" + result = billing_period_start(1) + assert result.tzinfo is UTC + assert result.hour == 0 + assert result.minute == 0 + assert result.second == 0 + + def test_result_is_utc_aware(self) -> None: + """Result always has UTC timezone.""" + result = billing_period_start( + 15, + now=datetime(2026, 3, 20, tzinfo=UTC), + ) + assert result.tzinfo is UTC + + @pytest.mark.parametrize( + "invalid_day", + [0, -1, 29, 31, 100], + ids=["zero", "negative", "29", "31", "100"], + ) + def test_invalid_reset_day_raises(self, invalid_day: int) -> None: + """Invalid reset_day raises ValueError.""" + with pytest.raises(ValueError, match="reset_day must be 1-28"): + billing_period_start(invalid_day) + + +@pytest.mark.unit +class TestDailyPeriodStart: + """Tests for daily_period_start().""" + + def test_returns_midnight_utc(self) -> None: + """Verify midnight UTC of the given day.""" + now = datetime(2026, 3, 15, 14, 30, 45, tzinfo=UTC) + result = daily_period_start(now=now) + assert result == datetime(2026, 3, 15, tzinfo=UTC) + + def test_already_at_midnight(self) -> None: + """When now is already midnight, return same instant.""" + now = datetime(2026, 3, 15, 0, 0, 0, tzinfo=UTC) + result = daily_period_start(now=now) + assert result == now + + def test_defaults_to_utc_now(self) -> None: + """Verify daily_period_start works without explicit now.""" + result = daily_period_start() + assert result.tzinfo is UTC + assert result.hour == 0 + assert result.minute == 0 + assert result.second == 0 diff --git a/tests/unit/budget/test_config.py b/tests/unit/budget/test_config.py index a83e5c3865..8a0e92405a 100644 --- a/tests/unit/budget/test_config.py +++ b/tests/unit/budget/test_config.py @@ -103,6 +103,7 @@ def test_defaults(self) -> None: assert cfg.enabled is False assert cfg.threshold == 85 assert cfg.downgrade_map == () + assert cfg.boundary == "task_assignment" def test_custom_values(self) -> None: """Accept valid custom configuration.""" @@ -183,6 +184,16 @@ def test_aliases_normalized(self) -> None: ) assert cfg.downgrade_map == (("large", "medium"),) + def test_boundary_default_is_task_assignment(self) -> None: + """Verify boundary default is 'task_assignment'.""" + cfg = AutoDowngradeConfig() + assert cfg.boundary == "task_assignment" + + def test_boundary_rejects_other_values(self) -> None: + """Reject boundary values other than 'task_assignment'.""" + with pytest.raises(ValidationError): + AutoDowngradeConfig(boundary="mid_execution") # type: ignore[arg-type] + def test_frozen(self) -> None: """Ensure AutoDowngradeConfig is immutable.""" cfg = AutoDowngradeConfig() @@ -210,6 +221,7 @@ def test_defaults(self) -> None: assert cfg.per_agent_daily_limit == 10.0 assert cfg.alerts.warn_at == 75 assert cfg.auto_downgrade.enabled is False + assert cfg.reset_day == 1 def test_custom_values(self, sample_budget_config: BudgetConfig) -> None: """Accept valid custom budget config.""" @@ -258,6 +270,38 @@ def test_zero_monthly_skips_limit_validation(self) -> None: assert cfg.per_task_limit == 100.0 assert cfg.per_agent_daily_limit == 100.0 + def test_reset_day_valid_range(self) -> None: + """Accept reset_day in valid range (1-28).""" + cfg_1 = BudgetConfig(reset_day=1) + assert cfg_1.reset_day == 1 + cfg_28 = BudgetConfig(reset_day=28) + assert cfg_28.reset_day == 28 + + def test_reset_day_zero_rejected(self) -> None: + """Reject reset_day of 0.""" + with pytest.raises(ValidationError): + BudgetConfig(reset_day=0) + + def test_reset_day_29_rejected(self) -> None: + """Reject reset_day of 29 (avoids month-length issues).""" + with pytest.raises(ValidationError): + BudgetConfig(reset_day=29) + + def test_reset_day_float_rejected(self) -> None: + """Reject float value for reset_day (strict int).""" + with pytest.raises(ValidationError): + BudgetConfig(reset_day=15.0) # type: ignore[arg-type] + + @pytest.mark.parametrize( + "value", + [float("inf"), float("-inf"), float("nan")], + ids=["inf", "neg_inf", "nan"], + ) + def test_inf_nan_rejected(self, value: float) -> None: + """Reject inf and NaN values for float fields.""" + with pytest.raises(ValidationError): + BudgetConfig(total_monthly=value) + def test_frozen(self) -> None: """Ensure BudgetConfig is immutable.""" cfg = BudgetConfig() diff --git a/tests/unit/budget/test_enforcer.py b/tests/unit/budget/test_enforcer.py new file mode 100644 index 0000000000..00124d3baa --- /dev/null +++ b/tests/unit/budget/test_enforcer.py @@ -0,0 +1,1067 @@ +"""Tests for BudgetEnforcer service.""" + +import contextlib +from datetime import UTC, date, datetime +from typing import TYPE_CHECKING +from unittest.mock import patch +from uuid import uuid4 + +if TYPE_CHECKING: + from collections.abc import Iterator + +import pytest + +from ai_company.budget.config import ( + AutoDowngradeConfig, + BudgetAlertConfig, + BudgetConfig, +) +from ai_company.budget.enforcer import BudgetEnforcer +from ai_company.budget.tracker import CostTracker +from ai_company.core.agent import AgentIdentity, ModelConfig +from ai_company.core.enums import TaskStatus, TaskType +from ai_company.core.task import Task +from ai_company.engine.context import AgentContext +from ai_company.engine.errors import BudgetExhaustedError, DailyLimitExceededError +from ai_company.observability.events.budget import BUDGET_ALERT_THRESHOLD_CROSSED +from ai_company.providers.models import TokenUsage +from ai_company.providers.routing.models import ResolvedModel +from ai_company.providers.routing.resolver import ModelResolver + +from .conftest import make_cost_record + +pytestmark = pytest.mark.timeout(30) + +# Timestamps within the test billing period (March 2026) +_BILLING_START = datetime(2026, 3, 1, tzinfo=UTC) +_DAY_START = datetime(2026, 3, 15, tzinfo=UTC) +_RECORD_TS = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + + +# ── Helpers ────────────────────────────────────────────────────────── + + +def _make_budget_config( # noqa: PLR0913 + *, + total_monthly: float = 100.0, + warn_at: int = 75, + critical_at: int = 90, + hard_stop_at: int = 100, + per_agent_daily_limit: float = 10.0, + per_task_limit: float = 5.0, + reset_day: int = 1, + auto_downgrade: AutoDowngradeConfig | None = None, +) -> BudgetConfig: + return BudgetConfig( + total_monthly=total_monthly, + alerts=BudgetAlertConfig( + warn_at=warn_at, + critical_at=critical_at, + hard_stop_at=hard_stop_at, + ), + per_agent_daily_limit=per_agent_daily_limit, + per_task_limit=per_task_limit, + reset_day=reset_day, + auto_downgrade=auto_downgrade or AutoDowngradeConfig(), + ) + + +def _make_identity( + *, + model_id: str = "test-large-001", + provider: str = "test-provider", +) -> AgentIdentity: + return AgentIdentity( + id=uuid4(), + name="Test Agent", + role="Developer", + department="Engineering", + model=ModelConfig(provider=provider, model_id=model_id), + hiring_date=date(2026, 1, 1), + ) + + +def _make_task( + *, + agent_id: str, + budget_limit: float = 0.0, +) -> Task: + return Task( + id="task-001", + title="Test task", + description="A test task", + type=TaskType.DEVELOPMENT, + project="proj-001", + created_by="manager", + assigned_to=agent_id, + status=TaskStatus.ASSIGNED, + budget_limit=budget_limit, + ) + + +def _make_resolver( + models: dict[str, ResolvedModel] | None = None, +) -> ModelResolver: + index = models or {} + return ModelResolver(index) + + +def _resolved( + *, + model_id: str, + provider: str = "test-provider", + alias: str | None = None, +) -> ResolvedModel: + return ResolvedModel( + provider_name=provider, + model_id=model_id, + alias=alias, + ) + + +def _ctx_with_cost( + identity: AgentIdentity, + task: Task, + cost_usd: float, +) -> AgentContext: + """Build an AgentContext with a specific accumulated cost.""" + ctx = AgentContext.from_identity(identity, task=task) + return ctx.model_copy( + update={ + "accumulated_cost": TokenUsage( + input_tokens=100, + output_tokens=50, + cost_usd=cost_usd, + ), + }, + ) + + +def _patch_periods() -> contextlib.AbstractContextManager[None]: + """Context manager that patches billing and daily period starts.""" + return _combined_patch(_BILLING_START, _DAY_START) + + +def _combined_patch( + billing_start: datetime, + day_start: datetime, +) -> contextlib.AbstractContextManager[None]: + """Return combined patch context manager for both period functions.""" + + @contextlib.contextmanager + def _ctx() -> Iterator[None]: + with ( + patch( + "ai_company.budget.enforcer.billing_period_start", + return_value=billing_start, + ), + patch( + "ai_company.budget.enforcer.daily_period_start", + return_value=day_start, + ), + ): + yield + + return _ctx() + + +# ── Pre-flight checks ─────────────────────────────────────────────── + + +@pytest.mark.unit +class TestCheckCanExecute: + """Tests for BudgetEnforcer.check_can_execute().""" + + async def test_passes_when_under_budget(self) -> None: + """Monthly budget not exceeded passes without exception.""" + cfg = _make_budget_config( + total_monthly=100.0, + per_agent_daily_limit=50.0, + ) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + cost_usd=30.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + ) + with _patch_periods(): + await enforcer.check_can_execute("alice") + + async def test_raises_at_exactly_hard_stop(self) -> None: + """Monthly budget at exactly 100% raises BudgetExhaustedError.""" + cfg = _make_budget_config(total_monthly=100.0, hard_stop_at=100) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + cost_usd=100.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + + with ( + _patch_periods(), + pytest.raises( + BudgetExhaustedError, + match="Monthly budget exhausted", + ), + ): + await enforcer.check_can_execute("alice") + + async def test_raises_over_hard_stop(self) -> None: + """Monthly budget over 100% raises BudgetExhaustedError.""" + cfg = _make_budget_config(total_monthly=100.0, hard_stop_at=100) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + cost_usd=110.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + + with _patch_periods(), pytest.raises(BudgetExhaustedError): + await enforcer.check_can_execute("alice") + + async def test_daily_limit_at_exact_limit_raises(self) -> None: + """Daily limit at exactly the limit raises DailyLimitExceededError.""" + cfg = _make_budget_config( + total_monthly=100.0, + per_agent_daily_limit=10.0, + ) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + agent_id="alice", + cost_usd=10.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + + with ( + _patch_periods(), + pytest.raises( + DailyLimitExceededError, + match="daily limit exceeded", + ), + ): + await enforcer.check_can_execute("alice") + + async def test_daily_limit_not_exceeded_passes(self) -> None: + """Daily limit not reached passes without exception.""" + cfg = _make_budget_config(per_agent_daily_limit=10.0) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + agent_id="alice", + cost_usd=5.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + + with _patch_periods(): + await enforcer.check_can_execute("alice") + + async def test_enforcement_disabled_always_passes(self) -> None: + """Budget disabled (total_monthly=0) always passes.""" + cfg = _make_budget_config(total_monthly=0.0) + tracker = CostTracker(budget_config=cfg) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + await enforcer.check_can_execute("alice") + + async def test_daily_limit_disabled_skips_check(self) -> None: + """Daily limit of 0 skips the daily check entirely.""" + cfg = _make_budget_config(per_agent_daily_limit=0.0) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + agent_id="alice", + cost_usd=50.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + + with _patch_periods(): + await enforcer.check_can_execute("alice") + + +# ── Auto-downgrade ─────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestResolveModel: + """Tests for BudgetEnforcer.resolve_model().""" + + async def test_below_threshold_returns_unchanged(self) -> None: + """Budget below downgrade threshold returns identity unchanged.""" + cfg = _make_budget_config( + auto_downgrade=AutoDowngradeConfig( + enabled=True, + threshold=85, + downgrade_map=(("large", "medium"),), + ), + ) + tracker = CostTracker(budget_config=cfg) + # 50% usage — below 85% threshold + await tracker.record( + make_cost_record( + cost_usd=50.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + resolver = _make_resolver( + {"test-large-001": _resolved(model_id="test-large-001", alias="large")}, + ) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + model_resolver=resolver, + ) + identity = _make_identity() + + with _patch_periods(): + result = await enforcer.resolve_model(identity) + + assert result.model.model_id == "test-large-001" + + async def test_above_threshold_with_mapping_downgrades(self) -> None: + """Budget above threshold with matching alias downgrades the model.""" + cfg = _make_budget_config( + auto_downgrade=AutoDowngradeConfig( + enabled=True, + threshold=85, + downgrade_map=(("large", "medium"),), + ), + ) + tracker = CostTracker(budget_config=cfg) + # 90% usage — above 85% threshold + await tracker.record( + make_cost_record( + cost_usd=90.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + resolver = _make_resolver( + { + "test-large-001": _resolved( + model_id="test-large-001", + alias="large", + ), + "large": _resolved(model_id="test-large-001", alias="large"), + "medium": _resolved( + model_id="test-medium-001", + provider="test-provider", + alias="medium", + ), + } + ) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + model_resolver=resolver, + ) + identity = _make_identity(model_id="test-large-001") + + with _patch_periods(): + result = await enforcer.resolve_model(identity) + + assert result.model.model_id == "test-medium-001" + assert result.model.provider == "test-provider" + + async def test_above_threshold_no_matching_alias_unchanged(self) -> None: + """Budget above threshold but no matching alias returns unchanged.""" + cfg = _make_budget_config( + auto_downgrade=AutoDowngradeConfig( + enabled=True, + threshold=85, + downgrade_map=(("small", "tiny"),), + ), + ) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + cost_usd=90.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + resolver = _make_resolver( + { + "test-large-001": _resolved( + model_id="test-large-001", + alias="large", + ), + } + ) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + model_resolver=resolver, + ) + identity = _make_identity() + + with _patch_periods(): + result = await enforcer.resolve_model(identity) + + assert result.model.model_id == "test-large-001" + + async def test_no_model_resolver_returns_unchanged(self) -> None: + """No model_resolver provided returns identity unchanged.""" + cfg = _make_budget_config( + auto_downgrade=AutoDowngradeConfig( + enabled=True, + threshold=85, + downgrade_map=(("large", "medium"),), + ), + ) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + cost_usd=90.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + model_resolver=None, + ) + identity = _make_identity() + + result = await enforcer.resolve_model(identity) + assert result.model.model_id == "test-large-001" + + async def test_disabled_returns_unchanged(self) -> None: + """Auto-downgrade disabled returns identity unchanged.""" + cfg = _make_budget_config( + auto_downgrade=AutoDowngradeConfig(enabled=False), + ) + tracker = CostTracker(budget_config=cfg) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + ) + identity = _make_identity() + + result = await enforcer.resolve_model(identity) + assert result.model.model_id == "test-large-001" + + async def test_at_exact_threshold_applies_downgrade(self) -> None: + """Budget at exactly the threshold applies downgrade.""" + cfg = _make_budget_config( + auto_downgrade=AutoDowngradeConfig( + enabled=True, + threshold=85, + downgrade_map=(("large", "medium"),), + ), + ) + tracker = CostTracker(budget_config=cfg) + # Exactly 85% usage + await tracker.record( + make_cost_record( + cost_usd=85.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + resolver = _make_resolver( + { + "test-large-001": _resolved( + model_id="test-large-001", + alias="large", + ), + "large": _resolved( + model_id="test-large-001", + alias="large", + ), + "medium": _resolved( + model_id="test-medium-001", + alias="medium", + ), + } + ) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + model_resolver=resolver, + ) + identity = _make_identity(model_id="test-large-001") + + with _patch_periods(): + result = await enforcer.resolve_model(identity) + + # At exactly threshold → downgrade applies (< is strict) + assert result.model.model_id == "test-medium-001" + + async def test_resolved_model_has_no_alias_unchanged(self) -> None: + """Model in resolver but with no alias returns unchanged.""" + cfg = _make_budget_config( + auto_downgrade=AutoDowngradeConfig( + enabled=True, + threshold=85, + downgrade_map=(("large", "medium"),), + ), + ) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + cost_usd=90.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + # Model registered without an alias + resolver = _make_resolver( + { + "test-large-001": _resolved( + model_id="test-large-001", + alias=None, + ), + } + ) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + model_resolver=resolver, + ) + identity = _make_identity() + + with _patch_periods(): + result = await enforcer.resolve_model(identity) + + assert result.model.model_id == "test-large-001" + + async def test_target_alias_not_resolvable_unchanged(self) -> None: + """Target alias in downgrade map but not in resolver skips.""" + cfg = _make_budget_config( + auto_downgrade=AutoDowngradeConfig( + enabled=True, + threshold=85, + downgrade_map=(("large", "nonexistent"),), + ), + ) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + cost_usd=90.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + resolver = _make_resolver( + { + "test-large-001": _resolved( + model_id="test-large-001", + alias="large", + ), + "large": _resolved( + model_id="test-large-001", + alias="large", + ), + # "nonexistent" is NOT registered + } + ) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + model_resolver=resolver, + ) + identity = _make_identity() + + with _patch_periods(): + result = await enforcer.resolve_model(identity) + + assert result.model.model_id == "test-large-001" + + async def test_chain_downgrade_applies_first_match_only(self) -> None: + """Only the first matching downgrade_map entry applies.""" + cfg = _make_budget_config( + auto_downgrade=AutoDowngradeConfig( + enabled=True, + threshold=85, + downgrade_map=( + ("large", "medium"), + ("medium", "small"), + ), + ), + ) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + cost_usd=90.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + resolver = _make_resolver( + { + "test-large-001": _resolved( + model_id="test-large-001", + alias="large", + ), + "large": _resolved(model_id="test-large-001", alias="large"), + "medium": _resolved( + model_id="test-medium-001", + alias="medium", + ), + "small": _resolved(model_id="test-small-001", alias="small"), + } + ) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + model_resolver=resolver, + ) + identity = _make_identity(model_id="test-large-001") + + with _patch_periods(): + result = await enforcer.resolve_model(identity) + + # Should downgrade to medium, NOT to small + assert result.model.model_id == "test-medium-001" + + +# ── Budget checker factory ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestMakeBudgetChecker: + """Tests for BudgetEnforcer.make_budget_checker().""" + + async def test_returns_none_when_all_disabled(self) -> None: + """Returns None when all limits are disabled.""" + cfg = _make_budget_config( + total_monthly=0.0, + per_agent_daily_limit=0.0, + ) + tracker = CostTracker(budget_config=cfg) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + identity = _make_identity() + task = _make_task( + agent_id=str(identity.id), + budget_limit=0.0, + ) + + checker = await enforcer.make_budget_checker(task, str(identity.id)) + assert checker is None + + async def test_returns_checker_when_only_task_limit_active(self) -> None: + """Returns a checker (not None) when only task_limit is set.""" + cfg = _make_budget_config( + total_monthly=0.0, + per_agent_daily_limit=0.0, + ) + tracker = CostTracker(budget_config=cfg) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + identity = _make_identity() + task = _make_task( + agent_id=str(identity.id), + budget_limit=5.0, + ) + + checker = await enforcer.make_budget_checker(task, str(identity.id)) + assert checker is not None + + # Under limit → not exhausted + ctx_under = _ctx_with_cost(identity, task, 4.99) + assert checker(ctx_under) is False + + # At limit → exhausted + ctx_at = _ctx_with_cost(identity, task, 5.0) + assert checker(ctx_at) is True + + async def test_task_budget_exhaustion(self) -> None: + """Checker detects task budget exhaustion at exact limit.""" + cfg = _make_budget_config(total_monthly=100.0) + tracker = CostTracker(budget_config=cfg) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + identity = _make_identity() + task = _make_task( + agent_id=str(identity.id), + budget_limit=5.0, + ) + + with _patch_periods(): + checker = await enforcer.make_budget_checker(task, str(identity.id)) + + assert checker is not None + + # Under limit + ctx_under = _ctx_with_cost(identity, task, 4.99) + assert checker(ctx_under) is False + + # At limit + ctx_at = _ctx_with_cost(identity, task, 5.0) + assert checker(ctx_at) is True + + async def test_monthly_hard_stop(self) -> None: + """Checker detects monthly hard stop (baseline + running cost).""" + cfg = _make_budget_config( + total_monthly=100.0, + hard_stop_at=100, + per_agent_daily_limit=0.0, + ) + tracker = CostTracker(budget_config=cfg) + # Pre-existing monthly spend of 90 + await tracker.record( + make_cost_record( + cost_usd=90.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + identity = _make_identity() + task = _make_task(agent_id=str(identity.id)) + + with _patch_periods(): + checker = await enforcer.make_budget_checker(task, str(identity.id)) + + assert checker is not None + + # Running cost of 9 → total 99 → under 100 + ctx_under = _ctx_with_cost(identity, task, 9.0) + assert checker(ctx_under) is False + + # Running cost of 10 → total 100 → at hard stop + ctx_at = _ctx_with_cost(identity, task, 10.0) + assert checker(ctx_at) is True + + async def test_daily_limit_in_checker(self) -> None: + """Checker detects daily limit (baseline + running cost).""" + cfg = _make_budget_config( + total_monthly=100.0, + per_agent_daily_limit=10.0, + ) + tracker = CostTracker(budget_config=cfg) + # Pre-existing daily spend of 8 + await tracker.record( + make_cost_record( + agent_id="alice", + cost_usd=8.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + identity = _make_identity() + task = _make_task(agent_id=str(identity.id)) + + with _patch_periods(): + checker = await enforcer.make_budget_checker(task, "alice") + + assert checker is not None + + # Running cost of 1 → daily total 9 → under 10 + ctx_under = _ctx_with_cost(identity, task, 1.0) + assert checker(ctx_under) is False + + # Running cost of 2 → daily total 10 → at limit + ctx_at = _ctx_with_cost(identity, task, 2.0) + assert checker(ctx_at) is True + + @pytest.mark.parametrize( + ("baseline", "running", "expected_exhausted"), + [ + (74.0, 0.9, False), # 74.9% → NORMAL, not exhausted + (74.0, 1.0, False), # 75.0% → WARNING alert, not exhausted + (89.0, 1.0, False), # 90.0% → CRITICAL alert, not exhausted + (99.0, 1.0, True), # 100.0% → HARD_STOP, exhausted + ], + ids=["74.9_normal", "75.0_warning", "90.0_critical", "100.0_hard_stop"], + ) + async def test_alert_thresholds( + self, + baseline: float, + running: float, + expected_exhausted: bool, + ) -> None: + """Alert fires at exact threshold percentages.""" + cfg = _make_budget_config( + total_monthly=100.0, + warn_at=75, + critical_at=90, + hard_stop_at=100, + per_agent_daily_limit=0.0, + ) + tracker = CostTracker(budget_config=cfg) + await tracker.record( + make_cost_record( + cost_usd=baseline, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + identity = _make_identity() + task = _make_task(agent_id=str(identity.id)) + + with _patch_periods(): + checker = await enforcer.make_budget_checker(task, str(identity.id)) + + assert checker is not None + ctx = _ctx_with_cost(identity, task, running) + assert checker(ctx) is expected_exhausted + + async def test_alert_deduplication(self) -> None: + """Same alert level is not logged twice.""" + cfg = _make_budget_config( + total_monthly=100.0, + warn_at=75, + critical_at=90, + hard_stop_at=100, + per_agent_daily_limit=0.0, + ) + tracker = CostTracker(budget_config=cfg) + # Baseline of 70 + await tracker.record( + make_cost_record( + cost_usd=70.0, + input_tokens=100, + output_tokens=50, + timestamp=_RECORD_TS, + ), + ) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + identity = _make_identity() + task = _make_task(agent_id=str(identity.id)) + + with _patch_periods(): + checker = await enforcer.make_budget_checker(task, str(identity.id)) + + assert checker is not None + + # Context at 75% (baseline 70 + running 5) + ctx_warning = _ctx_with_cost(identity, task, 5.0) + + # First call at 75% → should emit WARNING + with patch( + "ai_company.budget.enforcer.logger", + ) as mock_logger: + checker(ctx_warning) + warn_calls = [ + c + for c in mock_logger.warning.call_args_list + if c[0][0] == BUDGET_ALERT_THRESHOLD_CROSSED + ] + assert len(warn_calls) == 1 + + # Second call at same level → should NOT emit again + with patch( + "ai_company.budget.enforcer.logger", + ) as mock_logger2: + checker(ctx_warning) + warn_calls2 = [ + c + for c in mock_logger2.warning.call_args_list + if c[0][0] == BUDGET_ALERT_THRESHOLD_CROSSED + ] + assert len(warn_calls2) == 0 + + +# ── Graceful degradation ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestGracefulDegradation: + """Tests for CostTracker failure fallback paths.""" + + async def test_resolve_model_returns_unchanged_on_tracker_error( + self, + ) -> None: + """CostTracker failure in resolve_model returns identity unchanged.""" + cfg = _make_budget_config( + auto_downgrade=AutoDowngradeConfig( + enabled=True, + threshold=85, + downgrade_map=(("large", "medium"),), + ), + ) + tracker = CostTracker(budget_config=cfg) + resolver = _make_resolver( + { + "test-large-001": _resolved( + model_id="test-large-001", + alias="large", + ), + } + ) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + model_resolver=resolver, + ) + identity = _make_identity() + + with patch.object( + tracker, + "get_total_cost", + side_effect=RuntimeError("db connection failed"), + ): + result = await enforcer.resolve_model(identity) + + assert result.model.model_id == "test-large-001" + + async def test_resolve_model_propagates_memory_error(self) -> None: + """MemoryError from CostTracker in resolve_model is re-raised.""" + cfg = _make_budget_config( + auto_downgrade=AutoDowngradeConfig( + enabled=True, + threshold=85, + downgrade_map=(("large", "medium"),), + ), + ) + tracker = CostTracker(budget_config=cfg) + resolver = _make_resolver( + { + "test-large-001": _resolved( + model_id="test-large-001", + alias="large", + ), + } + ) + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + model_resolver=resolver, + ) + identity = _make_identity() + + with ( + patch.object( + tracker, + "get_total_cost", + side_effect=MemoryError("OOM"), + ), + pytest.raises(MemoryError, match="OOM"), + ): + await enforcer.resolve_model(identity) + + async def test_make_budget_checker_falls_back_on_tracker_error( + self, + ) -> None: + """CostTracker failure in make_budget_checker still returns a checker.""" + cfg = _make_budget_config( + total_monthly=100.0, + per_agent_daily_limit=10.0, + ) + tracker = CostTracker(budget_config=cfg) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + identity = _make_identity() + task = _make_task( + agent_id=str(identity.id), + budget_limit=5.0, + ) + + with patch.object( + tracker, + "get_total_cost", + side_effect=RuntimeError("db connection failed"), + ): + checker = await enforcer.make_budget_checker( + task, + str(identity.id), + ) + + # Checker should still be returned (not None) + assert checker is not None + + # Task limit should still be enforced + ctx_at = _ctx_with_cost(identity, task, 5.0) + assert checker(ctx_at) is True + + async def test_make_budget_checker_propagates_memory_error( + self, + ) -> None: + """MemoryError from CostTracker in make_budget_checker is re-raised.""" + cfg = _make_budget_config(total_monthly=100.0) + tracker = CostTracker(budget_config=cfg) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + identity = _make_identity() + task = _make_task(agent_id=str(identity.id)) + + with ( + patch.object( + tracker, + "get_total_cost", + side_effect=MemoryError("OOM"), + ), + pytest.raises(MemoryError, match="OOM"), + ): + await enforcer.make_budget_checker(task, str(identity.id)) + + async def test_checker_task_limit_zero_does_not_trigger(self) -> None: + """Checker with task_limit=0 but monthly active ignores task limit.""" + cfg = _make_budget_config( + total_monthly=100.0, + per_agent_daily_limit=0.0, + ) + tracker = CostTracker(budget_config=cfg) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + identity = _make_identity() + task = _make_task( + agent_id=str(identity.id), + budget_limit=0.0, + ) + + with _patch_periods(): + checker = await enforcer.make_budget_checker( + task, + str(identity.id), + ) + + assert checker is not None + + # High running cost should not trigger task limit (disabled) + # but should not hit monthly hard stop either (no baseline spend) + ctx = _ctx_with_cost(identity, task, 50.0) + assert checker(ctx) is False + + +# ── cost_tracker property ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestCostTrackerProperty: + """Tests for BudgetEnforcer.cost_tracker property.""" + + def test_returns_injected_tracker(self) -> None: + """Property returns the same tracker injected at construction.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + assert enforcer.cost_tracker is tracker diff --git a/tests/unit/engine/test_agent_engine_budget.py b/tests/unit/engine/test_agent_engine_budget.py new file mode 100644 index 0000000000..4c095a480c --- /dev/null +++ b/tests/unit/engine/test_agent_engine_budget.py @@ -0,0 +1,205 @@ +"""Tests for AgentEngine budget enforcer integration.""" + +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, patch + +import pytest + +from ai_company.budget.config import ( + BudgetAlertConfig, + BudgetConfig, +) +from ai_company.budget.enforcer import BudgetEnforcer +from ai_company.budget.tracker import CostTracker +from ai_company.engine.agent_engine import AgentEngine +from ai_company.engine.errors import BudgetExhaustedError, DailyLimitExceededError +from ai_company.engine.loop_protocol import TerminationReason + +if TYPE_CHECKING: + from ai_company.core.agent import AgentIdentity + from ai_company.core.task import Task + +from .conftest import ( + MockCompletionProvider, + make_completion_response, +) + +pytestmark = pytest.mark.timeout(30) + + +def _make_budget_config( + *, + total_monthly: float = 100.0, + hard_stop_at: int = 100, +) -> BudgetConfig: + return BudgetConfig( + total_monthly=total_monthly, + alerts=BudgetAlertConfig( + warn_at=75, + critical_at=90, + hard_stop_at=hard_stop_at, + ), + ) + + +@pytest.mark.unit +class TestEngineWithEnforcer: + """Tests for AgentEngine with budget_enforcer wired in.""" + + @pytest.mark.parametrize( + ("exc_cls", "msg"), + [ + (BudgetExhaustedError, "Monthly budget exhausted"), + (DailyLimitExceededError, "Daily limit exceeded"), + ], + ids=["monthly_exhausted", "daily_limit"], + ) + async def test_preflight_budget_stop_returns_budget_exhausted( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + exc_cls: type[BudgetExhaustedError], + msg: str, + ) -> None: + """Pre-flight budget errors propagate as BUDGET_EXHAUSTED result.""" + cfg = _make_budget_config(total_monthly=100.0) + tracker = CostTracker(budget_config=cfg) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + + provider = MockCompletionProvider( + [make_completion_response(content="Done.")], + ) + engine = AgentEngine( + provider=provider, + budget_enforcer=enforcer, + ) + + with patch.object( + enforcer, + "check_can_execute", + new=AsyncMock(side_effect=exc_cls(msg)), + ): + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert result.termination_reason == TerminationReason.BUDGET_EXHAUSTED + assert provider.call_count == 0 + + async def test_model_downgrade_applied( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Model downgrade at task boundary changes model used.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + + downgraded_identity = sample_agent_with_personality.model_copy( + update={ + "model": sample_agent_with_personality.model.model_copy( + update={"model_id": "test-small-001"}, + ), + }, + ) + + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + + provider = MockCompletionProvider( + [make_completion_response(content="Done.")], + ) + engine = AgentEngine( + provider=provider, + budget_enforcer=enforcer, + ) + + with ( + patch.object( + enforcer, + "check_can_execute", + new=AsyncMock(), + ), + patch.object( + enforcer, + "resolve_model", + new=AsyncMock(return_value=downgraded_identity), + ), + patch.object( + enforcer, + "make_budget_checker", + new=AsyncMock(return_value=None), + ), + ): + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert result.termination_reason == TerminationReason.COMPLETED + # Verify the downgraded model was used for the LLM call + assert provider.recorded_models[0] == "test-small-001" + + async def test_no_enforcer_uses_fallback_checker( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Without enforcer, uses existing make_budget_checker fallback.""" + provider = MockCompletionProvider( + [make_completion_response(content="Done.")], + ) + engine = AgentEngine(provider=provider) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert result.termination_reason == TerminationReason.COMPLETED + + async def test_enforcer_provides_cost_tracker( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """When no explicit cost_tracker, uses enforcer's tracker.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + enforcer = BudgetEnforcer(budget_config=cfg, cost_tracker=tracker) + + provider = MockCompletionProvider( + [make_completion_response(content="Done.")], + ) + engine = AgentEngine( + provider=provider, + budget_enforcer=enforcer, + ) + + # Run a task and verify costs were recorded to the enforcer's tracker + with ( + patch.object( + enforcer, + "check_can_execute", + new=AsyncMock(), + ), + patch.object( + enforcer, + "resolve_model", + new=AsyncMock(return_value=sample_agent_with_personality), + ), + patch.object( + enforcer, + "make_budget_checker", + new=AsyncMock(return_value=None), + ), + ): + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert result.termination_reason == TerminationReason.COMPLETED + # Verify cost was recorded to the enforcer's tracker + total = await tracker.get_total_cost() + assert total > 0 diff --git a/tests/unit/engine/test_agent_engine_errors.py b/tests/unit/engine/test_agent_engine_errors.py index b288128134..57c9d7ac35 100644 --- a/tests/unit/engine/test_agent_engine_errors.py +++ b/tests/unit/engine/test_agent_engine_errors.py @@ -27,7 +27,7 @@ from .conftest import MockCompletionProvider -from .conftest import make_completion_response as _make_completion_response +from .conftest import make_completion_response pytestmark = pytest.mark.timeout(30) @@ -160,6 +160,34 @@ async def test_negative_max_turns_raises( ) +@pytest.mark.unit +class TestAgentEngineTimeoutValidation: + """timeout_seconds <= 0 raises ValueError at the engine boundary.""" + + @pytest.mark.parametrize( + "timeout_val", + [0, -1.0, -0.001], + ids=["zero", "negative", "small_negative"], + ) + async def test_invalid_timeout_raises( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + timeout_val: float, + ) -> None: + """Invalid timeout_seconds raises ValueError.""" + provider = mock_provider_factory([]) + engine = AgentEngine(provider=provider) + + with pytest.raises(ValueError, match="timeout_seconds must be > 0"): + await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + timeout_seconds=timeout_val, + ) + + @pytest.mark.unit class TestAgentEngineCostRecordingNonRecoverable: """MemoryError/RecursionError in _record_costs propagate unconditionally.""" @@ -173,7 +201,7 @@ async def test_memory_error_in_cost_recording_propagates( """MemoryError from CostTracker.record() is not swallowed.""" tracker = MagicMock() tracker.record = AsyncMock(side_effect=MemoryError("OOM in tracker")) - response = _make_completion_response(cost_usd=0.05) + response = make_completion_response(cost_usd=0.05) provider = mock_provider_factory([response]) engine = AgentEngine(provider=provider, cost_tracker=tracker) @@ -194,7 +222,7 @@ async def test_recursion_error_in_cost_recording_propagates( tracker.record = AsyncMock( side_effect=RecursionError("infinite in tracker"), ) - response = _make_completion_response(cost_usd=0.05) + response = make_completion_response(cost_usd=0.05) provider = mock_provider_factory([response]) engine = AgentEngine(provider=provider, cost_tracker=tracker) @@ -261,8 +289,9 @@ async def test_handle_fatal_error_secondary_failure_raises_original( identity=sample_agent_with_personality, task=sample_task_with_criteria, ) - assert isinstance(exc_info.value.__cause__, ValueError) - assert "secondary failure" in str(exc_info.value.__cause__) + # raise exc from None suppresses the secondary error chain + # so the original exception propagates cleanly + assert exc_info.value.__cause__ is None @pytest.mark.unit