diff --git a/CLAUDE.md b/CLAUDE.md index 0f56961c83..366a8c22eb 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -49,9 +49,9 @@ src/ai_company/ communication/ # Message bus, dispatcher, messenger, channels, delegation, loop prevention, conflict resolution, meeting protocol config/ # YAML company config loading and validation core/ # Shared domain models and base classes - engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, task lifecycle, recovery, shutdown, workspace isolation, and coordination error classification + engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, task lifecycle, recovery, shutdown, workspace isolation, coordination error classification, and prompt policy validation hr/ # HR engine: hiring, firing, onboarding, offboarding, agent registry, performance tracking (task metrics, collaboration scoring, trend detection) - memory/ # Persistent agent memory (Mem0 initial, custom stack future — ADR-001), retrieval pipeline (ranking, injection, context formatting), shared org memory (org/), consolidation/archival (consolidation/) + memory/ # Persistent agent memory (Mem0 initial, custom stack future — ADR-001), retrieval pipeline (ranking, injection, context formatting, non-inferable filtering), shared org memory (org/), consolidation/archival (consolidation/) persistence/ # Operational data persistence — pluggable PersistenceBackend protocol, SQLite initial (§7.6) observability/ # Structured logging, correlation tracking, log sinks providers/ # LLM provider abstraction (LiteLLM adapter) @@ -84,7 +84,7 @@ src/ai_company/ - **Every module** with business logic MUST have: `from ai_company.observability import get_logger` then `logger = get_logger(__name__)` - **Never** use `import logging` / `logging.getLogger()` / `print()` in application code - **Variable name**: always `logger` (not `_logger`, not `log`) -- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` (e.g. `PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`). Import directly: `from ai_company.observability.events. import EVENT_CONSTANT` +- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` (e.g. `PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`). Import directly: `from ai_company.observability.events. import EVENT_CONSTANT` - **Structured kwargs**: always `logger.info(EVENT, key=value)` — never `logger.info("msg %s", val)` - **All error paths** must log at WARNING or ERROR with context before raising - **All state transitions** must log at INFO diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 136e714d1d..e56fd4e124 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -80,7 +80,7 @@ The MVP validates the core hypothesis: **a single agent can complete a real task > **How to read this spec:** Sections describe the full vision. Each section with deferred features includes an **MVP** callout box indicating what ships in M3 and what is deferred. The full design is documented upfront to inform architecture decisions — protocol interfaces are designed even for features that won't be built until later milestones. > **Implementation snapshot (2026-03-10):** -> - **Done:** M0–M6 (tooling, config/core, providers, single-agent engine, multi-agent orchestration, API/CLI surface) + Docker sandbox (#50), MCP bridge (#53), code runner + HR engine (hiring/firing/onboarding/offboarding/registry) + performance tracking (task metrics, quality scoring, collaboration scoring, trend detection, rolling windows). Memory layer backend selected ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). Persistence backend (§7.6) completed. Memory retrieval pipeline (#41: ranking, token-budget formatting, context injection) complete. Budget enforcement complete (BudgetEnforcer + configurable cost tiers + quota/subscription tracking). CFO cost optimization complete (CostOptimizer: anomaly detection, efficiency analysis, downgrade recommendations, routing optimization, approval decisions; ReportGenerator: multi-dimensional spending reports). Shared org memory (#125: HybridPromptRetrievalBackend, OrgFactStore, access control, factory) complete. Memory consolidation/archival (#48: ConsolidationService, SimpleConsolidationStrategy, RetentionEnforcer, ArchivalStore protocol) complete. +> - **Done:** M0–M6 (tooling, config/core, providers, single-agent engine, multi-agent orchestration, API/CLI surface) + Docker sandbox (#50), MCP bridge (#53), code runner + HR engine (hiring/firing/onboarding/offboarding/registry) + performance tracking (task metrics, quality scoring, collaboration scoring, trend detection, rolling windows). Memory layer backend selected ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). Persistence backend (§7.6) completed. Memory retrieval pipeline (#41: ranking, token-budget formatting, context injection, non-inferable filtering) complete. Budget enforcement complete (BudgetEnforcer + configurable cost tiers + quota/subscription tracking). CFO cost optimization complete (CostOptimizer: anomaly detection, efficiency analysis, downgrade recommendations, routing optimization, approval decisions; ReportGenerator: multi-dimensional spending reports). Shared org memory (#125: HybridPromptRetrievalBackend, OrgFactStore, access control, factory) complete. Memory consolidation/archival (#48: ConsolidationService, SimpleConsolidationStrategy, RetentionEnforcer, ArchivalStore protocol) complete. > - **Remaining:** M7 security + approval system (SecOps agent, progressive trust, JWT/OAuth auth). ### 1.5 Configuration Philosophy @@ -1605,11 +1605,12 @@ receives memories. > **Non-inferable filter:** Retrieved memories should be filtered before injection to exclude content the agent can discover by reading the codebase or environment. Only inject memories containing non-inferable information: prior decisions, learned conventions, interpersonal context, historical outcomes. [Research](https://arxiv.org/abs/2602.11988) shows generic context increases cost 20%+ with minimal success improvement; LLM-generated context can actually reduce success rates. > -> **Decision ([ADR-002](docs/decisions/ADR-002-design-decisions-batch-1.md) D23):** Pluggable `MemoryFilterStrategy` protocol. Initial: tag-based at write time. Define `non-inferable` tag convention enforced at `MemoryBackend.store()` boundary. System prompt instructs agents what qualifies: design rationale, team decisions, "why not X", cross-repo knowledge = non-inferable; code structure, API signatures, file contents = inferable. Uses existing `MemoryMetadata.tags` and `MemoryQuery.tags` — zero new models needed. Future strategies: LLM classification at retrieval, keyword/pattern heuristic. +> **Decision ([ADR-002](docs/decisions/ADR-002-design-decisions-batch-1.md) D23):** Pluggable `MemoryFilterStrategy` protocol. Initial: tag-based at write time. Define `non-inferable` tag convention with advisory validation at `MemoryBackend.store()` boundary (warns on missing tags, never blocks). System prompt instructs agents what qualifies: design rationale, team decisions, "why not X", cross-repo knowledge = non-inferable; code structure, API signatures, file contents = inferable. Uses existing `MemoryMetadata.tags` and `MemoryQuery.tags` — zero new models needed. Future strategies: LLM classification at retrieval, keyword/pattern heuristic. Pipeline: `MemoryBackend.retrieve()` -> rank by relevance+recency -> -filter by min_relevance -> greedy token-budget packing -> format as -ChatMessage (configured role: SYSTEM or USER) with delimiters. +filter by min_relevance -> apply `MemoryFilterStrategy` (D23, optional) -> +greedy token-budget packing -> format as ChatMessage (configured role: +SYSTEM or USER) with delimiters. Ranking algorithm: 1. `relevance = entry.relevance_score ?? config.default_relevance` @@ -1961,6 +1962,8 @@ Every completion call produces a `CompletionResponse` with `TokenUsage` (token c - `tokens_per_task` — total tokens consumed (from `AgentContext.accumulated_cost.total_tokens`) - `cost_per_task` — total USD cost (from `AgentContext.accumulated_cost.cost_usd` via `AgentRunResult.total_cost_usd`) - `duration_seconds` — wall-clock execution time in seconds (from `AgentRunResult.duration_seconds`) +- `prompt_tokens` — estimated system prompt tokens (from `SystemPrompt.estimated_tokens`) +- `prompt_token_ratio` — ratio of prompt tokens to total tokens (overhead indicator, `@computed_field`; warns when >0.3) These are natural overhead indicators — a task consuming 15 turns and 50k tokens for a one-line fix signals a problem. @@ -2771,7 +2774,7 @@ ai-company/ │ │ ├── role.py # Role model │ │ ├── role_catalog.py # Role catalog │ │ └── personality.py # Personality compatibility scoring -│ ├── engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, task lifecycle, recovery, shutdown, workspace isolation, and coordination error classification +│ ├── engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, task lifecycle, recovery, shutdown, workspace isolation, coordination error classification, and prompt policy validation │ │ ├── errors.py # Engine error hierarchy │ │ ├── prompt.py # System prompt builder │ │ ├── prompt_template.py # System prompt Jinja2 templates @@ -2779,6 +2782,7 @@ ai-company/ │ │ ├── context.py # AgentContext + AgentContextSnapshot │ │ ├── loop_protocol.py # ExecutionLoop protocol + result models │ │ ├── metrics.py # TaskCompletionMetrics proxy overhead model +│ │ ├── policy_validation.py # Org policy quality heuristics (non-inferable principle) │ │ ├── react_loop.py # ReAct loop implementation │ │ ├── plan_models.py # Plan step, plan, and plan-execute config models │ │ ├── plan_execute_loop.py # Plan-and-Execute loop implementation @@ -2910,7 +2914,7 @@ ai-company/ │ │ │ └── structured_phases.py # StructuredPhasesProtocol implementation │ │ ├── messenger.py # AgentMessenger per-agent facade │ │ └── subscription.py # Subscription + DeliveryEnvelope models -│ ├── memory/ # Agent memory system — protocols, models, config, factory, retrieval pipeline (M5) +│ ├── memory/ # Agent memory system — protocols, models, config, factory, retrieval pipeline (ranking, injection, context formatting, non-inferable filtering) (M5) │ │ ├── __init__.py # Re-exports │ │ ├── capabilities.py # MemoryCapabilities protocol │ │ ├── config.py # CompanyMemoryConfig, MemoryStorageConfig, MemoryOptionsConfig @@ -2922,7 +2926,9 @@ ai-company/ │ │ ├── protocol.py # MemoryBackend protocol │ │ ├── ranking.py # ScoredMemory model, rank_memories(), scoring functions │ │ ├── retrieval_config.py # MemoryRetrievalConfig (weights, thresholds, strategy selection) +│ │ ├── filter.py # MemoryFilterStrategy protocol, TagBasedMemoryFilter, PassthroughMemoryFilter │ │ ├── retriever.py # ContextInjectionStrategy (full retrieval → rank → format pipeline) +│ │ ├── store_guard.py # Advisory non-inferable tag enforcement at store boundary │ │ ├── shared.py # SharedKnowledgeStore protocol │ │ ├── consolidation/ # Memory consolidation — strategies, retention, archival │ │ │ ├── __init__.py @@ -2992,6 +2998,7 @@ ai-company/ │ │ │ ├── role.py # ROLE_* constants │ │ │ ├── routing.py # ROUTING_* constants │ │ │ ├── sandbox.py # SANDBOX_* constants +│ │ │ ├── security.py # SECURITY_* constants │ │ │ ├── task.py # TASK_* constants │ │ │ ├── task_assignment.py # TASK_ASSIGNMENT_* constants │ │ │ ├── task_routing.py # TASK_ROUTING_* constants diff --git a/README.md b/README.md index a46258f4db..81ca1af44b 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Task Intelligence (M4)** - Task decomposition, routing, assignment strategies, workspace isolation via git worktrees - **Templates** - Built-in templates, inheritance/merge, rendering, personality presets - **Persistence Layer (M5)** - Pluggable `PersistenceBackend` protocol with SQLite backend (aiosqlite), repository protocols, schema migrations -- **Memory Interface (M5)** - Pluggable `MemoryBackend` protocol with capability discovery, shared knowledge protocol, domain models, config, factory, and context injection retrieval pipeline (ranking, token-budget formatting). Shared organizational memory via `OrgMemoryBackend` protocol with hybrid prompt+retrieval backend. Memory consolidation/archival with pluggable strategies and retention enforcement +- **Memory Interface (M5)** - Pluggable `MemoryBackend` protocol with capability discovery, shared knowledge protocol, domain models, config, factory, and context injection retrieval pipeline (ranking, token-budget formatting, non-inferable filtering). Shared organizational memory via `OrgMemoryBackend` protocol with hybrid prompt+retrieval backend. Memory consolidation/archival with pluggable strategies and retention enforcement - **Coordination Error Taxonomy (M5)** - Post-execution classification pipeline detecting logical contradictions, numerical drift, context omissions, and coordination failures - **Budget Enforcement (M5)** - `BudgetEnforcer` service with pre-flight checks, in-flight budget checking, auto-downgrade, configurable cost tiers, and quota/subscription tracking; `CostOptimizer` CFO service with anomaly detection, efficiency analysis, downgrade recommendations, and approval decisions; `ReportGenerator` for multi-dimensional spending reports - **Litestar REST API (M6)** - 13 controllers + WebSocket handler covering company, agents, tasks, budget, approvals, analytics, messages, meetings, projects, departments, artifacts, providers, health, and WebSocket real-time feed diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 383ea03f4f..031351bac2 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -50,6 +50,7 @@ EXECUTION_ENGINE_TIMEOUT, EXECUTION_RECOVERY_FAILED, ) +from ai_company.observability.events.prompt import PROMPT_TOKEN_RATIO_HIGH from ai_company.observability.events.security import SECURITY_DISABLED from ai_company.providers.enums import MessageRole from ai_company.providers.models import ChatMessage @@ -91,6 +92,9 @@ logger = get_logger(__name__) +_PROMPT_TOKEN_RATIO_THRESHOLD: float = 0.3 +"""Prompt-to-total token ratio above which a warning is emitted.""" + _DEFAULT_RECOVERY_STRATEGY = FailAndReassignStrategy() """Module-level default instance for the recovery strategy.""" @@ -357,11 +361,12 @@ async def _post_execution_pipeline( except MemoryError, RecursionError: raise except Exception: - logger.debug( + logger.warning( EXECUTION_ENGINE_ERROR, agent_id=agent_id, task_id=task_id, - error="classification failed (details logged by pipeline)", + error="classification failed", + exc_info=True, ) return execution_result @@ -760,8 +765,20 @@ def _log_completion( tokens_per_task=metrics.tokens_per_task, cost_per_task=metrics.cost_per_task, duration_seconds=metrics.duration_seconds, + prompt_tokens=metrics.prompt_tokens, + prompt_token_ratio=metrics.prompt_token_ratio, ) + if metrics.prompt_token_ratio > _PROMPT_TOKEN_RATIO_THRESHOLD: + logger.warning( + PROMPT_TOKEN_RATIO_HIGH, + agent_id=agent_id, + task_id=task_id, + prompt_token_ratio=metrics.prompt_token_ratio, + prompt_tokens=metrics.prompt_tokens, + total_tokens=metrics.tokens_per_task, + ) + def _handle_budget_error( # noqa: PLR0913 self, *, diff --git a/src/ai_company/engine/metrics.py b/src/ai_company/engine/metrics.py index d8c9f9dfb4..ea523e97a4 100644 --- a/src/ai_company/engine/metrics.py +++ b/src/ai_company/engine/metrics.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator from ai_company.core.types import NotBlankStr # noqa: TC001 @@ -27,9 +27,15 @@ class TaskCompletionMetrics(BaseModel): tokens_per_task: Total tokens consumed (input + output). cost_per_task: Total USD cost for the task. duration_seconds: Wall-clock execution time in seconds. + prompt_tokens: Estimated system prompt tokens (per-call estimate + from ``SystemPrompt.estimated_tokens``). + prompt_token_ratio: Per-call ratio of prompt tokens to total tokens + (overhead indicator, derived via ``@computed_field``). For + multi-turn runs, the actual overhead is higher because the + system prompt is resent on every turn. """ - model_config = ConfigDict(frozen=True) + model_config = ConfigDict(frozen=True, allow_inf_nan=False) task_id: NotBlankStr | None = Field( default=None, @@ -52,6 +58,35 @@ class TaskCompletionMetrics(BaseModel): ge=0.0, description="Wall-clock execution time in seconds", ) + prompt_tokens: int = Field( + default=0, + ge=0, + description="Estimated system prompt tokens", + ) + + @model_validator(mode="after") + def _cap_prompt_tokens(self) -> TaskCompletionMetrics: + """Cap prompt_tokens to tokens_per_task. + + The heuristic estimator (char/4) can legitimately overshoot + actual provider-reported tokens, so we clamp rather than reject. + Skipped when ``tokens_per_task`` is 0 (zero-turn runs). + """ + if self.tokens_per_task > 0 and self.prompt_tokens > self.tokens_per_task: + object.__setattr__(self, "prompt_tokens", self.tokens_per_task) + return self + + @computed_field # type: ignore[prop-decorator] + @property + def prompt_token_ratio(self) -> float: + """Per-call ratio of prompt tokens to total tokens (overhead indicator). + + For multi-turn runs the actual overhead is higher because the + system prompt is resent on every turn. + """ + if self.tokens_per_task > 0: + return self.prompt_tokens / self.tokens_per_task + return 0.0 @classmethod def from_run_result(cls, result: AgentRunResult) -> TaskCompletionMetrics: @@ -72,4 +107,5 @@ def from_run_result(cls, result: AgentRunResult) -> TaskCompletionMetrics: tokens_per_task=accumulated.total_tokens, cost_per_task=result.total_cost_usd, duration_seconds=result.duration_seconds, + prompt_tokens=result.system_prompt.estimated_tokens, ) diff --git a/src/ai_company/engine/policy_validation.py b/src/ai_company/engine/policy_validation.py new file mode 100644 index 0000000000..58263f12cf --- /dev/null +++ b/src/ai_company/engine/policy_validation.py @@ -0,0 +1,229 @@ +"""Org policy quality validation heuristics. + +Applies lightweight checks to detect policies that likely violate the +non-inferable principle — e.g. policies that describe codebase structure +(inferable by reading the repo) rather than actionable constraints. + +Examples of **good** policies (non-inferable, actionable): + +- ``"All API responses must include a correlation_id header"`` +- ``"Never store PII in memory without encryption"`` +- ``"Escalate budget overruns above $5 to the CFO"`` + +Examples of **bad** policies (inferable or non-actionable): + +- ``"The project uses Python 3.14"`` — discoverable from pyproject.toml +- ``"src/api/ contains REST controllers"`` — discoverable by reading code +- ``"x"`` — too short to be actionable +""" + +import re +from typing import Final, Literal + +from pydantic import BaseModel, ConfigDict, Field + +from ai_company.observability import get_logger +from ai_company.observability.events.prompt import ( + PROMPT_POLICY_QUALITY_ISSUE, + PROMPT_POLICY_VALIDATION_START, +) + +logger = get_logger(__name__) + +_MIN_POLICY_LENGTH: Final[int] = 10 +_MAX_POLICY_LENGTH: Final[int] = 500 + +# Patterns that suggest inferable codebase context rather than a policy. +# Case-insensitive to catch capitalized variants like "Import json". +_CODE_PATTERNS: Final[tuple[re.Pattern[str], ...]] = ( + re.compile(r"(?:src|tests|lib|app)/[\w/]+\.py", re.IGNORECASE), # file paths + re.compile(r"\bfrom\s+\w+\s+import\b", re.IGNORECASE), # Python imports + re.compile(r"\bimport\s+\w+", re.IGNORECASE), # bare imports + re.compile(r"\bdef\s+\w+\s*\(", re.IGNORECASE), # function definitions + re.compile(r"\bclass\s+\w+[\s:(]", re.IGNORECASE), # class definitions +) + +# Directive keywords that signal an actionable policy constraint. +_ACTION_VERBS: Final[frozenset[str]] = frozenset( + { + "must", + "should", + "always", + "never", + "require", + "ensure", + "prohibit", + "enforce", + "restrict", + "mandate", + "avoid", + "prefer", + "escalate", + "approve", + "deny", + "reject", + "validate", + "verify", + } +) + + +class PolicyQualityIssue(BaseModel): + """A quality issue found in an org policy. + + Attributes: + policy: The policy text that triggered the issue. + issue: Human-readable description of the problem. + severity: ``"warning"`` for advisory issues; ``"error"`` + reserved for future stricter checks. + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + policy: str = Field(description="The policy text that triggered the issue") + issue: str = Field( + min_length=1, + description="Human-readable description of the problem", + ) + severity: Literal["warning", "error"] = Field( + description="Issue severity (``'error'`` reserved for future stricter checks)", + ) + + +def validate_policy_quality( + policies: tuple[str, ...], +) -> tuple[PolicyQualityIssue, ...]: + """Check org policies for non-inferable principle violations. + + Applies heuristic checks — results are advisory and never block + prompt construction. + + Args: + policies: Org policy texts to validate. + + Returns: + Tuple of quality issues found (empty if all policies pass). + """ + logger.debug( + PROMPT_POLICY_VALIDATION_START, + policy_count=len(policies), + ) + issues: list[PolicyQualityIssue] = [] + for policy in policies: + issues.extend(_check_single_policy(policy)) + + for issue in issues: + logger.warning( + PROMPT_POLICY_QUALITY_ISSUE, + policy_length=len(issue.policy), + issue=issue.issue, + severity=issue.severity, + ) + + return tuple(issues) + + +_ACTION_VERB_RE: re.Pattern[str] = re.compile( + r"\b(?:" + "|".join(sorted(_ACTION_VERBS)) + r")\b", +) + + +def _check_single_policy(policy: str) -> list[PolicyQualityIssue]: + """Run all heuristic checks on a single policy string. + + Args: + policy: The policy text to validate. + + Returns: + List of quality issues found (empty if the policy passes all checks). + """ + return [ + *_check_length(policy), + *_check_code_patterns(policy), + *_check_action_keywords(policy), + ] + + +def _check_length(policy: str) -> list[PolicyQualityIssue]: + """Flag policies that are too short or too long. + + Args: + policy: The policy text to validate. + + Returns: + List of length-related issues (0-2 items). + """ + found: list[PolicyQualityIssue] = [] + + if len(policy) < _MIN_POLICY_LENGTH: + found.append( + PolicyQualityIssue( + policy=policy, + issue=( + f"Too short ({len(policy)} chars) — likely not an actionable policy" + ), + severity="warning", + ), + ) + + if len(policy) > _MAX_POLICY_LENGTH: + found.append( + PolicyQualityIssue( + policy=policy, + issue=( + f"Too long ({len(policy)} chars) — " + f"may contain inferable context rather than a policy" + ), + severity="warning", + ), + ) + + return found + + +def _check_code_patterns(policy: str) -> list[PolicyQualityIssue]: + """Flag policies that contain inferable code patterns. + + Args: + policy: The policy text to validate. + + Returns: + List with one issue if a code pattern is found, else empty. + """ + for pattern in _CODE_PATTERNS: + if pattern.search(policy): + return [ + PolicyQualityIssue( + policy=policy, + issue=( + "Contains code patterns (file paths, imports, or " + "definitions) — likely inferable from the codebase" + ), + severity="warning", + ), + ] + return [] + + +def _check_action_keywords(policy: str) -> list[PolicyQualityIssue]: + """Flag policies missing directive keywords. + + Args: + policy: The policy text to validate. + + Returns: + List with one issue if no action keywords found, else empty. + """ + policy_lower = policy.lower() + if not _ACTION_VERB_RE.search(policy_lower): + return [ + PolicyQualityIssue( + policy=policy, + issue=( + "Missing action verbs (must, should, always, never, " + "etc.) — may not be an actionable policy" + ), + severity="warning", + ), + ] + return [] diff --git a/src/ai_company/engine/prompt.py b/src/ai_company/engine/prompt.py index 0368db92d7..8007feb732 100644 --- a/src/ai_company/engine/prompt.py +++ b/src/ai_company/engine/prompt.py @@ -3,6 +3,14 @@ Translates agent configuration (personality, skills, authority, role) into contextually rich system prompts that shape agent behavior during LLM calls. +**Non-inferable principle:** System prompts should contain only information +that agents cannot discover by reading the codebase or environment. Tool +definitions, for example, are already delivered via the LLM provider's API +``tools`` parameter, so repeating them in the system prompt would increase +cost without benefit (per D22, arXiv:2602.11988). The default template +therefore omits the ``Available Tools`` section. Custom templates may still +reference ``{{ tools }}`` when explicitly needed. + Example:: from ai_company.engine.prompt import build_system_prompt @@ -19,6 +27,7 @@ from pydantic import BaseModel, ConfigDict, Field from ai_company.engine.errors import PromptBuildError +from ai_company.engine.policy_validation import validate_policy_quality from ai_company.engine.prompt_template import ( AUTONOMY_INSTRUCTIONS, DEFAULT_TEMPLATE, @@ -33,6 +42,7 @@ PROMPT_BUILD_TOKEN_TRIMMED, PROMPT_CUSTOM_TEMPLATE_FAILED, PROMPT_CUSTOM_TEMPLATE_LOADED, + PROMPT_POLICY_VALIDATION_FAILED, ) if TYPE_CHECKING: @@ -132,13 +142,14 @@ def estimate_tokens(self, text: str) -> int: _SECTION_ORG_POLICIES = "org_policies" _SECTION_AUTONOMY = "autonomy" _SECTION_TASK = "task" -_SECTION_TOOLS = "tools" _SECTION_COMPANY = "company" +_SECTION_TOOLS = "tools" # Sections trimmed when over token budget, least critical first. +# Tools section was removed from the default template per D22 +# (non-inferable principle), but custom templates may still render tools. _TRIMMABLE_SECTIONS = ( _SECTION_COMPANY, - _SECTION_TOOLS, _SECTION_TASK, _SECTION_ORG_POLICIES, ) @@ -162,14 +173,17 @@ def build_system_prompt( # noqa: PLR0913 """Build a system prompt from agent identity and optional context. When ``max_tokens`` is provided and the prompt exceeds it, optional - sections are progressively trimmed (company, tools, task, org_policies). + sections are progressively trimmed (company, task, org_policies). Args: agent: Agent identity containing personality, skills, authority. role: Optional role with description and responsibilities. task: Optional task context injected into the prompt. - available_tools: Tool definitions available to the agent. - company: Optional company context (name, departments). + available_tools: Tool definitions populated into template context + for custom templates only; the default template omits tools + per D22 (non-inferable principle). + company: Opt-in. Non-inferable principle recommends omitting + unless agents need org-level context they cannot discover. org_policies: Company-wide policy texts to inject into prompt. max_tokens: Token budget; sections are trimmed if exceeded. custom_template: Optional Jinja2 template string override. @@ -182,6 +196,20 @@ def build_system_prompt( # noqa: PLR0913 PromptBuildError: If prompt construction fails. """ _validate_max_tokens(agent, max_tokens) + _validate_org_policies(agent, org_policies) + + # Advisory only — issues are logged but never block prompt construction. + if org_policies: + try: + validate_policy_quality(org_policies) + except MemoryError, RecursionError: + raise + except Exception: + logger.warning( + PROMPT_POLICY_VALIDATION_FAILED, + agent_id=str(agent.id), + exc_info=True, + ) logger.info( PROMPT_BUILD_START, @@ -248,6 +276,30 @@ def _validate_max_tokens( raise PromptBuildError(msg) +def _validate_org_policies( + agent: AgentIdentity, + org_policies: tuple[str, ...], +) -> None: + """Raise ``PromptBuildError`` on blank or non-string policy entries. + + Args: + agent: Agent identity for error context. + org_policies: Policy texts to validate. + + Raises: + PromptBuildError: If any policy entry is empty or whitespace-only. + """ + for index, policy in enumerate(org_policies): + if not isinstance(policy, str) or not policy.strip(): + msg = f"org_policies[{index}] must be a non-empty string" + logger.error( + PROMPT_BUILD_ERROR, + agent_id=str(agent.id), + error=msg, + ) + raise PromptBuildError(msg) + + def _log_and_return( agent: AgentIdentity, result: SystemPrompt, @@ -398,17 +450,24 @@ def _build_template_context( # noqa: PLR0913 def _compute_sections( *, task: Task | None, - available_tools: tuple[ToolDefinition, ...], + available_tools: tuple[ToolDefinition, ...] = (), company: Company | None, org_policies: tuple[str, ...] = (), + custom_template: bool = False, ) -> tuple[str, ...]: """Determine which sections are present in the rendered prompt. + The default template omits the tools section per D22 (non-inferable + principle). Custom templates may still render tools, so the tools + section is tracked when ``available_tools`` is non-empty and a custom + template is in use. + Args: task: Optional task context. - available_tools: Tool definitions. + available_tools: Tool definitions (tracked for custom templates). company: Optional company context. org_policies: Company-wide policy texts. + custom_template: Whether a custom template is being used. Returns: Tuple of section names that are included. @@ -425,7 +484,7 @@ def _compute_sections( sections.append(_SECTION_AUTONOMY) if task is not None: sections.append(_SECTION_TASK) - if available_tools: + if available_tools and custom_template: sections.append(_SECTION_TOOLS) if company is not None: sections.append(_SECTION_COMPANY) @@ -487,14 +546,13 @@ def _trim_sections( # noqa: PLR0913 str, int, Task | None, - tuple[ToolDefinition, ...], Company | None, tuple[str, ...], ]: """Progressively remove optional sections until under token budget. - Returns ``(content, estimated, task, available_tools, company, - org_policies)`` so the caller can reuse the final render. + Returns ``(content, estimated, task, company, org_policies)`` + so the caller can reuse the final render. """ trimmed_sections: list[str] = [] @@ -516,8 +574,6 @@ def _trim_sections( # noqa: PLR0913 company = None elif section == _SECTION_ORG_POLICIES and org_policies: org_policies = () - elif section == _SECTION_TOOLS and available_tools: - available_tools = () elif section == _SECTION_TASK and task is not None: task = None else: @@ -539,7 +595,7 @@ def _trim_sections( # noqa: PLR0913 _log_trim_results(agent, max_tokens, estimated, trimmed_sections) - return content, estimated, task, available_tools, company, org_policies + return content, estimated, task, company, org_policies def _log_trim_results( @@ -591,18 +647,16 @@ def _render_with_trimming( # noqa: PLR0913 ) if max_tokens is not None and estimated > max_tokens: - content, estimated, task, available_tools, company, org_policies = ( - _trim_sections( - template_str=template_str, - agent=agent, - role=role, - task=task, - available_tools=available_tools, - company=company, - org_policies=org_policies, - max_tokens=max_tokens, - estimator=estimator, - ) + content, estimated, task, company, org_policies = _trim_sections( + template_str=template_str, + agent=agent, + role=role, + task=task, + available_tools=available_tools, + company=company, + org_policies=org_policies, + max_tokens=max_tokens, + estimator=estimator, ) return _build_prompt_result( @@ -613,6 +667,7 @@ def _render_with_trimming( # noqa: PLR0913 company, org_policies, agent, + custom_template=template_str is not DEFAULT_TEMPLATE, ) @@ -624,6 +679,8 @@ def _build_prompt_result( # noqa: PLR0913 company: Company | None, org_policies: tuple[str, ...], agent: AgentIdentity, + *, + custom_template: bool = False, ) -> SystemPrompt: """Assemble the final ``SystemPrompt`` from rendered content.""" sections = _compute_sections( @@ -631,6 +688,7 @@ def _build_prompt_result( # noqa: PLR0913 available_tools=available_tools, company=company, org_policies=org_policies, + custom_template=custom_template, ) return SystemPrompt( content=content, diff --git a/src/ai_company/engine/prompt_template.py b/src/ai_company/engine/prompt_template.py index af9ff7e576..5d39d6e74f 100644 --- a/src/ai_company/engine/prompt_template.py +++ b/src/ai_company/engine/prompt_template.py @@ -4,13 +4,20 @@ :func:`~ai_company.engine.prompt.build_system_prompt` to render agent system prompts. The template uses conditional sections that are omitted when the corresponding context is absent. + +**Non-inferable principle (D22):** The default template omits the +``Available Tools`` section because tool definitions are already passed to +the LLM provider via the API's ``tools`` parameter. Injecting them again +into the system prompt doubles cost with no benefit — agents can discover +tool details from the API-level definitions. Custom templates may still +reference ``{{ tools }}`` when explicitly needed. """ from typing import Final from ai_company.core.enums import SeniorityLevel -PROMPT_TEMPLATE_VERSION: Final[str] = "1.2.0" +PROMPT_TEMPLATE_VERSION: Final[str] = "1.3.0" # ── Autonomy instructions by seniority level ───────────────────── @@ -156,13 +163,6 @@ **Deadline**: {{ task.deadline }} {% endif %} {% endif %} -{% if tools %} - -## Available Tools -{% for tool in tools %} -- **{{ tool.name }}**{% if tool.description %}: {{ tool.description }}{% endif %} -{% endfor %} -{% endif %} {% if company %} ## Company Context diff --git a/src/ai_company/memory/filter.py b/src/ai_company/memory/filter.py new file mode 100644 index 0000000000..af8a3968f4 --- /dev/null +++ b/src/ai_company/memory/filter.py @@ -0,0 +1,143 @@ +"""Memory filter strategies for non-inferable principle enforcement. + +Filters scored memories before injection into agent prompts. The +``TagBasedMemoryFilter`` (initial D23 implementation) retains only +memories tagged with ``"non-inferable"``; the ``PassthroughMemoryFilter`` +is a no-op for backward compatibility and testing. + +Both satisfy the ``MemoryFilterStrategy`` runtime-checkable protocol. +""" + +from typing import TYPE_CHECKING, Final, Protocol, runtime_checkable + +from ai_company.observability import get_logger +from ai_company.observability.events.memory import ( + MEMORY_FILTER_APPLIED, + MEMORY_FILTER_INIT, +) + +if TYPE_CHECKING: + from ai_company.memory.ranking import ScoredMemory + +logger = get_logger(__name__) + +NON_INFERABLE_TAG: Final[str] = "non-inferable" + + +@runtime_checkable +class MemoryFilterStrategy(Protocol): + """Protocol for filtering scored memories before prompt injection.""" + + def filter_for_injection( + self, + memories: tuple[ScoredMemory, ...], + ) -> tuple[ScoredMemory, ...]: + """Filter memories suitable for injection. + + Args: + memories: Ranked scored memories from the retrieval pipeline. + + Returns: + Subset of memories that pass the filter. + """ + ... + + @property + def strategy_name(self) -> str: + """Human-readable name of the filter strategy.""" + ... + + +class TagBasedMemoryFilter: + """Filter that retains only memories with a required tag. + + The default required tag is ``"non-inferable"`` per D23. Memories + whose ``entry.metadata.tags`` do not contain the required tag are + excluded from prompt injection. + + Args: + required_tag: Tag that must be present for a memory to pass. + """ + + def __init__(self, required_tag: str = NON_INFERABLE_TAG) -> None: + if not isinstance(required_tag, str) or not required_tag.strip(): + msg = "required_tag must be a non-empty string" + raise ValueError(msg) + self._required_tag = required_tag.strip() + logger.debug( + MEMORY_FILTER_INIT, + strategy=self.strategy_name, + required_tag=required_tag, + ) + + def filter_for_injection( + self, + memories: tuple[ScoredMemory, ...], + ) -> tuple[ScoredMemory, ...]: + """Return only memories containing the required tag. + + Args: + memories: Ranked scored memories. + + Returns: + Filtered tuple with only tagged memories. + """ + retained = tuple( + m for m in memories if self._required_tag in m.entry.metadata.tags + ) + + logger.info( + MEMORY_FILTER_APPLIED, + strategy=self.strategy_name, + candidates=len(memories), + retained=len(retained), + required_tag=self._required_tag, + ) + + return retained + + @property + def strategy_name(self) -> str: + """Human-readable name of the filter strategy. + + Returns: + ``"tag_based"``. + """ + return "tag_based" + + +class PassthroughMemoryFilter: + """No-op filter that returns all memories unchanged. + + Useful for backward compatibility and testing — all memories pass + through without filtering. + """ + + def filter_for_injection( + self, + memories: tuple[ScoredMemory, ...], + ) -> tuple[ScoredMemory, ...]: + """Return all memories unchanged. + + Args: + memories: Ranked scored memories. + + Returns: + The input tuple unchanged. + """ + logger.info( + MEMORY_FILTER_APPLIED, + strategy=self.strategy_name, + candidates=len(memories), + retained=len(memories), + ) + return memories + + @property + def strategy_name(self) -> str: + """Human-readable name of the filter strategy. + + Returns: + ``"passthrough"``. + """ + return "passthrough" diff --git a/src/ai_company/memory/retrieval_config.py b/src/ai_company/memory/retrieval_config.py index c047cb228c..395f0cb36a 100644 --- a/src/ai_company/memory/retrieval_config.py +++ b/src/ai_company/memory/retrieval_config.py @@ -31,6 +31,8 @@ class MemoryRetrievalConfig(BaseModel): include_shared: Whether to query SharedKnowledgeStore. default_relevance: Score for entries missing relevance_score. injection_point: Message role for context injection. + non_inferable_only: When True, auto-creates a ``TagBasedMemoryFilter`` + in ``ContextInjectionStrategy`` if no explicit filter is provided. """ model_config = ConfigDict(frozen=True, allow_inf_nan=False) @@ -88,6 +90,10 @@ class MemoryRetrievalConfig(BaseModel): default=InjectionPoint.SYSTEM, description="Message role for context injection", ) + non_inferable_only: bool = Field( + default=False, + description="When True, only inject memories tagged as non-inferable", + ) @model_validator(mode="after") def _validate_weight_sum(self) -> Self: diff --git a/src/ai_company/memory/retriever.py b/src/ai_company/memory/retriever.py index 62c6dbb7b1..1c34edd9e6 100644 --- a/src/ai_company/memory/retriever.py +++ b/src/ai_company/memory/retriever.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING from ai_company.memory import errors as memory_errors +from ai_company.memory.filter import TagBasedMemoryFilter from ai_company.memory.formatter import format_memory_context from ai_company.memory.injection import ( DefaultTokenEstimator, @@ -18,6 +19,7 @@ from ai_company.memory.ranking import rank_memories from ai_company.observability import get_logger from ai_company.observability.events.memory import ( + MEMORY_FILTER_INIT, MEMORY_RETRIEVAL_COMPLETE, MEMORY_RETRIEVAL_DEGRADED, MEMORY_RETRIEVAL_SKIPPED, @@ -29,6 +31,7 @@ from ai_company.core.enums import MemoryCategory from ai_company.core.types import NotBlankStr + from ai_company.memory.filter import MemoryFilterStrategy from ai_company.memory.models import MemoryEntry from ai_company.memory.protocol import MemoryBackend from ai_company.memory.retrieval_config import MemoryRetrievalConfig @@ -105,6 +108,7 @@ def __init__( config: MemoryRetrievalConfig, shared_store: SharedKnowledgeStore | None = None, token_estimator: TokenEstimator | None = None, + memory_filter: MemoryFilterStrategy | None = None, ) -> None: """Initialise the context injection strategy. @@ -113,10 +117,25 @@ def __init__( config: Retrieval pipeline configuration. shared_store: Optional shared knowledge store. token_estimator: Optional custom token estimator. + memory_filter: Optional filter applied after ranking, + before formatting. When ``None`` and + ``config.non_inferable_only`` is ``True``, a + ``TagBasedMemoryFilter`` is auto-created. When ``None`` + and ``non_inferable_only`` is ``False``, all ranked + memories are injected (backward-compatible). """ self._backend = backend self._config = config self._shared_store = shared_store + if memory_filter is None and config.non_inferable_only: + memory_filter = TagBasedMemoryFilter() + elif memory_filter is not None and config.non_inferable_only: + logger.debug( + MEMORY_FILTER_INIT, + note="explicit memory_filter overrides non_inferable_only config", + filter_strategy=getattr(memory_filter, "strategy_name", "unknown"), + ) + self._memory_filter = memory_filter self._estimator = ( token_estimator if token_estimator is not None else DefaultTokenEstimator() ) @@ -261,6 +280,31 @@ async def _execute_pipeline( ) return () + if self._memory_filter is not None: + try: + ranked = self._memory_filter.filter_for_injection(ranked) + except builtins_MemoryError, RecursionError: + raise + except Exception as exc: + logger.warning( + MEMORY_RETRIEVAL_DEGRADED, + source="memory_filter", + agent_id=agent_id, + error_type=type(exc).__qualname__, + filter_strategy=getattr( + self._memory_filter, "strategy_name", "unknown" + ), + exc_info=True, + ) + # Graceful degradation: use unfiltered ranked memories. + if not ranked: + logger.info( + MEMORY_RETRIEVAL_SKIPPED, + agent_id=agent_id, + reason="all filtered by memory filter", + ) + return () + result = format_memory_context( ranked, estimator=self._estimator, diff --git a/src/ai_company/memory/store_guard.py b/src/ai_company/memory/store_guard.py new file mode 100644 index 0000000000..52b057444c --- /dev/null +++ b/src/ai_company/memory/store_guard.py @@ -0,0 +1,35 @@ +"""Store-boundary tag enforcement for non-inferable principle. + +Advisory guard that warns when memories are stored without the +``"non-inferable"`` tag. Never blocks — the store always succeeds. +""" + +from typing import TYPE_CHECKING + +from ai_company.memory.filter import NON_INFERABLE_TAG +from ai_company.observability import get_logger +from ai_company.observability.events.memory import ( + MEMORY_FILTER_STORE_MISSING_TAG, +) + +if TYPE_CHECKING: + from ai_company.memory.models import MemoryStoreRequest + +logger = get_logger(__name__) + + +def validate_memory_tags(request: MemoryStoreRequest) -> None: + """Log a warning when the non-inferable tag is missing. + + This is advisory only — the store operation is never blocked. + Wire into ``MemoryBackend.store()`` callers to activate enforcement. + + Args: + request: The memory store request to validate. + """ + if NON_INFERABLE_TAG not in request.metadata.tags: + logger.warning( + MEMORY_FILTER_STORE_MISSING_TAG, + category=request.category.value, + content_length=len(request.content), + ) diff --git a/src/ai_company/observability/events/memory.py b/src/ai_company/observability/events/memory.py index 8f64e93c72..203a3d7c91 100644 --- a/src/ai_company/observability/events/memory.py +++ b/src/ai_company/observability/events/memory.py @@ -63,3 +63,9 @@ "memory.format.invalid_injection_point" ) MEMORY_TOKEN_BUDGET_EXCEEDED: Final[str] = "memory.token_budget.exceeded" # noqa: S105 + +# ── Memory filter ────────────────────────────────────────────── + +MEMORY_FILTER_INIT: Final[str] = "memory.filter.init" +MEMORY_FILTER_APPLIED: Final[str] = "memory.filter.applied" +MEMORY_FILTER_STORE_MISSING_TAG: Final[str] = "memory.filter.store_missing_tag" diff --git a/src/ai_company/observability/events/prompt.py b/src/ai_company/observability/events/prompt.py index 06bdc26686..9e2aafaa77 100644 --- a/src/ai_company/observability/events/prompt.py +++ b/src/ai_company/observability/events/prompt.py @@ -9,3 +9,7 @@ PROMPT_BUILD_BUDGET_EXCEEDED: Final[str] = "prompt.build.budget_exceeded" PROMPT_CUSTOM_TEMPLATE_LOADED: Final[str] = "prompt.custom_template.loaded" PROMPT_CUSTOM_TEMPLATE_FAILED: Final[str] = "prompt.custom_template.failed" +PROMPT_POLICY_VALIDATION_START: Final[str] = "prompt.policy.validation_start" +PROMPT_POLICY_QUALITY_ISSUE: Final[str] = "prompt.policy.quality_issue" +PROMPT_POLICY_VALIDATION_FAILED: Final[str] = "prompt.policy.validation_failed" +PROMPT_TOKEN_RATIO_HIGH: Final[str] = "prompt.token_ratio.high" # noqa: S105 — event name, not a credential diff --git a/tests/unit/engine/test_agent_engine.py b/tests/unit/engine/test_agent_engine.py index ee9bdd4365..6ff75b0839 100644 --- a/tests/unit/engine/test_agent_engine.py +++ b/tests/unit/engine/test_agent_engine.py @@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +import structlog.testing from ai_company.budget.coordination_config import ErrorTaxonomyConfig from ai_company.budget.tracker import CostTracker @@ -20,6 +21,7 @@ TurnRecord, ) from ai_company.engine.run_result import AgentRunResult +from ai_company.observability.events.prompt import PROMPT_TOKEN_RATIO_HIGH from ai_company.providers.enums import FinishReason if TYPE_CHECKING: @@ -355,8 +357,8 @@ async def execute( ) assert result.is_success is True - # System prompt should include tools section - assert "tools" in result.system_prompt.sections + # D22: tools section is no longer in the default template. + assert "tools" not in result.system_prompt.sections @pytest.mark.unit @@ -854,3 +856,78 @@ async def test_classification_memory_error_propagates( identity=sample_agent_with_personality, task=sample_task_with_criteria, ) + + +@pytest.mark.unit +class TestAgentEnginePromptTokenRatioWarning: + """High prompt-to-total token ratio emits PROMPT_TOKEN_RATIO_HIGH.""" + + @pytest.mark.parametrize( + ( + "prompt_tokens", + "input_tokens", + "output_tokens", + "cost_usd", + "expect_warning", + ), + [ + # prompt_tokens=200 out of 400 total → ratio 0.50 > 0.3 threshold. + (200, 300, 100, 0.01, True), + # prompt_tokens=50 out of 10000 total → ratio 0.005 < 0.3 threshold. + (50, 5000, 5000, 1.0, False), + ], + ids=["high_ratio", "low_ratio"], + ) + async def test_prompt_token_ratio_warning( # noqa: PLR0913 + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + *, + prompt_tokens: int, + input_tokens: int, + output_tokens: int, + cost_usd: float, + expect_warning: bool, + ) -> None: + """Warning emitted iff prompt tokens dominate total tokens. + + Injects a fixed ``estimated_tokens`` via mock to isolate the + threshold-check logic from the live prompt estimator. + """ + from ai_company.engine.prompt import SystemPrompt + + response = _make_completion_response( + input_tokens=input_tokens, + output_tokens=output_tokens, + cost_usd=cost_usd, + ) + provider = mock_provider_factory([response]) + engine = AgentEngine(provider=provider) + + fixed_prompt = SystemPrompt( + content="test", + template_version="test", + estimated_tokens=prompt_tokens, + sections=("identity",), + metadata={"agent_id": str(sample_agent_with_personality.id)}, + ) + + with ( + patch( + "ai_company.engine.agent_engine.build_system_prompt", + return_value=fixed_prompt, + ), + structlog.testing.capture_logs() as logs, + ): + await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + warning_events = [e for e in logs if e.get("event") == PROMPT_TOKEN_RATIO_HIGH] + if expect_warning: + assert len(warning_events) == 1 + assert "prompt_token_ratio" in warning_events[0] + else: + assert len(warning_events) == 0 diff --git a/tests/unit/engine/test_agent_engine_lifecycle.py b/tests/unit/engine/test_agent_engine_lifecycle.py index 8be682efc2..725f79662f 100644 --- a/tests/unit/engine/test_agent_engine_lifecycle.py +++ b/tests/unit/engine/test_agent_engine_lifecycle.py @@ -347,11 +347,18 @@ async def test_no_timeout_by_default( assert result.is_success is True - async def test_zero_timeout_raises( + @pytest.mark.parametrize( + "timeout_seconds", + [0, -1.0], + ids=["zero", "negative"], + ) + async def test_non_positive_timeout_raises( self, sample_agent_with_personality: AgentIdentity, sample_task_with_criteria: Task, mock_provider_factory: type[MockCompletionProvider], + *, + timeout_seconds: float, ) -> None: provider = mock_provider_factory([]) engine = AgentEngine(provider=provider) @@ -360,23 +367,7 @@ async def test_zero_timeout_raises( await engine.run( identity=sample_agent_with_personality, task=sample_task_with_criteria, - timeout_seconds=0, - ) - - async def test_negative_timeout_raises( - self, - sample_agent_with_personality: AgentIdentity, - sample_task_with_criteria: Task, - mock_provider_factory: type[MockCompletionProvider], - ) -> None: - provider = mock_provider_factory([]) - engine = AgentEngine(provider=provider) - - with pytest.raises(ValueError, match="timeout_seconds must be > 0"): - await engine.run( - identity=sample_agent_with_personality, - task=sample_task_with_criteria, - timeout_seconds=-1.0, + timeout_seconds=timeout_seconds, ) @@ -393,7 +384,11 @@ async def test_metrics_logged_on_completion( """Successful run computes and logs TaskCompletionMetrics.""" from ai_company.engine.metrics import TaskCompletionMetrics - response = _make_completion_response(cost_usd=0.05) + response = _make_completion_response( + input_tokens=400, + output_tokens=200, + cost_usd=0.05, + ) provider = mock_provider_factory([response]) engine = AgentEngine(provider=provider) @@ -410,6 +405,7 @@ async def test_metrics_logged_on_completion( assert metrics.duration_seconds > 0 assert metrics.agent_id == str(sample_agent_with_personality.id) assert metrics.task_id == sample_task_with_criteria.id + assert 0.0 <= metrics.prompt_token_ratio <= 1.0 @pytest.mark.unit diff --git a/tests/unit/engine/test_metrics.py b/tests/unit/engine/test_metrics.py index 7257f50c2d..f3490baf2e 100644 --- a/tests/unit/engine/test_metrics.py +++ b/tests/unit/engine/test_metrics.py @@ -28,6 +28,7 @@ def test_valid_construction(self) -> None: tokens_per_task=1500, cost_per_task=0.05, duration_seconds=12.5, + prompt_tokens=150, ) assert metrics.task_id == "task-001" assert metrics.agent_id == "agent-001" @@ -35,6 +36,54 @@ def test_valid_construction(self) -> None: assert metrics.tokens_per_task == 1500 assert metrics.cost_per_task == 0.05 assert metrics.duration_seconds == 12.5 + assert metrics.prompt_tokens == 150 + assert metrics.prompt_token_ratio == 0.1 + + def test_prompt_fields_default_to_zero(self) -> None: + metrics = TaskCompletionMetrics( + agent_id="agent-001", + turns_per_task=0, + tokens_per_task=0, + cost_per_task=0.0, + duration_seconds=0.0, + ) + assert metrics.prompt_tokens == 0 + assert metrics.prompt_token_ratio == 0.0 + + def test_negative_prompt_tokens_rejected(self) -> None: + with pytest.raises(ValidationError, match="prompt_tokens"): + TaskCompletionMetrics( + agent_id="agent-001", + turns_per_task=0, + tokens_per_task=0, + cost_per_task=0.0, + duration_seconds=0.0, + prompt_tokens=-1, + ) + + def test_prompt_token_ratio_is_computed(self) -> None: + """prompt_token_ratio is derived from prompt_tokens / tokens_per_task.""" + metrics = TaskCompletionMetrics( + agent_id="agent-001", + turns_per_task=1, + tokens_per_task=1000, + cost_per_task=0.01, + duration_seconds=1.0, + prompt_tokens=500, + ) + assert metrics.prompt_token_ratio == pytest.approx(0.5) + + def test_prompt_token_ratio_at_boundary(self) -> None: + """When prompt_tokens == tokens_per_task, ratio is 1.0.""" + metrics = TaskCompletionMetrics( + agent_id="agent-001", + turns_per_task=1, + tokens_per_task=100, + cost_per_task=0.01, + duration_seconds=1.0, + prompt_tokens=100, + ) + assert metrics.prompt_token_ratio == pytest.approx(1.0) def test_task_id_none(self) -> None: metrics = TaskCompletionMetrics( @@ -212,6 +261,10 @@ def test_from_run_result_extracts_values( assert metrics.tokens_per_task == 430 # 300 + 130 assert metrics.cost_per_task == 0.03 assert metrics.duration_seconds == 5.0 + # Prompt tokens come from the SystemPrompt estimated_tokens (10). + assert metrics.prompt_tokens == 10 + # 10 / 430 ≈ 0.0232... + assert 0.02 < metrics.prompt_token_ratio < 0.03 def test_from_run_result_zero_turns( self, @@ -223,3 +276,6 @@ def test_from_run_result_zero_turns( assert metrics.turns_per_task == 0 assert metrics.tokens_per_task == 0 assert metrics.cost_per_task == 0.0 + # Zero total tokens → 0.0 ratio (no divide-by-zero). + assert metrics.prompt_token_ratio == 0.0 + assert metrics.prompt_tokens == 10 diff --git a/tests/unit/engine/test_policy_validation.py b/tests/unit/engine/test_policy_validation.py new file mode 100644 index 0000000000..34dce5994c --- /dev/null +++ b/tests/unit/engine/test_policy_validation.py @@ -0,0 +1,211 @@ +"""Unit tests for org policy quality validation heuristics.""" + +import pytest +import structlog.testing +from pydantic import ValidationError + +from ai_company.engine.policy_validation import ( + PolicyQualityIssue, + validate_policy_quality, +) +from ai_company.observability.events.prompt import PROMPT_POLICY_QUALITY_ISSUE + +pytestmark = [pytest.mark.unit, pytest.mark.timeout(30)] + + +class TestPolicyQualityIssueModel: + """Tests for the PolicyQualityIssue frozen model.""" + + def test_valid_construction(self) -> None: + issue = PolicyQualityIssue( + policy="some policy", + issue="some issue", + severity="warning", + ) + assert issue.policy == "some policy" + assert issue.severity == "warning" + + def test_error_severity(self) -> None: + issue = PolicyQualityIssue( + policy="p", + issue="i", + severity="error", + ) + assert issue.severity == "error" + + +class TestGoodPolicies: + """Good policies should produce no issues.""" + + @pytest.mark.parametrize( + "policy", + [ + "All API responses must include a correlation_id header", + "Never store PII in memory without encryption", + "Escalate budget overruns above $5 to the CFO", + "Always validate user input before processing", + "Agents should prefer structured logging over print statements", + ], + ) + def test_good_policy_no_issues(self, policy: str) -> None: + result = validate_policy_quality((policy,)) + assert result == () + + +class TestTooShort: + """Policies shorter than 10 chars should produce a warning.""" + + @pytest.mark.parametrize("policy", ["x", "ab", "short"]) + def test_too_short_warning(self, policy: str) -> None: + result = validate_policy_quality((policy,)) + assert len(result) >= 1 + short_issues = [i for i in result if "Too short" in i.issue] + assert len(short_issues) == 1 + assert short_issues[0].severity == "warning" + + +class TestTooLong: + """Policies longer than 500 chars should produce a warning.""" + + def test_too_long_warning(self) -> None: + long_policy = "Agents must always " + "x" * 500 + result = validate_policy_quality((long_policy,)) + long_issues = [i for i in result if "Too long" in i.issue] + assert len(long_issues) == 1 + assert long_issues[0].severity == "warning" + + +class TestCodePatterns: + """Policies containing code patterns should produce a warning.""" + + @pytest.mark.parametrize( + "policy", + [ + "The file src/api/controllers.py contains endpoints", + "You should from os import path for file handling", + "Use import json to parse data", + "The def calculate_total(items) function handles pricing", + "The class UserService: handles authentication", + ], + ) + def test_code_pattern_warning(self, policy: str) -> None: + result = validate_policy_quality((policy,)) + code_issues = [i for i in result if "code patterns" in i.issue] + assert len(code_issues) == 1 + assert code_issues[0].severity == "warning" + + +class TestMissingActionVerbs: + """Policies without action verbs should produce a warning.""" + + @pytest.mark.parametrize( + "policy", + [ + "The project uses Python 3.14 for all services", + "Our database is PostgreSQL with replication", + "The codebase follows a hexagonal architecture pattern", + ], + ) + def test_missing_action_verb_warning(self, policy: str) -> None: + result = validate_policy_quality((policy,)) + verb_issues = [i for i in result if "action verbs" in i.issue] + assert len(verb_issues) == 1 + assert verb_issues[0].severity == "warning" + + +class TestEdgeCases: + """Edge cases for policy validation.""" + + def test_empty_tuple_returns_empty(self) -> None: + result = validate_policy_quality(()) + assert result == () + + def test_single_char_produces_two_issues(self) -> None: + """Single char is too short AND missing action verbs.""" + result = validate_policy_quality(("x",)) + assert len(result) >= 2 + + def test_multiple_policies(self) -> None: + """Validates all policies independently.""" + policies = ( + "All API responses must include correlation_id", + "x", + ) + result = validate_policy_quality(policies) + # First is good, second produces issues. + bad_issues = [i for i in result if i.policy == "x"] + assert len(bad_issues) >= 1 + + def test_logging_emits_events(self) -> None: + """Each issue logs a PROMPT_POLICY_QUALITY_ISSUE event.""" + with structlog.testing.capture_logs() as logs: + validate_policy_quality(("x",)) + + events = [e for e in logs if e["event"] == PROMPT_POLICY_QUALITY_ISSUE] + assert len(events) >= 1 + + def test_multiple_code_patterns_produce_one_issue(self) -> None: + """A policy with multiple code patterns produces exactly one issue.""" + policy = ( + "The file src/api/views.py has from os import path " + "and also import json for parsing" + ) + result = validate_policy_quality((policy,)) + code_issues = [i for i in result if "code patterns" in i.issue] + assert len(code_issues) == 1 + + def test_action_verb_word_boundary(self) -> None: + """Noun forms of action verbs (e.g. 'requirement') don't match.""" + policy = ( + "The database has a strict isolation requirement and a validation framework" + ) + result = validate_policy_quality((policy,)) + verb_issues = [i for i in result if "action verbs" in i.issue] + assert len(verb_issues) == 1 + + def test_frozen_enforcement(self) -> None: + """PolicyQualityIssue is immutable after construction.""" + issue = PolicyQualityIssue( + policy="some policy", + issue="some issue", + severity="warning", + ) + with pytest.raises(ValidationError): + issue.policy = "changed" # type: ignore[misc] + + def test_invalid_severity_rejected(self) -> None: + """Non-Literal severity values are rejected.""" + with pytest.raises(ValidationError, match="severity"): + PolicyQualityIssue( + policy="some policy", + issue="some issue", + severity="info", # type: ignore[arg-type] + ) + + @pytest.mark.parametrize( + ("policy", "expect_short"), + [ + ("123456789", True), # 9 chars — below _MIN_POLICY_LENGTH (10) + ("1234567890", False), # 10 chars — exactly at boundary + ], + ) + def test_min_length_boundary(self, policy: str, *, expect_short: bool) -> None: + """9 chars triggers 'Too short', 10 chars does not.""" + result = validate_policy_quality((policy,)) + short_issues = [i for i in result if "Too short" in i.issue] + assert len(short_issues) == (1 if expect_short else 0) + + def test_max_length_boundary(self) -> None: + """500 chars is OK, 501 triggers 'Too long'.""" + # 500-char policy with an action verb (no length warning expected). + at_limit = "Agents must always " + "x" * (500 - len("Agents must always ")) + assert len(at_limit) == 500 + result_ok = validate_policy_quality((at_limit,)) + long_issues_ok = [i for i in result_ok if "Too long" in i.issue] + assert len(long_issues_ok) == 0 + + over_limit = at_limit + "y" + assert len(over_limit) == 501 + result_bad = validate_policy_quality((over_limit,)) + long_issues_bad = [i for i in result_bad if "Too long" in i.issue] + assert len(long_issues_bad) == 1 diff --git a/tests/unit/engine/test_prompt.py b/tests/unit/engine/test_prompt.py index 05cc162d56..2ec243d875 100644 --- a/tests/unit/engine/test_prompt.py +++ b/tests/unit/engine/test_prompt.py @@ -182,17 +182,44 @@ def test_company_context_injected( assert dept.name in result.content @pytest.mark.unit - def test_tool_availability_in_prompt( + def test_tools_not_in_default_template( self, sample_agent_with_personality: AgentIdentity, sample_tool_definitions: tuple[ToolDefinition, ...], ) -> None: - """Tool names and descriptions appear in prompt.""" + """Tools passed to build_system_prompt don't appear (D22).""" result = build_system_prompt( agent=sample_agent_with_personality, available_tools=sample_tool_definitions, ) + assert "Available Tools" not in result.content + for tool in sample_tool_definitions: + assert tool.name not in result.content + assert "tools" not in result.sections + + @pytest.mark.unit + def test_tools_render_in_custom_template( + self, + sample_agent_with_personality: AgentIdentity, + sample_tool_definitions: tuple[ToolDefinition, ...], + ) -> None: + """Custom templates with {% if tools %} still render tools.""" + custom = ( + "Agent: {{ agent_name }}\n" + "{% if tools %}\n" + "Tools:\n" + "{% for tool in tools %}" + "- {{ tool.name }}: {{ tool.description }}\n" + "{% endfor %}" + "{% endif %}" + ) + result = build_system_prompt( + agent=sample_agent_with_personality, + available_tools=sample_tool_definitions, + custom_template=custom, + ) + for tool in sample_tool_definitions: assert tool.name in result.content assert tool.description in result.content @@ -277,11 +304,11 @@ def test_no_task_section_when_task_is_none( assert "task" not in result.sections @pytest.mark.unit - def test_no_tools_section_when_no_tools( + def test_no_tools_section_in_default_template( self, sample_agent_with_personality: AgentIdentity, ) -> None: - """No 'Available Tools' section when no tools are provided.""" + """Default template never includes 'Available Tools' section (D22).""" result = build_system_prompt(agent=sample_agent_with_personality) assert "Available Tools" not in result.content @@ -402,7 +429,6 @@ def test_max_tokens_triggers_trimming( self, sample_agent_with_personality: AgentIdentity, sample_task_with_criteria: Task, - sample_tool_definitions: tuple[ToolDefinition, ...], sample_company: Company, ) -> None: """Very low max_tokens causes optional sections to be removed.""" @@ -410,25 +436,21 @@ def test_max_tokens_triggers_trimming( full = build_system_prompt( agent=sample_agent_with_personality, task=sample_task_with_criteria, - available_tools=sample_tool_definitions, company=sample_company, ) assert "task" in full.sections - assert "tools" in full.sections assert "company" in full.sections # Now build with a tight token budget to force trimming. trimmed = build_system_prompt( agent=sample_agent_with_personality, task=sample_task_with_criteria, - available_tools=sample_tool_definitions, company=sample_company, max_tokens=10, ) # All optional sections should be removed. assert "company" not in trimmed.sections - assert "tools" not in trimmed.sections assert "task" not in trimmed.sections # Core sections remain. assert "identity" in trimmed.sections @@ -457,6 +479,56 @@ def estimate_tokens(self, text: str) -> int: assert result.estimated_tokens > 0 +# ── TestPolicyValidationIntegration ────────────────────────────── + + +class TestPolicyValidationIntegration: + """Tests for policy validation integration in build_system_prompt.""" + + @pytest.mark.unit + def test_policy_validation_error_does_not_block_prompt( + self, + sample_agent_with_personality: AgentIdentity, + ) -> None: + """When validate_policy_quality raises, prompt is still built.""" + from unittest.mock import patch + + with patch( + "ai_company.engine.prompt.validate_policy_quality", + side_effect=RuntimeError("boom"), + ): + result = build_system_prompt( + agent=sample_agent_with_personality, + org_policies=("All responses must include correlation_id",), + ) + + # Prompt is still built despite validation failure. + assert result.content + assert "org_policies" in result.sections + + @pytest.mark.unit + @pytest.mark.parametrize( + "policies", + [ + ("valid policy must exist", ""), + (" ",), + ], + ids=["empty_string", "whitespace_only"], + ) + def test_invalid_org_policy_raises( + self, + sample_agent_with_personality: AgentIdentity, + *, + policies: tuple[str, ...], + ) -> None: + """Empty or whitespace-only policy is rejected with PromptBuildError.""" + with pytest.raises(PromptBuildError, match="org_policies"): + build_system_prompt( + agent=sample_agent_with_personality, + org_policies=policies, + ) + + # ── TestPromptVersioning ───────────────────────────────────────── @@ -464,9 +536,9 @@ class TestPromptVersioning: """Tests for prompt versioning and section tracking.""" @pytest.mark.unit - def test_template_version_is_1_2_0(self) -> None: - """PROMPT_TEMPLATE_VERSION is '1.2.0'.""" - assert PROMPT_TEMPLATE_VERSION == "1.2.0" + def test_template_version_is_1_3_0(self) -> None: + """PROMPT_TEMPLATE_VERSION is '1.3.0' (D22 tools removal).""" + assert PROMPT_TEMPLATE_VERSION == "1.3.0" @pytest.mark.unit def test_template_version_in_result( @@ -485,7 +557,7 @@ def test_sections_tracked( sample_tool_definitions: tuple[ToolDefinition, ...], sample_company: Company, ) -> None: - """Sections tuple lists all included sections.""" + """Sections tuple lists all included sections (tools excluded per D22).""" result = build_system_prompt( agent=sample_agent_with_personality, task=sample_task_with_criteria, @@ -499,7 +571,7 @@ def test_sections_tracked( assert "authority" in result.sections assert "autonomy" in result.sections assert "task" in result.sections - assert "tools" in result.sections + assert "tools" not in result.sections assert "company" in result.sections @@ -633,11 +705,10 @@ class TestTrimmingPriority: """Tests for section trimming priority order.""" @pytest.mark.unit - def test_company_trimmed_before_tools_and_task( + def test_company_trimmed_before_task( self, sample_agent_with_personality: AgentIdentity, sample_task_with_criteria: Task, - sample_tool_definitions: tuple[ToolDefinition, ...], sample_company: Company, ) -> None: """With a moderately tight budget, only company is trimmed first.""" @@ -645,7 +716,6 @@ def test_company_trimmed_before_tools_and_task( full = build_system_prompt( agent=sample_agent_with_personality, task=sample_task_with_criteria, - available_tools=sample_tool_definitions, company=sample_company, ) assert "company" in full.sections @@ -655,7 +725,6 @@ def test_company_trimmed_before_tools_and_task( without_company = build_system_prompt( agent=sample_agent_with_personality, task=sample_task_with_criteria, - available_tools=sample_tool_definitions, ) # Set max_tokens between without-company and full sizes. @@ -663,49 +732,46 @@ def test_company_trimmed_before_tools_and_task( trimmed = build_system_prompt( agent=sample_agent_with_personality, task=sample_task_with_criteria, - available_tools=sample_tool_definitions, company=sample_company, max_tokens=budget, ) - # Company should be trimmed but tools and task remain. + # Company should be trimmed but task remains. assert "company" not in trimmed.sections - assert "tools" in trimmed.sections assert "task" in trimmed.sections @pytest.mark.unit - def test_tools_trimmed_before_task( + def test_trimming_order_without_tools( self, sample_agent_with_personality: AgentIdentity, sample_task_with_criteria: Task, - sample_tool_definitions: tuple[ToolDefinition, ...], sample_company: Company, ) -> None: - """With a tighter budget, company and tools are trimmed but task remains.""" - # Build with only task to find core + task size. - with_task = build_system_prompt( - agent=sample_agent_with_personality, - task=sample_task_with_criteria, - ) - # Build with task + tools to find core + task + tools size. - with_tools = build_system_prompt( + """Trimming order is company → task → org_policies (no tools section).""" + # Build with company + task + org_policies. + org_policies = ("All responses must include correlation_id",) + full = build_system_prompt( agent=sample_agent_with_personality, task=sample_task_with_criteria, - available_tools=sample_tool_definitions, + company=sample_company, + org_policies=org_policies, ) + assert "company" in full.sections + assert "task" in full.sections + assert "org_policies" in full.sections - budget = (with_task.estimated_tokens + with_tools.estimated_tokens) // 2 + # With very tight budget, all optional sections are removed. trimmed = build_system_prompt( agent=sample_agent_with_personality, task=sample_task_with_criteria, - available_tools=sample_tool_definitions, company=sample_company, - max_tokens=budget, + org_policies=org_policies, + max_tokens=10, ) - assert "company" not in trimmed.sections - assert "tools" not in trimmed.sections - assert "task" in trimmed.sections + assert "task" not in trimmed.sections + assert "org_policies" not in trimmed.sections + assert "identity" in trimmed.sections # ── TestDefaultAgentPrompt ───────────────────────────────────── diff --git a/tests/unit/memory/org/test_prompt_integration.py b/tests/unit/memory/org/test_prompt_integration.py index a2eec55a90..6f86226fde 100644 --- a/tests/unit/memory/org/test_prompt_integration.py +++ b/tests/unit/memory/org/test_prompt_integration.py @@ -49,7 +49,7 @@ def test_policies_rendered_in_prompt(self) -> None: assert "org_policies" in result.sections def test_template_version_updated(self) -> None: - assert PROMPT_TEMPLATE_VERSION == "1.2.0" + assert PROMPT_TEMPLATE_VERSION == "1.3.0" def test_policies_trimmed_under_budget(self) -> None: agent = _make_agent() diff --git a/tests/unit/memory/test_filter.py b/tests/unit/memory/test_filter.py new file mode 100644 index 0000000000..b27f836bef --- /dev/null +++ b/tests/unit/memory/test_filter.py @@ -0,0 +1,154 @@ +"""Unit tests for memory filter strategies.""" + +from datetime import UTC, datetime + +import pytest + +from ai_company.core.enums import MemoryCategory +from ai_company.memory.filter import ( + NON_INFERABLE_TAG, + MemoryFilterStrategy, + PassthroughMemoryFilter, + TagBasedMemoryFilter, +) +from ai_company.memory.models import MemoryEntry, MemoryMetadata +from ai_company.memory.ranking import ScoredMemory + +pytestmark = pytest.mark.timeout(30) + + +def _make_scored_memory( + *, + entry_id: str = "mem-1", + tags: tuple[str, ...] = (), + content: str = "test memory", + combined_score: float = 0.8, +) -> ScoredMemory: + """Build a ScoredMemory with specified tags.""" + entry = MemoryEntry( + id=entry_id, + agent_id="agent-1", + category=MemoryCategory.EPISODIC, + content=content, + metadata=MemoryMetadata(tags=tags), + created_at=datetime.now(UTC), + relevance_score=0.8, + ) + return ScoredMemory( + entry=entry, + relevance_score=0.8, + recency_score=0.9, + combined_score=combined_score, + ) + + +# ── Protocol compliance ────────────────────────────────────────── + + +@pytest.mark.unit +class TestProtocolCompliance: + """Both filters satisfy the MemoryFilterStrategy protocol.""" + + def test_tag_based_satisfies_protocol(self) -> None: + assert isinstance(TagBasedMemoryFilter(), MemoryFilterStrategy) + + def test_passthrough_satisfies_protocol(self) -> None: + assert isinstance(PassthroughMemoryFilter(), MemoryFilterStrategy) + + +# ── TagBasedMemoryFilter ────────────────────────────────────────── + + +@pytest.mark.unit +class TestTagBasedMemoryFilter: + """Tests for the tag-based memory filter.""" + + def test_retains_tagged_memories(self) -> None: + tagged = _make_scored_memory( + entry_id="m1", + tags=(NON_INFERABLE_TAG,), + ) + untagged = _make_scored_memory(entry_id="m2", tags=()) + filt = TagBasedMemoryFilter() + + result = filt.filter_for_injection((tagged, untagged)) + + assert len(result) == 1 + assert result[0].entry.id == "m1" + + def test_excludes_all_untagged(self) -> None: + untagged = _make_scored_memory(tags=("other-tag",)) + filt = TagBasedMemoryFilter() + + result = filt.filter_for_injection((untagged,)) + + assert result == () + + def test_retains_all_tagged(self) -> None: + m1 = _make_scored_memory( + entry_id="m1", + tags=(NON_INFERABLE_TAG, "extra"), + ) + m2 = _make_scored_memory( + entry_id="m2", + tags=(NON_INFERABLE_TAG,), + ) + filt = TagBasedMemoryFilter() + + result = filt.filter_for_injection((m1, m2)) + + assert len(result) == 2 + + def test_custom_required_tag(self) -> None: + memory = _make_scored_memory(tags=("custom-tag",)) + filt = TagBasedMemoryFilter(required_tag="custom-tag") + + result = filt.filter_for_injection((memory,)) + + assert len(result) == 1 + + def test_empty_input_returns_empty(self) -> None: + filt = TagBasedMemoryFilter() + result = filt.filter_for_injection(()) + assert result == () + + def test_strategy_name(self) -> None: + assert TagBasedMemoryFilter().strategy_name == "tag_based" + + +# ── PassthroughMemoryFilter ────────────────────────────────────── + + +@pytest.mark.unit +class TestPassthroughMemoryFilter: + """Tests for the passthrough (no-op) memory filter.""" + + def test_returns_all_unchanged(self) -> None: + m1 = _make_scored_memory(entry_id="m1") + m2 = _make_scored_memory(entry_id="m2") + filt = PassthroughMemoryFilter() + + result = filt.filter_for_injection((m1, m2)) + + assert len(result) == 2 + assert result[0].entry.id == "m1" + assert result[1].entry.id == "m2" + + def test_empty_input_returns_empty(self) -> None: + filt = PassthroughMemoryFilter() + result = filt.filter_for_injection(()) + assert result == () + + def test_strategy_name(self) -> None: + assert PassthroughMemoryFilter().strategy_name == "passthrough" + + +# ── NON_INFERABLE_TAG constant ─────────────────────────────────── + + +@pytest.mark.unit +class TestNonInferableTag: + """Tests for the NON_INFERABLE_TAG constant.""" + + def test_value(self) -> None: + assert NON_INFERABLE_TAG == "non-inferable" diff --git a/tests/unit/memory/test_retriever.py b/tests/unit/memory/test_retriever.py index 67cfc2dcf4..96b8a9c888 100644 --- a/tests/unit/memory/test_retriever.py +++ b/tests/unit/memory/test_retriever.py @@ -1,12 +1,18 @@ """Tests for ContextInjectionStrategy (retriever pipeline).""" from datetime import UTC, datetime +from typing import TYPE_CHECKING from unittest.mock import AsyncMock import pytest from ai_company.core.enums import MemoryCategory from ai_company.memory.errors import MemoryRetrievalError +from ai_company.memory.filter import ( + NON_INFERABLE_TAG, + PassthroughMemoryFilter, + TagBasedMemoryFilter, +) from ai_company.memory.formatter import MEMORY_BLOCK_START from ai_company.memory.injection import ( DefaultTokenEstimator, @@ -17,6 +23,9 @@ from ai_company.memory.retriever import ContextInjectionStrategy from ai_company.providers.enums import MessageRole +if TYPE_CHECKING: + from ai_company.memory.ranking import ScoredMemory + pytestmark = pytest.mark.timeout(30) @@ -420,3 +429,189 @@ async def test_zero_budget_returns_empty(self) -> None: token_budget=0, ) assert result == () + + +# ── Memory filter integration ──────────────────────────────────── + + +@pytest.mark.unit +class TestMemoryFilterIntegration: + async def test_filter_applied_after_ranking(self) -> None: + """TagBasedMemoryFilter excludes untagged memories after ranking.""" + tagged = _make_entry( + entry_id="tagged", + content="tagged memory", + relevance_score=0.9, + ) + # Manually set the non-inferable tag on metadata. + tagged = tagged.model_copy( + update={ + "metadata": MemoryMetadata(tags=(NON_INFERABLE_TAG,)), + }, + ) + untagged = _make_entry( + entry_id="untagged", + content="untagged memory", + relevance_score=0.9, + ) + strategy = ContextInjectionStrategy( + backend=_make_backend((tagged, untagged)), + config=MemoryRetrievalConfig(min_relevance=0.0), + memory_filter=TagBasedMemoryFilter(), + ) + result = await strategy.prepare_messages( + agent_id="agent-1", + query_text="query", + token_budget=5000, + ) + assert len(result) == 1 + content = result[0].content + assert content is not None + assert "tagged memory" in content + assert "untagged memory" not in content + + async def test_filter_skipped_when_none(self) -> None: + """When memory_filter is None, all ranked memories are injected.""" + entry = _make_entry(content="all memories pass") + strategy = ContextInjectionStrategy( + backend=_make_backend((entry,)), + config=MemoryRetrievalConfig( + min_relevance=0.0, + non_inferable_only=False, + ), + memory_filter=None, + ) + result = await strategy.prepare_messages( + agent_id="agent-1", + query_text="query", + token_budget=5000, + ) + assert len(result) == 1 + content = result[0].content + assert content is not None + assert "all memories pass" in content + + async def test_filter_reduces_output(self) -> None: + """Filter that excludes everything returns empty result.""" + entry = _make_entry(content="will be filtered out") + strategy = ContextInjectionStrategy( + backend=_make_backend((entry,)), + config=MemoryRetrievalConfig(min_relevance=0.0), + memory_filter=TagBasedMemoryFilter(), + ) + result = await strategy.prepare_messages( + agent_id="agent-1", + query_text="query", + token_budget=5000, + ) + assert result == () + + async def test_passthrough_filter_keeps_all(self) -> None: + """PassthroughMemoryFilter returns all memories unchanged.""" + entry = _make_entry(content="passthrough content") + strategy = ContextInjectionStrategy( + backend=_make_backend((entry,)), + config=MemoryRetrievalConfig(min_relevance=0.0), + memory_filter=PassthroughMemoryFilter(), + ) + result = await strategy.prepare_messages( + agent_id="agent-1", + query_text="query", + token_budget=5000, + ) + assert len(result) == 1 + content = result[0].content + assert content is not None + assert "passthrough content" in content + + async def test_non_inferable_only_config_creates_filter(self) -> None: + """non_inferable_only=True auto-creates TagBasedMemoryFilter.""" + tagged = _make_entry( + entry_id="tagged", + content="tagged memory", + relevance_score=0.9, + ) + tagged = tagged.model_copy( + update={"metadata": MemoryMetadata(tags=(NON_INFERABLE_TAG,))}, + ) + untagged = _make_entry( + entry_id="untagged", + content="untagged memory", + relevance_score=0.9, + ) + strategy = ContextInjectionStrategy( + backend=_make_backend((tagged, untagged)), + config=MemoryRetrievalConfig( + min_relevance=0.0, + non_inferable_only=True, + ), + ) + result = await strategy.prepare_messages( + agent_id="agent-1", + query_text="query", + token_budget=5000, + ) + assert len(result) == 1 + content = result[0].content + assert content is not None + assert "tagged memory" in content + assert "untagged memory" not in content + + async def test_filter_graceful_degradation(self) -> None: + """Filter error falls back to unfiltered ranked memories.""" + + class _BrokenFilter: + def filter_for_injection( + self, + memories: tuple[ScoredMemory, ...], + ) -> tuple[ScoredMemory, ...]: + msg = "filter exploded" + raise RuntimeError(msg) + + @property + def strategy_name(self) -> str: + return "broken" + + entry = _make_entry(content="survives filter error") + strategy = ContextInjectionStrategy( + backend=_make_backend((entry,)), + config=MemoryRetrievalConfig(min_relevance=0.0), + memory_filter=_BrokenFilter(), + ) + result = await strategy.prepare_messages( + agent_id="agent-1", + query_text="query", + token_budget=5000, + ) + # Graceful degradation: unfiltered memories are still returned. + assert len(result) == 1 + content = result[0].content + assert content is not None + assert "survives filter error" in content + + async def test_filter_memory_error_propagates(self) -> None: + """MemoryError through the filter path is re-raised.""" + + class _MemoryErrorFilter: + def filter_for_injection( + self, + memories: tuple[ScoredMemory, ...], + ) -> tuple[ScoredMemory, ...]: + raise MemoryError + + @property + def strategy_name(self) -> str: + return "oom" + + entry = _make_entry(content="oom test") + strategy = ContextInjectionStrategy( + backend=_make_backend((entry,)), + config=MemoryRetrievalConfig(min_relevance=0.0), + memory_filter=_MemoryErrorFilter(), + ) + with pytest.raises(MemoryError): + await strategy.prepare_messages( + agent_id="agent-1", + query_text="query", + token_budget=5000, + ) diff --git a/tests/unit/memory/test_store_guard.py b/tests/unit/memory/test_store_guard.py new file mode 100644 index 0000000000..d25b8b860d --- /dev/null +++ b/tests/unit/memory/test_store_guard.py @@ -0,0 +1,72 @@ +"""Unit tests for memory store guard (non-inferable tag validation).""" + +import pytest +import structlog.testing + +from ai_company.core.enums import MemoryCategory +from ai_company.memory.filter import NON_INFERABLE_TAG +from ai_company.memory.models import MemoryMetadata, MemoryStoreRequest +from ai_company.memory.store_guard import validate_memory_tags +from ai_company.observability.events.memory import ( + MEMORY_FILTER_STORE_MISSING_TAG, +) + +pytestmark = pytest.mark.timeout(30) + + +@pytest.mark.unit +class TestValidateMemoryTags: + """Tests for the validate_memory_tags advisory guard.""" + + def test_tagged_request_no_warning(self) -> None: + """Request with non-inferable tag produces no warning.""" + request = MemoryStoreRequest( + category=MemoryCategory.EPISODIC, + content="important fact", + metadata=MemoryMetadata(tags=(NON_INFERABLE_TAG,)), + ) + with structlog.testing.capture_logs() as logs: + validate_memory_tags(request) + + warning_events = [ + e for e in logs if e["event"] == MEMORY_FILTER_STORE_MISSING_TAG + ] + assert len(warning_events) == 0 + + def test_untagged_request_logs_warning(self) -> None: + """Request without non-inferable tag logs a warning.""" + request = MemoryStoreRequest( + category=MemoryCategory.SEMANTIC, + content="some knowledge", + ) + with structlog.testing.capture_logs() as logs: + validate_memory_tags(request) + + warning_events = [ + e for e in logs if e["event"] == MEMORY_FILTER_STORE_MISSING_TAG + ] + assert len(warning_events) == 1 + + def test_other_tags_still_warns(self) -> None: + """Request with other tags but missing non-inferable tag warns.""" + request = MemoryStoreRequest( + category=MemoryCategory.PROCEDURAL, + content="how to deploy", + metadata=MemoryMetadata(tags=("deployment", "ops")), + ) + with structlog.testing.capture_logs() as logs: + validate_memory_tags(request) + + warning_events = [ + e for e in logs if e["event"] == MEMORY_FILTER_STORE_MISSING_TAG + ] + assert len(warning_events) == 1 + + def test_store_never_blocked(self) -> None: + """validate_memory_tags never raises — advisory only.""" + request = MemoryStoreRequest( + category=MemoryCategory.EPISODIC, + content="any content", + ) + # Should complete without exception. + validate_memory_tags(request)