From 6690ea047189c7c2b4888897bed101691165c69c Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Wed, 8 Apr 2026 23:41:55 +0200 Subject: [PATCH 1/4] feat: persistent cost aggregation for project-lifetime budgets --- CLAUDE.md | 2 +- docs/design/operations.md | 4 +- src/synthorg/budget/__init__.py | 8 + src/synthorg/budget/enforcer.py | 107 +++++++---- src/synthorg/budget/project_cost_aggregate.py | 95 +++++++++ src/synthorg/budget/tracker.py | 49 +++++ src/synthorg/observability/events/budget.py | 7 + .../observability/events/persistence.py | 18 ++ src/synthorg/persistence/sqlite/backend.py | 24 +++ .../sqlite/project_cost_aggregate_repo.py | 180 ++++++++++++++++++ src/synthorg/persistence/sqlite/schema.sql | 10 + .../budget/test_enforcer_project_durable.py | 161 ++++++++++++++++ .../budget/test_project_cost_aggregate.py | 101 ++++++++++ .../budget/test_tracker_project_aggregate.py | 90 +++++++++ .../persistence/sqlite/test_migrations.py | 1 + .../test_project_cost_aggregate_repo.py | 106 +++++++++++ 16 files changed, 926 insertions(+), 37 deletions(-) create mode 100644 src/synthorg/budget/project_cost_aggregate.py create mode 100644 src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py create mode 100644 tests/unit/budget/test_enforcer_project_durable.py create mode 100644 tests/unit/budget/test_project_cost_aggregate.py create mode 100644 tests/unit/budget/test_tracker_project_aggregate.py create mode 100644 tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py diff --git a/CLAUDE.md b/CLAUDE.md index 97ac3bf04c..f29629b95c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -90,7 +90,7 @@ See `web/CLAUDE.md` for the full component inventory, design token rules, and po - **Every module** with business logic MUST have: `from synthorg.observability import get_logger` then `logger = get_logger(__name__)` - **Never** use `import logging` / `logging.getLogger()` / `print()` in application code (exception: `observability/setup.py`, `observability/sinks.py`, `observability/syslog_handler.py`, `observability/http_handler.py`, and `observability/otlp_handler.py` may use stdlib `logging` and `print(..., file=sys.stderr)` for handler construction, bootstrap, and error reporting code that runs before or during logging system configuration) - **Variable name**: always `logger` (not `_logger`, not `log`) -- **Event names**: always use constants from the domain-specific module under `synthorg.observability.events` (e.g., `API_REQUEST_STARTED` from `events.api`, `TOOL_INVOKE_START` from `events.tool`, `GIT_COMMAND_START` from `events.git`, `CONTEXT_BUDGET_FILL_UPDATED`, `CONTEXT_BUDGET_COMPACTION_STARTED`, `CONTEXT_BUDGET_COMPACTION_COMPLETED`, `CONTEXT_BUDGET_COMPACTION_FAILED`, `CONTEXT_BUDGET_COMPACTION_SKIPPED`, `CONTEXT_BUDGET_COMPACTION_FALLBACK`, `CONTEXT_BUDGET_INDICATOR_INJECTED`, `CONTEXT_BUDGET_AGENT_COMPACTION_REQUESTED`, `CONTEXT_BUDGET_EPISTEMIC_MARKERS_PRESERVED` from `events.context_budget`, `BACKUP_STARTED` from `events.backup`, `SETUP_COMPLETED` from `events.setup`, `ROUTING_CANDIDATE_SELECTED` from `events.routing`, `SHIPPING_HTTP_BATCH_SENT` from `events.shipping`, `EVAL_REPORT_COMPUTED` from `events.evaluation`, `PROMPT_PROFILE_SELECTED` from `events.prompt`, `PROCEDURAL_MEMORY_START` from `events.procedural_memory`, `PERF_LLM_JUDGE_STARTED` from `events.performance`, `TASK_ENGINE_OBSERVER_FAILED` from `events.task_engine`, `TASK_ASSIGNMENT_PROJECT_FILTERED` and `TASK_ASSIGNMENT_PROJECT_NO_ELIGIBLE` from `events.task_assignment`, `EXECUTION_SHUTDOWN_IMMEDIATE_CANCEL`, `EXECUTION_SHUTDOWN_TOOL_WAIT`, `EXECUTION_SHUTDOWN_CHECKPOINT_SAVE`, `EXECUTION_SHUTDOWN_CHECKPOINT_FAILED`, and `EXECUTION_PROJECT_VALIDATION_FAILED` from `events.execution`, `WORKFLOW_EXEC_COMPLETED` from `events.workflow_execution`, `BLUEPRINT_INSTANTIATE_START` from `events.blueprint`, `WORKFLOW_DEF_ROLLED_BACK` from `events.workflow_definition`, `WORKFLOW_VERSION_SAVED` from `events.workflow_version`, `MEMORY_FINE_TUNE_STARTED`, `MEMORY_SELF_EDIT_TOOL_EXECUTE`, `MEMORY_SELF_EDIT_CORE_READ`, `MEMORY_SELF_EDIT_CORE_WRITE`, `MEMORY_SELF_EDIT_CORE_WRITE_REJECTED`, `MEMORY_SELF_EDIT_ARCHIVAL_SEARCH`, `MEMORY_SELF_EDIT_ARCHIVAL_WRITE`, `MEMORY_SELF_EDIT_RECALL_READ`, `MEMORY_SELF_EDIT_RECALL_WRITE`, `MEMORY_SELF_EDIT_WRITE_FAILED` from `events.memory`, `REPORTING_GENERATION_STARTED` from `events.reporting`, `RISK_BUDGET_SCORE_COMPUTED` from `events.risk_budget`, `BUDGET_PROJECT_COST_QUERIED`, `BUDGET_PROJECT_RECORDS_QUERIED`, `BUDGET_PROJECT_BUDGET_EXCEEDED`, and `BUDGET_PROJECT_ENFORCEMENT_CHECK` from `events.budget`, `LLM_STRATEGY_SYNTHESIZED` and `DISTILLATION_CAPTURED` from `events.consolidation`, `MEMORY_DIVERSITY_RERANKED`, `MEMORY_DIVERSITY_RERANK_FAILED`, and `MEMORY_REFORMULATION_ROUND` from `events.memory`, `NOTIFICATION_DISPATCHED` and `NOTIFICATION_DISPATCH_FAILED` from `events.notification`, `QUALITY_STEP_CLASSIFIED` from `events.quality`, `HEALTH_TICKET_EMITTED` from `events.health`, `TRAJECTORY_SCORING_START` from `events.trajectory`, `COORD_METRICS_AMDAHL_COMPUTED` from `events.coordination_metrics`, `COORDINATION_STARTED`, `COORDINATION_COMPLETED`, `COORDINATION_FAILED`, `COORDINATION_PHASE_STARTED`, `COORDINATION_PHASE_COMPLETED`, `COORDINATION_PHASE_FAILED`, `COORDINATION_WAVE_STARTED`, `COORDINATION_WAVE_COMPLETED`, `COORDINATION_TOPOLOGY_RESOLVED`, `COORDINATION_CLEANUP_STARTED`, `COORDINATION_CLEANUP_COMPLETED`, `COORDINATION_CLEANUP_FAILED`, `COORDINATION_WAVE_BUILT`, `COORDINATION_FACTORY_BUILT`, and `COORDINATION_ATTRIBUTION_BUILT` from `events.coordination`, `WEB_REQUEST_START` and `WEB_SSRF_BLOCKED` from `events.web`, `DB_QUERY_START` and `DB_WRITE_BLOCKED` from `events.database`, `TERMINAL_COMMAND_START` and `TERMINAL_COMMAND_BLOCKED` from `events.terminal`, `SUB_CONSTRAINT_RESOLVED` and `SUB_CONSTRAINT_DENIED` from `events.sub_constraint`, `VERSION_SAVED` and `VERSION_SNAPSHOT_FAILED` from `events.versioning`, `ANALYTICS_AGGREGATION_COMPUTED` and `ANALYTICS_RETRY_RATE_ALERT` from `events.analytics`, `CALL_CLASSIFICATION_COMPUTED` from `events.call_classification`, `QUOTA_THRESHOLD_ALERT` and `QUOTA_POLL_FAILED` from `events.quota`, `CONFLICT_DEBATE_EVALUATOR_FAILED` from `events.conflict`, `DELEGATION_LOOP_CIRCUIT_BACKOFF` and `DELEGATION_LOOP_CIRCUIT_PERSIST_FAILED` from `events.delegation`, `MEETING_EVENT_COOLDOWN_SKIPPED` and `MEETING_TASKS_CAPPED` from `events.meeting`, `PERSISTENCE_CIRCUIT_BREAKER_SAVED`, `PERSISTENCE_CIRCUIT_BREAKER_SAVE_FAILED`, `PERSISTENCE_CIRCUIT_BREAKER_LOADED`, `PERSISTENCE_CIRCUIT_BREAKER_LOAD_FAILED`, `PERSISTENCE_CIRCUIT_BREAKER_DELETED`, and `PERSISTENCE_CIRCUIT_BREAKER_DELETE_FAILED` from `events.persistence`, `METRICS_SCRAPE_COMPLETED`, `METRICS_SCRAPE_FAILED`, `METRICS_COLLECTOR_INITIALIZED`, `METRICS_COORDINATION_RECORDED`, `METRICS_OTLP_EXPORT_COMPLETED` and `METRICS_OTLP_FLUSHER_STOPPED` from `events.metrics`, `ORG_MEMORY_QUERY_START`, `ORG_MEMORY_QUERY_COMPLETE`, `ORG_MEMORY_QUERY_FAILED`, `ORG_MEMORY_WRITE_START`, `ORG_MEMORY_WRITE_COMPLETE`, `ORG_MEMORY_WRITE_DENIED`, `ORG_MEMORY_WRITE_FAILED`, `ORG_MEMORY_POLICIES_LISTED`, `ORG_MEMORY_BACKEND_CREATED`, `ORG_MEMORY_CONNECT_FAILED`, `ORG_MEMORY_DISCONNECT_FAILED`, `ORG_MEMORY_NOT_CONNECTED`, `ORG_MEMORY_ROW_PARSE_FAILED`, `ORG_MEMORY_CONFIG_INVALID`, `ORG_MEMORY_MODEL_INVALID`, `ORG_MEMORY_MVCC_PUBLISH_APPENDED`, `ORG_MEMORY_MVCC_RETRACT_APPENDED`, `ORG_MEMORY_MVCC_SNAPSHOT_AT_QUERIED`, and `ORG_MEMORY_MVCC_LOG_QUERIED` from `events.org_memory`). Each domain has its own module -- see `src/synthorg/observability/events/` for the full inventory of constants. Import directly: `from synthorg.observability.events. import EVENT_CONSTANT` +- **Event names**: always use constants from the domain-specific module under `synthorg.observability.events` (e.g., `API_REQUEST_STARTED` from `events.api`, `TOOL_INVOKE_START` from `events.tool`, `GIT_COMMAND_START` from `events.git`, `CONTEXT_BUDGET_FILL_UPDATED`, `CONTEXT_BUDGET_COMPACTION_STARTED`, `CONTEXT_BUDGET_COMPACTION_COMPLETED`, `CONTEXT_BUDGET_COMPACTION_FAILED`, `CONTEXT_BUDGET_COMPACTION_SKIPPED`, `CONTEXT_BUDGET_COMPACTION_FALLBACK`, `CONTEXT_BUDGET_INDICATOR_INJECTED`, `CONTEXT_BUDGET_AGENT_COMPACTION_REQUESTED`, `CONTEXT_BUDGET_EPISTEMIC_MARKERS_PRESERVED` from `events.context_budget`, `BACKUP_STARTED` from `events.backup`, `SETUP_COMPLETED` from `events.setup`, `ROUTING_CANDIDATE_SELECTED` from `events.routing`, `SHIPPING_HTTP_BATCH_SENT` from `events.shipping`, `EVAL_REPORT_COMPUTED` from `events.evaluation`, `PROMPT_PROFILE_SELECTED` from `events.prompt`, `PROCEDURAL_MEMORY_START` from `events.procedural_memory`, `PERF_LLM_JUDGE_STARTED` from `events.performance`, `TASK_ENGINE_OBSERVER_FAILED` from `events.task_engine`, `TASK_ASSIGNMENT_PROJECT_FILTERED` and `TASK_ASSIGNMENT_PROJECT_NO_ELIGIBLE` from `events.task_assignment`, `EXECUTION_SHUTDOWN_IMMEDIATE_CANCEL`, `EXECUTION_SHUTDOWN_TOOL_WAIT`, `EXECUTION_SHUTDOWN_CHECKPOINT_SAVE`, `EXECUTION_SHUTDOWN_CHECKPOINT_FAILED`, and `EXECUTION_PROJECT_VALIDATION_FAILED` from `events.execution`, `WORKFLOW_EXEC_COMPLETED` from `events.workflow_execution`, `BLUEPRINT_INSTANTIATE_START` from `events.blueprint`, `WORKFLOW_DEF_ROLLED_BACK` from `events.workflow_definition`, `WORKFLOW_VERSION_SAVED` from `events.workflow_version`, `MEMORY_FINE_TUNE_STARTED`, `MEMORY_SELF_EDIT_TOOL_EXECUTE`, `MEMORY_SELF_EDIT_CORE_READ`, `MEMORY_SELF_EDIT_CORE_WRITE`, `MEMORY_SELF_EDIT_CORE_WRITE_REJECTED`, `MEMORY_SELF_EDIT_ARCHIVAL_SEARCH`, `MEMORY_SELF_EDIT_ARCHIVAL_WRITE`, `MEMORY_SELF_EDIT_RECALL_READ`, `MEMORY_SELF_EDIT_RECALL_WRITE`, `MEMORY_SELF_EDIT_WRITE_FAILED` from `events.memory`, `REPORTING_GENERATION_STARTED` from `events.reporting`, `RISK_BUDGET_SCORE_COMPUTED` from `events.risk_budget`, `BUDGET_PROJECT_COST_QUERIED`, `BUDGET_PROJECT_RECORDS_QUERIED`, `BUDGET_PROJECT_BUDGET_EXCEEDED`, `BUDGET_PROJECT_ENFORCEMENT_CHECK`, `BUDGET_PROJECT_COST_AGGREGATED`, `BUDGET_PROJECT_COST_AGGREGATION_FAILED`, and `BUDGET_PROJECT_BASELINE_SOURCE` from `events.budget`, `LLM_STRATEGY_SYNTHESIZED` and `DISTILLATION_CAPTURED` from `events.consolidation`, `MEMORY_DIVERSITY_RERANKED`, `MEMORY_DIVERSITY_RERANK_FAILED`, and `MEMORY_REFORMULATION_ROUND` from `events.memory`, `NOTIFICATION_DISPATCHED` and `NOTIFICATION_DISPATCH_FAILED` from `events.notification`, `QUALITY_STEP_CLASSIFIED` from `events.quality`, `HEALTH_TICKET_EMITTED` from `events.health`, `TRAJECTORY_SCORING_START` from `events.trajectory`, `COORD_METRICS_AMDAHL_COMPUTED` from `events.coordination_metrics`, `COORDINATION_STARTED`, `COORDINATION_COMPLETED`, `COORDINATION_FAILED`, `COORDINATION_PHASE_STARTED`, `COORDINATION_PHASE_COMPLETED`, `COORDINATION_PHASE_FAILED`, `COORDINATION_WAVE_STARTED`, `COORDINATION_WAVE_COMPLETED`, `COORDINATION_TOPOLOGY_RESOLVED`, `COORDINATION_CLEANUP_STARTED`, `COORDINATION_CLEANUP_COMPLETED`, `COORDINATION_CLEANUP_FAILED`, `COORDINATION_WAVE_BUILT`, `COORDINATION_FACTORY_BUILT`, and `COORDINATION_ATTRIBUTION_BUILT` from `events.coordination`, `WEB_REQUEST_START` and `WEB_SSRF_BLOCKED` from `events.web`, `DB_QUERY_START` and `DB_WRITE_BLOCKED` from `events.database`, `TERMINAL_COMMAND_START` and `TERMINAL_COMMAND_BLOCKED` from `events.terminal`, `SUB_CONSTRAINT_RESOLVED` and `SUB_CONSTRAINT_DENIED` from `events.sub_constraint`, `VERSION_SAVED` and `VERSION_SNAPSHOT_FAILED` from `events.versioning`, `ANALYTICS_AGGREGATION_COMPUTED` and `ANALYTICS_RETRY_RATE_ALERT` from `events.analytics`, `CALL_CLASSIFICATION_COMPUTED` from `events.call_classification`, `QUOTA_THRESHOLD_ALERT` and `QUOTA_POLL_FAILED` from `events.quota`, `CONFLICT_DEBATE_EVALUATOR_FAILED` from `events.conflict`, `DELEGATION_LOOP_CIRCUIT_BACKOFF` and `DELEGATION_LOOP_CIRCUIT_PERSIST_FAILED` from `events.delegation`, `MEETING_EVENT_COOLDOWN_SKIPPED` and `MEETING_TASKS_CAPPED` from `events.meeting`, `PERSISTENCE_CIRCUIT_BREAKER_SAVED`, `PERSISTENCE_CIRCUIT_BREAKER_SAVE_FAILED`, `PERSISTENCE_CIRCUIT_BREAKER_LOADED`, `PERSISTENCE_CIRCUIT_BREAKER_LOAD_FAILED`, `PERSISTENCE_CIRCUIT_BREAKER_DELETED`, `PERSISTENCE_CIRCUIT_BREAKER_DELETE_FAILED`, `PERSISTENCE_PROJECT_COST_AGG_INCREMENTED`, `PERSISTENCE_PROJECT_COST_AGG_INCREMENT_FAILED`, `PERSISTENCE_PROJECT_COST_AGG_FETCHED`, `PERSISTENCE_PROJECT_COST_AGG_FETCH_FAILED`, and `PERSISTENCE_PROJECT_COST_AGG_DESERIALIZE_FAILED` from `events.persistence`, `METRICS_SCRAPE_COMPLETED`, `METRICS_SCRAPE_FAILED`, `METRICS_COLLECTOR_INITIALIZED`, `METRICS_COORDINATION_RECORDED`, `METRICS_OTLP_EXPORT_COMPLETED` and `METRICS_OTLP_FLUSHER_STOPPED` from `events.metrics`, `ORG_MEMORY_QUERY_START`, `ORG_MEMORY_QUERY_COMPLETE`, `ORG_MEMORY_QUERY_FAILED`, `ORG_MEMORY_WRITE_START`, `ORG_MEMORY_WRITE_COMPLETE`, `ORG_MEMORY_WRITE_DENIED`, `ORG_MEMORY_WRITE_FAILED`, `ORG_MEMORY_POLICIES_LISTED`, `ORG_MEMORY_BACKEND_CREATED`, `ORG_MEMORY_CONNECT_FAILED`, `ORG_MEMORY_DISCONNECT_FAILED`, `ORG_MEMORY_NOT_CONNECTED`, `ORG_MEMORY_ROW_PARSE_FAILED`, `ORG_MEMORY_CONFIG_INVALID`, `ORG_MEMORY_MODEL_INVALID`, `ORG_MEMORY_MVCC_PUBLISH_APPENDED`, `ORG_MEMORY_MVCC_RETRACT_APPENDED`, `ORG_MEMORY_MVCC_SNAPSHOT_AT_QUERIED`, and `ORG_MEMORY_MVCC_LOG_QUERIED` from `events.org_memory`). Each domain has its own module -- see `src/synthorg/observability/events/` for the full inventory of constants. Import directly: `from synthorg.observability.events. import EVENT_CONSTANT` - **Structured kwargs**: always `logger.info(EVENT, key=value)` -- never `logger.info("msg %s", val)` - **All error paths** must log at WARNING or ERROR with context before raising - **All state transitions** must log at INFO diff --git a/docs/design/operations.md b/docs/design/operations.md index 94b1dd97d9..ee1b520c1d 100644 --- a/docs/design/operations.md +++ b/docs/design/operations.md @@ -1854,9 +1854,9 @@ them is required to support the full control-plane positioning claim. | G3 | ~~No policy-as-code export/import~~ | ~~Medium~~ | **Implemented** -- `GET /settings/security/export` and `POST /settings/security/import` (persists registered settings; code-defined policies require matching Python code). | | G4 | ~~No coordination metrics API~~ | ~~Medium~~ | **Implemented** -- `GET /coordination/metrics` exposes the 9 Kim et al. metrics with filtering. | | G5 | ~~No audit log query API~~ | ~~Medium~~ | **Implemented** -- `GET /security/audit` with agent_id, tool_name, verdict, action_type, and time-range filters. | -| G6 | Budget history granularity | Low | `CostTracker` is in-memory with TTL eviction. Multi-dimensional queries (provider X, agent Y, period Z) require persistence layer investigation. | +| G6 | Budget history granularity | Low | `CostTracker` is in-memory with TTL eviction. Project-level lifetime budgets are now backed by a durable `project_cost_aggregates` table (#1156). Multi-dimensional queries (provider X, agent Y, period Z) still require full persistence layer investigation. | -All gaps G1-G5 are now closed. G6 (budget history granularity) remains low-priority. +All gaps G1-G5 are now closed. G6 (budget history granularity) is partially addressed: project-level budgets are durable; broader multi-dimensional queries remain low-priority. ### Recommended Framing diff --git a/src/synthorg/budget/__init__.py b/src/synthorg/budget/__init__.py index 4597830704..ab51be1979 100644 --- a/src/synthorg/budget/__init__.py +++ b/src/synthorg/budget/__init__.py @@ -69,6 +69,7 @@ from synthorg.budget.errors import ( BudgetExhaustedError, DailyLimitExceededError, + ProjectBudgetExhaustedError, QuotaExhaustedError, RiskBudgetExhaustedError, ) @@ -93,6 +94,10 @@ RoutingSuggestion, SpendingAnomaly, ) +from synthorg.budget.project_cost_aggregate import ( + ProjectCostAggregate, + ProjectCostAggregateRepository, +) from synthorg.budget.quota import ( DegradationAction, DegradationConfig, @@ -212,6 +217,9 @@ "PeriodComparison", "PeriodSpending", "PreFlightResult", + "ProjectBudgetExhaustedError", + "ProjectCostAggregate", + "ProjectCostAggregateRepository", "ProviderCostModel", "ProviderDistribution", "QuotaAlertThresholds", diff --git a/src/synthorg/budget/enforcer.py b/src/synthorg/budget/enforcer.py index 821a394328..64abc229f0 100644 --- a/src/synthorg/budget/enforcer.py +++ b/src/synthorg/budget/enforcer.py @@ -47,6 +47,7 @@ BUDGET_HARD_STOP_EXCEEDED, BUDGET_NOTIFICATION_FAILED, BUDGET_PREFLIGHT_ERROR, + BUDGET_PROJECT_BASELINE_SOURCE, BUDGET_PROJECT_BUDGET_EXCEEDED, BUDGET_PROJECT_ENFORCEMENT_CHECK, BUDGET_RESOLVE_MODEL_ERROR, @@ -71,6 +72,9 @@ from synthorg.budget.config import BudgetConfig from synthorg.budget.degradation import DegradationResult + from synthorg.budget.project_cost_aggregate import ( + ProjectCostAggregateRepository, + ) from synthorg.budget.quota import QuotaCheckResult from synthorg.budget.quota_tracker import QuotaTracker from synthorg.budget.risk_record import RiskRecord @@ -113,12 +117,14 @@ def __init__( # noqa: PLR0913 risk_tracker: RiskTracker | None = None, risk_scorer: RiskScorer | None = None, notification_dispatcher: NotificationDispatcher | None = None, + project_cost_repo: ProjectCostAggregateRepository | None = None, ) -> None: self._budget_config = budget_config self._cost_tracker = cost_tracker self._model_resolver = model_resolver self._quota_tracker = quota_tracker self._notification_dispatcher = notification_dispatcher + self._project_cost_repo = project_cost_repo self._degradation_configs: MappingProxyType[str, DegradationConfig] | None = ( MappingProxyType(copy.deepcopy(dict(degradation_configs))) if degradation_configs is not None @@ -253,15 +259,11 @@ async def check_project_budget( ) -> None: """Check project-level budget and raise if exceeded. - .. warning:: - - The current in-memory tracker applies retention-based - pruning (168 h). For projects whose lifetime exceeds - the retention window, older cost records are pruned and - the aggregate returned by ``get_project_cost`` under- - reports actual spend. A persistent cost-tracking backend - (planned) will resolve this; until then, project budgets - are accurate only within the retention window. + Uses the durable project cost aggregate when available, + providing accurate lifetime totals that survive the + in-memory tracker's 168-hour retention window. Falls + back to in-memory tracking when no aggregate repository + is configured or when the aggregate query fails. Args: project_id: Project identifier for cost lookup. @@ -273,18 +275,8 @@ async def check_project_budget( if project_budget <= 0: return - try: - project_cost = await self._cost_tracker.get_project_cost( - project_id, - ) - except MemoryError, RecursionError: - raise - except Exception: - logger.exception( - BUDGET_PREFLIGHT_ERROR, - project_id=project_id, - reason="project_cost_query_failed", - ) + project_cost = await self._get_project_cost(project_id) + if project_cost is None: return logger.debug( @@ -636,19 +628,9 @@ async def make_budget_checker( project_baseline = 0.0 if project_id is not None and project_budget > 0: - try: - project_baseline = await self._cost_tracker.get_project_cost( - project_id, - ) - except MemoryError, RecursionError: - raise - except Exception: - logger.exception( - BUDGET_BASELINE_ERROR, - agent_id=agent_id, - project_id=project_id, - reason="project_baseline_query_failed", - ) + baseline = await self._get_project_cost(project_id) + if baseline is not None: + project_baseline = baseline thresholds = _compute_thresholds(cfg, monthly_budget) @@ -884,3 +866,60 @@ async def _compute_baselines( ) return monthly_baseline, daily_baseline + + async def _get_project_cost( + self, + project_id: str, + ) -> float | None: + """Query project cost from durable aggregate or in-memory tracker. + + Returns the total cost, or ``None`` when both sources fail + (caller should skip enforcement on ``None``). + """ + # Try durable aggregate first. + if self._project_cost_repo is not None: + try: + aggregate = await self._project_cost_repo.get( + project_id, + ) + except MemoryError, RecursionError: + raise + except Exception: + logger.exception( + BUDGET_PREFLIGHT_ERROR, + project_id=project_id, + reason="project_cost_aggregate_query_failed", + ) + # Fall through to in-memory. + else: + cost = aggregate.total_cost if aggregate else 0.0 + logger.debug( + BUDGET_PROJECT_BASELINE_SOURCE, + project_id=project_id, + source="aggregate", + cost=cost, + ) + return cost + + # Fallback to in-memory tracker. + try: + cost = await self._cost_tracker.get_project_cost( + project_id, + ) + except MemoryError, RecursionError: + raise + except Exception: + logger.exception( + BUDGET_PREFLIGHT_ERROR, + project_id=project_id, + reason="project_cost_query_failed", + ) + return None + else: + logger.debug( + BUDGET_PROJECT_BASELINE_SOURCE, + project_id=project_id, + source="in_memory", + cost=cost, + ) + return cost diff --git a/src/synthorg/budget/project_cost_aggregate.py b/src/synthorg/budget/project_cost_aggregate.py new file mode 100644 index 0000000000..63fbd78d56 --- /dev/null +++ b/src/synthorg/budget/project_cost_aggregate.py @@ -0,0 +1,95 @@ +"""Durable per-project cost aggregate model and repository protocol. + +Stores lifetime cost totals per project, surviving the in-memory +CostTracker's 168-hour retention window. Updated atomically on +each cost recording; queried by BudgetEnforcer for project-level +budget enforcement. +""" + +from typing import Protocol, runtime_checkable + +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field + +from synthorg.core.types import NotBlankStr # noqa: TC001 + + +class ProjectCostAggregate(BaseModel): + """Immutable snapshot of a project's lifetime cost totals. + + One row per project in the ``project_cost_aggregates`` table. + Totals are monotonically increasing (never pruned). + + Attributes: + project_id: Unique project identifier (primary key). + total_cost: Accumulated cost in base currency. + total_input_tokens: Accumulated input token count. + total_output_tokens: Accumulated output token count. + record_count: Number of cost records aggregated. + last_updated: Timestamp of the most recent increment. + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + project_id: NotBlankStr = Field(description="Project identifier") + total_cost: float = Field(ge=0.0, description="Accumulated cost") + total_input_tokens: int = Field( + ge=0, + description="Accumulated input tokens", + ) + total_output_tokens: int = Field( + ge=0, + description="Accumulated output tokens", + ) + record_count: int = Field( + ge=0, + description="Number of cost records aggregated", + ) + last_updated: AwareDatetime = Field( + description="Timestamp of last increment", + ) + + +@runtime_checkable +class ProjectCostAggregateRepository(Protocol): + """Repository for durable per-project cost aggregates. + + Implementations must provide atomic increment semantics so + concurrent cost recordings do not lose updates. + """ + + async def get( + self, + project_id: NotBlankStr, + ) -> ProjectCostAggregate | None: + """Retrieve the aggregate for a project. + + Args: + project_id: Project identifier. + + Returns: + The aggregate, or ``None`` if no costs have been recorded. + """ + ... + + async def increment( + self, + project_id: NotBlankStr, + cost: float, + input_tokens: int, + output_tokens: int, + ) -> ProjectCostAggregate: + """Atomically increment the project's cost aggregate. + + Creates a new aggregate row on the first call for a project. + Subsequent calls increment the existing totals. + + Args: + project_id: Project identifier. + cost: Cost delta to add. + input_tokens: Input token delta to add. + output_tokens: Output token delta to add. + + Returns: + The updated aggregate after the increment. + """ + ... diff --git a/src/synthorg/budget/tracker.py b/src/synthorg/budget/tracker.py index 13f3221bdd..fd595b54d8 100644 --- a/src/synthorg/budget/tracker.py +++ b/src/synthorg/budget/tracker.py @@ -37,6 +37,8 @@ BUDGET_DEPARTMENT_RESOLVE_FAILED, BUDGET_ORCHESTRATION_RATIO_ALERT, BUDGET_ORCHESTRATION_RATIO_QUERIED, + BUDGET_PROJECT_COST_AGGREGATED, + BUDGET_PROJECT_COST_AGGREGATION_FAILED, BUDGET_PROJECT_COST_QUERIED, BUDGET_PROJECT_RECORDS_QUERIED, BUDGET_PROVIDER_USAGE_QUERIED, @@ -59,6 +61,9 @@ OrchestrationAlertThresholds, ) from synthorg.budget.cost_record import CostRecord + from synthorg.budget.project_cost_aggregate import ( + ProjectCostAggregateRepository, + ) from synthorg.core.types import NotBlankStr # noqa: TC001 @@ -113,6 +118,7 @@ def __init__( budget_config: BudgetConfig | None = None, department_resolver: Callable[[str], str | None] | None = None, auto_prune_threshold: int = _AUTO_PRUNE_THRESHOLD, + project_cost_repo: ProjectCostAggregateRepository | None = None, ) -> None: if auto_prune_threshold < 1: msg = f"auto_prune_threshold must be >= 1, got {auto_prune_threshold}" @@ -122,10 +128,12 @@ def __init__( self._budget_config = budget_config self._department_resolver = department_resolver self._auto_prune_threshold = auto_prune_threshold + self._project_cost_repo = project_cost_repo logger.debug( BUDGET_TRACKER_CREATED, has_budget_config=budget_config is not None, has_department_resolver=department_resolver is not None, + has_project_cost_repo=project_cost_repo is not None, ) @property @@ -140,6 +148,11 @@ def budget_config(self) -> BudgetConfig | None: async def record(self, cost_record: CostRecord) -> None: """Append a cost record. + Also updates the durable project cost aggregate when the + record has a ``project_id`` and a repository is configured. + Aggregate writes are best-effort: failures are logged but + do not affect the in-memory recording. + Args: cost_record: Immutable cost record to store. """ @@ -152,6 +165,8 @@ async def record(self, cost_record: CostRecord) -> None: cost_usd=cost_record.cost_usd, ) + await self._update_project_aggregate(cost_record) + async def prune_expired(self, *, now: datetime | None = None) -> int: """Remove records older than the 168-hour (7-day) cost window. @@ -568,6 +583,40 @@ async def get_orchestration_ratio( # ── Private helpers ────────────────────────────────────────────── + async def _update_project_aggregate( + self, + cost_record: CostRecord, + ) -> None: + """Best-effort update of the durable project cost aggregate. + + No-op when the record has no ``project_id`` or no repository + is configured. Failures are logged at WARNING and swallowed. + """ + if self._project_cost_repo is None or cost_record.project_id is None: + return + + try: + await self._project_cost_repo.increment( + cost_record.project_id, + cost_record.cost_usd, + cost_record.input_tokens, + cost_record.output_tokens, + ) + logger.debug( + BUDGET_PROJECT_COST_AGGREGATED, + project_id=cost_record.project_id, + cost_usd=cost_record.cost_usd, + ) + except MemoryError, RecursionError: + raise + except Exception: + logger.warning( + BUDGET_PROJECT_COST_AGGREGATION_FAILED, + project_id=cost_record.project_id, + cost_usd=cost_record.cost_usd, + exc_info=True, + ) + async def _snapshot( self, *, diff --git a/src/synthorg/observability/events/budget.py b/src/synthorg/observability/events/budget.py index 51517bbdb3..1d7a59d3dd 100644 --- a/src/synthorg/observability/events/budget.py +++ b/src/synthorg/observability/events/budget.py @@ -54,3 +54,10 @@ BUDGET_PROJECT_RECORDS_QUERIED: Final[str] = "budget.project_records.queried" BUDGET_PROJECT_BUDGET_EXCEEDED: Final[str] = "budget.project_budget.exceeded" BUDGET_PROJECT_ENFORCEMENT_CHECK: Final[str] = "budget.project.enforcement_check" + +# -- Durable project cost aggregate events -- +BUDGET_PROJECT_COST_AGGREGATED: Final[str] = "budget.project_cost.aggregated" +BUDGET_PROJECT_COST_AGGREGATION_FAILED: Final[str] = ( + "budget.project_cost.aggregation_failed" +) +BUDGET_PROJECT_BASELINE_SOURCE: Final[str] = "budget.project_baseline.source" diff --git a/src/synthorg/observability/events/persistence.py b/src/synthorg/observability/events/persistence.py index ac4c16b93c..546d6baf66 100644 --- a/src/synthorg/observability/events/persistence.py +++ b/src/synthorg/observability/events/persistence.py @@ -261,6 +261,24 @@ "persistence.project.deserialize_failed" ) +# -- Project cost aggregate events -------------------------------------------- + +PERSISTENCE_PROJECT_COST_AGG_INCREMENTED: Final[str] = ( + "persistence.project_cost_agg.incremented" +) +PERSISTENCE_PROJECT_COST_AGG_INCREMENT_FAILED: Final[str] = ( + "persistence.project_cost_agg.increment_failed" +) +PERSISTENCE_PROJECT_COST_AGG_FETCHED: Final[str] = ( + "persistence.project_cost_agg.fetched" +) +PERSISTENCE_PROJECT_COST_AGG_FETCH_FAILED: Final[str] = ( + "persistence.project_cost_agg.fetch_failed" +) +PERSISTENCE_PROJECT_COST_AGG_DESERIALIZE_FAILED: Final[str] = ( + "persistence.project_cost_agg.deserialize_failed" +) + # -- Workflow definition events ----------------------------------------------- PERSISTENCE_WORKFLOW_DEF_SAVED: Final[str] = "persistence.workflow_def.saved" diff --git a/src/synthorg/persistence/sqlite/backend.py b/src/synthorg/persistence/sqlite/backend.py index 924d1b2329..955d05e9d2 100644 --- a/src/synthorg/persistence/sqlite/backend.py +++ b/src/synthorg/persistence/sqlite/backend.py @@ -57,6 +57,9 @@ from synthorg.persistence.sqlite.preset_repo import ( SQLitePersonalityPresetRepository, ) +from synthorg.persistence.sqlite.project_cost_aggregate_repo import ( + SQLiteProjectCostAggregateRepository, +) from synthorg.persistence.sqlite.project_repo import ( SQLiteProjectRepository, ) @@ -144,6 +147,9 @@ def __init__(self, config: SQLiteConfig) -> None: self._risk_overrides: SQLiteRiskOverrideRepository | None = None self._ssrf_violations: SQLiteSsrfViolationRepository | None = None self._circuit_breaker_state: SQLiteCircuitBreakerStateRepository | None = None + self._project_cost_aggregates: SQLiteProjectCostAggregateRepository | None = ( + None + ) def _clear_state(self) -> None: """Reset connection and repository references to ``None``.""" @@ -173,6 +179,7 @@ def _clear_state(self) -> None: self._risk_overrides = None self._ssrf_violations = None self._circuit_breaker_state = None + self._project_cost_aggregates = None async def connect(self) -> None: """Open the SQLite database and configure WAL mode.""" @@ -279,6 +286,9 @@ def _create_repositories(self) -> None: self._db, write_lock=self._shared_write_lock, ) + self._project_cost_aggregates = SQLiteProjectCostAggregateRepository( + self._db, + ) async def _cleanup_failed_connect(self, exc: sqlite3.Error | OSError) -> None: """Log failure, close partial connection, and raise. @@ -542,6 +552,20 @@ def projects(self) -> SQLiteProjectRepository: """ return self._require_connected(self._projects, "projects") + @property + def project_cost_aggregates( + self, + ) -> SQLiteProjectCostAggregateRepository: + """Repository for durable project cost aggregates. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected( + self._project_cost_aggregates, + "project_cost_aggregates", + ) + @property def custom_presets(self) -> SQLitePersonalityPresetRepository: """Repository for custom personality preset persistence. diff --git a/src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py b/src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py new file mode 100644 index 0000000000..17a717deaf --- /dev/null +++ b/src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py @@ -0,0 +1,180 @@ +"""SQLite repository for durable project cost aggregates.""" + +import sqlite3 +from datetime import UTC, datetime + +import aiosqlite +from pydantic import ValidationError + +from synthorg.budget.project_cost_aggregate import ProjectCostAggregate +from synthorg.core.types import NotBlankStr # noqa: TC001 +from synthorg.observability import get_logger +from synthorg.observability.events.persistence import ( + PERSISTENCE_PROJECT_COST_AGG_DESERIALIZE_FAILED, + PERSISTENCE_PROJECT_COST_AGG_FETCH_FAILED, + PERSISTENCE_PROJECT_COST_AGG_FETCHED, + PERSISTENCE_PROJECT_COST_AGG_INCREMENT_FAILED, + PERSISTENCE_PROJECT_COST_AGG_INCREMENTED, +) +from synthorg.persistence.errors import QueryError + +logger = get_logger(__name__) + +_UPSERT_SQL = """\ +INSERT INTO project_cost_aggregates + (project_id, total_cost, total_input_tokens, + total_output_tokens, record_count, last_updated) +VALUES (?, ?, ?, ?, 1, ?) +ON CONFLICT(project_id) DO UPDATE SET + total_cost = total_cost + excluded.total_cost, + total_input_tokens = total_input_tokens + excluded.total_input_tokens, + total_output_tokens = total_output_tokens + excluded.total_output_tokens, + record_count = record_count + 1, + last_updated = excluded.last_updated +""" + +_SELECT_SQL = """\ +SELECT project_id, total_cost, total_input_tokens, + total_output_tokens, record_count, last_updated +FROM project_cost_aggregates +WHERE project_id = ? +""" + + +def _row_to_aggregate(row: aiosqlite.Row) -> ProjectCostAggregate: + """Reconstruct a ``ProjectCostAggregate`` from a database row. + + Args: + row: A single database row. + + Returns: + Validated model instance. + + Raises: + ValidationError: If the row data fails Pydantic validation. + """ + data = dict(row) + return ProjectCostAggregate.model_validate(data) + + +class SQLiteProjectCostAggregateRepository: + """SQLite-backed project cost aggregate repository. + + Provides atomic increment and lookup for per-project lifetime + cost totals. Uses ``INSERT ... ON CONFLICT DO UPDATE`` for + atomic upsert semantics. + + Args: + db: An open aiosqlite connection with ``row_factory`` + set to ``aiosqlite.Row``. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def get( + self, + project_id: NotBlankStr, + ) -> ProjectCostAggregate | None: + """Retrieve the aggregate for a project. + + Args: + project_id: Project identifier. + + Returns: + The aggregate, or ``None`` if no costs recorded. + + Raises: + QueryError: If the database operation fails. + """ + try: + cursor = await self._db.execute(_SELECT_SQL, (project_id,)) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + logger.exception( + PERSISTENCE_PROJECT_COST_AGG_FETCH_FAILED, + project_id=project_id, + error=str(exc), + ) + raise QueryError(str(exc)) from exc + + if row is None: + logger.debug( + PERSISTENCE_PROJECT_COST_AGG_FETCHED, + project_id=project_id, + found=False, + ) + return None + + try: + aggregate = _row_to_aggregate(row) + except ValidationError as exc: + logger.exception( + PERSISTENCE_PROJECT_COST_AGG_DESERIALIZE_FAILED, + project_id=project_id, + error=str(exc), + ) + raise QueryError(str(exc)) from exc + + logger.debug( + PERSISTENCE_PROJECT_COST_AGG_FETCHED, + project_id=project_id, + found=True, + total_cost=aggregate.total_cost, + record_count=aggregate.record_count, + ) + return aggregate + + async def increment( + self, + project_id: NotBlankStr, + cost: float, + input_tokens: int, + output_tokens: int, + ) -> ProjectCostAggregate: + """Atomically increment the project's cost aggregate. + + Creates a new row on first call; increments on subsequent. + + Args: + project_id: Project identifier. + cost: Cost delta to add. + input_tokens: Input token delta. + output_tokens: Output token delta. + + Returns: + The updated aggregate after the increment. + + Raises: + QueryError: If the database operation fails. + """ + now = datetime.now(UTC).isoformat() + try: + await self._db.execute( + _UPSERT_SQL, + (project_id, cost, input_tokens, output_tokens, now), + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + logger.exception( + PERSISTENCE_PROJECT_COST_AGG_INCREMENT_FAILED, + project_id=project_id, + cost=cost, + error=str(exc), + ) + raise QueryError(str(exc)) from exc + + # Read back the updated aggregate. + aggregate = await self.get(project_id) + if aggregate is None: # pragma: no cover -- defensive + msg = f"Aggregate for {project_id!r} missing after upsert" + raise QueryError(msg) + + logger.debug( + PERSISTENCE_PROJECT_COST_AGG_INCREMENTED, + project_id=project_id, + cost_delta=cost, + total_cost=aggregate.total_cost, + record_count=aggregate.record_count, + ) + return aggregate diff --git a/src/synthorg/persistence/sqlite/schema.sql b/src/synthorg/persistence/sqlite/schema.sql index 59ef513b85..ca8b60a09a 100644 --- a/src/synthorg/persistence/sqlite/schema.sql +++ b/src/synthorg/persistence/sqlite/schema.sql @@ -317,6 +317,16 @@ CREATE TABLE IF NOT EXISTS projects ( CREATE INDEX IF NOT EXISTS idx_projects_status ON projects(status); CREATE INDEX IF NOT EXISTS idx_projects_lead ON projects(lead); +-- ── Project-lifetime cost aggregates ───────────────────────── +CREATE TABLE IF NOT EXISTS project_cost_aggregates ( + project_id TEXT PRIMARY KEY, + total_cost REAL NOT NULL DEFAULT 0.0, + total_input_tokens INTEGER NOT NULL DEFAULT 0, + total_output_tokens INTEGER NOT NULL DEFAULT 0, + record_count INTEGER NOT NULL DEFAULT 0, + last_updated TEXT NOT NULL +); + -- ── Custom personality presets (user-defined) ──────────────── CREATE TABLE IF NOT EXISTS custom_presets ( name TEXT PRIMARY KEY CHECK(length(name) > 0), diff --git a/tests/unit/budget/test_enforcer_project_durable.py b/tests/unit/budget/test_enforcer_project_durable.py new file mode 100644 index 0000000000..55da73edeb --- /dev/null +++ b/tests/unit/budget/test_enforcer_project_durable.py @@ -0,0 +1,161 @@ +"""Unit tests for BudgetEnforcer with durable project cost aggregate.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock + +import pytest + +from synthorg.budget.config import BudgetConfig +from synthorg.budget.enforcer import BudgetEnforcer +from synthorg.budget.errors import ProjectBudgetExhaustedError +from synthorg.budget.project_cost_aggregate import ProjectCostAggregate +from synthorg.budget.tracker import CostTracker +from synthorg.core.task import Task + +from .conftest import make_cost_record + + +def _make_task() -> Task: + return Task( + id="t-1", + title="Test task", + description="A test task", + type="development", + priority="medium", + project="proj-1", + created_by="alice", + ) + + +def _make_aggregate( + project_id: str = "proj-1", + total_cost: float = 0.0, +) -> ProjectCostAggregate: + return ProjectCostAggregate( + project_id=project_id, + total_cost=total_cost, + total_input_tokens=0, + total_output_tokens=0, + record_count=1, + last_updated=datetime.now(UTC), + ) + + +def _make_repo( + get_return: ProjectCostAggregate | None = None, +) -> AsyncMock: + repo = AsyncMock() + repo.get = AsyncMock(return_value=get_return) + return repo + + +def _make_enforcer( + *, + tracker: CostTracker | None = None, + project_cost_repo: AsyncMock | None = None, +) -> BudgetEnforcer: + config = BudgetConfig(total_monthly=100.0) + t = tracker or CostTracker(budget_config=config) + return BudgetEnforcer( + budget_config=config, + cost_tracker=t, + project_cost_repo=project_cost_repo, + ) + + +@pytest.mark.unit +class TestCheckProjectBudgetDurable: + """Tests for check_project_budget() with durable aggregate.""" + + async def test_uses_aggregate_when_available(self) -> None: + repo = _make_repo(_make_aggregate(total_cost=8.0)) + enforcer = _make_enforcer(project_cost_repo=repo) + + # Should pass: 8.0 < 10.0 + await enforcer.check_project_budget("proj-1", project_budget=10.0) + + repo.get.assert_awaited_once_with("proj-1") + + async def test_raises_from_aggregate_data(self) -> None: + repo = _make_repo(_make_aggregate(total_cost=15.0)) + enforcer = _make_enforcer(project_cost_repo=repo) + + with pytest.raises(ProjectBudgetExhaustedError) as exc_info: + await enforcer.check_project_budget("proj-1", project_budget=10.0) + + assert exc_info.value.project_spent >= 15.0 + + async def test_falls_back_to_in_memory_on_repo_error(self) -> None: + repo = _make_repo() + repo.get.side_effect = RuntimeError("DB error") + + tracker = CostTracker() + await tracker.record(make_cost_record(project_id="proj-1", cost_usd=2.0)) + enforcer = _make_enforcer(tracker=tracker, project_cost_repo=repo) + + # Falls back to in-memory (2.0 < 10.0), should pass + await enforcer.check_project_budget("proj-1", project_budget=10.0) + + async def test_uses_in_memory_when_no_repo(self) -> None: + tracker = CostTracker() + await tracker.record(make_cost_record(project_id="proj-1", cost_usd=5.0)) + enforcer = _make_enforcer(tracker=tracker) + + # No repo -- uses in-memory tracker + with pytest.raises(ProjectBudgetExhaustedError): + await enforcer.check_project_budget("proj-1", project_budget=5.0) + + async def test_aggregate_none_treated_as_zero(self) -> None: + repo = _make_repo(get_return=None) + enforcer = _make_enforcer(project_cost_repo=repo) + + # No aggregate record -> 0.0 cost, passes any budget + await enforcer.check_project_budget("proj-1", project_budget=10.0) + + async def test_zero_budget_skips_regardless_of_repo(self) -> None: + repo = _make_repo(_make_aggregate(total_cost=999.0)) + enforcer = _make_enforcer(project_cost_repo=repo) + + # Zero budget means no enforcement + await enforcer.check_project_budget("proj-1", project_budget=0.0) + repo.get.assert_not_awaited() + + +@pytest.mark.unit +class TestMakeBudgetCheckerDurable: + """Tests for make_budget_checker() with durable aggregate baseline.""" + + async def test_uses_aggregate_baseline(self) -> None: + repo = _make_repo(_make_aggregate(total_cost=7.0)) + enforcer = _make_enforcer(project_cost_repo=repo) + + task = _make_task() + + checker = await enforcer.make_budget_checker( + task, + "alice", + project_id="proj-1", + project_budget=10.0, + ) + + assert checker is not None + repo.get.assert_awaited_once_with("proj-1") + + async def test_falls_back_to_in_memory_on_error(self) -> None: + repo = _make_repo() + repo.get.side_effect = RuntimeError("DB error") + + tracker = CostTracker() + await tracker.record(make_cost_record(project_id="proj-1", cost_usd=3.0)) + enforcer = _make_enforcer(tracker=tracker, project_cost_repo=repo) + + task = _make_task() + + # Should not raise -- falls back to in-memory baseline + checker = await enforcer.make_budget_checker( + task, + "alice", + project_id="proj-1", + project_budget=10.0, + ) + assert checker is not None diff --git a/tests/unit/budget/test_project_cost_aggregate.py b/tests/unit/budget/test_project_cost_aggregate.py new file mode 100644 index 0000000000..3dbe63ce6c --- /dev/null +++ b/tests/unit/budget/test_project_cost_aggregate.py @@ -0,0 +1,101 @@ +"""Unit tests for the ProjectCostAggregate model.""" + +from datetime import UTC, datetime + +import pytest +from pydantic import ValidationError + +from synthorg.budget.project_cost_aggregate import ( + ProjectCostAggregate, + ProjectCostAggregateRepository, +) + + +@pytest.mark.unit +class TestProjectCostAggregate: + """Tests for the ProjectCostAggregate frozen model.""" + + def test_valid_construction(self) -> None: + agg = ProjectCostAggregate( + project_id="proj-1", + total_cost=10.5, + total_input_tokens=1000, + total_output_tokens=500, + record_count=3, + last_updated=datetime.now(UTC), + ) + assert agg.project_id == "proj-1" + assert agg.total_cost == 10.5 + assert agg.total_input_tokens == 1000 + assert agg.total_output_tokens == 500 + assert agg.record_count == 3 + + def test_frozen(self) -> None: + agg = ProjectCostAggregate( + project_id="proj-1", + total_cost=1.0, + total_input_tokens=100, + total_output_tokens=50, + record_count=1, + last_updated=datetime.now(UTC), + ) + with pytest.raises(ValidationError): + agg.total_cost = 999.0 # type: ignore[misc] + + def test_rejects_blank_project_id(self) -> None: + with pytest.raises(ValidationError): + ProjectCostAggregate( + project_id=" ", + total_cost=0.0, + total_input_tokens=0, + total_output_tokens=0, + record_count=0, + last_updated=datetime.now(UTC), + ) + + def test_rejects_negative_cost(self) -> None: + with pytest.raises(ValidationError): + ProjectCostAggregate( + project_id="proj-1", + total_cost=-1.0, + total_input_tokens=0, + total_output_tokens=0, + record_count=0, + last_updated=datetime.now(UTC), + ) + + def test_rejects_negative_tokens(self) -> None: + with pytest.raises(ValidationError): + ProjectCostAggregate( + project_id="proj-1", + total_cost=0.0, + total_input_tokens=-1, + total_output_tokens=0, + record_count=0, + last_updated=datetime.now(UTC), + ) + + def test_rejects_nan(self) -> None: + with pytest.raises(ValidationError): + ProjectCostAggregate( + project_id="proj-1", + total_cost=float("nan"), + total_input_tokens=0, + total_output_tokens=0, + record_count=0, + last_updated=datetime.now(UTC), + ) + + def test_rejects_inf(self) -> None: + with pytest.raises(ValidationError): + ProjectCostAggregate( + project_id="proj-1", + total_cost=float("inf"), + total_input_tokens=0, + total_output_tokens=0, + record_count=0, + last_updated=datetime.now(UTC), + ) + + def test_protocol_is_runtime_checkable(self) -> None: + assert hasattr(ProjectCostAggregateRepository, "__protocol_attrs__") diff --git a/tests/unit/budget/test_tracker_project_aggregate.py b/tests/unit/budget/test_tracker_project_aggregate.py new file mode 100644 index 0000000000..339df0f0bc --- /dev/null +++ b/tests/unit/budget/test_tracker_project_aggregate.py @@ -0,0 +1,90 @@ +"""Unit tests for CostTracker project aggregate write path.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock + +import pytest + +from synthorg.budget.project_cost_aggregate import ( + ProjectCostAggregate, +) +from synthorg.budget.tracker import CostTracker + +from .conftest import make_cost_record + + +def _make_mock_repo() -> AsyncMock: + """Build a mock ProjectCostAggregateRepository.""" + repo = AsyncMock() + repo.increment = AsyncMock( + return_value=ProjectCostAggregate( + project_id="proj-1", + total_cost=1.0, + total_input_tokens=100, + total_output_tokens=50, + record_count=1, + last_updated=datetime.now(UTC), + ), + ) + return repo + + +@pytest.mark.unit +class TestTrackerProjectAggregate: + """Tests for CostTracker aggregate write path.""" + + async def test_record_calls_repo_increment_for_project(self) -> None: + repo = _make_mock_repo() + tracker = CostTracker(project_cost_repo=repo) + record = make_cost_record(project_id="proj-1", cost_usd=1.0) + + await tracker.record(record) + + repo.increment.assert_awaited_once_with( + "proj-1", + 1.0, + record.input_tokens, + record.output_tokens, + ) + + async def test_record_skips_repo_when_no_project_id(self) -> None: + repo = _make_mock_repo() + tracker = CostTracker(project_cost_repo=repo) + record = make_cost_record(project_id=None) + + await tracker.record(record) + + repo.increment.assert_not_awaited() + + async def test_record_succeeds_when_repo_raises(self) -> None: + repo = _make_mock_repo() + repo.increment.side_effect = RuntimeError("DB down") + tracker = CostTracker(project_cost_repo=repo) + record = make_cost_record(project_id="proj-1") + + # Should not raise -- aggregate write is best-effort + await tracker.record(record) + + # In-memory record still present + count = await tracker.get_record_count() + assert count == 1 + + async def test_record_works_without_repo(self) -> None: + tracker = CostTracker() + record = make_cost_record(project_id="proj-1") + + await tracker.record(record) + + count = await tracker.get_record_count() + assert count == 1 + + async def test_in_memory_still_works_alongside_repo(self) -> None: + repo = _make_mock_repo() + tracker = CostTracker(project_cost_repo=repo) + + await tracker.record(make_cost_record(project_id="proj-1", cost_usd=2.0)) + await tracker.record(make_cost_record(project_id="proj-1", cost_usd=3.0)) + + # In-memory queries still work + cost = await tracker.get_project_cost("proj-1") + assert cost == pytest.approx(5.0) diff --git a/tests/unit/persistence/sqlite/test_migrations.py b/tests/unit/persistence/sqlite/test_migrations.py index 54acaa1f5a..fc02b6618f 100644 --- a/tests/unit/persistence/sqlite/test_migrations.py +++ b/tests/unit/persistence/sqlite/test_migrations.py @@ -43,6 +43,7 @@ "ssrf_violations", "agent_identity_versions", "circuit_breaker_state", + "project_cost_aggregates", } _EXPECTED_INDEXES = { diff --git a/tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py b/tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py new file mode 100644 index 0000000000..fbe47a8c36 --- /dev/null +++ b/tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py @@ -0,0 +1,106 @@ +"""Unit tests for SQLiteProjectCostAggregateRepository.""" + +from typing import TYPE_CHECKING + +import pytest + +from synthorg.persistence.sqlite.project_cost_aggregate_repo import ( + SQLiteProjectCostAggregateRepository, +) + +if TYPE_CHECKING: + import aiosqlite + + +@pytest.mark.unit +class TestSQLiteProjectCostAggregateRepository: + """Tests for the durable project cost aggregate repo.""" + + async def test_get_returns_none_when_not_found( + self, + migrated_db: aiosqlite.Connection, + ) -> None: + repo = SQLiteProjectCostAggregateRepository(migrated_db) + result = await repo.get("proj-nonexistent") + assert result is None + + async def test_increment_creates_new_aggregate( + self, + migrated_db: aiosqlite.Connection, + ) -> None: + repo = SQLiteProjectCostAggregateRepository(migrated_db) + agg = await repo.increment("proj-1", 1.5, 100, 50) + + assert agg.project_id == "proj-1" + assert agg.total_cost == 1.5 + assert agg.total_input_tokens == 100 + assert agg.total_output_tokens == 50 + assert agg.record_count == 1 + + async def test_increment_updates_existing( + self, + migrated_db: aiosqlite.Connection, + ) -> None: + repo = SQLiteProjectCostAggregateRepository(migrated_db) + await repo.increment("proj-1", 1.0, 100, 50) + agg = await repo.increment("proj-1", 2.0, 200, 100) + + assert agg.total_cost == pytest.approx(3.0) + assert agg.total_input_tokens == 300 + assert agg.total_output_tokens == 150 + assert agg.record_count == 2 + + async def test_multiple_increments_accumulate( + self, + migrated_db: aiosqlite.Connection, + ) -> None: + repo = SQLiteProjectCostAggregateRepository(migrated_db) + for _ in range(5): + await repo.increment("proj-1", 0.1, 10, 5) + + agg = await repo.get("proj-1") + assert agg is not None + assert agg.total_cost == pytest.approx(0.5) + assert agg.total_input_tokens == 50 + assert agg.total_output_tokens == 25 + assert agg.record_count == 5 + + async def test_get_after_increment( + self, + migrated_db: aiosqlite.Connection, + ) -> None: + repo = SQLiteProjectCostAggregateRepository(migrated_db) + await repo.increment("proj-1", 3.0, 500, 200) + + agg = await repo.get("proj-1") + assert agg is not None + assert agg.total_cost == 3.0 + assert agg.total_input_tokens == 500 + assert agg.total_output_tokens == 200 + assert agg.record_count == 1 + + async def test_isolation_between_projects( + self, + migrated_db: aiosqlite.Connection, + ) -> None: + repo = SQLiteProjectCostAggregateRepository(migrated_db) + await repo.increment("proj-a", 10.0, 1000, 500) + await repo.increment("proj-b", 5.0, 200, 100) + + agg_a = await repo.get("proj-a") + agg_b = await repo.get("proj-b") + + assert agg_a is not None + assert agg_b is not None + assert agg_a.total_cost == 10.0 + assert agg_b.total_cost == 5.0 + + async def test_last_updated_changes( + self, + migrated_db: aiosqlite.Connection, + ) -> None: + repo = SQLiteProjectCostAggregateRepository(migrated_db) + agg1 = await repo.increment("proj-1", 1.0, 10, 5) + agg2 = await repo.increment("proj-1", 1.0, 10, 5) + + assert agg2.last_updated >= agg1.last_updated From 625ef2b84d190bc23fad13a79a347b79715342fc Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Wed, 8 Apr 2026 23:43:11 +0200 Subject: [PATCH 2/4] fix: use enum types in test Task construction for mypy --- tests/unit/budget/test_enforcer_project_durable.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/budget/test_enforcer_project_durable.py b/tests/unit/budget/test_enforcer_project_durable.py index 55da73edeb..03aa86dd50 100644 --- a/tests/unit/budget/test_enforcer_project_durable.py +++ b/tests/unit/budget/test_enforcer_project_durable.py @@ -10,6 +10,7 @@ from synthorg.budget.errors import ProjectBudgetExhaustedError from synthorg.budget.project_cost_aggregate import ProjectCostAggregate from synthorg.budget.tracker import CostTracker +from synthorg.core.enums import Priority, TaskType from synthorg.core.task import Task from .conftest import make_cost_record @@ -20,8 +21,8 @@ def _make_task() -> Task: id="t-1", title="Test task", description="A test task", - type="development", - priority="medium", + type=TaskType.DEVELOPMENT, + priority=Priority.MEDIUM, project="proj-1", created_by="alice", ) From 2f634d63845dea80e303dad7af3819e7e29df587 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:13:08 +0200 Subject: [PATCH 3/4] fix: address pre-PR review findings for persistent cost aggregation --- docs/reference/claude-reference.md | 4 +- src/synthorg/budget/enforcer.py | 14 +++-- src/synthorg/budget/tracker.py | 12 ++-- src/synthorg/persistence/sqlite/backend.py | 1 + .../sqlite/project_cost_aggregate_repo.py | 37 +++++++++--- src/synthorg/persistence/sqlite/schema.sql | 10 ++-- .../budget/test_enforcer_project_durable.py | 27 +++++++++ .../test_project_cost_aggregate_repo.py | 56 +++++++++++++++++++ 8 files changed, 136 insertions(+), 25 deletions(-) diff --git a/docs/reference/claude-reference.md b/docs/reference/claude-reference.md index d64550af62..d72032f13b 100644 --- a/docs/reference/claude-reference.md +++ b/docs/reference/claude-reference.md @@ -39,7 +39,7 @@ curl http://localhost:3000/api/v1/health # backend (via web proxy) src/synthorg/ api/ # Litestar REST + WebSocket API, RFC 9457 errors, setup wizard, personality presets, auth/ (role-based access control, HttpOnly cookie sessions, CSRF double-submit, lockout_store, refresh_store, concurrent session enforcement, session store, user presence, OrgRole enum for org config permissions), guards (HumanRole-based + OrgRole-based with department scoping via require_org_mutation), user management (CRUD + org-role grant/revoke), dto_org (request DTOs for company/department/agent mutations), services/org_mutations (read-modify-write config mutation service), auto-wiring, lifecycle (auto-promote first owner), bootstrap (agent registry init from config), template packs (list + live-apply), memory admin (fine-tuning pipeline with orchestrator, checkpoint management, preflight checks, run history, embedder queries), optimistic concurrency (ETag/If-Match), TLS config, tiered rate limiting (unauth by IP, auth by user ID), workflows (visual workflow definition CRUD, validation, YAML export, blueprint listing, blueprint instantiation, version history, diff, rollback), workflow executions (activate, list, get, cancel), ceremony policy (project + per-department query/override, resolved policy with field origins), quality overrides (per-agent quality score override CRUD), reports (on-demand report generation, period listing), notification_dispatcher (fan-out notification sink) backup/ # Backup/restore orchestrator, scheduler, retention, handlers/ - budget/ # Cost tracking, budget enforcement, quota degradation (including synchronous peek for routing-time selector hints), CFO optimization, trend analysis, budget forecasting, configurable currency formatting, risk budget (cumulative risk-unit tracking, risk scoring integration, risk check, risk records), automated reporting (periodic comprehensive reports, spending/performance/task-completion/risk-trends templates, report scheduling config), coordination metrics (9 empirical metrics: efficiency, overhead, error amplification, message density, redundancy, Amdahl ceiling, straggler gap, token/speedup ratio, message overhead) + budget/ # Cost tracking, budget enforcement, quota degradation (including synchronous peek for routing-time selector hints), CFO optimization, trend analysis, budget forecasting, configurable currency formatting, risk budget (cumulative risk-unit tracking, risk scoring integration, risk check, risk records), automated reporting (periodic comprehensive reports, spending/performance/task-completion/risk-trends templates, report scheduling config), coordination metrics (9 empirical metrics: efficiency, overhead, error amplification, message density, redundancy, Amdahl ceiling, straggler gap, token/speedup ratio, message overhead), project cost aggregates (durable per-project lifetime cost totals surviving retention pruning) cli/ # Python CLI module (superseded by top-level cli/ Go binary) communication/ # Message bus, dispatcher, channels, delegation, conflict resolution, meeting/ config/ # YAML company config loading and validation @@ -48,7 +48,7 @@ src/synthorg/ hr/ # Hiring, firing, onboarding, agent registry, performance tracking, activity timeline, activity event types, cost event redaction, career history, promotion/demotion, evaluation/ (five-pillar evaluation framework, pluggable pillar scoring strategies, EvaluationConfig), quality scoring (layered composite: CI signal + LLM judge + human override, QualityOverrideStore) notifications/ # NotificationSink protocol, NotificationDispatcher fan-out, Notification model (category taxonomy: approval/budget/security/stagnation/system/agent/health + severity taxonomy), adapters/ (console, ntfy, slack, email), config memory/ # Pluggable MemoryBackend, retrieval pipeline (hybrid dense+BM25 sparse with RRF fusion, MMR diversity re-ranking via apply_diversity_penalty with pre-computed bigram cache), tool-based injection strategy with iterative Search-and-Ask reformulation loop (fail-safe reformulator/sufficiency_checker), ToolRegistry memory tool wrappers (SearchMemoryTool, RecallMemoryTool), fail-closed memory filter, agentic query reformulation, org memory, backends/ (composite namespace-based routing, inmemory session-scoped, mem0 Qdrant+SQLite, EmbeddingCostConfig embedding cost tracking), consolidation/ (SimpleConsolidationStrategy, DualModeConsolidationStrategy density-aware, LLMConsolidationStrategy with parallel TaskGroup per-category processing + trajectory-context injection from distillation entries, LLMConsolidationConfig, DistillationRequest capture helper tagged "distillation" EPISODIC, retention, archival), embedding/ (LMEB-ranked model selection, embedder config resolution, fine-tuning pipeline with orchestrator, cancellation, checkpoint management), procedural/ (failure-driven auto-generation, proposer LLM pipeline, SKILL.md materialization, ProceduralMemoryConfig) - persistence/ # Pluggable PersistenceBackend, SQLite, settings + user + artifact + project + preset + workflow definition + workflow execution + workflow version + agent identity versions + fine-tune + decision record (append-only audit drop-box) + risk override + SSRF violation repositories, artifact content storage (pluggable ArtifactStorageBackend, filesystem impl) + persistence/ # Pluggable PersistenceBackend, SQLite, settings + user + artifact + project + preset + workflow definition + workflow execution + workflow version + agent identity versions + fine-tune + decision record (append-only audit drop-box) + risk override + SSRF violation + project cost aggregate repositories, artifact content storage (pluggable ArtifactStorageBackend, filesystem impl) versioning/ # Generic versioning infrastructure: VersionSnapshot[T] model, VersioningService[T] (content-addressable deduplication via SHA-256 hash, INSERT OR IGNORE concurrent-write safety), compute_content_hash observability/ # Structured logging, correlation tracking, redaction, third-party logger taming, log shipping (syslog, HTTP), compressed archival, events/ providers/ # LLM provider abstraction, presets, model auto-discovery, capabilities, runtime CRUD (management/), local model management (pull/delete/config via LocalModelManager protocol), provider families, discovery SSRF allowlist, health tracking, active health probing, routing/ (strategy-based model routing, multi-provider resolution with ModelCandidateSelector protocol, QuotaAwareSelector, CheapestSelector) diff --git a/src/synthorg/budget/enforcer.py b/src/synthorg/budget/enforcer.py index 64abc229f0..a261dfbb83 100644 --- a/src/synthorg/budget/enforcer.py +++ b/src/synthorg/budget/enforcer.py @@ -104,6 +104,9 @@ class BudgetEnforcer: degradation_configs: Per-provider degradation strategies. risk_tracker: Optional risk tracking service. risk_scorer: Optional risk scoring implementation. + notification_dispatcher: Optional notification dispatcher. + project_cost_repo: Optional durable project cost aggregate + repository for lifetime budget enforcement. """ def __init__( # noqa: PLR0913 @@ -259,11 +262,12 @@ async def check_project_budget( ) -> None: """Check project-level budget and raise if exceeded. - Uses the durable project cost aggregate when available, - providing accurate lifetime totals that survive the - in-memory tracker's 168-hour retention window. Falls - back to in-memory tracking when no aggregate repository - is configured or when the aggregate query fails. + Returns immediately when ``project_budget <= 0`` (enforcement + disabled). Otherwise uses the durable project cost aggregate + when available, providing accurate lifetime totals that survive + the in-memory tracker's 168-hour retention window. Falls back + to in-memory tracking when no aggregate repository is configured + or when the aggregate query fails. Args: project_id: Project identifier for cost lookup. diff --git a/src/synthorg/budget/tracker.py b/src/synthorg/budget/tracker.py index fd595b54d8..0ce7def2e1 100644 --- a/src/synthorg/budget/tracker.py +++ b/src/synthorg/budget/tracker.py @@ -148,14 +148,18 @@ def budget_config(self) -> BudgetConfig | None: async def record(self, cost_record: CostRecord) -> None: """Append a cost record. - Also updates the durable project cost aggregate when the - record has a ``project_id`` and a repository is configured. - Aggregate writes are best-effort: failures are logged but - do not affect the in-memory recording. + After the in-memory write, updates the durable project cost + aggregate when the record has a ``project_id`` and a repository + is configured. Aggregate updates are best-effort: failures are + logged at WARNING by ``_update_project_aggregate`` but do not + affect the in-memory recording. Args: cost_record: Immutable cost record to store. """ + # Lock protects in-memory list only. DB aggregate update is + # best-effort and runs outside the lock to avoid blocking other + # callers on I/O. async with self._lock: self._records.append(cost_record) logger.info( diff --git a/src/synthorg/persistence/sqlite/backend.py b/src/synthorg/persistence/sqlite/backend.py index 955d05e9d2..b2bad00b88 100644 --- a/src/synthorg/persistence/sqlite/backend.py +++ b/src/synthorg/persistence/sqlite/backend.py @@ -288,6 +288,7 @@ def _create_repositories(self) -> None: ) self._project_cost_aggregates = SQLiteProjectCostAggregateRepository( self._db, + write_lock=self._shared_write_lock, ) async def _cleanup_failed_connect(self, exc: sqlite3.Error | OSError) -> None: diff --git a/src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py b/src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py index 17a717deaf..22353f1d01 100644 --- a/src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py +++ b/src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py @@ -1,5 +1,6 @@ """SQLite repository for durable project cost aggregates.""" +import asyncio import sqlite3 from datetime import UTC, datetime @@ -67,10 +68,18 @@ class SQLiteProjectCostAggregateRepository: Args: db: An open aiosqlite connection with ``row_factory`` set to ``aiosqlite.Row``. + write_lock: Optional shared write lock for serialising + multi-statement write operations. """ - def __init__(self, db: aiosqlite.Connection) -> None: + def __init__( + self, + db: aiosqlite.Connection, + *, + write_lock: asyncio.Lock | None = None, + ) -> None: self._db = db + self._write_lock = write_lock or asyncio.Lock() async def get( self, @@ -138,23 +147,33 @@ async def increment( Args: project_id: Project identifier. - cost: Cost delta to add. - input_tokens: Input token delta. - output_tokens: Output token delta. + cost: Cost delta to add (must be >= 0). + input_tokens: Input token delta (must be >= 0). + output_tokens: Output token delta (must be >= 0). Returns: The updated aggregate after the increment. Raises: QueryError: If the database operation fails. + ValueError: If any delta is negative. """ + if cost < 0 or input_tokens < 0 or output_tokens < 0: + msg = ( + f"Deltas must be non-negative: " + f"cost={cost}, input_tokens={input_tokens}, " + f"output_tokens={output_tokens}" + ) + raise ValueError(msg) + now = datetime.now(UTC).isoformat() try: - await self._db.execute( - _UPSERT_SQL, - (project_id, cost, input_tokens, output_tokens, now), - ) - await self._db.commit() + async with self._write_lock: + await self._db.execute( + _UPSERT_SQL, + (project_id, cost, input_tokens, output_tokens, now), + ) + await self._db.commit() except (sqlite3.Error, aiosqlite.Error) as exc: logger.exception( PERSISTENCE_PROJECT_COST_AGG_INCREMENT_FAILED, diff --git a/src/synthorg/persistence/sqlite/schema.sql b/src/synthorg/persistence/sqlite/schema.sql index ca8b60a09a..a70d78d038 100644 --- a/src/synthorg/persistence/sqlite/schema.sql +++ b/src/synthorg/persistence/sqlite/schema.sql @@ -319,11 +319,11 @@ CREATE INDEX IF NOT EXISTS idx_projects_lead ON projects(lead); -- ── Project-lifetime cost aggregates ───────────────────────── CREATE TABLE IF NOT EXISTS project_cost_aggregates ( - project_id TEXT PRIMARY KEY, - total_cost REAL NOT NULL DEFAULT 0.0, - total_input_tokens INTEGER NOT NULL DEFAULT 0, - total_output_tokens INTEGER NOT NULL DEFAULT 0, - record_count INTEGER NOT NULL DEFAULT 0, + project_id TEXT PRIMARY KEY CHECK(length(project_id) > 0), + total_cost REAL NOT NULL DEFAULT 0.0 CHECK(total_cost >= 0.0), + total_input_tokens INTEGER NOT NULL DEFAULT 0 CHECK(total_input_tokens >= 0), + total_output_tokens INTEGER NOT NULL DEFAULT 0 CHECK(total_output_tokens >= 0), + record_count INTEGER NOT NULL DEFAULT 0 CHECK(record_count >= 0), last_updated TEXT NOT NULL ); diff --git a/tests/unit/budget/test_enforcer_project_durable.py b/tests/unit/budget/test_enforcer_project_durable.py index 03aa86dd50..cb9a281d16 100644 --- a/tests/unit/budget/test_enforcer_project_durable.py +++ b/tests/unit/budget/test_enforcer_project_durable.py @@ -113,6 +113,13 @@ async def test_aggregate_none_treated_as_zero(self) -> None: # No aggregate record -> 0.0 cost, passes any budget await enforcer.check_project_budget("proj-1", project_budget=10.0) + async def test_raises_at_exact_boundary(self) -> None: + repo = _make_repo(_make_aggregate(total_cost=10.0)) + enforcer = _make_enforcer(project_cost_repo=repo) + + with pytest.raises(ProjectBudgetExhaustedError): + await enforcer.check_project_budget("proj-1", project_budget=10.0) + async def test_zero_budget_skips_regardless_of_repo(self) -> None: repo = _make_repo(_make_aggregate(total_cost=999.0)) enforcer = _make_enforcer(project_cost_repo=repo) @@ -160,3 +167,23 @@ async def test_falls_back_to_in_memory_on_error(self) -> None: project_budget=10.0, ) assert checker is not None + + async def test_both_sources_fail_uses_zero_baseline(self) -> None: + repo = _make_repo() + repo.get.side_effect = RuntimeError("DB error") + + tracker = CostTracker() + enforcer = _make_enforcer(tracker=tracker, project_cost_repo=repo) + + task = _make_task() + + # Both repo and in-memory fail for a fresh tracker with + # no records -- _get_project_cost returns None, baseline + # defaults to 0.0, checker is still created. + checker = await enforcer.make_budget_checker( + task, + "alice", + project_id="proj-1", + project_budget=10.0, + ) + assert checker is not None diff --git a/tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py b/tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py index fbe47a8c36..fef9efa8ae 100644 --- a/tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py +++ b/tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py @@ -1,9 +1,12 @@ """Unit tests for SQLiteProjectCostAggregateRepository.""" +import sqlite3 from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, patch import pytest +from synthorg.persistence.errors import QueryError from synthorg.persistence.sqlite.project_cost_aggregate_repo import ( SQLiteProjectCostAggregateRepository, ) @@ -104,3 +107,56 @@ async def test_last_updated_changes( agg2 = await repo.increment("proj-1", 1.0, 10, 5) assert agg2.last_updated >= agg1.last_updated + + async def test_zero_cost_increment( + self, + migrated_db: aiosqlite.Connection, + ) -> None: + repo = SQLiteProjectCostAggregateRepository(migrated_db) + agg = await repo.increment("proj-1", 0.0, 0, 0) + + assert agg.total_cost == 0.0 + assert agg.record_count == 1 + + agg2 = await repo.increment("proj-1", 0.0, 0, 0) + assert agg2.record_count == 2 + + async def test_get_raises_query_error_on_db_failure( + self, + migrated_db: aiosqlite.Connection, + ) -> None: + repo = SQLiteProjectCostAggregateRepository(migrated_db) + with ( + patch.object( + migrated_db, + "execute", + new_callable=AsyncMock, + side_effect=sqlite3.OperationalError("disk I/O error"), + ), + pytest.raises(QueryError), + ): + await repo.get("proj-1") + + async def test_increment_raises_query_error_on_db_failure( + self, + migrated_db: aiosqlite.Connection, + ) -> None: + repo = SQLiteProjectCostAggregateRepository(migrated_db) + with ( + patch.object( + migrated_db, + "execute", + new_callable=AsyncMock, + side_effect=sqlite3.OperationalError("disk I/O error"), + ), + pytest.raises(QueryError), + ): + await repo.increment("proj-1", 1.0, 100, 50) + + async def test_increment_rejects_negative_deltas( + self, + migrated_db: aiosqlite.Connection, + ) -> None: + repo = SQLiteProjectCostAggregateRepository(migrated_db) + with pytest.raises(ValueError, match="non-negative"): + await repo.increment("proj-1", -1.0, 100, 50) From 13c8af6340a272752bfd605b75867c0167457099 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:46:27 +0200 Subject: [PATCH 4/4] fix: address 17 PR review items from local agents, Gemini, Copilot, CodeRabbit --- src/synthorg/budget/enforcer.py | 38 +++++++++--- src/synthorg/budget/tracker.py | 11 ++-- .../sqlite/project_cost_aggregate_repo.py | 58 +++++++++++++++---- src/synthorg/persistence/sqlite/schema.sql | 4 +- .../budget/test_enforcer_project_durable.py | 6 +- .../budget/test_project_cost_aggregate.py | 18 +++++- .../test_project_cost_aggregate_repo.py | 24 +++++++- 7 files changed, 128 insertions(+), 31 deletions(-) diff --git a/src/synthorg/budget/enforcer.py b/src/synthorg/budget/enforcer.py index a261dfbb83..2d27c3e98b 100644 --- a/src/synthorg/budget/enforcer.py +++ b/src/synthorg/budget/enforcer.py @@ -82,6 +82,7 @@ from synthorg.budget.tracker import CostTracker from synthorg.core.agent import AgentIdentity from synthorg.core.task import Task + from synthorg.core.types import NotBlankStr from synthorg.engine.loop_protocol import BudgetChecker from synthorg.providers.routing.resolver import ModelResolver from synthorg.security.risk_scorer import RiskScorer @@ -257,7 +258,7 @@ async def check_can_execute( async def check_project_budget( self, - project_id: str, + project_id: NotBlankStr, project_budget: float, ) -> None: """Check project-level budget and raise if exceeded. @@ -275,6 +276,8 @@ async def check_project_budget( Raises: ProjectBudgetExhaustedError: When project spend >= budget. + MemoryError: Re-raised unconditionally. + RecursionError: Re-raised unconditionally. """ if project_budget <= 0: return @@ -593,7 +596,7 @@ async def make_budget_checker( task: Task, agent_id: str, *, - project_id: str | None = None, + project_id: NotBlankStr | None = None, project_budget: float = 0.0, ) -> BudgetChecker | None: """Create a sync BudgetChecker with pre-computed baselines. @@ -632,7 +635,10 @@ async def make_budget_checker( project_baseline = 0.0 if project_id is not None and project_budget > 0: - baseline = await self._get_project_cost(project_id) + baseline = await self._get_project_cost( + project_id, + error_event=BUDGET_BASELINE_ERROR, + ) if baseline is not None: project_baseline = baseline @@ -873,12 +879,25 @@ async def _compute_baselines( async def _get_project_cost( self, - project_id: str, + project_id: NotBlankStr, + *, + error_event: str = BUDGET_PREFLIGHT_ERROR, ) -> float | None: """Query project cost from durable aggregate or in-memory tracker. - Returns the total cost, or ``None`` when both sources fail - (caller should skip enforcement on ``None``). + Returns the total cost (rounded to + ``BUDGET_ROUNDING_PRECISION``), or ``None`` when both + sources fail (caller should skip enforcement on ``None``). + + Args: + project_id: Project identifier. + error_event: Event constant to log on failure. Allows + callers to preserve distinct monitoring semantics + (e.g. preflight vs baseline). + + Raises: + MemoryError: Re-raised unconditionally. + RecursionError: Re-raised unconditionally. """ # Try durable aggregate first. if self._project_cost_repo is not None: @@ -890,13 +909,14 @@ async def _get_project_cost( raise except Exception: logger.exception( - BUDGET_PREFLIGHT_ERROR, + error_event, project_id=project_id, reason="project_cost_aggregate_query_failed", ) # Fall through to in-memory. else: - cost = aggregate.total_cost if aggregate else 0.0 + raw = aggregate.total_cost if aggregate else 0.0 + cost = round(raw, BUDGET_ROUNDING_PRECISION) logger.debug( BUDGET_PROJECT_BASELINE_SOURCE, project_id=project_id, @@ -914,7 +934,7 @@ async def _get_project_cost( raise except Exception: logger.exception( - BUDGET_PREFLIGHT_ERROR, + error_event, project_id=project_id, reason="project_cost_query_failed", ) diff --git a/src/synthorg/budget/tracker.py b/src/synthorg/budget/tracker.py index 0ce7def2e1..75914834e3 100644 --- a/src/synthorg/budget/tracker.py +++ b/src/synthorg/budget/tracker.py @@ -148,11 +148,12 @@ def budget_config(self) -> BudgetConfig | None: async def record(self, cost_record: CostRecord) -> None: """Append a cost record. - After the in-memory write, updates the durable project cost - aggregate when the record has a ``project_id`` and a repository - is configured. Aggregate updates are best-effort: failures are - logged at WARNING by ``_update_project_aggregate`` but do not - affect the in-memory recording. + The in-memory append runs under ``_lock``. After the lock + is released, ``_update_project_aggregate`` is awaited to + update the durable project cost aggregate when the record + has a ``project_id`` and a repository is configured. + Aggregate updates are best-effort: failures are logged at + WARNING but do not affect the in-memory recording. Args: cost_record: Immutable cost record to store. diff --git a/src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py b/src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py index 22353f1d01..44ac44b572 100644 --- a/src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py +++ b/src/synthorg/persistence/sqlite/project_cost_aggregate_repo.py @@ -1,6 +1,7 @@ """SQLite repository for durable project cost aggregates.""" import asyncio +import math import sqlite3 from datetime import UTC, datetime @@ -32,6 +33,8 @@ total_output_tokens = total_output_tokens + excluded.total_output_tokens, record_count = record_count + 1, last_updated = excluded.last_updated +RETURNING project_id, total_cost, total_input_tokens, + total_output_tokens, record_count, last_updated """ _SELECT_SQL = """\ @@ -79,7 +82,7 @@ def __init__( write_lock: asyncio.Lock | None = None, ) -> None: self._db = db - self._write_lock = write_lock or asyncio.Lock() + self._write_lock = write_lock if write_lock is not None else asyncio.Lock() async def get( self, @@ -105,7 +108,8 @@ async def get( project_id=project_id, error=str(exc), ) - raise QueryError(str(exc)) from exc + msg = f"Failed to fetch project cost aggregate for {project_id!r}: {exc}" + raise QueryError(msg) from exc if row is None: logger.debug( @@ -123,7 +127,11 @@ async def get( project_id=project_id, error=str(exc), ) - raise QueryError(str(exc)) from exc + msg = ( + f"Failed to deserialize project cost aggregate" + f" for {project_id!r}: {exc}" + ) + raise QueryError(msg) from exc logger.debug( PERSISTENCE_PROJECT_COST_AGG_FETCHED, @@ -144,10 +152,13 @@ async def increment( """Atomically increment the project's cost aggregate. Creates a new row on first call; increments on subsequent. + Uses ``RETURNING`` to read back the updated row inside the + same locked section, avoiding race conditions with concurrent + increments. Args: project_id: Project identifier. - cost: Cost delta to add (must be >= 0). + cost: Cost delta to add (must be finite and >= 0). input_tokens: Input token delta (must be >= 0). output_tokens: Output token delta (must be >= 0). @@ -156,23 +167,31 @@ async def increment( Raises: QueryError: If the database operation fails. - ValueError: If any delta is negative. + ValueError: If any delta is negative or cost is + non-finite (NaN/Inf). """ - if cost < 0 or input_tokens < 0 or output_tokens < 0: + if not math.isfinite(cost) or cost < 0 or input_tokens < 0 or output_tokens < 0: msg = ( - f"Deltas must be non-negative: " + "Deltas must be finite and non-negative: " f"cost={cost}, input_tokens={input_tokens}, " f"output_tokens={output_tokens}" ) + logger.warning( + PERSISTENCE_PROJECT_COST_AGG_INCREMENT_FAILED, + project_id=project_id, + cost=cost, + error=msg, + ) raise ValueError(msg) now = datetime.now(UTC).isoformat() try: async with self._write_lock: - await self._db.execute( + cursor = await self._db.execute( _UPSERT_SQL, (project_id, cost, input_tokens, output_tokens, now), ) + row = await cursor.fetchone() await self._db.commit() except (sqlite3.Error, aiosqlite.Error) as exc: logger.exception( @@ -181,14 +200,29 @@ async def increment( cost=cost, error=str(exc), ) - raise QueryError(str(exc)) from exc + msg = ( + f"Failed to increment project cost aggregate for {project_id!r}: {exc}" + ) + raise QueryError(msg) from exc - # Read back the updated aggregate. - aggregate = await self.get(project_id) - if aggregate is None: # pragma: no cover -- defensive + if row is None: # pragma: no cover -- defensive msg = f"Aggregate for {project_id!r} missing after upsert" raise QueryError(msg) + try: + aggregate = _row_to_aggregate(row) + except ValidationError as exc: + logger.exception( + PERSISTENCE_PROJECT_COST_AGG_DESERIALIZE_FAILED, + project_id=project_id, + error=str(exc), + ) + msg = ( + f"Failed to deserialize project cost aggregate" + f" for {project_id!r} after increment: {exc}" + ) + raise QueryError(msg) from exc + logger.debug( PERSISTENCE_PROJECT_COST_AGG_INCREMENTED, project_id=project_id, diff --git a/src/synthorg/persistence/sqlite/schema.sql b/src/synthorg/persistence/sqlite/schema.sql index a70d78d038..8d95a4d296 100644 --- a/src/synthorg/persistence/sqlite/schema.sql +++ b/src/synthorg/persistence/sqlite/schema.sql @@ -324,7 +324,9 @@ CREATE TABLE IF NOT EXISTS project_cost_aggregates ( total_input_tokens INTEGER NOT NULL DEFAULT 0 CHECK(total_input_tokens >= 0), total_output_tokens INTEGER NOT NULL DEFAULT 0 CHECK(total_output_tokens >= 0), record_count INTEGER NOT NULL DEFAULT 0 CHECK(record_count >= 0), - last_updated TEXT NOT NULL + last_updated TEXT NOT NULL CHECK( + last_updated LIKE '%+00:00' OR last_updated LIKE '%Z' + ) ); -- ── Custom personality presets (user-defined) ──────────────── diff --git a/tests/unit/budget/test_enforcer_project_durable.py b/tests/unit/budget/test_enforcer_project_durable.py index cb9a281d16..5cba268f1f 100644 --- a/tests/unit/budget/test_enforcer_project_durable.py +++ b/tests/unit/budget/test_enforcer_project_durable.py @@ -103,9 +103,13 @@ async def test_uses_in_memory_when_no_repo(self) -> None: enforcer = _make_enforcer(tracker=tracker) # No repo -- uses in-memory tracker - with pytest.raises(ProjectBudgetExhaustedError): + with pytest.raises(ProjectBudgetExhaustedError) as exc_info: await enforcer.check_project_budget("proj-1", project_budget=5.0) + assert exc_info.value.project_id == "proj-1" + assert exc_info.value.project_spent >= 5.0 + assert exc_info.value.project_budget == 5.0 + async def test_aggregate_none_treated_as_zero(self) -> None: repo = _make_repo(get_return=None) enforcer = _make_enforcer(project_cost_repo=repo) diff --git a/tests/unit/budget/test_project_cost_aggregate.py b/tests/unit/budget/test_project_cost_aggregate.py index 3dbe63ce6c..59b797dda4 100644 --- a/tests/unit/budget/test_project_cost_aggregate.py +++ b/tests/unit/budget/test_project_cost_aggregate.py @@ -98,4 +98,20 @@ def test_rejects_inf(self) -> None: ) def test_protocol_is_runtime_checkable(self) -> None: - assert hasattr(ProjectCostAggregateRepository, "__protocol_attrs__") + class _RepoStub: + async def get( + self, + project_id: str, + ) -> None: + return None + + async def increment( + self, + project_id: str, + cost: float, + input_tokens: int, + output_tokens: int, + ) -> None: + return None + + assert isinstance(_RepoStub(), ProjectCostAggregateRepository) diff --git a/tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py b/tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py index fef9efa8ae..30b84a78a5 100644 --- a/tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py +++ b/tests/unit/persistence/sqlite/test_project_cost_aggregate_repo.py @@ -153,10 +153,30 @@ async def test_increment_raises_query_error_on_db_failure( ): await repo.increment("proj-1", 1.0, 100, 50) - async def test_increment_rejects_negative_deltas( + @pytest.mark.parametrize( + ("cost", "input_tokens", "output_tokens"), + [ + (-1.0, 100, 50), + (1.0, -1, 50), + (1.0, 100, -1), + (float("nan"), 100, 50), + (float("inf"), 100, 50), + ], + ids=[ + "negative_cost", + "negative_input_tokens", + "negative_output_tokens", + "nan_cost", + "inf_cost", + ], + ) + async def test_increment_rejects_invalid_deltas( self, migrated_db: aiosqlite.Connection, + cost: float, + input_tokens: int, + output_tokens: int, ) -> None: repo = SQLiteProjectCostAggregateRepository(migrated_db) with pytest.raises(ValueError, match="non-negative"): - await repo.increment("proj-1", -1.0, 100, 50) + await repo.increment("proj-1", cost, input_tokens, output_tokens)