Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ src/ai_company/
## Code Conventions

- **No `from __future__ import annotations`** — Python 3.14 has PEP 649
- **PEP 758 except syntax**: use `except A, B:` (no parentheses) — ruff enforces this on Python 3.14

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The documented syntax except A, B: for multiple exceptions is from Python 2 and is invalid in Python 3. The correct syntax in Python 3 is to use a tuple: except (A, B):.

Additionally, I couldn't find a reference to PEP 758. This might be a typo.

To prevent confusion for developers, it would be good to update this convention.

Suggested change
- **PEP 758 except syntax**: use `except A, B:` (no parentheses) — ruff enforces this on Python 3.14
- **Multiple exceptions**: use `except (A, B):` to catch multiple exception types.

Copilot AI Mar 1, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CLAUDE.md guidance on exception syntax appears incorrect: except A, B: is not valid Python 3 syntax for catching multiple exceptions (it should use a tuple, e.g. except (A, B):). As written this will mislead contributors and may cause invalid code to be introduced.

Suggested change
- **PEP 758 except syntax**: use `except A, B:` (no parentheses) — ruff enforces this on Python 3.14
- **Exception syntax**: for multiple exceptions use `except (A, B):` (tuple) — Python 3 only supports the tuple form

Copilot uses AI. Check for mistakes.
- **Type hints**: all public functions, mypy strict mode
- **Docstrings**: Google style, required on public classes/functions (enforced by ruff D rules)
- **Immutability**: create new objects, never mutate existing ones
Expand Down
57 changes: 57 additions & 0 deletions src/ai_company/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Unified provider interface for LLM completion.

Exports protocols, base classes, domain models, enums, and errors
for the provider layer.
"""

from .base import BaseCompletionProvider
from .capabilities import ModelCapabilities
from .enums import FinishReason, MessageRole, StreamEventType
from .errors import (
AuthenticationError,
ContentFilterError,
InvalidRequestError,
ModelNotFoundError,
ProviderConnectionError,
ProviderError,
ProviderInternalError,
ProviderTimeoutError,
RateLimitError,
)
from .models import (
ChatMessage,
CompletionConfig,
CompletionResponse,
StreamChunk,
TokenUsage,
ToolCall,
ToolDefinition,
ToolResult,
)
from .protocol import CompletionProvider

__all__ = [
"AuthenticationError",
"BaseCompletionProvider",
"ChatMessage",
"CompletionConfig",
"CompletionProvider",
"CompletionResponse",
"ContentFilterError",
"FinishReason",
"InvalidRequestError",
"MessageRole",
"ModelCapabilities",
"ModelNotFoundError",
"ProviderConnectionError",
"ProviderError",
"ProviderInternalError",
"ProviderTimeoutError",
"RateLimitError",
"StreamChunk",
"StreamEventType",
"TokenUsage",
"ToolCall",
"ToolDefinition",
"ToolResult",
]
271 changes: 271 additions & 0 deletions src/ai_company/providers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
"""Abstract base class for completion providers.

Concrete adapters subclass ``BaseCompletionProvider`` and implement
the ``_do_*`` hooks. The base class handles input validation and
provides a cost-computation helper.
"""

import math
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator # noqa: TC003

from .capabilities import ModelCapabilities # noqa: TC001
from .errors import InvalidRequestError
from .models import (
ChatMessage,
CompletionConfig,
CompletionResponse,
StreamChunk,
TokenUsage,
ToolDefinition,
)

_COST_ROUNDING_PRECISION: int = 10
"""Decimal places for cost rounding to avoid floating-point dust."""

Copilot AI Mar 1, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This introduces a new rounding-precision constant with the same value/purpose as ai_company.constants.BUDGET_ROUNDING_PRECISION. To avoid duplicated “10” precision definitions drifting over time, consider reusing a shared constant (or promoting a generic rounding constant in ai_company.constants) instead of defining _COST_ROUNDING_PRECISION locally.

Copilot uses AI. Check for mistakes.


class BaseCompletionProvider(ABC):
"""Shared base for all completion provider adapters.

