diff --git a/CLAUDE.md b/CLAUDE.md index 364a1f704d..9fd7c547ff 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -44,7 +44,7 @@ uv run pre-commit run --all-files # all pre-commit hooks ```text src/ai_company/ api/ # FastAPI REST + WebSocket routes - budget/ # Cost tracking, budget enforcement (pre-flight/in-flight checks, auto-downgrade), billing periods + budget/ # Cost tracking, budget enforcement (pre-flight/in-flight checks, auto-downgrade), billing periods, cost tiers, quota/subscription tracking cli/ # Typer CLI commands communication/ # Message bus, dispatcher, messenger, channels, delegation, loop prevention, conflict resolution, meeting protocol config/ # YAML company config loading and validation diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 79b1954fcf..3d5be123cb 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -81,7 +81,7 @@ The MVP validates the core hypothesis: **a single agent can complete a real task > **Implementation snapshot (2026-03-09):** > - **Done:** M0–M4 (tooling, config/core, providers, single-agent engine, multi-agent orchestration). Memory layer backend selected ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). Persistence backend (§7.6) completed. -> - **In progress:** M5 — memory interface protocol complete (MemoryBackend, MemoryCapabilities, SharedKnowledgeStore protocols, models, config, factory), Mem0 adapter (#41) and budget enforcement pending. +> - **In progress:** M5 — memory interface protocol complete (MemoryBackend, MemoryCapabilities, SharedKnowledgeStore protocols, models, config, factory), budget enforcement complete (BudgetEnforcer + configurable cost tiers + quota/subscription tracking), Mem0 adapter (#41) pending. > - **Not started (mostly placeholders):** M6 API/CLI surface, M7 security + approval system. ### 1.5 Configuration Philosophy @@ -1694,6 +1694,8 @@ providers: cost_per_1k_output: 0.0 ``` +> **Implementation note (M5):** `ProviderConfig` now includes `subscription: SubscriptionConfig` and `degradation: DegradationConfig` fields for per-provider quota limits and subscription-aware degradation behavior. The default degradation strategy is `ALERT` (raise `QuotaExhaustedError`). `FALLBACK` (route to fallback providers) and `QUEUE` (delay and retry) strategies are defined in the model but **not yet implemented** — the engine currently always raises on quota exhaustion regardless of strategy. Regular quota polling / proactive alerting before quotas are hit is deferred to a follow-up issue. + ### 9.3 LiteLLM Integration Use **LiteLLM** as the provider abstraction layer: @@ -1825,12 +1827,14 @@ budget: > **Auto-downgrade boundary:** Model downgrades apply only at **task assignment time**, never mid-execution. An agent halfway through an architecture review cannot be switched to a cheaper model — the task completes on its assigned model. The next task assignment respects the downgrade threshold. This prevents quality degradation from mid-thought model switches. > **Implementation note (M5):** `BudgetEnforcer` composes `CostTracker` + -> `BudgetConfig` to provide three enforcement layers: (1) pre-flight checks -> via `check_can_execute` (monthly hard stop + per-agent daily limit), (2) -> in-flight budget checking via a sync `BudgetChecker` closure with -> pre-computed baselines (task + monthly + daily limits, alert deduplication), -> and (3) task-boundary auto-downgrade via `resolve_model`. Billing periods -> are scoped by `billing_period_start(reset_day)`. `DailyLimitExceededError` +> `BudgetConfig` + optional `QuotaTracker` + optional `ModelResolver` to +> provide three enforcement layers: (1) pre-flight checks via +> `check_can_execute` (monthly hard stop + per-agent daily limit + provider +> quota enforcement when `QuotaTracker` is present), (2) in-flight budget +> checking via a sync `BudgetChecker` closure with pre-computed baselines +> (task + monthly + daily limits, alert deduplication), and (3) +> task-boundary auto-downgrade via `resolve_model`. Billing periods are +> scoped by `billing_period_start(reset_day)`. `DailyLimitExceededError` > is a subclass of `BudgetExhaustedError` for granular error handling. ### 10.5 LLM Call Analytics @@ -2794,6 +2798,7 @@ ai-company/ │ │ │ ├── persistence.py # PERSISTENCE_* constants │ │ │ ├── personality.py # PERSONALITY_* constants │ │ │ ├── prompt.py # PROMPT_* constants +│ │ │ ├── quota.py # QUOTA_* event constants │ │ │ ├── provider.py # PROVIDER_* constants │ │ │ ├── role.py # ROLE_* constants │ │ │ ├── routing.py # ROUTING_* constants @@ -2869,6 +2874,7 @@ ai-company/ │ ├── budget/ # Cost management │ │ ├── config.py # Budget configuration models │ │ ├── cost_record.py # CostRecord model (frozen) +│ │ ├── cost_tiers.py # Cost tier definitions, classification, and built-in tiers │ │ ├── call_category.py # LLM call category enums (productive, coordination, system) │ │ ├── category_analytics.py # Per-category cost breakdown + orchestration ratio │ │ ├── coordination_config.py # Coordination metrics config models @@ -2880,6 +2886,8 @@ ai-company/ │ │ ├── billing.py # Billing period computation utilities │ │ ├── enforcer.py # BudgetEnforcer service (pre-flight, in-flight, auto-downgrade) │ │ ├── optimizer.py # Cost optimization / CFO logic (M5) +│ │ ├── quota.py # Quota/subscription models, degradation config, quota snapshots +│ │ ├── quota_tracker.py # QuotaTracker service: per-provider request/token quota enforcement │ │ └── reports.py # Spending reports (M5) │ ├── api/ # REST + WebSocket API (M6, stubs only) │ │ ├── app.py # FastAPI application (M6) @@ -2954,6 +2962,7 @@ These conventions were established during the M0–M2+ review cycle. **Adopted** | **Personality compatibility scoring** | Adopted (M3) | Weighted composite: 60% Big Five similarity (openness, conscientiousness, agreeableness, stress_response → 1−\|diff\|; extraversion → tent-function peaking at 0.3 diff), 20% collaboration alignment (ordinal adjacency: INDEPENDENT↔PAIR↔TEAM), 20% conflict approach (constructive pairs score 1.0, destructive pairs 0.2, mixed 0.4–0.6). `itertools.combinations` for team-level averaging. Result clamped to [0, 1]. | Covers behavioral diversity (extraversion complement), task alignment (conscientiousness similarity), and interpersonal friction (conflict approach). Weights are configurable module constants. | | **Agent behavior testing** | Planned (M3) | Scripted `FakeProvider` for unit tests (deterministic turn sequences); behavioral outcome assertions for integration tests (task completed, tools called, cost within budget). | Leverages existing `FakeProvider` and `CompletionResponseFactory` fixtures. Precise engine testing without brittle response-matching at integration level. | | **LLM call analytics** | Adopted (incremental) | M3: proxy metrics (`turns_per_task`, `tokens_per_task`) — adopted. M4 data models: call categorization (`productive`, `coordination`, `system`), category analytics, coordination metrics, orchestration ratio — adopted. M4 runtime collection pipeline and M5+ full analytics: planned. | Append-only, never blocks execution. Builds on existing `CostRecord` infrastructure. Detects orchestration overhead early. See §10.5. | +| **Cost tiers & quota tracking** | Adopted (M5) | Configurable `CostTierDefinition` definitions with merge/override semantics via `resolve_tiers(config: CostTiersConfig)`. `SubscriptionConfig` + `QuotaLimit` model per-provider subscription plans. `QuotaTracker` enforces per-provider request/token quotas with window-based rotation. `DegradationConfig` controls behavior when quotas are exhausted (default: `ALERT` — raise error; `FALLBACK` and `QUEUE` strategies defined but not yet implemented). | Enables cost classification without hardcoding vendor tiers. Quota tracking prevents surprise overages at the provider level. Window-based rotation aligns quota resets with billing periods. See §10.4. | | **State coordination** | Planned (M4) | Centralized single-writer: `TaskEngine` owns all task/project mutations via `asyncio.Queue`. Agents submit requests, engine applies `model_copy(update=...)` sequentially and publishes snapshots. `version: int` field on state models for future optimistic concurrency if multi-process scaling is needed. | Prevents lost updates by design. Trivial in single-threaded asyncio (no locks). Perfect audit trail. Industry consensus: MetaGPT, CrewAI, AutoGen all use prevention-by-design, not conflict resolution. See §6.8 State Coordination table. | | **Workspace isolation** | Adopted (M4 core) | Pluggable `WorkspaceIsolationStrategy` protocol. Default: planner + git worktrees. Each agent works in an isolated worktree; sequential merge on completion. Textual conflicts detected by git; semantic conflicts reviewed by agent or human. Runtime multi-agent coordination wiring remains M4 hardening work. | Industry standard (Codex, Cursor, Claude Code, VS Code). Maximum parallelism. Leverages mature git infrastructure. See §6.8. | | **Graceful shutdown** | Adopted (M3) | Pluggable `ShutdownStrategy` protocol. Default: cooperative with 30s timeout. Agents check shutdown event at turn boundaries. Force-cancel after timeout. `INTERRUPTED` status for force-cancelled tasks. M4/M5: upgrade to checkpoint-and-stop. | Cross-platform (Windows `signal.signal()` fallback). Bounded shutdown time. Mirrors cooperative shutdown in §6.7. | diff --git a/README.md b/README.md index 20321cc4eb..ba383b57f9 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Persistence Layer (M5)** - Pluggable `PersistenceBackend` protocol with SQLite backend (aiosqlite), repository protocols, schema migrations - **Memory Interface (M5)** - Pluggable `MemoryBackend` protocol with capability discovery, shared knowledge protocol, domain models, config, and factory - **Coordination Error Taxonomy (M5)** - Post-execution classification pipeline detecting logical contradictions, numerical drift, context omissions, and coordination failures -- **Budget Enforcement (M5)** - `BudgetEnforcer` service with pre-flight checks, in-flight budget checking, and auto-downgrade; CFO agent and advanced reporting pending +- **Budget Enforcement (M5)** - `BudgetEnforcer` service with pre-flight checks, in-flight budget checking, auto-downgrade, configurable cost tiers, and quota/subscription tracking; CFO agent and advanced reporting pending ### Not implemented yet (planned milestones) diff --git a/src/ai_company/budget/__init__.py b/src/ai_company/budget/__init__.py index f077c7e697..c8f7e006e1 100644 --- a/src/ai_company/budget/__init__.py +++ b/src/ai_company/budget/__init__.py @@ -29,6 +29,13 @@ RedundancyRate, ) from ai_company.budget.cost_record import CostRecord +from ai_company.budget.cost_tiers import ( + BUILTIN_TIERS, + CostTierDefinition, + CostTiersConfig, + classify_model_tier, + resolve_tiers, +) from ai_company.budget.enforcer import BudgetEnforcer from ai_company.budget.enums import BudgetAlertLevel from ai_company.budget.hierarchy import ( @@ -36,6 +43,18 @@ DepartmentBudget, TeamBudget, ) +from ai_company.budget.quota import ( + DegradationAction, + DegradationConfig, + ProviderCostModel, + QuotaCheckResult, + QuotaLimit, + QuotaSnapshot, + QuotaWindow, + SubscriptionConfig, + effective_cost_per_1k, +) +from ai_company.budget.quota_tracker import QuotaTracker from ai_company.budget.spending_summary import ( AgentSpending, DepartmentSpending, @@ -45,6 +64,7 @@ from ai_company.budget.tracker import CostTracker __all__ = [ + "BUILTIN_TIERS", "AgentSpending", "AutoDowngradeConfig", "BudgetAlertConfig", @@ -59,7 +79,11 @@ "CoordinationMetricsConfig", "CoordinationOverhead", "CostRecord", + "CostTierDefinition", + "CostTiersConfig", "CostTracker", + "DegradationAction", + "DegradationConfig", "DepartmentBudget", "DepartmentSpending", "ErrorAmplification", @@ -71,9 +95,19 @@ "OrchestrationAlertThresholds", "OrchestrationRatio", "PeriodSpending", + "ProviderCostModel", + "QuotaCheckResult", + "QuotaLimit", + "QuotaSnapshot", + "QuotaTracker", + "QuotaWindow", "RedundancyRate", "SpendingSummary", + "SubscriptionConfig", "TeamBudget", "billing_period_start", + "classify_model_tier", "daily_period_start", + "effective_cost_per_1k", + "resolve_tiers", ] diff --git a/src/ai_company/budget/cost_tiers.py b/src/ai_company/budget/cost_tiers.py new file mode 100644 index 0000000000..6ee66f48d5 --- /dev/null +++ b/src/ai_company/budget/cost_tiers.py @@ -0,0 +1,255 @@ +"""Cost tier definitions and classification. + +Provides configurable metadata for cost tiers: price ranges, display +properties, and model-to-tier classification. The built-in ``CostTier`` +enum (``ai_company.core.enums``) remains for backward compatibility; +this module adds a configurable layer on top. +""" + +from typing import Self + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.budget import ( + BUDGET_TIER_CLASSIFY_MISS, + BUDGET_TIER_RESOLVED, +) + +logger = get_logger(__name__) + + +class CostTierDefinition(BaseModel): + """Metadata for a single cost tier. + + Attributes: + id: Unique tier identifier (e.g. ``"low"``, ``"custom-budget"``). + display_name: Human-readable name. + description: What this tier represents. + price_range_min: Minimum cost_per_1k_total for models in this + tier (USD). + price_range_max: Maximum cost_per_1k_total; ``None`` means + unbounded above. + color: Hex color for UI rendering. + icon: Icon identifier for UI rendering. + sort_order: Display ordering (lower = cheaper). + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + id: NotBlankStr = Field(description="Unique tier identifier") + display_name: NotBlankStr = Field(description="Human-readable name") + description: str = Field( + default="", + description="What this tier represents", + ) + price_range_min: float = Field( + default=0.0, + ge=0.0, + description="Minimum cost_per_1k_total (USD)", + ) + price_range_max: float | None = Field( + default=None, + ge=0.0, + description="Maximum cost_per_1k_total (USD); None = unbounded", + ) + color: str = Field( + default="#6b7280", + description="Hex color for UI", + ) + icon: str = Field( + default="circle", + description="Icon identifier for UI", + ) + sort_order: int = Field( + default=0, + description="Display ordering (lower = cheaper)", + ) + + @model_validator(mode="after") + def _validate_price_range(self) -> Self: + """Ensure max > min when both are set. + + A zero-width range (min == max) with a finite max can never + match any cost because classification uses ``[min, max)`` + semantics. + """ + if self.price_range_max is not None: + if self.price_range_max < self.price_range_min: + msg = ( + f"price_range_max ({self.price_range_max}) must be " + f"> price_range_min ({self.price_range_min})" + ) + raise ValueError(msg) + if self.price_range_max == self.price_range_min: + msg = ( + f"price_range_max ({self.price_range_max}) must be " + f"> price_range_min ({self.price_range_min}); " + f"zero-width range can never match with [min, max) " + f"semantics" + ) + raise ValueError(msg) + return self + + +class CostTiersConfig(BaseModel): + """Configuration for cost tier definitions. + + Attributes: + tiers: User-defined tier overrides/additions. + include_builtin: Whether to merge built-in default tiers. + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + tiers: tuple[CostTierDefinition, ...] = Field( + default=(), + description="User-defined tier overrides/additions", + ) + include_builtin: bool = Field( + default=True, + description="Whether to merge built-in default tiers", + ) + + @model_validator(mode="after") + def _validate_unique_ids(self) -> Self: + """Ensure tier IDs are unique within user-defined tiers.""" + seen: set[str] = set() + dupes: set[str] = set() + for t in self.tiers: + if t.id in seen: + dupes.add(t.id) + seen.add(t.id) + if dupes: + msg = f"Duplicate tier IDs: {sorted(dupes)}" + raise ValueError(msg) + return self + + +BUILTIN_TIERS: tuple[CostTierDefinition, ...] = ( + CostTierDefinition( + id="low", + display_name="Low", + description="Budget-friendly models for simple tasks", + price_range_min=0.0, + price_range_max=0.002, + color="#22c55e", + icon="circle", + sort_order=0, + ), + CostTierDefinition( + id="medium", + display_name="Medium", + description="Balanced cost-performance models", + price_range_min=0.002, + price_range_max=0.01, + color="#eab308", + icon="circle", + sort_order=1, + ), + CostTierDefinition( + id="high", + display_name="High", + description="High-capability models for complex tasks", + price_range_min=0.01, + price_range_max=0.03, + color="#f97316", + icon="circle", + sort_order=2, + ), + CostTierDefinition( + id="premium", + display_name="Premium", + description="Top-tier models for maximum capability", + price_range_min=0.03, + price_range_max=None, + color="#ef4444", + icon="circle", + sort_order=3, + ), +) + + +def resolve_tiers( + config: CostTiersConfig, +) -> tuple[CostTierDefinition, ...]: + """Merge built-in and user-defined tiers, sorted by sort_order. + + User tiers override built-in tiers with the same ID. + + Args: + config: Cost tiers configuration. + + Returns: + Merged and sorted tuple of tier definitions. + """ + if not config.include_builtin: + result = sorted(config.tiers, key=lambda t: t.sort_order) + logger.debug( + BUDGET_TIER_RESOLVED, + tier_count=len(result), + include_builtin=False, + ) + return tuple(result) + + # User tiers override built-in by ID + user_ids = {t.id for t in config.tiers} + merged: list[CostTierDefinition] = [ + t for t in BUILTIN_TIERS if t.id not in user_ids + ] + merged.extend(config.tiers) + merged.sort(key=lambda t: t.sort_order) + + logger.debug( + BUDGET_TIER_RESOLVED, + tier_count=len(merged), + include_builtin=True, + overridden_ids=sorted(user_ids & {t.id for t in BUILTIN_TIERS}), + ) + return tuple(merged) + + +def classify_model_tier( + cost_per_1k_total: float, + tiers: tuple[CostTierDefinition, ...], +) -> str | None: + """Classify a model into a cost tier based on total cost per 1k tokens. + + Matches the first tier whose price range contains the given cost. + Range check: ``min <= cost < max`` (or ``min <= cost`` if max is + ``None``). If tiers have overlapping ranges, the first match in + iteration order wins — callers should ensure tiers are sorted by + ``sort_order``. + + Args: + cost_per_1k_total: Combined ``cost_per_1k_input + + cost_per_1k_output``. + tiers: Resolved tier definitions (should be sorted by + sort_order). + + Returns: + Tier ID of the matching tier, or ``None`` if no tier matches. + """ + if cost_per_1k_total < 0: + logger.warning( + BUDGET_TIER_CLASSIFY_MISS, + cost_per_1k_total=cost_per_1k_total, + tier_count=len(tiers), + reason="negative_cost", + ) + return None + + for tier in tiers: + if tier.price_range_max is None: + if cost_per_1k_total >= tier.price_range_min: + return tier.id + elif tier.price_range_min <= cost_per_1k_total < tier.price_range_max: + return tier.id + + logger.debug( + BUDGET_TIER_CLASSIFY_MISS, + cost_per_1k_total=cost_per_1k_total, + tier_count=len(tiers), + ) + return None diff --git a/src/ai_company/budget/enforcer.py b/src/ai_company/budget/enforcer.py index cbbef1c8dc..fe7cabe93c 100644 --- a/src/ai_company/budget/enforcer.py +++ b/src/ai_company/budget/enforcer.py @@ -10,8 +10,13 @@ from ai_company.budget.billing import billing_period_start, daily_period_start from ai_company.budget.enums import BudgetAlertLevel +from ai_company.budget.quota import QuotaCheckResult from ai_company.constants import BUDGET_ROUNDING_PRECISION -from ai_company.engine.errors import BudgetExhaustedError, DailyLimitExceededError +from ai_company.engine.errors import ( + BudgetExhaustedError, + DailyLimitExceededError, + QuotaExhaustedError, +) from ai_company.observability import get_logger from ai_company.observability.events.budget import ( BUDGET_ALERT_THRESHOLD_CROSSED, @@ -27,9 +32,14 @@ BUDGET_RESOLVE_MODEL_ERROR, BUDGET_TASK_LIMIT_HIT, ) +from ai_company.observability.events.quota import ( + QUOTA_CHECK_ALLOWED, + QUOTA_CHECK_DENIED, +) if TYPE_CHECKING: from ai_company.budget.config import BudgetConfig + from ai_company.budget.quota_tracker import QuotaTracker from ai_company.budget.tracker import CostTracker from ai_company.core.agent import AgentIdentity, ModelConfig from ai_company.core.task import Task @@ -59,6 +69,8 @@ class BudgetEnforcer: cost_tracker: Cost tracking service for querying spend. model_resolver: Optional model resolver for auto-downgrade alias lookup. + quota_tracker: Optional quota tracker for provider-level + quota enforcement. """ def __init__( @@ -67,23 +79,41 @@ def __init__( budget_config: BudgetConfig, cost_tracker: CostTracker, model_resolver: ModelResolver | None = None, + quota_tracker: QuotaTracker | None = None, ) -> None: self._budget_config = budget_config self._cost_tracker = cost_tracker self._model_resolver = model_resolver + self._quota_tracker = quota_tracker @property def cost_tracker(self) -> CostTracker: """The underlying cost tracker.""" return self._cost_tracker - async def check_can_execute(self, agent_id: str) -> None: - """Pre-flight: verify monthly + daily limits allow execution. + async def check_can_execute( + self, + agent_id: str, + *, + provider_name: str | None = None, + estimated_tokens: int = 0, + ) -> None: + """Pre-flight: verify monthly + daily + quota limits allow execution. + + Args: + agent_id: Agent requesting execution. + provider_name: Optional provider name for quota checks. + When ``None``, quota check is skipped. + estimated_tokens: Estimated tokens for the upcoming request. + Forwarded to the quota tracker for token-based checks. Raises: BudgetExhaustedError: Monthly hard stop exceeded. DailyLimitExceededError: Agent daily limit exceeded (subclass of ``BudgetExhaustedError``). + QuotaExhaustedError: Provider quota exhausted. + Degradation routing (FALLBACK/QUEUE) is not yet + implemented; currently always raises. """ cfg = self._budget_config @@ -91,6 +121,13 @@ async def check_can_execute(self, agent_id: str) -> None: if cfg.total_monthly > 0: await self._check_monthly_hard_stop(cfg, agent_id) await self._check_daily_limit(cfg, agent_id) + + if provider_name is not None: + await self._check_provider_quota( + agent_id, + provider_name, + estimated_tokens=estimated_tokens, + ) except BudgetExhaustedError: raise except MemoryError, RecursionError: # builtin MemoryError (OOM) @@ -109,6 +146,66 @@ async def check_can_execute(self, agent_id: str) -> None: result="pass", ) + async def check_quota( + self, + provider_name: str, + *, + estimated_tokens: int = 0, + ) -> QuotaCheckResult: + """Check provider quota, delegating to QuotaTracker. + + Returns always-allowed when no quota tracker is configured. + + Note: + Unlike ``check_can_execute``, this method does **not** catch + unexpected exceptions from the underlying ``QuotaTracker``. + ``check_can_execute`` wraps quota checks in a try/except + that falls back to allowing execution on unexpected errors + (graceful degradation), but direct callers of ``check_quota`` + are responsible for their own error handling. + + Args: + provider_name: Provider to check. + estimated_tokens: Estimated tokens for the request. + + Returns: + Quota check result. + """ + if self._quota_tracker is None: + logger.debug( + QUOTA_CHECK_ALLOWED, + provider=provider_name, + reason="no_quota_tracker", + ) + return _always_allowed_result(provider_name) + + return await self._quota_tracker.check_quota( + provider_name, + estimated_tokens=estimated_tokens, + ) + + async def _check_provider_quota( + self, + agent_id: str, + provider_name: str, + *, + estimated_tokens: int = 0, + ) -> None: + """Check provider quota, raising on exhaustion.""" + quota_result = await self.check_quota( + provider_name, + estimated_tokens=estimated_tokens, + ) + if not quota_result.allowed: + logger.warning( + QUOTA_CHECK_DENIED, + agent_id=agent_id, + provider=provider_name, + reason=quota_result.reason, + ) + msg = f"Provider {provider_name!r} quota exhausted: {quota_result.reason}" + raise QuotaExhaustedError(msg) + async def _check_monthly_hard_stop( self, cfg: BudgetConfig, @@ -348,6 +445,14 @@ async def _compute_baselines( # ── Module-level pure helpers ──────────────────────────────────── +def _always_allowed_result(provider_name: str) -> QuotaCheckResult: + """Build an always-allowed QuotaCheckResult.""" + return QuotaCheckResult( + allowed=True, + provider_name=provider_name, + ) + + def _apply_downgrade( identity: AgentIdentity, resolver: ModelResolver, diff --git a/src/ai_company/budget/quota.py b/src/ai_company/budget/quota.py new file mode 100644 index 0000000000..8d9305fddb --- /dev/null +++ b/src/ai_company/budget/quota.py @@ -0,0 +1,397 @@ +"""Quota and subscription models for provider cost tracking. + +Defines quota windows, subscription configurations, degradation +strategies, and quota check result models for providers that operate +under subscription plans, local deployments, or pay-as-you-go billing. +""" + +from datetime import UTC, datetime +from enum import StrEnum +from typing import Self + +from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator + +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.config import CONFIG_VALIDATION_FAILED + +logger = get_logger(__name__) + + +class QuotaWindow(StrEnum): + """Time window for quota enforcement.""" + + PER_MINUTE = "per_minute" + PER_HOUR = "per_hour" + PER_DAY = "per_day" + PER_MONTH = "per_month" + + +class QuotaLimit(BaseModel): + """A single quota limit for a time window. + + Attributes: + window: Time window for this limit. + max_requests: Maximum requests in the window (0 = unlimited). + max_tokens: Maximum tokens in the window (0 = unlimited). + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + window: QuotaWindow = Field(description="Time window for this limit") + max_requests: int = Field( + default=0, + ge=0, + description="Maximum requests in the window (0 = unlimited)", + ) + max_tokens: int = Field( + default=0, + ge=0, + description="Maximum tokens in the window (0 = unlimited)", + ) + + @model_validator(mode="after") + def _at_least_one_limit(self) -> Self: + """Ensure at least one of max_requests or max_tokens is set.""" + if self.max_requests == 0 and self.max_tokens == 0: + msg = "At least one of max_requests or max_tokens must be > 0" + logger.warning( + CONFIG_VALIDATION_FAILED, + model="QuotaLimit", + field="max_requests/max_tokens", + reason=msg, + ) + raise ValueError(msg) + return self + + +class ProviderCostModel(StrEnum): + """How a provider charges for usage. + + Members: + PER_TOKEN: Standard pay-as-you-go; cost computed from + cost_per_1k_input/output. + SUBSCRIPTION: Monthly flat fee; individual calls are pre-paid. + LOCAL: Zero monetary cost; only hardware constraints. + """ + + PER_TOKEN = "per_token" # noqa: S105 — billing concept, not a secret + SUBSCRIPTION = "subscription" + LOCAL = "local" + + +class SubscriptionConfig(BaseModel): + """Subscription and quota configuration for a provider. + + Attributes: + plan_name: Name of the subscription plan. + cost_model: How the provider charges for usage. + monthly_cost: Fixed monthly subscription fee in USD. + quotas: Rate/token/request limits per time window. + hardware_limits: Free-text hardware constraints for local models. + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + plan_name: NotBlankStr = Field( + default="pay_as_you_go", + description="Subscription plan name", + ) + cost_model: ProviderCostModel = Field( + default=ProviderCostModel.PER_TOKEN, + description="How the provider charges for usage", + ) + monthly_cost: float = Field( + default=0.0, + ge=0.0, + description="Fixed monthly subscription fee in USD", + ) + quotas: tuple[QuotaLimit, ...] = Field( + default=(), + description="Rate/token/request limits per time window", + ) + hardware_limits: str | None = Field( + default=None, + description="Free-text hardware constraints for local models", + ) + + @model_validator(mode="after") + def _validate_quotas_unique_windows(self) -> Self: + """Ensure quota windows are unique.""" + seen: set[QuotaWindow] = set() + dupes: set[str] = set() + for q in self.quotas: + if q.window in seen: + dupes.add(q.window.value) + seen.add(q.window) + if dupes: + msg = f"Duplicate quota windows: {sorted(dupes)}" + logger.warning( + CONFIG_VALIDATION_FAILED, + model="SubscriptionConfig", + field="quotas", + reason=msg, + ) + raise ValueError(msg) + return self + + @model_validator(mode="after") + def _validate_cost_model_constraints(self) -> Self: + """Validate cost_model-specific constraints.""" + if self.cost_model == ProviderCostModel.LOCAL and self.monthly_cost > 0: + msg = ( + f"LOCAL cost_model must have monthly_cost=0.0, got {self.monthly_cost}" + ) + raise ValueError(msg) + + if self.cost_model == ProviderCostModel.SUBSCRIPTION and self.monthly_cost <= 0: + logger.warning( + CONFIG_VALIDATION_FAILED, + model="SubscriptionConfig", + field="monthly_cost", + reason=( + "SUBSCRIPTION cost_model typically has monthly_cost > 0; " + f"got {self.monthly_cost}" + ), + ) + + return self + + +class DegradationAction(StrEnum): + """Action to take when a provider's quota is exhausted. + + Members: + FALLBACK: Route to a fallback provider. + QUEUE: Queue for later (not implemented in M5). + ALERT: Raise error and alert user. + """ + + FALLBACK = "fallback" + QUEUE = "queue" + ALERT = "alert" + + +class DegradationConfig(BaseModel): + """Configuration for graceful degradation when quota is exhausted. + + Attributes: + strategy: What to do when quota is exhausted. + fallback_providers: Ordered fallback provider names. + queue_max_wait_seconds: Max seconds to wait when queueing. + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + strategy: DegradationAction = Field( + default=DegradationAction.ALERT, + description="Degradation strategy when quota exhausted", + ) + fallback_providers: tuple[NotBlankStr, ...] = Field( + default=(), + description="Ordered fallback provider names", + ) + queue_max_wait_seconds: int = Field( + default=300, + ge=0, + le=3600, + description="Max wait seconds when queueing", + ) + + @model_validator(mode="after") + def _validate_fallback_providers(self) -> Self: + """Warn if FALLBACK strategy has no fallback providers.""" + if self.strategy == DegradationAction.FALLBACK and not self.fallback_providers: + logger.warning( + CONFIG_VALIDATION_FAILED, + model="DegradationConfig", + field="fallback_providers", + reason=( + "FALLBACK strategy specified but no fallback_providers configured" + ), + ) + return self + + +class QuotaSnapshot(BaseModel): + """Point-in-time snapshot of quota usage for a provider window. + + Attributes: + provider_name: Provider this snapshot belongs to. + window: Time window for this snapshot. + requests_used: Requests consumed in this window. + requests_limit: Maximum requests allowed (0 = unlimited). + tokens_used: Tokens consumed in this window. + tokens_limit: Maximum tokens allowed (0 = unlimited). + window_resets_at: When the current window resets. + captured_at: When this snapshot was captured. + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + provider_name: NotBlankStr = Field(description="Provider name") + window: QuotaWindow = Field(description="Time window") + requests_used: int = Field(default=0, ge=0, description="Requests used") + requests_limit: int = Field( + default=0, + ge=0, + description="Requests limit (0 = unlimited)", + ) + tokens_used: int = Field(default=0, ge=0, description="Tokens used") + tokens_limit: int = Field( + default=0, + ge=0, + description="Tokens limit (0 = unlimited)", + ) + window_resets_at: datetime | None = Field( + default=None, + description="When the current window resets", + ) + captured_at: datetime = Field(description="When snapshot was captured") + + @computed_field # type: ignore[prop-decorator] + @property + def requests_remaining(self) -> int | None: + """Remaining requests in this window. + + Returns ``None`` when the limit is not enforced (unlimited). + Returns 0 when fully consumed. + """ + if self.requests_limit == 0: + return None + return max(0, self.requests_limit - self.requests_used) + + @computed_field # type: ignore[prop-decorator] + @property + def tokens_remaining(self) -> int | None: + """Remaining tokens in this window. + + Returns ``None`` when the limit is not enforced (unlimited). + Returns 0 when fully consumed. + """ + if self.tokens_limit == 0: + return None + return max(0, self.tokens_limit - self.tokens_used) + + @computed_field # type: ignore[prop-decorator] + @property + def is_exhausted(self) -> bool: + """Whether any enforced limit in this window is exhausted.""" + if self.requests_limit > 0 and self.requests_used >= self.requests_limit: + return True + return self.tokens_limit > 0 and self.tokens_used >= self.tokens_limit + + +class QuotaCheckResult(BaseModel): + """Result of a pre-flight quota check. + + Attributes: + allowed: Whether the request is allowed. + provider_name: Provider that was checked. + reason: Human-readable reason (set when denied). + exhausted_windows: Which windows are exhausted (if any). + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + allowed: bool = Field(description="Whether the request is allowed") + provider_name: NotBlankStr = Field(description="Provider checked") + reason: str = Field(default="", description="Reason (set when denied)") + exhausted_windows: tuple[QuotaWindow, ...] = Field( + default=(), + description="Exhausted windows", + ) + + @model_validator(mode="after") + def _validate_denied_has_reason(self) -> Self: + """Ensure denied results have a non-empty reason.""" + if not self.allowed and not self.reason: + msg = "Denied QuotaCheckResult must have a non-empty reason" + raise ValueError(msg) + if self.allowed and self.exhausted_windows: + msg = "Allowed QuotaCheckResult must not have exhausted_windows" + raise ValueError(msg) + return self + + +def window_start( + window: QuotaWindow, + *, + now: datetime | None = None, +) -> datetime: + """Compute the UTC-aware start of the current quota window. + + Args: + window: Which time window to compute. + now: Reference timestamp. Defaults to ``datetime.now(UTC)``. + Must be timezone-aware; naive datetimes are rejected. + + Returns: + UTC-aware datetime at the start of the current window. + + Raises: + ValueError: If *now* is a naive (timezone-unaware) datetime. + """ + if now is None: + now = datetime.now(UTC) + elif now.tzinfo is None: + msg = "now must be timezone-aware, got naive datetime" + logger.warning( + CONFIG_VALIDATION_FAILED, + model="window_start", + field="now", + reason=msg, + ) + raise ValueError(msg) + else: + now = now.astimezone(UTC) + + if window == QuotaWindow.PER_MINUTE: + return datetime( + now.year, + now.month, + now.day, + now.hour, + now.minute, + tzinfo=UTC, + ) + if window == QuotaWindow.PER_HOUR: + return datetime( + now.year, + now.month, + now.day, + now.hour, + tzinfo=UTC, + ) + if window == QuotaWindow.PER_DAY: + return datetime( + now.year, + now.month, + now.day, + tzinfo=UTC, + ) + # PER_MONTH — first day of the month + return datetime(now.year, now.month, 1, tzinfo=UTC) + + +def effective_cost_per_1k( + cost_per_1k_input: float, + cost_per_1k_output: float, + cost_model: ProviderCostModel, +) -> float: + """Compute effective cost per 1k tokens based on cost model. + + Returns 0.0 for SUBSCRIPTION and LOCAL models (pre-paid / free). + Returns ``cost_per_1k_input + cost_per_1k_output`` for PER_TOKEN. + + Args: + cost_per_1k_input: Cost per 1k input tokens. + cost_per_1k_output: Cost per 1k output tokens. + cost_model: The provider's cost model. + + Returns: + Effective cost per 1k tokens. + """ + if cost_model in (ProviderCostModel.SUBSCRIPTION, ProviderCostModel.LOCAL): + return 0.0 + return cost_per_1k_input + cost_per_1k_output diff --git a/src/ai_company/budget/quota_tracker.py b/src/ai_company/budget/quota_tracker.py new file mode 100644 index 0000000000..06fbe2da07 --- /dev/null +++ b/src/ai_company/budget/quota_tracker.py @@ -0,0 +1,436 @@ +"""Quota tracking service. + +Tracks per-provider request and token usage against configured quota +windows. Window-based counters are rotated automatically when a window +boundary is crossed. + +Concurrency-safe via ``asyncio.Lock`` (same pattern as +:class:`~ai_company.budget.tracker.CostTracker`). +""" + +import asyncio +import copy +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING, NamedTuple + +if TYPE_CHECKING: + from collections.abc import Mapping + +from ai_company.budget.quota import ( + QuotaCheckResult, + QuotaLimit, + QuotaSnapshot, + QuotaWindow, + SubscriptionConfig, + window_start, +) +from ai_company.observability import get_logger +from ai_company.observability.events.quota import ( + QUOTA_CHECK_ALLOWED, + QUOTA_CHECK_DENIED, + QUOTA_SNAPSHOT_QUERIED, + QUOTA_TRACKER_CREATED, + QUOTA_USAGE_RECORDED, + QUOTA_USAGE_SKIPPED, + QUOTA_WINDOW_ROTATED, +) + +logger = get_logger(__name__) + + +class _WindowUsage(NamedTuple): + """Immutable usage counters for a single window (replaced on update).""" + + requests: int + tokens: int + window_start: datetime + + +class QuotaTracker: + """Tracks per-provider quota usage across configured time windows. + + Providers without a subscription config are silently ignored (no-op + on record, always allowed on check). + + Note: + ``check_quota`` followed by ``record_usage`` is subject to a + TOCTOU gap: another coroutine could record usage between the + check and the record, pushing the counter past the limit. + This is by-design for the single-loop ``asyncio`` concurrency + model — the gap only exists across ``await`` points and is + acceptable for quota enforcement (the in-flight budget checker + is the true safety net). + + Args: + subscriptions: Mapping of provider name to subscription config. + """ + + def __init__( + self, + *, + subscriptions: Mapping[str, SubscriptionConfig], + ) -> None: + self._subscriptions: dict[str, SubscriptionConfig] = copy.deepcopy( + dict(subscriptions), + ) + self._lock = asyncio.Lock() + self._usage: dict[str, dict[QuotaWindow, _WindowUsage]] = {} + + # Initialize usage tracking for providers with quotas + for provider_name, sub_config in self._subscriptions.items(): + if sub_config.quotas: + self._usage[provider_name] = {} + for quota in sub_config.quotas: + ws = window_start(quota.window) + self._usage[provider_name][quota.window] = _WindowUsage( + requests=0, + tokens=0, + window_start=ws, + ) + + logger.debug( + QUOTA_TRACKER_CREATED, + provider_count=len(self._subscriptions), + tracked_providers=sorted(self._usage), + ) + + async def record_usage( + self, + provider_name: str, + *, + requests: int = 1, + tokens: int = 0, + ) -> None: + """Record usage against all configured windows for a provider. + + Rotates window counters if a window boundary has been crossed. + Providers with no subscription config are skipped with a DEBUG log. + + Args: + provider_name: Provider to record usage for. + requests: Number of requests to record (must be >= 0). + tokens: Number of tokens to record (must be >= 0). + + Raises: + ValueError: If requests or tokens is negative. + """ + if requests < 0: + msg = f"requests must be non-negative, got {requests}" + logger.warning( + QUOTA_USAGE_SKIPPED, + provider=provider_name, + reason=msg, + ) + raise ValueError(msg) + if tokens < 0: + msg = f"tokens must be non-negative, got {tokens}" + logger.warning( + QUOTA_USAGE_SKIPPED, + provider=provider_name, + reason=msg, + ) + raise ValueError(msg) + + if provider_name not in self._usage: + reason = ( + "no_quotas_configured" + if provider_name in self._subscriptions + else "unknown_provider" + ) + logger.debug( + QUOTA_USAGE_SKIPPED, + provider=provider_name, + reason=reason, + ) + return + + async with self._lock: + now = datetime.now(UTC) + provider_usage = self._usage[provider_name] + + for window_type in list(provider_usage): + current = provider_usage[window_type] + expected_start = window_start(window_type, now=now) + + if expected_start != current.window_start: + # Window boundary crossed — rotate + provider_usage[window_type] = _WindowUsage( + requests=requests, + tokens=tokens, + window_start=expected_start, + ) + logger.debug( + QUOTA_WINDOW_ROTATED, + provider=provider_name, + window=window_type.value, + old_start=str(current.window_start), + new_start=str(expected_start), + ) + else: + provider_usage[window_type] = _WindowUsage( + requests=current.requests + requests, + tokens=current.tokens + tokens, + window_start=current.window_start, + ) + + logger.debug( + QUOTA_USAGE_RECORDED, + provider=provider_name, + requests=requests, + tokens=tokens, + ) + + async def check_quota( + self, + provider_name: str, + *, + estimated_tokens: int = 0, + ) -> QuotaCheckResult: + """Pre-flight check: can this provider handle a request? + + Providers with no subscription config always return allowed. + + Args: + provider_name: Provider to check. + estimated_tokens: Estimated tokens for the request (must be >= 0). + + Returns: + Check result with allowed status and reason. + + Raises: + ValueError: If estimated_tokens is negative. + """ + if estimated_tokens < 0: + msg = f"estimated_tokens must be non-negative, got {estimated_tokens}" + logger.warning( + QUOTA_CHECK_DENIED, + provider=provider_name, + reason=msg, + ) + raise ValueError(msg) + + if provider_name not in self._usage: + reason = ( + "no_quotas_configured" + if provider_name in self._subscriptions + else "unknown_provider" + ) + logger.debug( + QUOTA_CHECK_ALLOWED, + provider=provider_name, + reason=reason, + ) + return QuotaCheckResult( + allowed=True, + provider_name=provider_name, + ) + + sub_config = self._subscriptions[provider_name] + quota_map = {q.window: q for q in sub_config.quotas} + + async with self._lock: + now = datetime.now(UTC) + provider_usage = self._usage[provider_name] + exhausted: list[QuotaWindow] = [] + reasons: list[str] = [] + + for window_type, usage in provider_usage.items(): + expected_start = window_start(window_type, now=now) + + # If window has rotated, counters would be zero + if expected_start != usage.window_start: + continue + + quota = quota_map.get(window_type) + if quota is None: + continue + + if _is_window_exhausted( + usage, + quota, + estimated_tokens, + ): + exhausted.append(window_type) + reasons.append( + _build_exhaustion_reason( + provider_name, + window_type, + usage, + quota, + estimated_tokens, + ), + ) + + if exhausted: + result = QuotaCheckResult( + allowed=False, + provider_name=provider_name, + reason="; ".join(reasons), + exhausted_windows=tuple(exhausted), + ) + logger.info( + QUOTA_CHECK_DENIED, + provider=provider_name, + exhausted_windows=[w.value for w in exhausted], + reason=result.reason, + ) + return result + + logger.debug( + QUOTA_CHECK_ALLOWED, + provider=provider_name, + ) + return QuotaCheckResult( + allowed=True, + provider_name=provider_name, + ) + + async def get_snapshot( + self, + provider_name: str, + window: QuotaWindow | None = None, + ) -> tuple[QuotaSnapshot, ...]: + """Get current usage snapshots for a provider. + + Args: + provider_name: Provider to query. + window: Optional specific window to query. If ``None``, + returns all windows. + + Returns: + Tuple of quota snapshots. + """ + if provider_name not in self._usage: + reason = ( + "no_quotas_configured" + if provider_name in self._subscriptions + else "unknown_provider" + ) + logger.debug( + QUOTA_SNAPSHOT_QUERIED, + provider=provider_name, + snapshot_count=0, + reason=reason, + ) + return () + + sub_config = self._subscriptions[provider_name] + quota_map = {q.window: q for q in sub_config.quotas} + + async with self._lock: + now = datetime.now(UTC) + snapshots: list[QuotaSnapshot] = [] + provider_usage = self._usage[provider_name] + + for window_type, usage in provider_usage.items(): + if window is not None and window_type != window: + continue + + quota = quota_map.get(window_type) + if quota is None: + continue + + expected_start = window_start(window_type, now=now) + # If window has rotated, show zero usage + if expected_start != usage.window_start: + req_used = 0 + tok_used = 0 + else: + req_used = usage.requests + tok_used = usage.tokens + + snapshots.append( + QuotaSnapshot( + provider_name=provider_name, + window=window_type, + requests_used=req_used, + requests_limit=quota.max_requests, + tokens_used=tok_used, + tokens_limit=quota.max_tokens, + window_resets_at=_window_end( + window_type, + expected_start, + ), + captured_at=now, + ), + ) + + logger.debug( + QUOTA_SNAPSHOT_QUERIED, + provider=provider_name, + snapshot_count=len(snapshots), + ) + return tuple(snapshots) + + async def get_all_snapshots( + self, + ) -> dict[str, tuple[QuotaSnapshot, ...]]: + """Get usage snapshots for all tracked providers. + + Note: + Snapshots are collected per-provider with separate lock + acquisitions, so cross-provider consistency is not guaranteed + under concurrent writes. + + Returns: + Dict mapping provider name to tuple of snapshots. + """ + result: dict[str, tuple[QuotaSnapshot, ...]] = {} + for provider_name in self._usage: + result[provider_name] = await self.get_snapshot(provider_name) + return result + + +def _is_window_exhausted( + usage: _WindowUsage, + quota: QuotaLimit, + estimated_tokens: int, +) -> bool: + """Check if a window's quota is exhausted. + + Request check uses ``>=`` (hard limit — *at* the limit means + exhausted because the next request would exceed it). Token check + uses ``>`` for projected tokens (``usage + estimated``), allowing + exact-fill: a request whose projected total exactly matches the + limit is still permitted. + """ + if quota.max_requests > 0 and usage.requests >= quota.max_requests: + return True + return quota.max_tokens > 0 and usage.tokens + estimated_tokens > quota.max_tokens + + +def _build_exhaustion_reason( + provider_name: str, + window: QuotaWindow, + usage: _WindowUsage, + quota: QuotaLimit, + estimated_tokens: int = 0, +) -> str: + """Build a human-readable exhaustion reason.""" + parts: list[str] = [f"{provider_name} {window.value}:"] + if quota.max_requests > 0 and usage.requests >= quota.max_requests: + parts.append(f"requests {usage.requests}/{quota.max_requests}") + projected = usage.tokens + estimated_tokens + if quota.max_tokens > 0 and projected > quota.max_tokens: + parts.append(f"tokens {projected}/{quota.max_tokens}") + return " ".join(parts) + + +_MONTHS_PER_YEAR = 12 + +_WINDOW_DELTAS: dict[QuotaWindow, timedelta] = { + QuotaWindow.PER_MINUTE: timedelta(minutes=1), + QuotaWindow.PER_HOUR: timedelta(hours=1), + QuotaWindow.PER_DAY: timedelta(days=1), +} + + +def _window_end(window: QuotaWindow, start: datetime) -> datetime: + """Compute end of a quota window from its start.""" + delta = _WINDOW_DELTAS.get(window) + if delta is not None: + return start + delta + # PER_MONTH — advance to first of next month + month = start.month % _MONTHS_PER_YEAR + 1 + year = start.year + (1 if start.month == _MONTHS_PER_YEAR else 0) + return start.replace(year=year, month=month) diff --git a/src/ai_company/config/defaults.py b/src/ai_company/config/defaults.py index 279b2aea6d..a00d74dcf8 100644 --- a/src/ai_company/config/defaults.py +++ b/src/ai_company/config/defaults.py @@ -32,4 +32,5 @@ def default_config_dict() -> dict[str, Any]: "task_assignment": {}, "memory": {}, "persistence": {}, + "cost_tiers": {}, } diff --git a/src/ai_company/config/schema.py b/src/ai_company/config/schema.py index c2c2f8a5b5..89bb5873e0 100644 --- a/src/ai_company/config/schema.py +++ b/src/ai_company/config/schema.py @@ -7,6 +7,8 @@ from ai_company.budget.config import BudgetConfig from ai_company.budget.coordination_config import CoordinationMetricsConfig +from ai_company.budget.cost_tiers import CostTiersConfig +from ai_company.budget.quota import DegradationConfig, SubscriptionConfig from ai_company.communication.config import CommunicationConfig from ai_company.core.company import ( CompanyConfig, @@ -166,6 +168,8 @@ class ProviderConfig(BaseModel): models: Available models for this provider. retry: Retry configuration for transient errors. rate_limiter: Client-side rate limiting configuration. + subscription: Subscription and quota configuration. + degradation: Degradation strategy when quota exhausted. """ model_config = ConfigDict(frozen=True) @@ -195,6 +199,14 @@ class ProviderConfig(BaseModel): default_factory=RateLimiterConfig, description="Client-side rate limiting configuration", ) + subscription: SubscriptionConfig = Field( + default_factory=SubscriptionConfig, + description="Subscription and quota configuration", + ) + degradation: DegradationConfig = Field( + default_factory=DegradationConfig, + description="Degradation strategy when quota exhausted", + ) @model_validator(mode="after") def _validate_unique_model_identifiers(self) -> Self: @@ -470,6 +482,7 @@ class RootConfig(BaseModel): task_assignment: Task assignment configuration. memory: Memory backend configuration. persistence: Persistence backend configuration. + cost_tiers: Cost tier definitions. """ model_config = ConfigDict(frozen=True) @@ -545,6 +558,10 @@ class RootConfig(BaseModel): default_factory=PersistenceConfig, description="Persistence backend configuration", ) + cost_tiers: CostTiersConfig = Field( + default_factory=CostTiersConfig, + description="Cost tier definitions", + ) @model_validator(mode="after") def _validate_unique_agent_names(self) -> Self: @@ -635,3 +652,23 @@ def _validate_routing_references(self) -> Self: ) raise ValueError(msg) return self + + @model_validator(mode="after") + def _validate_degradation_fallback_providers(self) -> Self: + """Ensure degradation fallback_providers reference known providers.""" + known_providers = set(self.providers) + for prov_name, prov_config in self.providers.items(): + for fb in prov_config.degradation.fallback_providers: + if fb not in known_providers: + msg = ( + f"Provider {prov_name!r} degradation " + f"fallback_providers references unknown " + f"provider: {fb!r}" + ) + logger.warning( + CONFIG_VALIDATION_FAILED, + model="RootConfig", + error=msg, + ) + raise ValueError(msg) + return self diff --git a/src/ai_company/engine/__init__.py b/src/ai_company/engine/__init__.py index 326af58790..138137888c 100644 --- a/src/ai_company/engine/__init__.py +++ b/src/ai_company/engine/__init__.py @@ -67,6 +67,7 @@ NoEligibleAgentError, ParallelExecutionError, PromptBuildError, + QuotaExhaustedError, ResourceConflictError, TaskAssignmentError, TaskRoutingError, @@ -220,6 +221,7 @@ "ProgressCallback", "PromptBuildError", "PromptTokenEstimator", + "QuotaExhaustedError", "ReactLoop", "RecoveryResult", "RecoveryStrategy", diff --git a/src/ai_company/engine/errors.py b/src/ai_company/engine/errors.py index b1a2efeaa3..42a1c4a636 100644 --- a/src/ai_company/engine/errors.py +++ b/src/ai_company/engine/errors.py @@ -38,6 +38,14 @@ class DailyLimitExceededError(BudgetExhaustedError): """Per-agent daily spending limit exceeded.""" +class QuotaExhaustedError(BudgetExhaustedError): + """Raised when provider quota is exhausted. + + Currently raised for all degradation strategies. Degradation routing + (FALLBACK/QUEUE) is planned for a future milestone. + """ + + class LoopExecutionError(EngineError): """Non-recoverable execution loop error for the engine layer. diff --git a/src/ai_company/observability/events/budget.py b/src/ai_company/observability/events/budget.py index e0157d79f6..389e87bc81 100644 --- a/src/ai_company/observability/events/budget.py +++ b/src/ai_company/observability/events/budget.py @@ -26,3 +26,6 @@ BUDGET_BASELINE_ERROR: Final[str] = "budget.baseline.error" BUDGET_PREFLIGHT_ERROR: Final[str] = "budget.preflight.error" BUDGET_RESOLVE_MODEL_ERROR: Final[str] = "budget.resolve_model.error" + +BUDGET_TIER_RESOLVED: Final[str] = "budget.tier.resolved" +BUDGET_TIER_CLASSIFY_MISS: Final[str] = "budget.tier.classify_miss" diff --git a/src/ai_company/observability/events/quota.py b/src/ai_company/observability/events/quota.py new file mode 100644 index 0000000000..877a0e092b --- /dev/null +++ b/src/ai_company/observability/events/quota.py @@ -0,0 +1,11 @@ +"""Quota tracking event constants.""" + +from typing import Final + +QUOTA_TRACKER_CREATED: Final[str] = "quota.tracker.created" +QUOTA_USAGE_RECORDED: Final[str] = "quota.usage.recorded" +QUOTA_CHECK_ALLOWED: Final[str] = "quota.check.allowed" +QUOTA_CHECK_DENIED: Final[str] = "quota.check.denied" +QUOTA_WINDOW_ROTATED: Final[str] = "quota.window.rotated" +QUOTA_SNAPSHOT_QUERIED: Final[str] = "quota.snapshot.queried" +QUOTA_USAGE_SKIPPED: Final[str] = "quota.usage.skipped" diff --git a/tests/unit/budget/conftest.py b/tests/unit/budget/conftest.py index 42e4e99163..e8fe67f2fd 100644 --- a/tests/unit/budget/conftest.py +++ b/tests/unit/budget/conftest.py @@ -12,12 +12,19 @@ BudgetConfig, ) from ai_company.budget.cost_record import CostRecord +from ai_company.budget.cost_tiers import CostTierDefinition, CostTiersConfig from ai_company.budget.enums import BudgetAlertLevel from ai_company.budget.hierarchy import ( BudgetHierarchy, DepartmentBudget, TeamBudget, ) +from ai_company.budget.quota import ( + QuotaLimit, + QuotaWindow, + SubscriptionConfig, +) +from ai_company.budget.quota_tracker import QuotaTracker from ai_company.budget.spending_summary import ( AgentSpending, DepartmentSpending, @@ -98,6 +105,27 @@ class SpendingSummaryFactory(ModelFactory[SpendingSummary]): by_department = () +class CostTierDefinitionFactory(ModelFactory[CostTierDefinition]): + __model__ = CostTierDefinition + sort_order = 0 + + +class CostTiersConfigFactory(ModelFactory[CostTiersConfig]): + __model__ = CostTiersConfig + tiers = () + include_builtin = True + + +class QuotaLimitFactory(ModelFactory[QuotaLimit]): + __model__ = QuotaLimit + max_requests = 60 + + +class SubscriptionConfigFactory(ModelFactory[SubscriptionConfig]): + __model__ = SubscriptionConfig + quotas = () + + # ── Sample Fixtures ──────────────────────────────────────────────── @@ -235,6 +263,19 @@ def cost_tracker( ) +def make_quota_tracker( + *, + provider: str = "test-provider", + max_requests: int = 60, + window: QuotaWindow = QuotaWindow.PER_MINUTE, +) -> QuotaTracker: + """Build a QuotaTracker with a single provider and quota.""" + sub = SubscriptionConfig( + quotas=(QuotaLimit(window=window, max_requests=max_requests),), + ) + return QuotaTracker(subscriptions={provider: sub}) + + def make_cost_record( # noqa: PLR0913 *, agent_id: str = "alice", diff --git a/tests/unit/budget/test_cost_tiers.py b/tests/unit/budget/test_cost_tiers.py new file mode 100644 index 0000000000..4caf954ab4 --- /dev/null +++ b/tests/unit/budget/test_cost_tiers.py @@ -0,0 +1,364 @@ +"""Tests for cost tier definitions and classification.""" + +import pytest +from pydantic import ValidationError + +from ai_company.budget.cost_tiers import ( + BUILTIN_TIERS, + CostTierDefinition, + CostTiersConfig, + classify_model_tier, + resolve_tiers, +) + +pytestmark = pytest.mark.timeout(30) + + +# ── CostTierDefinition ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestCostTierDefinition: + """Tests for CostTierDefinition model.""" + + def test_valid_minimal(self) -> None: + """Minimal valid tier with just id and display_name.""" + tier = CostTierDefinition(id="low", display_name="Low") + assert tier.id == "low" + assert tier.display_name == "Low" + assert tier.description == "" + assert tier.price_range_min == 0.0 + assert tier.price_range_max is None + assert tier.color == "#6b7280" + assert tier.icon == "circle" + assert tier.sort_order == 0 + + def test_valid_full(self) -> None: + """Full tier with all fields set.""" + tier = CostTierDefinition( + id="custom", + display_name="Custom Tier", + description="A custom cost tier", + price_range_min=0.01, + price_range_max=0.05, + color="#ff0000", + icon="star", + sort_order=5, + ) + assert tier.price_range_min == 0.01 + assert tier.price_range_max == 0.05 + assert tier.color == "#ff0000" + assert tier.icon == "star" + assert tier.sort_order == 5 + + def test_unbounded_max_allowed(self) -> None: + """price_range_max=None means unbounded above.""" + tier = CostTierDefinition( + id="premium", + display_name="Premium", + price_range_min=0.03, + price_range_max=None, + ) + assert tier.price_range_max is None + + def test_equal_min_max_rejected(self) -> None: + """price_range_max == price_range_min is rejected (zero-width).""" + with pytest.raises(ValidationError, match="zero-width"): + CostTierDefinition( + id="exact", + display_name="Exact", + price_range_min=0.01, + price_range_max=0.01, + ) + + def test_max_less_than_min_rejected(self) -> None: + """price_range_max < price_range_min raises ValueError.""" + with pytest.raises(ValidationError, match="price_range_max"): + CostTierDefinition( + id="bad", + display_name="Bad", + price_range_min=0.05, + price_range_max=0.01, + ) + + def test_negative_min_rejected(self) -> None: + """Negative price_range_min is rejected.""" + with pytest.raises(ValidationError): + CostTierDefinition( + id="neg", + display_name="Neg", + price_range_min=-0.01, + ) + + def test_negative_max_rejected(self) -> None: + """Negative price_range_max is rejected.""" + with pytest.raises(ValidationError): + CostTierDefinition( + id="neg", + display_name="Neg", + price_range_max=-0.01, + ) + + def test_blank_id_rejected(self) -> None: + """Empty id is rejected.""" + with pytest.raises(ValidationError): + CostTierDefinition(id="", display_name="X") + + def test_whitespace_id_rejected(self) -> None: + """Whitespace-only id is rejected.""" + with pytest.raises(ValidationError, match="whitespace-only"): + CostTierDefinition(id=" ", display_name="X") + + def test_blank_display_name_rejected(self) -> None: + """Empty display_name is rejected.""" + with pytest.raises(ValidationError): + CostTierDefinition(id="x", display_name="") + + def test_frozen(self) -> None: + """Model is frozen (immutable).""" + tier = CostTierDefinition(id="x", display_name="X") + with pytest.raises(ValidationError): + tier.id = "y" # type: ignore[misc] + + def test_nan_rejected(self) -> None: + """NaN values are rejected by allow_inf_nan=False.""" + with pytest.raises(ValidationError): + CostTierDefinition( + id="x", + display_name="X", + price_range_min=float("nan"), + ) + + def test_inf_rejected(self) -> None: + """Inf values are rejected by allow_inf_nan=False.""" + with pytest.raises(ValidationError): + CostTierDefinition( + id="x", + display_name="X", + price_range_min=float("inf"), + ) + + +# ── CostTiersConfig ──────────────────────────────────────────────── + + +@pytest.mark.unit +class TestCostTiersConfig: + """Tests for CostTiersConfig model.""" + + def test_defaults(self) -> None: + """Default config has no user tiers and includes builtins.""" + cfg = CostTiersConfig() + assert cfg.tiers == () + assert cfg.include_builtin is True + + def test_with_user_tiers(self) -> None: + """User-defined tiers are accepted.""" + tier = CostTierDefinition(id="custom", display_name="Custom") + cfg = CostTiersConfig(tiers=(tier,)) + assert len(cfg.tiers) == 1 + assert cfg.tiers[0].id == "custom" + + def test_duplicate_tier_ids_rejected(self) -> None: + """Duplicate tier IDs within user tiers are rejected.""" + tier = CostTierDefinition(id="dup", display_name="Dup") + with pytest.raises(ValidationError, match="Duplicate tier IDs"): + CostTiersConfig(tiers=(tier, tier)) + + def test_include_builtin_false(self) -> None: + """include_builtin=False disables built-in tiers.""" + cfg = CostTiersConfig(include_builtin=False) + assert cfg.include_builtin is False + + def test_frozen(self) -> None: + """Model is frozen.""" + cfg = CostTiersConfig() + with pytest.raises(ValidationError): + cfg.include_builtin = False # type: ignore[misc] + + +# ── BUILTIN_TIERS ────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestBuiltinTiers: + """Tests for built-in tier constants.""" + + def test_four_builtin_tiers(self) -> None: + """There are exactly 4 built-in tiers.""" + assert len(BUILTIN_TIERS) == 4 + + def test_builtin_ids(self) -> None: + """Built-in tiers have expected IDs.""" + ids = {t.id for t in BUILTIN_TIERS} + assert ids == {"low", "medium", "high", "premium"} + + def test_builtin_sort_order(self) -> None: + """Built-in tiers are ordered by sort_order.""" + orders = [t.sort_order for t in BUILTIN_TIERS] + assert orders == sorted(orders) + + def test_premium_unbounded(self) -> None: + """Premium tier has unbounded max.""" + premium = next(t for t in BUILTIN_TIERS if t.id == "premium") + assert premium.price_range_max is None + + def test_low_starts_at_zero(self) -> None: + """Low tier starts at 0.0.""" + low = next(t for t in BUILTIN_TIERS if t.id == "low") + assert low.price_range_min == 0.0 + + def test_no_gaps_in_ranges(self) -> None: + """Adjacent tiers share boundary values (max of one == min of next).""" + sorted_tiers = sorted(BUILTIN_TIERS, key=lambda t: t.sort_order) + for i in range(len(sorted_tiers) - 1): + current = sorted_tiers[i] + next_tier = sorted_tiers[i + 1] + assert current.price_range_max == next_tier.price_range_min + + +# ── resolve_tiers ────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestResolveTiers: + """Tests for resolve_tiers function.""" + + def test_default_config_returns_builtins(self) -> None: + """Default config resolves to just built-in tiers.""" + result = resolve_tiers(CostTiersConfig()) + assert len(result) == 4 + assert result[0].id == "low" + assert result[-1].id == "premium" + + def test_include_builtin_false_returns_user_only(self) -> None: + """include_builtin=False returns only user-defined tiers.""" + tier = CostTierDefinition(id="custom", display_name="Custom") + cfg = CostTiersConfig(tiers=(tier,), include_builtin=False) + result = resolve_tiers(cfg) + assert len(result) == 1 + assert result[0].id == "custom" + + def test_user_override_replaces_builtin(self) -> None: + """User tier with same ID as built-in replaces it.""" + override = CostTierDefinition( + id="premium", + display_name="Premium+", + price_range_min=0.03, + color="#a855f7", + sort_order=3, + ) + cfg = CostTiersConfig(tiers=(override,)) + result = resolve_tiers(cfg) + premium = next(t for t in result if t.id == "premium") + assert premium.display_name == "Premium+" + assert premium.color == "#a855f7" + # Still 4 total (override replaces, doesn't add) + assert len(result) == 4 + + def test_user_addition_adds_to_builtins(self) -> None: + """User tier with unique ID adds to built-in tiers.""" + extra = CostTierDefinition( + id="budget", + display_name="Budget", + sort_order=-1, + ) + cfg = CostTiersConfig(tiers=(extra,)) + result = resolve_tiers(cfg) + assert len(result) == 5 + # Budget should be first (sort_order=-1) + assert result[0].id == "budget" + + def test_sorted_by_sort_order(self) -> None: + """Result is always sorted by sort_order.""" + tiers = ( + CostTierDefinition(id="z", display_name="Z", sort_order=10), + CostTierDefinition(id="a", display_name="A", sort_order=-5), + ) + cfg = CostTiersConfig(tiers=tiers, include_builtin=False) + result = resolve_tiers(cfg) + assert result[0].id == "a" + assert result[1].id == "z" + + def test_empty_config_no_builtins(self) -> None: + """Empty tiers with no builtins returns empty tuple.""" + cfg = CostTiersConfig(include_builtin=False) + result = resolve_tiers(cfg) + assert result == () + + +# ── classify_model_tier ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestClassifyModelTier: + """Tests for classify_model_tier function.""" + + @pytest.fixture + def default_tiers(self) -> tuple[CostTierDefinition, ...]: + """Resolved default tiers.""" + return resolve_tiers(CostTiersConfig()) + + def test_no_matching_tier_returns_none(self) -> None: + """Returns None when no tier matches.""" + # Tier with range [0.01, 0.02) — cost of 0.0 won't match + tiers = ( + CostTierDefinition( + id="narrow", + display_name="Narrow", + price_range_min=0.01, + price_range_max=0.02, + ), + ) + assert classify_model_tier(0.0, tiers) is None + + def test_empty_tiers_returns_none(self) -> None: + """Empty tiers always returns None.""" + assert classify_model_tier(0.01, ()) is None + + def test_negative_cost_returns_none( + self, + default_tiers: tuple[CostTierDefinition, ...], + ) -> None: + """Negative cost returns None (logged as warning).""" + assert classify_model_tier(-0.01, default_tiers) is None + + @pytest.mark.parametrize( + ("cost", "expected"), + [ + (0.0, "low"), + (0.001, "low"), + (0.0019, "low"), + (0.002, "medium"), + (0.005, "medium"), + (0.009, "medium"), + (0.01, "high"), + (0.02, "high"), + (0.029, "high"), + (0.03, "premium"), + (0.1, "premium"), + (1.0, "premium"), + ], + ids=[ + "0.000_low", + "0.001_low", + "0.0019_low", + "0.002_medium_boundary", + "0.005_medium_mid", + "0.009_medium_near_top", + "0.01_high_boundary", + "0.02_high_mid", + "0.029_high_near_top", + "0.03_premium_boundary", + "0.1_premium_mid", + "1.0_premium_very_high", + ], + ) + def test_classification_boundaries( + self, + cost: float, + expected: str, + default_tiers: tuple[CostTierDefinition, ...], + ) -> None: + """Parametrized boundary tests.""" + assert classify_model_tier(cost, default_tiers) == expected diff --git a/tests/unit/budget/test_enforcer_quota.py b/tests/unit/budget/test_enforcer_quota.py new file mode 100644 index 0000000000..1df53c447a --- /dev/null +++ b/tests/unit/budget/test_enforcer_quota.py @@ -0,0 +1,283 @@ +"""Tests for BudgetEnforcer quota integration.""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, patch + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + +import pytest + +from ai_company.budget.config import BudgetAlertConfig, BudgetConfig +from ai_company.budget.enforcer import BudgetEnforcer +from ai_company.budget.quota import ( + QuotaCheckResult, + QuotaLimit, + QuotaWindow, + SubscriptionConfig, +) +from ai_company.budget.quota_tracker import QuotaTracker +from ai_company.budget.tracker import CostTracker +from ai_company.engine.errors import QuotaExhaustedError + +pytestmark = pytest.mark.timeout(30) + +_BILLING_START = datetime(2026, 3, 1, tzinfo=UTC) +_DAY_START = datetime(2026, 3, 15, tzinfo=UTC) + + +# ── Helpers ──────────────────────────────────────────────────────── + + +def _make_budget_config( + *, + total_monthly: float = 100.0, + per_agent_daily_limit: float = 10.0, +) -> BudgetConfig: + return BudgetConfig( + total_monthly=total_monthly, + alerts=BudgetAlertConfig(warn_at=75, critical_at=90, hard_stop_at=100), + per_agent_daily_limit=per_agent_daily_limit, + ) + + +def _make_quota_tracker( + *, + provider: str = "test-provider", + max_requests: int = 60, +) -> QuotaTracker: + sub = SubscriptionConfig( + quotas=( + QuotaLimit( + window=QuotaWindow.PER_HOUR, + max_requests=max_requests, + ), + ), + ) + return QuotaTracker(subscriptions={provider: sub}) + + +def _patch_periods() -> tuple[ + AbstractContextManager[Any], + AbstractContextManager[Any], +]: + """Patch billing and daily period starts.""" + return ( + patch( + "ai_company.budget.enforcer.billing_period_start", + return_value=_BILLING_START, + ), + patch( + "ai_company.budget.enforcer.daily_period_start", + return_value=_DAY_START, + ), + ) + + +# ── check_can_execute with quota ─────────────────────────────────── + + +@pytest.mark.unit +class TestCheckCanExecuteWithQuota: + """Tests for quota-aware pre-flight check.""" + + async def test_passes_when_quota_allowed(self) -> None: + """Pre-flight passes when quota is not exhausted.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + quota_tracker = _make_quota_tracker() + + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + quota_tracker=quota_tracker, + ) + + billing_patch, daily_patch = _patch_periods() + with billing_patch, daily_patch: + await enforcer.check_can_execute( + "alice", + provider_name="test-provider", + ) + + async def test_raises_when_quota_exhausted(self) -> None: + """Pre-flight raises QuotaExhaustedError when exhausted.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + quota_tracker = _make_quota_tracker(max_requests=5) + + # Exhaust quota + for _ in range(5): + await quota_tracker.record_usage("test-provider") + + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + quota_tracker=quota_tracker, + ) + + billing_patch, daily_patch = _patch_periods() + with ( + billing_patch, + daily_patch, + pytest.raises( + QuotaExhaustedError, + match="quota exhausted", + ), + ): + await enforcer.check_can_execute( + "alice", + provider_name="test-provider", + ) + + async def test_skips_quota_when_no_provider_name(self) -> None: + """Quota check is skipped when provider_name is None.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + quota_tracker = _make_quota_tracker(max_requests=5) + + # Exhaust quota + for _ in range(5): + await quota_tracker.record_usage("test-provider") + + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + quota_tracker=quota_tracker, + ) + + billing_patch, daily_patch = _patch_periods() + with billing_patch, daily_patch: + # No provider_name → skip quota check, even though exhausted + await enforcer.check_can_execute("alice") + + async def test_skips_quota_when_no_quota_tracker(self) -> None: + """Quota check is skipped when no quota_tracker is set.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + quota_tracker=None, + ) + + billing_patch, daily_patch = _patch_periods() + with billing_patch, daily_patch: + await enforcer.check_can_execute( + "alice", + provider_name="test-provider", + ) + + +# ── check_quota ──────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestCheckQuota: + """Tests for BudgetEnforcer.check_quota().""" + + async def test_delegates_to_quota_tracker(self) -> None: + """Delegates to QuotaTracker when set.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + quota_tracker = _make_quota_tracker() + + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + quota_tracker=quota_tracker, + ) + + result = await enforcer.check_quota("test-provider") + assert result.allowed is True + assert result.provider_name == "test-provider" + + async def test_returns_allowed_without_quota_tracker(self) -> None: + """Returns always-allowed when no quota tracker.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + quota_tracker=None, + ) + + result = await enforcer.check_quota("test-provider") + assert result.allowed is True + + async def test_passes_estimated_tokens(self) -> None: + """Estimated tokens are forwarded to QuotaTracker.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + quota_tracker = _make_quota_tracker() + + mock_check = AsyncMock( + return_value=QuotaCheckResult( + allowed=True, + provider_name="test-provider", + ), + ) + quota_tracker.check_quota = mock_check # type: ignore[method-assign] + + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + quota_tracker=quota_tracker, + ) + + await enforcer.check_quota( + "test-provider", + estimated_tokens=5000, + ) + + mock_check.assert_awaited_once_with( + "test-provider", + estimated_tokens=5000, + ) + + async def test_quota_exhausted_returns_denied(self) -> None: + """Returns denied result when quota is exhausted.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + quota_tracker = _make_quota_tracker(max_requests=2) + + # Exhaust quota + await quota_tracker.record_usage("test-provider", requests=2) + + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + quota_tracker=quota_tracker, + ) + + result = await enforcer.check_quota("test-provider") + assert result.allowed is False + assert result.provider_name == "test-provider" + + async def test_graceful_degradation_on_generic_exception(self) -> None: + """Falls back to allow when quota_tracker raises unexpectedly.""" + cfg = _make_budget_config() + tracker = CostTracker(budget_config=cfg) + quota_tracker = _make_quota_tracker() + + # Mock check_quota to raise a generic error + quota_tracker.check_quota = AsyncMock( # type: ignore[method-assign] + side_effect=RuntimeError("unexpected"), + ) + + enforcer = BudgetEnforcer( + budget_config=cfg, + cost_tracker=tracker, + quota_tracker=quota_tracker, + ) + + billing_patch, daily_patch = _patch_periods() + with billing_patch, daily_patch: + # Should not raise — graceful degradation + await enforcer.check_can_execute( + "alice", + provider_name="test-provider", + ) diff --git a/tests/unit/budget/test_quota.py b/tests/unit/budget/test_quota.py new file mode 100644 index 0000000000..61f730feb3 --- /dev/null +++ b/tests/unit/budget/test_quota.py @@ -0,0 +1,545 @@ +"""Tests for quota and subscription domain models.""" + +from datetime import UTC, datetime + +import pytest +from pydantic import ValidationError + +from ai_company.budget.quota import ( + DegradationAction, + DegradationConfig, + ProviderCostModel, + QuotaCheckResult, + QuotaLimit, + QuotaSnapshot, + QuotaWindow, + SubscriptionConfig, + effective_cost_per_1k, + window_start, +) + +pytestmark = pytest.mark.timeout(30) + + +# ── QuotaWindow ──────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestQuotaWindow: + """Tests for QuotaWindow enum.""" + + def test_values(self) -> None: + """All expected windows exist.""" + assert QuotaWindow.PER_MINUTE.value == "per_minute" + assert QuotaWindow.PER_HOUR.value == "per_hour" + assert QuotaWindow.PER_DAY.value == "per_day" + assert QuotaWindow.PER_MONTH.value == "per_month" + + def test_member_count(self) -> None: + """Exactly 4 windows.""" + assert len(QuotaWindow) == 4 + + +# ── QuotaLimit ───────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestQuotaLimit: + """Tests for QuotaLimit model.""" + + def test_valid_requests_only(self) -> None: + """Limit with only max_requests is valid.""" + ql = QuotaLimit(window=QuotaWindow.PER_MINUTE, max_requests=60) + assert ql.max_requests == 60 + assert ql.max_tokens == 0 + + def test_valid_tokens_only(self) -> None: + """Limit with only max_tokens is valid.""" + ql = QuotaLimit(window=QuotaWindow.PER_DAY, max_tokens=1_000_000) + assert ql.max_tokens == 1_000_000 + assert ql.max_requests == 0 + + def test_both_set(self) -> None: + """Limit with both fields set is valid.""" + ql = QuotaLimit( + window=QuotaWindow.PER_HOUR, + max_requests=100, + max_tokens=500_000, + ) + assert ql.max_requests == 100 + assert ql.max_tokens == 500_000 + + def test_both_zero_rejected(self) -> None: + """Both at zero is rejected.""" + with pytest.raises(ValidationError, match="At least one"): + QuotaLimit(window=QuotaWindow.PER_MINUTE) + + def test_negative_requests_rejected(self) -> None: + """Negative max_requests is rejected.""" + with pytest.raises(ValidationError): + QuotaLimit( + window=QuotaWindow.PER_MINUTE, + max_requests=-1, + ) + + def test_negative_tokens_rejected(self) -> None: + """Negative max_tokens is rejected.""" + with pytest.raises(ValidationError): + QuotaLimit( + window=QuotaWindow.PER_MINUTE, + max_tokens=-1, + ) + + def test_frozen(self) -> None: + """Model is frozen.""" + ql = QuotaLimit(window=QuotaWindow.PER_MINUTE, max_requests=10) + with pytest.raises(ValidationError): + ql.max_requests = 20 # type: ignore[misc] + + +# ── ProviderCostModel ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestProviderCostModel: + """Tests for ProviderCostModel enum.""" + + def test_values(self) -> None: + """All expected cost models exist.""" + assert ProviderCostModel.PER_TOKEN.value == "per_token" + assert ProviderCostModel.SUBSCRIPTION.value == "subscription" + assert ProviderCostModel.LOCAL.value == "local" + + +# ── SubscriptionConfig ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestSubscriptionConfig: + """Tests for SubscriptionConfig model.""" + + def test_defaults(self) -> None: + """Default config is pay-as-you-go.""" + sc = SubscriptionConfig() + assert sc.plan_name == "pay_as_you_go" + assert sc.cost_model == ProviderCostModel.PER_TOKEN + assert sc.monthly_cost == 0.0 + assert sc.quotas == () + assert sc.hardware_limits is None + + def test_subscription_with_monthly_cost(self) -> None: + """Subscription model with monthly cost.""" + sc = SubscriptionConfig( + plan_name="pro", + cost_model=ProviderCostModel.SUBSCRIPTION, + monthly_cost=20.0, + ) + assert sc.monthly_cost == 20.0 + + def test_local_with_hardware_limits(self) -> None: + """Local model with hardware limits.""" + sc = SubscriptionConfig( + plan_name="local", + cost_model=ProviderCostModel.LOCAL, + hardware_limits="RTX 4090, ~30 tok/s", + ) + assert sc.hardware_limits == "RTX 4090, ~30 tok/s" + + def test_local_with_monthly_cost_rejected(self) -> None: + """LOCAL cost_model with monthly_cost > 0 is rejected.""" + with pytest.raises(ValidationError, match="LOCAL cost_model"): + SubscriptionConfig( + cost_model=ProviderCostModel.LOCAL, + monthly_cost=10.0, + ) + + def test_subscription_zero_monthly_cost_warns(self) -> None: + """SUBSCRIPTION with monthly_cost=0 logs warning but is accepted.""" + # Should not raise — just warns + sc = SubscriptionConfig( + cost_model=ProviderCostModel.SUBSCRIPTION, + monthly_cost=0.0, + ) + assert sc.monthly_cost == 0.0 + + def test_duplicate_quota_windows_rejected(self) -> None: + """Duplicate quota windows are rejected.""" + with pytest.raises(ValidationError, match="Duplicate quota windows"): + SubscriptionConfig( + quotas=( + QuotaLimit( + window=QuotaWindow.PER_MINUTE, + max_requests=60, + ), + QuotaLimit( + window=QuotaWindow.PER_MINUTE, + max_requests=30, + ), + ), + ) + + def test_multiple_unique_windows_accepted(self) -> None: + """Multiple unique windows are accepted.""" + sc = SubscriptionConfig( + quotas=( + QuotaLimit(window=QuotaWindow.PER_MINUTE, max_requests=60), + QuotaLimit(window=QuotaWindow.PER_DAY, max_tokens=1_000_000), + ), + ) + assert len(sc.quotas) == 2 + + def test_negative_monthly_cost_rejected(self) -> None: + """Negative monthly_cost is rejected.""" + with pytest.raises(ValidationError): + SubscriptionConfig(monthly_cost=-10.0) + + def test_frozen(self) -> None: + """Model is frozen.""" + sc = SubscriptionConfig() + with pytest.raises(ValidationError): + sc.plan_name = "other" # type: ignore[misc] + + def test_blank_plan_name_rejected(self) -> None: + """Blank plan_name is rejected.""" + with pytest.raises(ValidationError): + SubscriptionConfig(plan_name="") + + +# ── DegradationAction ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestDegradationAction: + """Tests for DegradationAction enum.""" + + def test_values(self) -> None: + """All expected actions exist.""" + assert DegradationAction.FALLBACK.value == "fallback" + assert DegradationAction.QUEUE.value == "queue" + assert DegradationAction.ALERT.value == "alert" + + +# ── DegradationConfig ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestDegradationConfig: + """Tests for DegradationConfig model.""" + + def test_defaults(self) -> None: + """Default is ALERT with no fallback providers.""" + dc = DegradationConfig() + assert dc.strategy == DegradationAction.ALERT + assert dc.fallback_providers == () + assert dc.queue_max_wait_seconds == 300 + + def test_alert_strategy(self) -> None: + """ALERT strategy is accepted.""" + dc = DegradationConfig(strategy=DegradationAction.ALERT) + assert dc.strategy == DegradationAction.ALERT + + def test_fallback_with_providers(self) -> None: + """FALLBACK with providers is accepted.""" + dc = DegradationConfig( + strategy=DegradationAction.FALLBACK, + fallback_providers=("provider-a", "provider-b"), + ) + assert len(dc.fallback_providers) == 2 + + def test_fallback_without_providers_warns(self) -> None: + """FALLBACK with empty providers logs warning but is accepted.""" + dc = DegradationConfig(strategy=DegradationAction.FALLBACK) + assert dc.fallback_providers == () + + def test_queue_max_wait_bounds(self) -> None: + """queue_max_wait_seconds must be in [0, 3600].""" + with pytest.raises(ValidationError): + DegradationConfig(queue_max_wait_seconds=-1) + with pytest.raises(ValidationError): + DegradationConfig(queue_max_wait_seconds=3601) + + def test_frozen(self) -> None: + """Model is frozen.""" + dc = DegradationConfig() + with pytest.raises(ValidationError): + dc.strategy = DegradationAction.ALERT # type: ignore[misc] + + +# ── QuotaSnapshot ────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestQuotaSnapshot: + """Tests for QuotaSnapshot model.""" + + def _make_snapshot( + self, + *, + requests_used: int = 0, + requests_limit: int = 100, + tokens_used: int = 0, + tokens_limit: int = 0, + ) -> QuotaSnapshot: + return QuotaSnapshot( + provider_name="test-provider", + window=QuotaWindow.PER_MINUTE, + requests_used=requests_used, + requests_limit=requests_limit, + tokens_used=tokens_used, + tokens_limit=tokens_limit, + captured_at=datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC), + ) + + def test_requests_remaining(self) -> None: + """Computes remaining requests correctly.""" + snap = self._make_snapshot( + requests_used=30, + requests_limit=100, + ) + assert snap.requests_remaining == 70 + + def test_requests_remaining_at_limit(self) -> None: + """Remaining is 0 when at limit.""" + snap = self._make_snapshot( + requests_used=100, + requests_limit=100, + ) + assert snap.requests_remaining == 0 + + def test_requests_remaining_unlimited(self) -> None: + """Remaining is None when limit is unlimited (0).""" + snap = self._make_snapshot( + requests_used=50, + requests_limit=0, + ) + assert snap.requests_remaining is None + + def test_tokens_remaining(self) -> None: + """Computes remaining tokens correctly.""" + snap = self._make_snapshot( + tokens_used=500, + tokens_limit=1000, + ) + assert snap.tokens_remaining == 500 + + def test_tokens_remaining_unlimited(self) -> None: + """Remaining is None when tokens unlimited.""" + snap = self._make_snapshot( + tokens_used=500, + tokens_limit=0, + ) + assert snap.tokens_remaining is None + + def test_is_exhausted_requests(self) -> None: + """Exhausted when requests at limit.""" + snap = self._make_snapshot( + requests_used=100, + requests_limit=100, + ) + assert snap.is_exhausted is True + + def test_is_exhausted_tokens(self) -> None: + """Exhausted when tokens at limit.""" + snap = self._make_snapshot( + requests_limit=0, + tokens_used=1000, + tokens_limit=1000, + ) + assert snap.is_exhausted is True + + def test_not_exhausted(self) -> None: + """Not exhausted when under both limits.""" + snap = self._make_snapshot( + requests_used=50, + requests_limit=100, + tokens_used=500, + tokens_limit=1000, + ) + assert snap.is_exhausted is False + + def test_not_exhausted_unlimited(self) -> None: + """Not exhausted when both limits are unlimited.""" + snap = self._make_snapshot( + requests_used=1000, + requests_limit=0, + tokens_used=1000, + tokens_limit=0, + ) + assert snap.is_exhausted is False + + +# ── QuotaCheckResult ─────────────────────────────────────────────── + + +@pytest.mark.unit +class TestQuotaCheckResult: + """Tests for QuotaCheckResult model.""" + + def test_allowed(self) -> None: + """Allowed result.""" + result = QuotaCheckResult( + allowed=True, + provider_name="test-provider", + ) + assert result.allowed is True + assert result.reason == "" + assert result.exhausted_windows == () + + def test_denied(self) -> None: + """Denied result with reason and exhausted windows.""" + result = QuotaCheckResult( + allowed=False, + provider_name="test-provider", + reason="per_minute requests exhausted", + exhausted_windows=(QuotaWindow.PER_MINUTE,), + ) + assert result.allowed is False + assert "per_minute" in result.reason + + +# ── window_start ─────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestWindowStart: + """Tests for window_start function.""" + + def test_per_minute(self) -> None: + """PER_MINUTE truncates to minute start.""" + now = datetime(2026, 3, 15, 14, 35, 42, tzinfo=UTC) + result = window_start(QuotaWindow.PER_MINUTE, now=now) + assert result == datetime(2026, 3, 15, 14, 35, tzinfo=UTC) + + def test_per_hour(self) -> None: + """PER_HOUR truncates to hour start.""" + now = datetime(2026, 3, 15, 14, 35, 42, tzinfo=UTC) + result = window_start(QuotaWindow.PER_HOUR, now=now) + assert result == datetime(2026, 3, 15, 14, 0, tzinfo=UTC) + + def test_per_day(self) -> None: + """PER_DAY truncates to day start.""" + now = datetime(2026, 3, 15, 14, 35, 42, tzinfo=UTC) + result = window_start(QuotaWindow.PER_DAY, now=now) + assert result == datetime(2026, 3, 15, tzinfo=UTC) + + def test_per_month(self) -> None: + """PER_MONTH truncates to first of month.""" + now = datetime(2026, 3, 15, 14, 35, 42, tzinfo=UTC) + result = window_start(QuotaWindow.PER_MONTH, now=now) + assert result == datetime(2026, 3, 1, tzinfo=UTC) + + def test_naive_datetime_rejected(self) -> None: + """Naive datetime raises ValueError.""" + naive = datetime(2026, 3, 15, 14, 30, 0) # noqa: DTZ001 + with pytest.raises(ValueError, match="timezone-aware"): + window_start(QuotaWindow.PER_HOUR, now=naive) + + def test_defaults_to_now(self) -> None: + """Uses current time when now is not provided.""" + before = datetime.now(UTC) + result = window_start(QuotaWindow.PER_MONTH) + after = datetime.now(UTC) + # PER_MONTH truncates to 1st of month — stable within a month + assert result.year == before.year + assert result.month == before.month + assert result.day == 1 + assert before <= after # sanity + + +# ── effective_cost_per_1k ────────────────────────────────────────── + + +@pytest.mark.unit +class TestEffectiveCostPer1k: + """Tests for effective_cost_per_1k function.""" + + def test_per_token(self) -> None: + """PER_TOKEN returns sum of input + output costs.""" + result = effective_cost_per_1k(0.003, 0.015, ProviderCostModel.PER_TOKEN) + assert result == 0.018 + + def test_subscription_returns_zero(self) -> None: + """SUBSCRIPTION returns 0.0 (pre-paid).""" + result = effective_cost_per_1k(0.003, 0.015, ProviderCostModel.SUBSCRIPTION) + assert result == 0.0 + + def test_local_returns_zero(self) -> None: + """LOCAL returns 0.0 (free).""" + result = effective_cost_per_1k(0.003, 0.015, ProviderCostModel.LOCAL) + assert result == 0.0 + + def test_per_token_zero_costs(self) -> None: + """PER_TOKEN with zero costs returns 0.0.""" + result = effective_cost_per_1k(0.0, 0.0, ProviderCostModel.PER_TOKEN) + assert result == 0.0 + + def test_per_token_negative_inputs(self) -> None: + """PER_TOKEN with negative inputs returns the sum as-is.""" + result = effective_cost_per_1k(-0.001, 0.005, ProviderCostModel.PER_TOKEN) + assert result == pytest.approx(0.004) + + +# ── SubscriptionConfig nan/inf rejection ────────────────────────── + + +@pytest.mark.unit +class TestSubscriptionConfigNanInf: + """Tests that SubscriptionConfig rejects nan/inf values.""" + + def test_rejects_nan_monthly_cost(self) -> None: + """NaN monthly_cost is rejected.""" + with pytest.raises(ValidationError): + SubscriptionConfig(monthly_cost=float("nan")) + + def test_rejects_inf_monthly_cost(self) -> None: + """Inf monthly_cost is rejected.""" + with pytest.raises(ValidationError): + SubscriptionConfig(monthly_cost=float("inf")) + + +# ── QuotaSnapshot over-limit ───────────────────────────────────── + + +@pytest.mark.unit +class TestQuotaSnapshotOverLimit: + """Tests for QuotaSnapshot with usage exceeding limits.""" + + def test_is_exhausted_over_limit(self) -> None: + """is_exhausted returns True when usage exceeds limit.""" + snap = QuotaSnapshot( + provider_name="test-provider", + window=QuotaWindow.PER_MINUTE, + requests_used=150, + requests_limit=100, + captured_at=datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC), + ) + assert snap.is_exhausted is True + assert snap.requests_remaining == 0 + + +# ── QuotaCheckResult cross-field validation ─────────────────────── + + +@pytest.mark.unit +class TestQuotaCheckResultValidation: + """Tests for QuotaCheckResult cross-field validation.""" + + def test_denied_without_reason_rejected(self) -> None: + """Denied result with empty reason is rejected.""" + with pytest.raises(ValidationError, match="non-empty reason"): + QuotaCheckResult( + allowed=False, + provider_name="test-provider", + ) + + def test_allowed_with_exhausted_windows_rejected(self) -> None: + """Allowed result with exhausted_windows is rejected.""" + with pytest.raises( + ValidationError, + match="must not have exhausted_windows", + ): + QuotaCheckResult( + allowed=True, + provider_name="test-provider", + exhausted_windows=(QuotaWindow.PER_MINUTE,), + ) diff --git a/tests/unit/budget/test_quota_tracker.py b/tests/unit/budget/test_quota_tracker.py new file mode 100644 index 0000000000..9a962bd547 --- /dev/null +++ b/tests/unit/budget/test_quota_tracker.py @@ -0,0 +1,535 @@ +"""Tests for QuotaTracker service.""" + +from datetime import UTC, datetime +from unittest.mock import patch + +import pytest + +from ai_company.budget.quota import ( + QuotaLimit, + QuotaWindow, + SubscriptionConfig, +) +from ai_company.budget.quota_tracker import QuotaTracker + +pytestmark = pytest.mark.timeout(30) + +# Fixed timestamps for deterministic tests +_NOW = datetime(2026, 3, 15, 14, 30, 0, tzinfo=UTC) +_MINUTE_START = datetime(2026, 3, 15, 14, 30, tzinfo=UTC) +_HOUR_START = datetime(2026, 3, 15, 14, 0, tzinfo=UTC) +_DAY_START = datetime(2026, 3, 15, tzinfo=UTC) +_MONTH_START = datetime(2026, 3, 1, tzinfo=UTC) + + +# ── Helpers ──────────────────────────────────────────────────────── + + +def _make_tracker( + *, + provider: str = "test-provider", + quotas: tuple[QuotaLimit, ...] = (), + **kwargs: object, +) -> QuotaTracker: + """Create a QuotaTracker with a single provider.""" + sub = SubscriptionConfig(quotas=quotas, **kwargs) # type: ignore[arg-type] + return QuotaTracker(subscriptions={provider: sub}) + + +def _hour_quota(max_requests: int = 60) -> QuotaLimit: + return QuotaLimit(window=QuotaWindow.PER_HOUR, max_requests=max_requests) + + +def _day_token_quota(max_tokens: int = 1_000_000) -> QuotaLimit: + return QuotaLimit(window=QuotaWindow.PER_DAY, max_tokens=max_tokens) + + +# ── Construction ─────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestQuotaTrackerConstruction: + """Tests for QuotaTracker initialization.""" + + def test_creates_with_empty_subscriptions(self) -> None: + """Tracker with no subscriptions creates successfully.""" + tracker = QuotaTracker(subscriptions={}) + assert tracker is not None + + def test_creates_with_provider(self) -> None: + """Tracker with provider subscription creates successfully.""" + tracker = _make_tracker( + quotas=(_hour_quota(),), + ) + assert tracker is not None + + async def test_provider_without_quotas_not_tracked(self) -> None: + """Provider with no quotas is not actively tracked.""" + sub = SubscriptionConfig() # No quotas + tracker = QuotaTracker(subscriptions={"test-provider": sub}) + + # Provider is known but has no quotas — should still be allowed + result = await tracker.check_quota("test-provider") + assert result.allowed is True + + # Recording usage is a no-op (no crash) + await tracker.record_usage("test-provider") + + # Snapshot returns empty (no windows tracked) + snapshots = await tracker.get_snapshot("test-provider") + assert snapshots == () + + +# ── record_usage ─────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestRecordUsage: + """Tests for QuotaTracker.record_usage().""" + + async def test_records_request(self) -> None: + """Records a single request.""" + tracker = _make_tracker(quotas=(_hour_quota(60),)) + + await tracker.record_usage("test-provider") + + snapshots = await tracker.get_snapshot("test-provider") + assert len(snapshots) == 1 + assert snapshots[0].requests_used == 1 + + async def test_records_tokens(self) -> None: + """Records token usage.""" + tracker = _make_tracker(quotas=(_day_token_quota(1_000_000),)) + + await tracker.record_usage("test-provider", tokens=5000) + + snapshots = await tracker.get_snapshot("test-provider") + assert len(snapshots) == 1 + assert snapshots[0].tokens_used == 5000 + + async def test_accumulates_usage(self) -> None: + """Multiple records accumulate within same window.""" + tracker = _make_tracker(quotas=(_hour_quota(60),)) + + await tracker.record_usage("test-provider", requests=3) + await tracker.record_usage("test-provider", requests=2) + + snapshots = await tracker.get_snapshot("test-provider") + assert snapshots[0].requests_used == 5 + + async def test_unknown_provider_is_noop(self) -> None: + """Recording for unknown provider does nothing.""" + tracker = _make_tracker(quotas=(_hour_quota(),)) + + # Should not raise + await tracker.record_usage("unknown-provider") + + async def test_window_rotation(self) -> None: + """Counters reset when window boundary is crossed.""" + # Use per_hour to avoid minute-boundary flakiness + hour_quota = QuotaLimit( + window=QuotaWindow.PER_HOUR, + max_requests=60, + ) + tracker = _make_tracker(quotas=(hour_quota,)) + + # Record 10 requests (same hour as tracker creation) + await tracker.record_usage("test-provider", requests=10) + + # Verify initial count + snapshots = await tracker.get_snapshot( + "test-provider", + window=QuotaWindow.PER_HOUR, + ) + assert snapshots[0].requests_used == 10 + + # Force window rotation by mocking time to next hour + next_hour = datetime(2099, 1, 1, 1, 0, 0, tzinfo=UTC) + with patch( + "ai_company.budget.quota_tracker.datetime", + ) as mock_dt: + mock_dt.now.return_value = next_hour + mock_dt.side_effect = datetime + await tracker.record_usage("test-provider", requests=1) + + # Query in same mocked time so window matches + snapshots = await tracker.get_snapshot( + "test-provider", + window=QuotaWindow.PER_HOUR, + ) + + # After rotation, only the new request counts + assert snapshots[0].requests_used == 1 + + +# ── check_quota ──────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestCheckQuota: + """Tests for QuotaTracker.check_quota().""" + + async def test_allowed_when_under_limit(self) -> None: + """Check passes when under quota limit.""" + tracker = _make_tracker(quotas=(_hour_quota(60),)) + + await tracker.record_usage("test-provider", requests=30) + + result = await tracker.check_quota("test-provider") + assert result.allowed is True + + async def test_denied_when_at_limit(self) -> None: + """Check denied when at quota limit.""" + tracker = _make_tracker(quotas=(_hour_quota(10),)) + + await tracker.record_usage("test-provider", requests=10) + + result = await tracker.check_quota("test-provider") + assert result.allowed is False + assert QuotaWindow.PER_HOUR in result.exhausted_windows + assert "requests" in result.reason + + async def test_denied_when_over_limit(self) -> None: + """Check denied when over quota limit.""" + tracker = _make_tracker(quotas=(_hour_quota(10),)) + + await tracker.record_usage("test-provider", requests=15) + + result = await tracker.check_quota("test-provider") + assert result.allowed is False + + async def test_unknown_provider_always_allowed(self) -> None: + """Unknown providers are always allowed.""" + tracker = _make_tracker(quotas=(_hour_quota(),)) + + result = await tracker.check_quota("unknown-provider") + assert result.allowed is True + + async def test_estimated_tokens_checked(self) -> None: + """Estimated tokens are considered in quota check.""" + tracker = _make_tracker(quotas=(_day_token_quota(1000),)) + + await tracker.record_usage("test-provider", tokens=800) + + # 800 used + 300 estimated = 1100 > 1000 + result = await tracker.check_quota( + "test-provider", + estimated_tokens=300, + ) + assert result.allowed is False + + async def test_estimated_tokens_under_limit_allowed(self) -> None: + """Estimated tokens under limit passes.""" + tracker = _make_tracker(quotas=(_day_token_quota(1000),)) + + await tracker.record_usage("test-provider", tokens=500) + + # 500 used + 100 estimated = 600 < 1000 + result = await tracker.check_quota( + "test-provider", + estimated_tokens=100, + ) + assert result.allowed is True + + async def test_multiple_windows_checked(self) -> None: + """All configured windows are checked.""" + tracker = _make_tracker( + quotas=( + _hour_quota(100), + _day_token_quota(10_000), + ), + ) + + # Exhaust daily tokens (exceed limit — tokens use > semantics) + await tracker.record_usage("test-provider", requests=5, tokens=10_001) + + result = await tracker.check_quota("test-provider") + assert result.allowed is False + assert QuotaWindow.PER_DAY in result.exhausted_windows + + async def test_rotated_window_resets_check(self) -> None: + """Rotated window allows requests again.""" + # Use per_hour to avoid minute-boundary flakiness + hour_quota = QuotaLimit( + window=QuotaWindow.PER_HOUR, + max_requests=10, + ) + tracker = _make_tracker(quotas=(hour_quota,)) + + # Exhaust quota + await tracker.record_usage("test-provider", requests=10) + + # Verify exhausted + result = await tracker.check_quota("test-provider") + assert result.allowed is False + + # Check in next hour (window rotated — check sees fresh window) + next_hour = datetime(2099, 1, 1, 1, 0, 0, tzinfo=UTC) + with patch( + "ai_company.budget.quota_tracker.datetime", + ) as mock_dt: + mock_dt.now.return_value = next_hour + mock_dt.side_effect = datetime + result = await tracker.check_quota("test-provider") + + assert result.allowed is True + + +# ── get_snapshot ─────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestGetSnapshot: + """Tests for QuotaTracker.get_snapshot().""" + + async def test_returns_snapshots_for_tracked_provider(self) -> None: + """Returns snapshots for tracked provider.""" + tracker = _make_tracker(quotas=(_hour_quota(60),)) + + await tracker.record_usage("test-provider", requests=5) + + snapshots = await tracker.get_snapshot("test-provider") + assert len(snapshots) == 1 + assert snapshots[0].provider_name == "test-provider" + assert snapshots[0].window == QuotaWindow.PER_HOUR + assert snapshots[0].requests_used == 5 + assert snapshots[0].requests_limit == 60 + + async def test_returns_empty_for_unknown_provider(self) -> None: + """Returns empty tuple for unknown provider.""" + tracker = _make_tracker(quotas=(_hour_quota(),)) + + snapshots = await tracker.get_snapshot("unknown") + assert snapshots == () + + async def test_filter_by_window(self) -> None: + """Can filter snapshots by specific window.""" + tracker = _make_tracker( + quotas=( + _hour_quota(60), + _day_token_quota(1_000_000), + ), + ) + + snapshots = await tracker.get_snapshot( + "test-provider", + window=QuotaWindow.PER_HOUR, + ) + assert len(snapshots) == 1 + assert snapshots[0].window == QuotaWindow.PER_HOUR + + async def test_rotated_window_shows_zero(self) -> None: + """Rotated window shows zero usage in snapshot.""" + # Use per_hour to avoid minute-boundary flakiness + hour_quota = QuotaLimit( + window=QuotaWindow.PER_HOUR, + max_requests=60, + ) + tracker = _make_tracker(quotas=(hour_quota,)) + + # Record requests + await tracker.record_usage("test-provider", requests=10) + + # Query in next hour + next_hour = datetime(2099, 1, 1, 1, 0, 0, tzinfo=UTC) + with patch( + "ai_company.budget.quota_tracker.datetime", + ) as mock_dt: + mock_dt.now.return_value = next_hour + mock_dt.side_effect = datetime + snapshots = await tracker.get_snapshot("test-provider") + + assert snapshots[0].requests_used == 0 + + +# ── get_all_snapshots ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestGetAllSnapshots: + """Tests for QuotaTracker.get_all_snapshots().""" + + async def test_returns_all_providers(self) -> None: + """Returns snapshots for all tracked providers.""" + sub_a = SubscriptionConfig( + quotas=(_hour_quota(60),), + ) + sub_b = SubscriptionConfig( + quotas=(_day_token_quota(1_000_000),), + ) + tracker = QuotaTracker( + subscriptions={"provider-a": sub_a, "provider-b": sub_b}, + ) + + all_snapshots = await tracker.get_all_snapshots() + assert "provider-a" in all_snapshots + assert "provider-b" in all_snapshots + + async def test_empty_when_no_subscriptions(self) -> None: + """Returns empty dict when no subscriptions.""" + tracker = QuotaTracker(subscriptions={}) + all_snapshots = await tracker.get_all_snapshots() + assert all_snapshots == {} + + +# ── Deep copy isolation ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestDeepCopyIsolation: + """Tests that QuotaTracker defensively copies subscriptions.""" + + async def test_external_mutation_does_not_affect_tracker(self) -> None: + """Mutating the original dict after construction has no effect.""" + sub = SubscriptionConfig( + quotas=( + QuotaLimit( + window=QuotaWindow.PER_HOUR, + max_requests=10, + ), + ), + ) + subs: dict[str, SubscriptionConfig] = {"test-provider": sub} + tracker = QuotaTracker(subscriptions=subs) + + # Mutate original dict + subs["new-provider"] = sub + del subs["test-provider"] + + # Tracker should still work with original provider + result = await tracker.check_quota("test-provider") + assert result.allowed is True + + # New provider should not be tracked + result = await tracker.check_quota("new-provider") + assert result.allowed is True # unknown = always allowed + + +# ── Exhaustion reason with estimated_tokens ─────────────────────── + + +@pytest.mark.unit +class TestExhaustionReasonWithEstimatedTokens: + """Tests for exhaustion reason when triggered by projected tokens.""" + + async def test_reason_includes_projected_tokens(self) -> None: + """Reason string mentions projected tokens when denial is + triggered by estimated_tokens projection.""" + sub = SubscriptionConfig( + quotas=( + QuotaLimit( + window=QuotaWindow.PER_DAY, + max_tokens=1000, + ), + ), + ) + tracker = QuotaTracker(subscriptions={"test-provider": sub}) + + # Record 800 tokens (under limit) + await tracker.record_usage("test-provider", requests=0, tokens=800) + + # Check with estimated_tokens=300 → projected 1100 > 1000 + result = await tracker.check_quota( + "test-provider", + estimated_tokens=300, + ) + assert result.allowed is False + assert "tokens" in result.reason + assert "1100" in result.reason # projected total + assert "1000" in result.reason # limit + + +# ── Multiple exhausted windows ──────────────────────────────────── + + +@pytest.mark.unit +class TestMultipleExhaustedWindows: + """Tests for simultaneous exhaustion across multiple windows.""" + + async def test_both_windows_exhausted(self) -> None: + """Both windows appear in result when both are exhausted.""" + sub = SubscriptionConfig( + quotas=( + QuotaLimit( + window=QuotaWindow.PER_HOUR, + max_requests=5, + ), + QuotaLimit( + window=QuotaWindow.PER_DAY, + max_tokens=100, + ), + ), + ) + tracker = QuotaTracker(subscriptions={"test-provider": sub}) + + # Exhaust both (tokens=101 to exceed > threshold) + await tracker.record_usage( + "test-provider", + requests=5, + tokens=101, + ) + + result = await tracker.check_quota("test-provider") + assert result.allowed is False + assert len(result.exhausted_windows) == 2 + assert QuotaWindow.PER_HOUR in result.exhausted_windows + assert QuotaWindow.PER_DAY in result.exhausted_windows + # Reason should have both, joined by "; " + assert "; " in result.reason + + async def test_record_usage_updates_all_windows(self) -> None: + """Recording usage updates counters for all configured windows.""" + sub = SubscriptionConfig( + quotas=( + QuotaLimit( + window=QuotaWindow.PER_HOUR, + max_requests=100, + ), + QuotaLimit( + window=QuotaWindow.PER_DAY, + max_requests=1000, + max_tokens=50_000, + ), + ), + ) + tracker = QuotaTracker(subscriptions={"test-provider": sub}) + + await tracker.record_usage( + "test-provider", + requests=3, + tokens=500, + ) + + snapshots = await tracker.get_snapshot("test-provider") + assert len(snapshots) == 2 + + by_window = {s.window: s for s in snapshots} + assert by_window[QuotaWindow.PER_HOUR].requests_used == 3 + assert by_window[QuotaWindow.PER_DAY].requests_used == 3 + assert by_window[QuotaWindow.PER_DAY].tokens_used == 500 + + +# ── Input validation ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestInputValidation: + """Tests for negative input rejection.""" + + async def test_record_usage_rejects_negative_requests(self) -> None: + """Negative requests raise ValueError.""" + tracker = QuotaTracker(subscriptions={}) + with pytest.raises(ValueError, match="non-negative"): + await tracker.record_usage("test-provider", requests=-1) + + async def test_record_usage_rejects_negative_tokens(self) -> None: + """Negative tokens raise ValueError.""" + tracker = QuotaTracker(subscriptions={}) + with pytest.raises(ValueError, match="non-negative"): + await tracker.record_usage("test-provider", tokens=-1) + + async def test_check_quota_rejects_negative_estimated(self) -> None: + """Negative estimated_tokens raises ValueError.""" + tracker = QuotaTracker(subscriptions={}) + with pytest.raises(ValueError, match="non-negative"): + await tracker.check_quota("p", estimated_tokens=-1) diff --git a/tests/unit/config/conftest.py b/tests/unit/config/conftest.py index ab242e0ff3..207e4fa36b 100644 --- a/tests/unit/config/conftest.py +++ b/tests/unit/config/conftest.py @@ -7,6 +7,8 @@ from ai_company.budget.config import BudgetConfig from ai_company.budget.coordination_config import CoordinationMetricsConfig +from ai_company.budget.cost_tiers import CostTiersConfig +from ai_company.budget.quota import DegradationConfig, SubscriptionConfig from ai_company.communication.config import CommunicationConfig from ai_company.config.schema import ( AgentConfig, @@ -45,6 +47,8 @@ class ProviderConfigFactory(ModelFactory[ProviderConfig]): models = () retry = RetryConfig() rate_limiter = RateLimiterConfig() + subscription = SubscriptionConfig() + degradation = DegradationConfig() class RoutingRuleConfigFactory(ModelFactory[RoutingRuleConfig]): @@ -77,6 +81,7 @@ class RootConfigFactory(ModelFactory[RootConfig]): task_assignment = TaskAssignmentConfig() memory = CompanyMemoryConfig() persistence = PersistenceConfig() + cost_tiers = CostTiersConfig() # ── Sample YAML strings ────────────────────────────────────────── diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index 76ae29dbd5..dde625f0b9 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -192,6 +192,7 @@ def test_all_domain_modules_discovered(self) -> None: "personality", "prompt", "provider", + "quota", "role", "routing", "sandbox",