Merged
2 changes: 1 addition & 1 deletion CLAUDE.md
@@ -68,7 +68,7 @@ src/ai_company/
- **Docstrings**: Google style, required on public classes/functions (enforced by ruff D rules)
- **Immutability**: create new objects, never mutate existing ones. For non-Pydantic internal collections (registries, `BaseTool`), use `copy.deepcopy()` at construction + `MappingProxyType` wrapping for read-only enforcement. For `dict`/`list` fields in frozen Pydantic models, rely on `frozen=True` for field reassignment prevention and `copy.deepcopy()` at system boundaries (tool execution, LLM provider serialization, inter-agent delegation, serializing for persistence).
- **Config vs runtime state**: frozen Pydantic models for config/identity; separate mutable-via-copy models (using `model_copy(update=...)`) for runtime state that evolves (e.g. agent execution state, task progress). Never mix static config fields with mutable runtime fields in one model.
- **Models**: Pydantic v2 (`BaseModel`, `model_validator`, `ConfigDict`). Planned conventions for new code: use `@computed_field` for derived values instead of storing + validating redundant fields; use `NotBlankStr` (from `core.types`) for non-optional identifier/name fields instead of manual whitespace validators. Existing models are being migrated incrementally.
- **Models**: Pydantic v2 (`BaseModel`, `model_validator`, `computed_field`, `ConfigDict`). Adopted conventions: use `@computed_field` for derived values instead of storing + validating redundant fields (e.g. `TokenUsage.total_tokens`). Planned conventions for new code: use `NotBlankStr` (from `core.types`) for non-optional identifier/name fields instead of manual whitespace validators. Existing models are being migrated incrementally.
- **Async concurrency**: prefer `asyncio.TaskGroup` for fan-out/fan-in parallel operations in new code (e.g. multiple tool invocations, parallel agent calls). Prefer structured concurrency over bare `create_task`. Existing code is being migrated incrementally.
- **Line length**: 88 characters (ruff)
- **Functions**: < 50 lines, files < 800 lines
6 changes: 3 additions & 3 deletions DESIGN_SPEC.md
@@ -857,7 +857,7 @@ Every API call is tracked (illustrative schema):
}
```

> **Implementation note:** `total_tokens` is a stored `Field` validated by `@model_validator` to equal `input_tokens + output_tokens`. Spending summary models (`AgentSpending`, `DepartmentSpending`, `PeriodSpending`) each independently define `total_cost_usd`, `total_input_tokens`, `total_output_tokens`, and `record_count` fields. Extracting a shared `_SpendingTotals` base and migrating `total_tokens` to `@computed_field` are planned conventions (see §15.5).
> **Implementation note:** `total_tokens` is a `@computed_field` property that returns `input_tokens + output_tokens` — no stored field or validator needed. Spending summary models (`AgentSpending`, `DepartmentSpending`, `PeriodSpending`) each independently define `total_cost_usd`, `total_input_tokens`, `total_output_tokens`, and `record_count` fields. Extracting a shared `_SpendingTotals` base is a planned convention (see §15.5).

### 10.3 CFO Agent Responsibilities

@@ -1335,7 +1335,7 @@ ai-company/
│ ├── providers/ # LLM provider abstraction
│ │ ├── base.py # BaseCompletionProvider (retry + rate limiting)
│ │ ├── protocol.py # Provider protocol (abstract interface)
│ │ ├── models.py # CompletionRequest/Response, TokenUsage, ToolCall/Result
│ │ ├── models.py # CompletionConfig/Response, TokenUsage, ToolCall/Result
│ │ ├── capabilities.py # Provider capability registry
│ │ ├── registry.py # Provider registry
│ │ ├── enums.py # Provider enumerations
@@ -1438,7 +1438,7 @@ These conventions were established during the M0–M2 review cycle. **Adopted**
|------------|--------|----------|-----------|
| **Immutability strategy** | Adopted | `copy.deepcopy()` at construction + `MappingProxyType` wrapping for non-Pydantic internal collections (registries, `BaseTool`). For Pydantic frozen models: `frozen=True` prevents field reassignment; `copy.deepcopy()` at system boundaries (tool execution, LLM provider serialization) prevents nested mutation. No MappingProxyType inside Pydantic models (serialization friction). | Deep-copy at construction fully isolates nested structures; `MappingProxyType` enforces read-only access. Boundary-copy for Pydantic models is simple, centralized, and Pydantic-native. A future CPython built-in immutable mapping type (e.g. `frozendict`) would provide zero-friction field-level immutability when available. |
| **Config vs runtime split** | Adopted (M3) | Frozen models for config/identity; `model_copy(update=...)` for runtime state transitions | `TaskExecution` and `AgentContext` (in `engine/`) are frozen Pydantic models that use `model_copy(update=...)` for copy-on-write state transitions without re-running validators (per Pydantic `model_copy` semantics). Config layer (`AgentIdentity`, `Task`) remains unchanged. |
| **Derived fields** | Planned | `@computed_field` instead of stored + validated | Eliminates redundant storage and impossible-to-fail validators (e.g. `total_tokens = input + output`). Currently `total_tokens` uses stored `Field` + `@model_validator`. |
| **Derived fields** | Adopted | `@computed_field` instead of stored + validated | Eliminates redundant storage and impossible-to-fail validators. `TokenUsage.total_tokens` migrated from stored `Field` + `@model_validator` to `@computed_field` property. |
| **String validation** | Planned | `NotBlankStr` type from `core.types` for all identifiers | Eliminates per-model `@model_validator` boilerplate for whitespace checks. `NotBlankStr` is defined but models still use `Field(min_length=1)` + manual validators. |
| **Shared field groups** | Planned | Extract common field sets into base models (e.g. `_SpendingTotals`) | Prevents field duplication across spending summary models. Not yet implemented — each model independently defines fields. |
| **Event constants** | Adopted (flat) | Single `events.py` module with domain-scoped naming (e.g. `PROVIDER_CALL_START`, `BUDGET_RECORD_ADDED`) | Current approach uses a single module. Splitting into per-domain submodules may be revisited when the file exceeds ~200 constants. |
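
The immutability strategy row can be sketched with the standard library alone. `ToolRegistry` below is a hypothetical class written for this example, not the project's registry; it only shows the deep-copy-at-construction plus `MappingProxyType` pattern the table describes.

```python
import copy
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any


class ToolRegistry:
    """Hypothetical registry; illustrates the pattern only."""

    def __init__(self, tools: Mapping[str, dict[str, Any]]) -> None:
        # Deep-copy first so later changes to the caller's dict, or to
        # anything nested inside it, cannot leak into the registry.
        self._tools: Mapping[str, dict[str, Any]] = MappingProxyType(
            copy.deepcopy(dict(tools))
        )

    @property
    def tools(self) -> Mapping[str, dict[str, Any]]:
        """Read-only view of the registered tools."""
        return self._tools


source = {"search": {"timeout": 5}}
registry = ToolRegistry(source)
source["search"]["timeout"] = 99  # mutates only the caller's copy

try:
    registry.tools["new"] = {}  # type: ignore[index]
    write_blocked = False
except TypeError:
    # mappingproxy does not support item assignment
    write_blocked = True
```

Note that `MappingProxyType` guards only the top-level mapping: a nested dict reached through the proxy is still mutable, which is presumably why the spec also calls for `copy.deepcopy()` at system boundaries.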
1 change: 0 additions & 1 deletion src/ai_company/providers/base.py
@@ -391,7 +391,6 @@ def compute_cost(
return TokenUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
cost_usd=round(cost, BUDGET_ROUNDING_PRECISION),
)

35 changes: 11 additions & 24 deletions src/ai_company/providers/models.py
@@ -2,7 +2,7 @@

from typing import Any, Self

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator

from ai_company.core.types import NotBlankStr # noqa: TC001

@@ -18,34 +18,26 @@ class TokenUsage(BaseModel):
Attributes:
input_tokens: Number of input (prompt) tokens.
output_tokens: Number of output (completion) tokens.
total_tokens: Sum of input and output tokens.
total_tokens: Sum of input and output tokens (computed).
cost_usd: Estimated cost in USD for this call.
"""

model_config = ConfigDict(frozen=True)

input_tokens: int = Field(ge=0, description="Input token count")
output_tokens: int = Field(ge=0, description="Output token count")
total_tokens: int = Field(ge=0, description="Total token count")
cost_usd: float = Field(ge=0.0, description="Estimated cost in USD")

@model_validator(mode="after")
def _validate_total(self) -> Self:
"""Ensure total_tokens equals the sum of input and output tokens."""
expected = self.input_tokens + self.output_tokens
if self.total_tokens != expected:
msg = (
f"total_tokens ({self.total_tokens}) must equal "
f"input_tokens + output_tokens ({expected})"
)
raise ValueError(msg)
return self
@computed_field(description="Total token count") # type: ignore[prop-decorator] # mypy doesn't support stacked decorators on @property
@property
def total_tokens(self) -> int:
"""Sum of input and output tokens."""
return self.input_tokens + self.output_tokens


ZERO_TOKEN_USAGE = TokenUsage(
input_tokens=0,
output_tokens=0,
total_tokens=0,
cost_usd=0.0,
)
"""Additive identity for ``TokenUsage``."""
@@ -54,22 +46,17 @@ def _validate_total(self) -> Self:
def add_token_usage(a: TokenUsage, b: TokenUsage) -> TokenUsage:
"""Create a new ``TokenUsage`` with summed token counts and cost.

Computes ``total_tokens`` from the summed parts to maintain the
``total_tokens == input_tokens + output_tokens`` invariant.

Args:
a: First usage record.
b: Second usage record.

Returns:
New ``TokenUsage`` with summed token counts and cost.
New ``TokenUsage`` with summed token counts and cost
(``total_tokens`` is computed automatically).
"""
input_tokens = a.input_tokens + b.input_tokens
output_tokens = a.output_tokens + b.output_tokens
return TokenUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
input_tokens=a.input_tokens + b.input_tokens,
output_tokens=a.output_tokens + b.output_tokens,
cost_usd=a.cost_usd + b.cost_usd,
)
Comment on lines 46 to 61
medium

While this function works correctly, consider implementing operator overloading by adding an __add__ method to the TokenUsage class. This would allow for more idiomatic and readable code, like usage_a + usage_b, instead of add_token_usage(usage_a, usage_b). This would make TokenUsage behave more like a numeric type where addition is a natural operation.

Example implementation inside TokenUsage class:

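from typing import Self

# ...

    def __add__(self, other: Self) -> Self:
        """Return a new instance with summed token counts and cost."""
        if not isinstance(other, TokenUsage):
            return NotImplemented
        # Build via type(self) so the result matches the Self annotation;
        # returning TokenUsage(...) directly is flagged by type checkers.
        return type(self)(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            cost_usd=self.cost_usd + other.cost_usd,
        )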

With this change, this add_token_usage function would no longer be necessary and could be removed, and call sites could be updated to use the + operator.
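
One further consequence of the `__add__` suggestion, sketched here under stated assumptions: a list of usage records can be folded with the built-in `sum()` when the additive identity is passed as `start`. `Usage` and `ZERO` below are stdlib stand-ins (a frozen dataclass with a plain property) for the Pydantic `TokenUsage` model and `ZERO_TOKEN_USAGE`; they are not the PR's code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Usage:
    """Stdlib stand-in for TokenUsage (the real model is Pydantic)."""

    input_tokens: int
    output_tokens: int
    cost_usd: float

    @property
    def total_tokens(self) -> int:
        """Derived on access, never stored."""
        return self.input_tokens + self.output_tokens

    def __add__(self, other: object) -> "Usage":
        if not isinstance(other, Usage):
            return NotImplemented
        return Usage(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            cost_usd=self.cost_usd + other.cost_usd,
        )


# Plays the role of ZERO_TOKEN_USAGE (additive identity).
ZERO = Usage(input_tokens=0, output_tokens=0, cost_usd=0.0)

turns = [Usage(10, 5, 0.01), Usage(20, 10, 0.02)]
# sum() starts from the int 0 by default, so pass the identity explicitly.
total = sum(turns, start=ZERO)
```

Without the explicit `start`, `sum()` would try `0 + Usage(...)` and raise `TypeError` unless `__radd__` were also defined, so keeping a named identity value around remains useful even with operator overloading.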


1 change: 0 additions & 1 deletion tests/unit/engine/conftest.py
@@ -126,7 +126,6 @@ def sample_token_usage() -> TokenUsage:
return TokenUsage(
input_tokens=100,
output_tokens=50,
total_tokens=150,
cost_usd=0.01,
)

5 changes: 0 additions & 5 deletions tests/unit/engine/test_context.py
@@ -157,7 +157,6 @@ def test_has_turns_remaining_boundary(
usage = TokenUsage(
input_tokens=1,
output_tokens=1,
total_tokens=2,
cost_usd=0.0,
)
msg = _make_assistant_msg()
@@ -174,7 +173,6 @@ def test_max_turns_exceeded_raises(
usage = TokenUsage(
input_tokens=1,
output_tokens=1,
total_tokens=2,
cost_usd=0.0,
)
msg = _make_assistant_msg()
@@ -267,7 +265,6 @@ def test_snapshot_task_id_without_status_rejected(self) -> None:
accumulated_cost=TokenUsage(
input_tokens=0,
output_tokens=0,
total_tokens=0,
cost_usd=0.0,
),
started_at=datetime.now(UTC),
@@ -286,7 +283,6 @@ def test_snapshot_task_status_without_id_rejected(self) -> None:
accumulated_cost=TokenUsage(
input_tokens=0,
output_tokens=0,
total_tokens=0,
cost_usd=0.0,
),
started_at=datetime.now(UTC),
@@ -370,7 +366,6 @@ def test_max_turns_exceeded_logs_error(
usage = TokenUsage(
input_tokens=1,
output_tokens=1,
total_tokens=2,
cost_usd=0.0,
)
msg = _make_assistant_msg()
7 changes: 1 addition & 6 deletions tests/unit/engine/test_task_execution.py
@@ -252,13 +252,11 @@ def test_sums_correctly(self) -> None:
a = TokenUsage(
input_tokens=10,
output_tokens=5,
total_tokens=15,
cost_usd=0.01,
)
b = TokenUsage(
input_tokens=20,
output_tokens=10,
total_tokens=30,
cost_usd=0.02,
)
result = add_token_usage(a, b)
medium

If you implement __add__ on TokenUsage as suggested in src/ai_company/providers/models.py, this call can be simplified.

Suggested change
result = add_token_usage(a, b)
result = a + b

@@ -267,17 +265,15 @@ def test_sums_correctly(self) -> None:
assert result.total_tokens == 45
assert result.cost_usd == pytest.approx(0.03)

def test_maintains_total_invariant(self) -> None:
def test_total_tokens_is_sum_of_parts(self) -> None:
a = TokenUsage(
input_tokens=7,
output_tokens=3,
total_tokens=10,
cost_usd=0.0,
)
b = TokenUsage(
input_tokens=13,
output_tokens=7,
total_tokens=20,
cost_usd=0.0,
)
result = add_token_usage(a, b)
medium

If you implement __add__ on TokenUsage as suggested in src/ai_company/providers/models.py, this call can be simplified.

Suggested change
result = add_token_usage(a, b)
result = a + b

@@ -287,7 +283,6 @@ def test_with_zero_usage(self) -> None:
usage = TokenUsage(
input_tokens=50,
output_tokens=25,
total_tokens=75,
cost_usd=0.05,
)
result = add_token_usage(ZERO_TOKEN_USAGE, usage)
medium

If you implement __add__ on TokenUsage as suggested in src/ai_company/providers/models.py, this call can be simplified.

Suggested change
result = add_token_usage(ZERO_TOKEN_USAGE, usage)
result = ZERO_TOKEN_USAGE + usage

2 changes: 0 additions & 2 deletions tests/unit/providers/conftest.py
@@ -25,7 +25,6 @@ class TokenUsageFactory(ModelFactory[TokenUsage]):
__model__ = TokenUsage
input_tokens = 100
output_tokens = 50
total_tokens = 150
cost_usd = 0.001


@@ -174,7 +173,6 @@ def sample_token_usage() -> TokenUsage:
return TokenUsage(
input_tokens=4500,
output_tokens=1200,
total_tokens=5700,
cost_usd=0.0315,
)

1 change: 0 additions & 1 deletion tests/unit/providers/test_base_provider.py
@@ -54,7 +54,6 @@ async def _do_complete(
usage=TokenUsage(
input_tokens=10,
output_tokens=5,
total_tokens=15,
cost_usd=0.0,
),
model=model,
47 changes: 35 additions & 12 deletions tests/unit/providers/test_models.py
@@ -42,21 +42,11 @@ def test_valid(self, sample_token_usage: TokenUsage) -> None:
assert sample_token_usage.total_tokens == 5700
assert sample_token_usage.cost_usd == 0.0315

def test_total_must_equal_sum(self) -> None:
with pytest.raises(ValidationError, match="total_tokens"):
TokenUsage(
input_tokens=100,
output_tokens=50,
total_tokens=200,
cost_usd=0.001,
)

def test_negative_input_rejected(self) -> None:
with pytest.raises(ValidationError):
TokenUsage(
input_tokens=-1,
output_tokens=0,
total_tokens=-1,
cost_usd=0.0,
)

@@ -65,15 +55,13 @@ def test_negative_cost_rejected(self) -> None:
TokenUsage(
input_tokens=100,
output_tokens=0,
total_tokens=100,
cost_usd=-0.01,
)

def test_zero_tokens_valid(self) -> None:
usage = TokenUsage(
input_tokens=0,
output_tokens=0,
total_tokens=0,
cost_usd=0.0,
)
assert usage.total_tokens == 0
@@ -82,6 +70,41 @@ def test_frozen(self, sample_token_usage: TokenUsage) -> None:
with pytest.raises(ValidationError):
sample_token_usage.cost_usd = 999.0 # type: ignore[misc]

def test_total_tokens_is_always_computed(self) -> None:
usage = TokenUsage(input_tokens=10, output_tokens=5, cost_usd=0.0)
assert usage.total_tokens == 15

def test_total_tokens_in_serialization(self) -> None:
usage = TokenUsage(input_tokens=100, output_tokens=50, cost_usd=0.01)
dumped = usage.model_dump()
assert dumped["total_tokens"] == 150

def test_total_tokens_roundtrip(self) -> None:
"""Stale total_tokens in serialized data is ignored on load."""
payload = TokenUsage(
input_tokens=10,
output_tokens=5,
cost_usd=0.0,
).model_dump()
payload["total_tokens"] = 999
usage = TokenUsage.model_validate(payload)
assert usage.total_tokens == 15

# JSON roundtrip also recomputes correctly
json_str = TokenUsage(
input_tokens=20,
output_tokens=10,
cost_usd=0.01,
).model_dump_json()
restored = TokenUsage.model_validate_json(json_str)
assert restored.total_tokens == 30

def test_total_tokens_not_assignable(self) -> None:
"""Computed property rejects direct assignment."""
usage = TokenUsage(input_tokens=10, output_tokens=5, cost_usd=0.0)
with pytest.raises((ValidationError, AttributeError)):
usage.total_tokens = 999 # type: ignore[misc]

def test_factory(self) -> None:
usage = TokenUsageFactory.build()
assert isinstance(usage, TokenUsage)