Subclasses implement three hooks:

* ``_do_complete`` — raw non-streaming call
* ``_do_stream`` — raw streaming call
* ``_do_get_model_capabilities`` — capability lookup

The public methods validate inputs before delegating to hooks.
A static ``compute_cost`` helper is available for subclasses to
build ``TokenUsage`` records from raw token counts.
"""

# -- Public API ---------------------------------------------------

async def complete(
self,
messages: list[ChatMessage],
model: str,
*,
tools: list[ToolDefinition] | None = None,
config: CompletionConfig | None = None,
) -> CompletionResponse:
"""Validate inputs, delegate to ``_do_complete``.

Args:
messages: Conversation history.
model: Model identifier to use.
tools: Available tools for function calling.
config: Optional completion parameters.

Returns:
The completion response returned by the subclass
``_do_complete`` hook, unmodified.

Raises:
InvalidRequestError: If messages are empty or model is blank.
"""
self._validate_messages(messages)
self._validate_model(model)
return await self._do_complete(
messages,
model,
tools=tools,
config=config,
)

async def stream(
self,
messages: list[ChatMessage],
model: str,
*,
tools: list[ToolDefinition] | None = None,
config: CompletionConfig | None = None,
) -> AsyncIterator[StreamChunk]:
"""Validate inputs, delegate to ``_do_stream``.

Args:
messages: Conversation history.
model: Model identifier to use.
tools: Available tools for function calling.
config: Optional completion parameters.

Returns:
Async iterator of stream chunks returned by the subclass
``_do_stream`` hook, unmodified.

Raises:
InvalidRequestError: If messages are empty or model is blank.
"""
self._validate_messages(messages)
self._validate_model(model)
return await self._do_stream(
messages,
model,
tools=tools,
config=config,
)

async def get_model_capabilities(self, model: str) -> ModelCapabilities:
"""Validate model identifier, delegate to ``_do_get_model_capabilities``.

Args:
model: Model identifier.

Returns:
Static capability and cost information.

Raises:
InvalidRequestError: If model is blank.
"""
self._validate_model(model)
return await self._do_get_model_capabilities(model)

# -- Hooks (subclasses implement) ---------------------------------

@abstractmethod
async def _do_complete(
self,
messages: list[ChatMessage],
model: str,
*,
tools: list[ToolDefinition] | None = None,
config: CompletionConfig | None = None,
) -> CompletionResponse:
"""Provider-specific non-streaming completion.

Subclasses **must** catch all provider-specific exceptions and
re-raise them as appropriate ``ProviderError`` subclasses.
Exceptions that escape without wrapping will bypass the error
hierarchy.

Raises:
ProviderError: All errors must use the provider error hierarchy.
"""
...

@abstractmethod
async def _do_stream(
self,
messages: list[ChatMessage],
model: str,
*,
tools: list[ToolDefinition] | None = None,
config: CompletionConfig | None = None,
) -> AsyncIterator[StreamChunk]:
r"""Provider-specific streaming completion.

Implementations must *return* an ``AsyncIterator`` (not ``yield``
directly), since the caller ``await``\s this coroutine to obtain

Copilot AI Mar 1, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring contains the text await\s, which renders literally as "await\s" and looks like an escaped typo. Replace it with normal wording (e.g., "awaits") to avoid confusing readers and doc tooling.

Suggested change
directly), since the caller ``await``\s this coroutine to obtain
directly), since the caller awaits this coroutine to obtain

Copilot uses AI. Check for mistakes.
the iterator.

Subclasses **must** catch all provider-specific exceptions and
re-raise them as appropriate ``ProviderError`` subclasses.

Raises:
ProviderError: All errors must use the provider error hierarchy.
"""
...

@abstractmethod
async def _do_get_model_capabilities(
self,
model: str,
) -> ModelCapabilities:
"""Provider-specific capability lookup.

