diff --git a/CLAUDE.md b/CLAUDE.md index 77aa783ee5..36e4833348 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -68,7 +68,7 @@ src/ai_company/ - **Docstrings**: Google style, required on public classes/functions (enforced by ruff D rules) - **Immutability**: create new objects, never mutate existing ones. For non-Pydantic internal collections (registries, `BaseTool`), use `copy.deepcopy()` at construction + `MappingProxyType` wrapping for read-only enforcement. For `dict`/`list` fields in frozen Pydantic models, rely on `frozen=True` for field reassignment prevention and `copy.deepcopy()` at system boundaries (tool execution, LLM provider serialization, inter-agent delegation, serializing for persistence). - **Config vs runtime state**: frozen Pydantic models for config/identity; separate mutable-via-copy models (using `model_copy(update=...)`) for runtime state that evolves (e.g. agent execution state, task progress). Never mix static config fields with mutable runtime fields in one model. -- **Models**: Pydantic v2 (`BaseModel`, `model_validator`, `ConfigDict`). Planned conventions for new code: use `@computed_field` for derived values instead of storing + validating redundant fields; use `NotBlankStr` (from `core.types`) for non-optional identifier/name fields instead of manual whitespace validators. Existing models are being migrated incrementally. +- **Models**: Pydantic v2 (`BaseModel`, `model_validator`, `computed_field`, `ConfigDict`). Adopted conventions: use `@computed_field` for derived values instead of storing + validating redundant fields (e.g. `TokenUsage.total_tokens`). Planned conventions for new code: use `NotBlankStr` (from `core.types`) for non-optional identifier/name fields instead of manual whitespace validators. Existing models are being migrated incrementally. - **Async concurrency**: prefer `asyncio.TaskGroup` for fan-out/fan-in parallel operations in new code (e.g. multiple tool invocations, parallel agent calls). Prefer structured concurrency over bare `create_task`. Existing code is being migrated incrementally. - **Line length**: 88 characters (ruff) - **Functions**: < 50 lines, files < 800 lines diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 2e050fcf1e..c7ece84bd0 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -857,7 +857,7 @@ Every API call is tracked (illustrative schema): } ``` -> **Implementation note:** `total_tokens` is a stored `Field` validated by `@model_validator` to equal `input_tokens + output_tokens`. Spending summary models (`AgentSpending`, `DepartmentSpending`, `PeriodSpending`) each independently define `total_cost_usd`, `total_input_tokens`, `total_output_tokens`, and `record_count` fields. Extracting a shared `_SpendingTotals` base and migrating `total_tokens` to `@computed_field` are planned conventions (see §15.5). +> **Implementation note:** `total_tokens` is a `@computed_field` property that returns `input_tokens + output_tokens` — no stored field or validator needed. Spending summary models (`AgentSpending`, `DepartmentSpending`, `PeriodSpending`) each independently define `total_cost_usd`, `total_input_tokens`, `total_output_tokens`, and `record_count` fields. Extracting a shared `_SpendingTotals` base is a planned convention (see §15.5). ### 10.3 CFO Agent Responsibilities @@ -1335,7 +1335,7 @@ ai-company/ │ ├── providers/ # LLM provider abstraction │ │ ├── base.py # BaseCompletionProvider (retry + rate limiting) │ │ ├── protocol.py # Provider protocol (abstract interface) -│ │ ├── models.py # CompletionRequest/Response, TokenUsage, ToolCall/Result +│ │ ├── models.py # CompletionConfig/Response, TokenUsage, ToolCall/Result │ │ ├── capabilities.py # Provider capability registry │ │ ├── registry.py # Provider registry │ │ ├── enums.py # Provider enumerations @@ -1438,7 +1438,7 @@ These conventions were established during the M0–M2 review cycle. **Adopted** |------------|--------|----------|-----------| | **Immutability strategy** | Adopted | `copy.deepcopy()` at construction + `MappingProxyType` wrapping for non-Pydantic internal collections (registries, `BaseTool`). For Pydantic frozen models: `frozen=True` prevents field reassignment; `copy.deepcopy()` at system boundaries (tool execution, LLM provider serialization) prevents nested mutation. No MappingProxyType inside Pydantic models (serialization friction). | Deep-copy at construction fully isolates nested structures; `MappingProxyType` enforces read-only access. Boundary-copy for Pydantic models is simple, centralized, and Pydantic-native. A future CPython built-in immutable mapping type (e.g. `frozendict`) would provide zero-friction field-level immutability when available. | | **Config vs runtime split** | Adopted (M3) | Frozen models for config/identity; `model_copy(update=...)` for runtime state transitions | `TaskExecution` and `AgentContext` (in `engine/`) are frozen Pydantic models that use `model_copy(update=...)` for copy-on-write state transitions without re-running validators (per Pydantic `model_copy` semantics). Config layer (`AgentIdentity`, `Task`) remains unchanged. | -| **Derived fields** | Planned | `@computed_field` instead of stored + validated | Eliminates redundant storage and impossible-to-fail validators (e.g. `total_tokens = input + output`). Currently `total_tokens` uses stored `Field` + `@model_validator`. | +| **Derived fields** | Adopted | `@computed_field` instead of stored + validated | Eliminates redundant storage and impossible-to-fail validators. `TokenUsage.total_tokens` migrated from stored `Field` + `@model_validator` to `@computed_field` property. | | **String validation** | Planned | `NotBlankStr` type from `core.types` for all identifiers | Eliminates per-model `@model_validator` boilerplate for whitespace checks. `NotBlankStr` is defined but models still use `Field(min_length=1)` + manual validators. | | **Shared field groups** | Planned | Extract common field sets into base models (e.g. `_SpendingTotals`) | Prevents field duplication across spending summary models. Not yet implemented — each model independently defines fields. | | **Event constants** | Adopted (flat) | Single `events.py` module with domain-scoped naming (e.g. `PROVIDER_CALL_START`, `BUDGET_RECORD_ADDED`) | Current approach uses a single module. Splitting into per-domain submodules may be revisited when the file exceeds ~200 constants. | diff --git a/src/ai_company/providers/base.py b/src/ai_company/providers/base.py index 18f72792aa..f6c6698496 100644 --- a/src/ai_company/providers/base.py +++ b/src/ai_company/providers/base.py @@ -391,7 +391,6 @@ def compute_cost( return TokenUsage( input_tokens=input_tokens, output_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, cost_usd=round(cost, BUDGET_ROUNDING_PRECISION), ) diff --git a/src/ai_company/providers/models.py b/src/ai_company/providers/models.py index a267769974..73e3610e96 100644 --- a/src/ai_company/providers/models.py +++ b/src/ai_company/providers/models.py @@ -2,7 +2,7 @@ from typing import Any, Self -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator from ai_company.core.types import NotBlankStr # noqa: TC001 @@ -18,7 +18,7 @@ class TokenUsage(BaseModel): Attributes: input_tokens: Number of input (prompt) tokens. output_tokens: Number of output (completion) tokens. - total_tokens: Sum of input and output tokens. + total_tokens: Sum of input and output tokens (computed). cost_usd: Estimated cost in USD for this call. """ @@ -26,26 +26,18 @@ class TokenUsage(BaseModel): input_tokens: int = Field(ge=0, description="Input token count") output_tokens: int = Field(ge=0, description="Output token count") - total_tokens: int = Field(ge=0, description="Total token count") cost_usd: float = Field(ge=0.0, description="Estimated cost in USD") - @model_validator(mode="after") - def _validate_total(self) -> Self: - """Ensure total_tokens equals the sum of input and output tokens.""" - expected = self.input_tokens + self.output_tokens - if self.total_tokens != expected: - msg = ( - f"total_tokens ({self.total_tokens}) must equal " - f"input_tokens + output_tokens ({expected})" - ) - raise ValueError(msg) - return self + @computed_field(description="Total token count") # type: ignore[prop-decorator] # mypy doesn't support stacked decorators on @property + @property + def total_tokens(self) -> int: + """Sum of input and output tokens.""" + return self.input_tokens + self.output_tokens ZERO_TOKEN_USAGE = TokenUsage( input_tokens=0, output_tokens=0, - total_tokens=0, cost_usd=0.0, ) """Additive identity for ``TokenUsage``.""" @@ -54,22 +46,17 @@ def _validate_total(self) -> Self: def add_token_usage(a: TokenUsage, b: TokenUsage) -> TokenUsage: """Create a new ``TokenUsage`` with summed token counts and cost. - Computes ``total_tokens`` from the summed parts to maintain the - ``total_tokens == input_tokens + output_tokens`` invariant. - Args: a: First usage record. b: Second usage record. Returns: - New ``TokenUsage`` with summed token counts and cost. + New ``TokenUsage`` with summed token counts and cost + (``total_tokens`` is computed automatically). """ - input_tokens = a.input_tokens + b.input_tokens - output_tokens = a.output_tokens + b.output_tokens return TokenUsage( - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, + input_tokens=a.input_tokens + b.input_tokens, + output_tokens=a.output_tokens + b.output_tokens, cost_usd=a.cost_usd + b.cost_usd, ) diff --git a/tests/unit/engine/conftest.py b/tests/unit/engine/conftest.py index 051d76d3e4..4dc53742e3 100644 --- a/tests/unit/engine/conftest.py +++ b/tests/unit/engine/conftest.py @@ -126,7 +126,6 @@ def sample_token_usage() -> TokenUsage: return TokenUsage( input_tokens=100, output_tokens=50, - total_tokens=150, cost_usd=0.01, ) diff --git a/tests/unit/engine/test_context.py b/tests/unit/engine/test_context.py index 2fa0824a0d..cf139644d3 100644 --- a/tests/unit/engine/test_context.py +++ b/tests/unit/engine/test_context.py @@ -157,7 +157,6 @@ def test_has_turns_remaining_boundary( usage = TokenUsage( input_tokens=1, output_tokens=1, - total_tokens=2, cost_usd=0.0, ) msg = _make_assistant_msg() @@ -174,7 +173,6 @@ def test_max_turns_exceeded_raises( usage = TokenUsage( input_tokens=1, output_tokens=1, - total_tokens=2, cost_usd=0.0, ) msg = _make_assistant_msg() @@ -267,7 +265,6 @@ def test_snapshot_task_id_without_status_rejected(self) -> None: accumulated_cost=TokenUsage( input_tokens=0, output_tokens=0, - total_tokens=0, cost_usd=0.0, ), started_at=datetime.now(UTC), @@ -286,7 +283,6 @@ def test_snapshot_task_status_without_id_rejected(self) -> None: accumulated_cost=TokenUsage( input_tokens=0, output_tokens=0, - total_tokens=0, cost_usd=0.0, ), started_at=datetime.now(UTC), @@ -370,7 +366,6 @@ def test_max_turns_exceeded_logs_error( usage = TokenUsage( input_tokens=1, output_tokens=1, - total_tokens=2, cost_usd=0.0, ) msg = _make_assistant_msg() diff --git a/tests/unit/engine/test_task_execution.py b/tests/unit/engine/test_task_execution.py index 91ca3c224d..5f49aea8d9 100644 --- a/tests/unit/engine/test_task_execution.py +++ b/tests/unit/engine/test_task_execution.py @@ -252,13 +252,11 @@ def test_sums_correctly(self) -> None: a = TokenUsage( input_tokens=10, output_tokens=5, - total_tokens=15, cost_usd=0.01, ) b = TokenUsage( input_tokens=20, output_tokens=10, - total_tokens=30, cost_usd=0.02, ) result = add_token_usage(a, b) @@ -267,17 +265,15 @@ def test_sums_correctly(self) -> None: assert result.total_tokens == 45 assert result.cost_usd == pytest.approx(0.03) - def test_maintains_total_invariant(self) -> None: + def test_total_tokens_is_sum_of_parts(self) -> None: a = TokenUsage( input_tokens=7, output_tokens=3, - total_tokens=10, cost_usd=0.0, ) b = TokenUsage( input_tokens=13, output_tokens=7, - total_tokens=20, cost_usd=0.0, ) result = add_token_usage(a, b) @@ -287,7 +283,6 @@ def test_with_zero_usage(self) -> None: usage = TokenUsage( input_tokens=50, output_tokens=25, - total_tokens=75, cost_usd=0.05, ) result = add_token_usage(ZERO_TOKEN_USAGE, usage) diff --git a/tests/unit/providers/conftest.py b/tests/unit/providers/conftest.py index f63b2f6c8d..abcfe0eb63 100644 --- a/tests/unit/providers/conftest.py +++ b/tests/unit/providers/conftest.py @@ -25,7 +25,6 @@ class TokenUsageFactory(ModelFactory[TokenUsage]): __model__ = TokenUsage input_tokens = 100 output_tokens = 50 - total_tokens = 150 cost_usd = 0.001 @@ -174,7 +173,6 @@ def sample_token_usage() -> TokenUsage: return TokenUsage( input_tokens=4500, output_tokens=1200, - total_tokens=5700, cost_usd=0.0315, ) diff --git a/tests/unit/providers/test_base_provider.py b/tests/unit/providers/test_base_provider.py index 2164e57b64..8258b9af36 100644 --- a/tests/unit/providers/test_base_provider.py +++ b/tests/unit/providers/test_base_provider.py @@ -54,7 +54,6 @@ async def _do_complete( usage=TokenUsage( input_tokens=10, output_tokens=5, - total_tokens=15, cost_usd=0.0, ), model=model, diff --git a/tests/unit/providers/test_models.py b/tests/unit/providers/test_models.py index 6d2ab36252..f86bffad29 100644 --- a/tests/unit/providers/test_models.py +++ b/tests/unit/providers/test_models.py @@ -42,21 +42,11 @@ def test_valid(self, sample_token_usage: TokenUsage) -> None: assert sample_token_usage.total_tokens == 5700 assert sample_token_usage.cost_usd == 0.0315 - def test_total_must_equal_sum(self) -> None: - with pytest.raises(ValidationError, match="total_tokens"): - TokenUsage( - input_tokens=100, - output_tokens=50, - total_tokens=200, - cost_usd=0.001, - ) - def test_negative_input_rejected(self) -> None: with pytest.raises(ValidationError): TokenUsage( input_tokens=-1, output_tokens=0, - total_tokens=-1, cost_usd=0.0, ) @@ -65,7 +55,6 @@ def test_negative_cost_rejected(self) -> None: TokenUsage( input_tokens=100, output_tokens=0, - total_tokens=100, cost_usd=-0.01, ) @@ -73,7 +62,6 @@ def test_zero_tokens_valid(self) -> None: usage = TokenUsage( input_tokens=0, output_tokens=0, - total_tokens=0, cost_usd=0.0, ) assert usage.total_tokens == 0 @@ -82,6 +70,41 @@ def test_frozen(self, sample_token_usage: TokenUsage) -> None: with pytest.raises(ValidationError): sample_token_usage.cost_usd = 999.0 # type: ignore[misc] + def test_total_tokens_is_always_computed(self) -> None: + usage = TokenUsage(input_tokens=10, output_tokens=5, cost_usd=0.0) + assert usage.total_tokens == 15 + + def test_total_tokens_in_serialization(self) -> None: + usage = TokenUsage(input_tokens=100, output_tokens=50, cost_usd=0.01) + dumped = usage.model_dump() + assert dumped["total_tokens"] == 150 + + def test_total_tokens_roundtrip(self) -> None: + """Stale total_tokens in serialized data is ignored on load.""" + payload = TokenUsage( + input_tokens=10, + output_tokens=5, + cost_usd=0.0, + ).model_dump() + payload["total_tokens"] = 999 + usage = TokenUsage.model_validate(payload) + assert usage.total_tokens == 15 + + # JSON roundtrip also recomputes correctly + json_str = TokenUsage( + input_tokens=20, + output_tokens=10, + cost_usd=0.01, + ).model_dump_json() + restored = TokenUsage.model_validate_json(json_str) + assert restored.total_tokens == 30 + + def test_total_tokens_not_assignable(self) -> None: + """Computed property rejects direct assignment.""" + usage = TokenUsage(input_tokens=10, output_tokens=5, cost_usd=0.0) + with pytest.raises((ValidationError, AttributeError)): + usage.total_tokens = 999 # type: ignore[misc] + def test_factory(self) -> None: usage = TokenUsageFactory.build() assert isinstance(usage, TokenUsage)