Raises:
ProviderError: All errors must use the provider error hierarchy.
"""
...

# -- Helpers ------------------------------------------------------

@staticmethod
def compute_cost(
input_tokens: int,
output_tokens: int,
*,
cost_per_1k_input: float,
cost_per_1k_output: float,
) -> TokenUsage:
"""Build a ``TokenUsage`` from raw token counts and per-1k rates.

Args:
input_tokens: Number of input tokens (must be >= 0).
output_tokens: Number of output tokens (must be >= 0).
cost_per_1k_input: Cost per 1 000 input tokens in USD (>= 0).
cost_per_1k_output: Cost per 1 000 output tokens in USD (>= 0).

Returns:
Populated ``TokenUsage`` with computed cost.

Raises:
InvalidRequestError: If any parameter is negative.

Copilot AI Mar 1, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compute_cost() rejects non-finite rates (NaN/inf) in addition to negatives, but the docstring says it only raises InvalidRequestError when a parameter is negative. Please update the docstring to reflect the actual validation behavior (non-negative and finite).

Suggested change
cost_per_1k_input: Cost per 1 000 input tokens in USD (>= 0).
cost_per_1k_output: Cost per 1 000 output tokens in USD (>= 0).
Returns:
Populated ``TokenUsage`` with computed cost.
Raises:
InvalidRequestError: If any parameter is negative.
cost_per_1k_input: Cost per 1 000 input tokens in USD (finite and >= 0).
cost_per_1k_output: Cost per 1 000 output tokens in USD (finite and >= 0).
Returns:
Populated ``TokenUsage`` with computed cost.
Raises:
InvalidRequestError: If any parameter is negative or non-finite
(for example, NaN or infinity).

Copilot uses AI. Check for mistakes.
"""
if input_tokens < 0:
msg = "input_tokens must be non-negative"
raise InvalidRequestError(
msg,
context={"input_tokens": input_tokens},
)
if output_tokens < 0:
msg = "output_tokens must be non-negative"
raise InvalidRequestError(
msg,
context={"output_tokens": output_tokens},
)
if cost_per_1k_input < 0 or not math.isfinite(cost_per_1k_input):
msg = "cost_per_1k_input must be a finite non-negative number"
raise InvalidRequestError(
msg,
context={"cost_per_1k_input": cost_per_1k_input},
)
if cost_per_1k_output < 0 or not math.isfinite(cost_per_1k_output):
msg = "cost_per_1k_output must be a finite non-negative number"
raise InvalidRequestError(
msg,
context={"cost_per_1k_output": cost_per_1k_output},
)
cost = (input_tokens / 1000) * cost_per_1k_input + (
output_tokens / 1000
) * cost_per_1k_output
return TokenUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
cost_usd=round(cost, _COST_ROUNDING_PRECISION),
)
Comment on lines +230 to +238

Copilot AI Mar 1, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compute_cost() hard-codes the rounding precision (round(cost, 10)), which makes the intent and consistency across the codebase harder to manage. There is already a shared BUDGET_ROUNDING_PRECISION = 10 constant used elsewhere to avoid float artifacts; consider reusing a shared constant (or introducing a provider-specific one) instead of a magic number here.

Copilot uses AI. Check for mistakes.

@staticmethod
def _validate_messages(messages: list[ChatMessage]) -> None:
"""Reject empty message lists.

Args:
messages: Conversation messages.

Raises:
InvalidRequestError: If no messages are provided.
"""
if not messages:
msg = "messages must not be empty"
raise InvalidRequestError(msg, context={"field": "messages"})

@staticmethod
def _validate_model(model: str) -> None:
"""Reject blank, empty, or non-string model identifiers.

Args:
model: Model identifier string.

Raises:
InvalidRequestError: If model is not a string, empty,
or whitespace-only.
"""
if not isinstance(model, str) or not model.strip():
msg = "model must be a non-blank string"
raise InvalidRequestError(
msg,
context={
"field": "model",
"received_type": type(model).__name__,
},
)
Loading