From 4c4f3a0f7e9f36f1a08797d2234dc04588882e78 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Sat, 28 Feb 2026 20:18:35 +0100 Subject: [PATCH 1/3] feat: implement company template system with 7 built-in presets (#60) Add a two-pass template rendering pipeline that bootstraps company configurations from pre-built YAML templates with Jinja2 variable substitution. Pass 1 extracts metadata and variables from plain YAML, Pass 2 renders the full template through a SandboxedEnvironment then validates as RootConfig. - Extract deep_merge to config/utils.py as shared utility - Add template error hierarchy (TemplateError, NotFound, Render, Validation) - Add Pydantic schema models (CompanyTemplate, TemplateVariable, etc.) - Add 4 minimal personality presets and auto-name generation - Add template loader with user dir (~/.ai-company/templates/) override - Add Jinja2 renderer pipeline with variable collection and agent expansion - Ship 7 built-in templates: solo_founder, startup, dev_shop, product_team, agency, full_company, research_lab - Add jinja2==3.1.6 explicit dependency - Exclude Jinja2 template files from check-yaml pre-commit hook - 88 new unit tests, 97% overall coverage Closes #60 --- .pre-commit-config.yaml | 1 + pyproject.toml | 1 + src/ai_company/config/loader.py | 35 +- src/ai_company/config/utils.py | 31 ++ src/ai_company/templates/__init__.py | 64 +++ src/ai_company/templates/builtins/__init__.py | 0 src/ai_company/templates/builtins/agency.yaml | 52 +++ .../templates/builtins/dev_shop.yaml | 63 +++ .../templates/builtins/full_company.yaml | 68 +++ .../templates/builtins/product_team.yaml | 57 +++ .../templates/builtins/research_lab.yaml | 50 +++ .../templates/builtins/solo_founder.yaml | 46 ++ .../templates/builtins/startup.yaml | 63 +++ src/ai_company/templates/errors.py | 33 ++ src/ai_company/templates/loader.py | 349 ++++++++++++++++ src/ai_company/templates/presets.py | 103 +++++ src/ai_company/templates/renderer.py | 392 ++++++++++++++++++ src/ai_company/templates/schema.py | 237 +++++++++++ tests/unit/config/test_loader.py | 24 +- tests/unit/config/test_utils.py | 68 +++ tests/unit/templates/__init__.py | 0 tests/unit/templates/conftest.py | 131 ++++++ tests/unit/templates/test_loader.py | 201 +++++++++ tests/unit/templates/test_presets.py | 74 ++++ tests/unit/templates/test_renderer.py | 191 +++++++++ tests/unit/templates/test_schema.py | 255 ++++++++++++ uv.lock | 2 + 27 files changed, 2548 insertions(+), 43 deletions(-) create mode 100644 src/ai_company/config/utils.py create mode 100644 src/ai_company/templates/builtins/__init__.py create mode 100644 src/ai_company/templates/builtins/agency.yaml create mode 100644 src/ai_company/templates/builtins/dev_shop.yaml create mode 100644 src/ai_company/templates/builtins/full_company.yaml create mode 100644 src/ai_company/templates/builtins/product_team.yaml create mode 100644 src/ai_company/templates/builtins/research_lab.yaml create mode 100644 src/ai_company/templates/builtins/solo_founder.yaml create mode 100644 src/ai_company/templates/builtins/startup.yaml create mode 100644 src/ai_company/templates/errors.py create mode 100644 src/ai_company/templates/loader.py create mode 100644 src/ai_company/templates/presets.py create mode 100644 src/ai_company/templates/renderer.py create mode 100644 src/ai_company/templates/schema.py create mode 100644 tests/unit/config/test_utils.py create mode 100644 tests/unit/templates/__init__.py create mode 100644 tests/unit/templates/conftest.py create mode 100644 tests/unit/templates/test_loader.py create mode 100644 tests/unit/templates/test_presets.py create mode 100644 tests/unit/templates/test_renderer.py create mode 100644 tests/unit/templates/test_schema.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a1bf19803e..a45bb75f8a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,7 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - id: check-yaml + exclude: ^src/ai_company/templates/builtins/ - id: check-toml - id: check-json - id: check-merge-conflict diff --git a/pyproject.toml b/pyproject.toml index ace4b2bc4f..a4e1c4c9f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ + "jinja2==3.1.6", "pydantic==2.12.5", "pyyaml==6.0.2", "structlog==25.5.0", diff --git a/src/ai_company/config/loader.py b/src/ai_company/config/loader.py index d0da63612c..32bd49cf55 100644 --- a/src/ai_company/config/loader.py +++ b/src/ai_company/config/loader.py @@ -1,6 +1,5 @@ """YAML configuration loader with layered merging and validation.""" -import copy import logging import os import re @@ -18,6 +17,7 @@ ConfigValidationError, ) from ai_company.config.schema import RootConfig +from ai_company.config.utils import deep_merge logger = logging.getLogger(__name__) @@ -35,33 +35,6 @@ # --------------------------------------------------------------------------- -def _deep_merge( - base: dict[str, Any], - override: dict[str, Any], -) -> dict[str, Any]: - """Recursively merge *override* into *base*, returning a new dict. - - Nested dicts are merged recursively. Lists, scalars, and all other - types in *override* replace the corresponding value in *base* - entirely. Keys present only in *base* are preserved unchanged in - the result. Neither input dict is mutated. - - Args: - base: Base configuration dict. - override: Override values to layer on top. - - Returns: - A new merged dict. - """ - result = copy.deepcopy(base) - for key, value in override.items(): - if key in result and isinstance(result[key], dict) and isinstance(value, dict): - result[key] = _deep_merge(result[key], value) - else: - result[key] = copy.deepcopy(value) - return result - - def _read_config_text(file_path: Path) -> str: """Read a configuration file as UTF-8 text. @@ -446,12 +419,12 @@ def load_config( # and line-map construction) yaml_text = _read_config_text(config_path) primary = _parse_yaml_string(yaml_text, str(config_path)) - merged = _deep_merge(merged, primary) + merged = deep_merge(merged, primary) # 3. Apply override layers for override_path in override_paths: override = _parse_yaml_file(Path(override_path)) - merged = _deep_merge(merged, override) + merged = deep_merge(merged, override) # 4. Substitute environment variables on the fully merged config. # Use a neutral label so env-var errors aren't misattributed solely @@ -491,7 +464,7 @@ def load_config_from_string( ConfigValidationError: If the merged config fails validation. """ data = _parse_yaml_string(yaml_string, source_name) - merged = _deep_merge(default_config_dict(), data) + merged = deep_merge(default_config_dict(), data) merged = _substitute_env_vars(merged, source_file=source_name) line_map = _build_line_map(yaml_string) return _validate_config_dict( diff --git a/src/ai_company/config/utils.py b/src/ai_company/config/utils.py new file mode 100644 index 0000000000..df9d1185f3 --- /dev/null +++ b/src/ai_company/config/utils.py @@ -0,0 +1,31 @@ +"""Shared configuration utilities.""" + +import copy +from typing import Any + + +def deep_merge( + base: dict[str, Any], + override: dict[str, Any], +) -> dict[str, Any]: + """Recursively merge *override* into *base*, returning a new dict. + + Nested dicts are merged recursively. Lists, scalars, and all other + types in *override* replace the corresponding value in *base* + entirely. Keys present only in *base* are preserved unchanged in + the result. Neither input dict is mutated. + + Args: + base: Base configuration dict. + override: Override values to layer on top. + + Returns: + A new merged dict. + """ + result = copy.deepcopy(base) + for key, value in override.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = deep_merge(result[key], value) + else: + result[key] = copy.deepcopy(value) + return result diff --git a/src/ai_company/templates/__init__.py b/src/ai_company/templates/__init__.py index e69de29bb2..f47b6e9ab6 100644 --- a/src/ai_company/templates/__init__.py +++ b/src/ai_company/templates/__init__.py @@ -0,0 +1,64 @@ +"""Company templates: built-in presets and custom template loading. + +Public API +---------- +.. autosummary:: + load_template + load_template_file + list_templates + list_builtin_templates + render_template + CompanyTemplate + LoadedTemplate + TemplateInfo + TemplateMetadata + TemplateVariable + TemplateAgentConfig + TemplateDepartmentConfig + TemplateError + TemplateNotFoundError + TemplateRenderError + TemplateValidationError +""" + +from ai_company.templates.errors import ( + TemplateError, + TemplateNotFoundError, + TemplateRenderError, + TemplateValidationError, +) +from ai_company.templates.loader import ( + LoadedTemplate, + TemplateInfo, + list_builtin_templates, + list_templates, + load_template, + load_template_file, +) +from ai_company.templates.renderer import render_template +from ai_company.templates.schema import ( + CompanyTemplate, + TemplateAgentConfig, + TemplateDepartmentConfig, + TemplateMetadata, + TemplateVariable, +) + +__all__ = [ + "CompanyTemplate", + "LoadedTemplate", + "TemplateAgentConfig", + "TemplateDepartmentConfig", + "TemplateError", + "TemplateInfo", + "TemplateMetadata", + "TemplateNotFoundError", + "TemplateRenderError", + "TemplateValidationError", + "TemplateVariable", + "list_builtin_templates", + "list_templates", + "load_template", + "load_template_file", + "render_template", +] diff --git a/src/ai_company/templates/builtins/__init__.py b/src/ai_company/templates/builtins/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai_company/templates/builtins/agency.yaml b/src/ai_company/templates/builtins/agency.yaml new file mode 100644 index 0000000000..b3501c7a97 --- /dev/null +++ b/src/ai_company/templates/builtins/agency.yaml @@ -0,0 +1,52 @@ +template: + name: "Agency" + description: "Client-focused agency with project management and creative roles" + version: "1.0.0" + tags: + - "agency" + - "client-work" + - "creative" + + variables: + - name: "company_name" + description: "Name of your agency" + default: "Creative Agency" + - name: "budget" + description: "Monthly budget in USD" + var_type: "float" + default: 100.0 + + company: + type: "agency" + budget_monthly: {{ budget | default(100.0) }} + autonomy: 0.5 + + departments: + - name: "operations" + budget_percent: 30 + head_role: "Project Manager" + - name: "engineering" + budget_percent: 40 + head_role: "Full-Stack Developer" + - name: "design" + budget_percent: 30 + head_role: "UI Designer" + + agents: + - role: "Project Manager" + level: "senior" + model: "sonnet" + personality_preset: "visionary_leader" + department: "operations" + - role: "UI Designer" + level: "mid" + model: "sonnet" + department: "design" + - role: "Full-Stack Developer" + level: "senior" + model: "sonnet" + personality_preset: "pragmatic_builder" + department: "engineering" + + workflow: "kanban" + communication: "hybrid" diff --git a/src/ai_company/templates/builtins/dev_shop.yaml b/src/ai_company/templates/builtins/dev_shop.yaml new file mode 100644 index 0000000000..b70283ec95 --- /dev/null +++ b/src/ai_company/templates/builtins/dev_shop.yaml @@ -0,0 +1,63 @@ +template: + name: "Dev Shop" + description: "Software development focused team with QA and DevOps" + version: "1.0.0" + tags: + - "development" + - "engineering" + - "qa" + + variables: + - name: "company_name" + description: "Name of your company" + default: "Dev Shop" + - name: "budget" + description: "Monthly budget in USD" + var_type: "float" + default: 75.0 + + company: + type: "dev_shop" + budget_monthly: {{ budget | default(75.0) }} + autonomy: 0.5 + + departments: + - name: "engineering" + budget_percent: 70 + head_role: "Software Architect" + - name: "quality_assurance" + budget_percent: 20 + head_role: "QA Lead" + - name: "operations" + budget_percent: 10 + head_role: "DevOps/SRE Engineer" + + agents: + - role: "Software Architect" + level: "principal" + model: "opus" + personality_preset: "methodical_analyst" + department: "engineering" + - role: "Backend Developer" + level: "senior" + model: "sonnet" + personality_preset: "pragmatic_builder" + department: "engineering" + - role: "Backend Developer" + level: "mid" + model: "haiku" + personality_preset: "eager_learner" + department: "engineering" + - role: "QA Lead" + level: "lead" + model: "sonnet" + personality_preset: "methodical_analyst" + department: "quality_assurance" + - role: "DevOps/SRE Engineer" + level: "mid" + model: "sonnet" + personality_preset: "pragmatic_builder" + department: "operations" + + workflow: "agile_kanban" + communication: "hybrid" diff --git a/src/ai_company/templates/builtins/full_company.yaml b/src/ai_company/templates/builtins/full_company.yaml new file mode 100644 index 0000000000..f0602dcd17 --- /dev/null +++ b/src/ai_company/templates/builtins/full_company.yaml @@ -0,0 +1,68 @@ +template: + name: "Full Company" + description: "Enterprise simulation with all departments and full hierarchy" + version: "1.0.0" + tags: + - "enterprise" + - "full-hierarchy" + - "all-departments" + + variables: + - name: "company_name" + description: "Name of your company" + default: "Acme Corp" + - name: "budget" + description: "Monthly budget in USD" + var_type: "float" + default: 200.0 + + company: + type: "full_company" + budget_monthly: {{ budget | default(200.0) }} + autonomy: 0.5 + + departments: + - name: "executive" + budget_percent: 15 + head_role: "CEO" + - name: "engineering" + budget_percent: 50 + head_role: "CTO" + - name: "product" + budget_percent: 15 + head_role: "Product Manager" + - name: "quality_assurance" + budget_percent: 10 + head_role: "QA Engineer" + - name: "operations" + budget_percent: 10 + head_role: "CFO" + + agents: + - role: "CEO" + level: "c_suite" + model: "opus" + personality_preset: "visionary_leader" + department: "executive" + - role: "CTO" + level: "c_suite" + model: "opus" + personality_preset: "methodical_analyst" + department: "executive" + - role: "CFO" + level: "c_suite" + model: "opus" + department: "operations" + - role: "Backend Developer" + level: "senior" + model: "sonnet" + personality_preset: "pragmatic_builder" + department: "engineering" + - role: "QA Engineer" + level: "mid" + model: "haiku" + personality_preset: "methodical_analyst" + department: "quality_assurance" + + workflow: "agile_kanban" + communication: "hybrid" diff --git a/src/ai_company/templates/builtins/product_team.yaml b/src/ai_company/templates/builtins/product_team.yaml new file mode 100644 index 0000000000..afe96dc296 --- /dev/null +++ b/src/ai_company/templates/builtins/product_team.yaml @@ -0,0 +1,57 @@ +template: + name: "Product Team" + description: "Product-focused development with design and QA" + version: "1.0.0" + tags: + - "product" + - "design" + - "user-centered" + + variables: + - name: "company_name" + description: "Name of your company" + default: "Product Co" + - name: "budget" + description: "Monthly budget in USD" + var_type: "float" + default: 80.0 + + company: + type: "product_team" + budget_monthly: {{ budget | default(80.0) }} + autonomy: 0.5 + + departments: + - name: "product" + budget_percent: 30 + head_role: "Product Manager" + - name: "engineering" + budget_percent: 50 + head_role: "Backend Developer" + - name: "quality_assurance" + budget_percent: 20 + head_role: "QA Engineer" + + agents: + - role: "Product Manager" + level: "senior" + model: "sonnet" + personality_preset: "visionary_leader" + department: "product" + - role: "UX Designer" + level: "mid" + model: "sonnet" + department: "product" + - role: "Full-Stack Developer" + level: "senior" + model: "sonnet" + personality_preset: "pragmatic_builder" + department: "engineering" + - role: "QA Engineer" + level: "mid" + model: "haiku" + personality_preset: "methodical_analyst" + department: "quality_assurance" + + workflow: "agile_kanban" + communication: "hybrid" diff --git a/src/ai_company/templates/builtins/research_lab.yaml b/src/ai_company/templates/builtins/research_lab.yaml new file mode 100644 index 0000000000..40417f0989 --- /dev/null +++ b/src/ai_company/templates/builtins/research_lab.yaml @@ -0,0 +1,50 @@ +template: + name: "Research Lab" + description: "Research and analysis focused team" + version: "1.0.0" + tags: + - "research" + - "analysis" + - "data" + + variables: + - name: "company_name" + description: "Name of your research lab" + default: "Research Lab" + - name: "budget" + description: "Monthly budget in USD" + var_type: "float" + default: 60.0 + + company: + type: "research_lab" + budget_monthly: {{ budget | default(60.0) }} + autonomy: 0.5 + + departments: + - name: "engineering" + budget_percent: 40 + head_role: "Software Architect" + - name: "data_analytics" + budget_percent: 60 + head_role: "Data Analyst" + + agents: + - role: "Software Architect" + level: "principal" + model: "opus" + personality_preset: "methodical_analyst" + department: "engineering" + - role: "Data Engineer" + level: "senior" + model: "sonnet" + personality_preset: "pragmatic_builder" + department: "data_analytics" + - role: "Data Analyst" + level: "mid" + model: "sonnet" + personality_preset: "methodical_analyst" + department: "data_analytics" + + workflow: "kanban" + communication: "hybrid" diff --git a/src/ai_company/templates/builtins/solo_founder.yaml b/src/ai_company/templates/builtins/solo_founder.yaml new file mode 100644 index 0000000000..d369594fa2 --- /dev/null +++ b/src/ai_company/templates/builtins/solo_founder.yaml @@ -0,0 +1,46 @@ +template: + name: "Solo Founder" + description: "Minimal setup for quick prototypes and solo projects" + version: "1.0.0" + tags: + - "minimal" + - "solo" + - "prototype" + + variables: + - name: "company_name" + description: "Name of your company" + default: "My Company" + - name: "budget" + description: "Monthly budget in USD" + var_type: "float" + default: 25.0 + + company: + type: "solo_founder" + budget_monthly: {{ budget | default(25.0) }} + autonomy: 0.5 + + departments: + - name: "executive" + budget_percent: 40 + head_role: "CEO" + - name: "engineering" + budget_percent: 60 + head_role: "Full-Stack Developer" + + agents: + - role: "CEO" + name: "{{ company_name | default('My Company') }} CEO" + level: "c_suite" + model: "opus" + personality_preset: "visionary_leader" + department: "executive" + - role: "Full-Stack Developer" + level: "senior" + model: "sonnet" + personality_preset: "pragmatic_builder" + department: "engineering" + + workflow: "kanban" + communication: "hybrid" diff --git a/src/ai_company/templates/builtins/startup.yaml b/src/ai_company/templates/builtins/startup.yaml new file mode 100644 index 0000000000..863cf9b945 --- /dev/null +++ b/src/ai_company/templates/builtins/startup.yaml @@ -0,0 +1,63 @@ +template: + name: "Tech Startup" + description: "Small team for building MVPs and prototypes" + version: "1.0.0" + tags: + - "startup" + - "mvp" + - "small-team" + + variables: + - name: "company_name" + description: "Name of your company" + default: "Startup Co" + - name: "budget" + description: "Monthly budget in USD" + var_type: "float" + default: 50.0 + + company: + type: "startup" + budget_monthly: {{ budget | default(50.0) }} + autonomy: 0.5 + + departments: + - name: "executive" + budget_percent: 20 + head_role: "CEO" + - name: "engineering" + budget_percent: 60 + head_role: "CTO" + - name: "product" + budget_percent: 20 + head_role: "Product Manager" + + agents: + - role: "CEO" + name: "{{ company_name | default('Startup Co') }} CEO" + level: "c_suite" + model: "opus" + personality_preset: "visionary_leader" + department: "executive" + - role: "CTO" + level: "c_suite" + model: "opus" + personality_preset: "methodical_analyst" + department: "executive" + - role: "Full-Stack Developer" + level: "senior" + model: "sonnet" + personality_preset: "pragmatic_builder" + department: "engineering" + - role: "Full-Stack Developer" + level: "mid" + model: "haiku" + personality_preset: "eager_learner" + department: "engineering" + - role: "Product Manager" + level: "senior" + model: "sonnet" + department: "product" + + workflow: "agile_kanban" + communication: "hybrid" diff --git a/src/ai_company/templates/errors.py b/src/ai_company/templates/errors.py new file mode 100644 index 0000000000..ff372ef821 --- /dev/null +++ b/src/ai_company/templates/errors.py @@ -0,0 +1,33 @@ +"""Custom exception hierarchy for template errors.""" + +from ai_company.config.errors import ConfigError, ConfigLocation + + +class TemplateError(ConfigError): + """Base exception for template errors.""" + + +class TemplateNotFoundError(TemplateError): + """Raised when a template cannot be found.""" + + +class TemplateRenderError(TemplateError): + """Raised when Jinja2 rendering fails or a required variable is missing.""" + + +class TemplateValidationError(TemplateError): + """Raised when a rendered template fails validation. + + Attributes: + field_errors: Per-field error messages as + ``(key_path, message)`` pairs. + """ + + def __init__( + self, + message: str, + locations: tuple[ConfigLocation, ...] = (), + field_errors: tuple[tuple[str, str], ...] = (), + ) -> None: + super().__init__(message, locations) + self.field_errors = field_errors diff --git a/src/ai_company/templates/loader.py b/src/ai_company/templates/loader.py new file mode 100644 index 0000000000..3a4bbafe64 --- /dev/null +++ b/src/ai_company/templates/loader.py @@ -0,0 +1,349 @@ +"""Template loading from built-in, user directory, and file-system sources. + +Implements a two-pass loading strategy: + +- **Pass 1**: YAML-parse the template to extract metadata and the + ``variables`` section (which uses plain YAML, no Jinja2). +- **Pass 2**: Performed later by the renderer — Jinja2-renders the raw + YAML text, then YAML-parses the result. + +The loader returns both the structured :class:`CompanyTemplate` (from +Pass 1) and the raw YAML text (for Pass 2). +""" + +import logging +import re +from dataclasses import dataclass +from importlib import resources +from pathlib import Path +from typing import Any + +import yaml + +from ai_company.config.errors import ConfigLocation +from ai_company.templates.errors import ( + TemplateNotFoundError, + TemplateRenderError, + TemplateValidationError, +) +from ai_company.templates.schema import CompanyTemplate + +logger = logging.getLogger(__name__) + +_USER_TEMPLATES_DIR = Path.home() / ".ai-company" / "templates" + +# Registry of built-in template names -> resource filenames. +BUILTIN_TEMPLATES: dict[str, str] = { + "solo_founder": "solo_founder.yaml", + "startup": "startup.yaml", + "dev_shop": "dev_shop.yaml", + "product_team": "product_team.yaml", + "agency": "agency.yaml", + "full_company": "full_company.yaml", + "research_lab": "research_lab.yaml", +} + + +@dataclass(frozen=True) +class TemplateInfo: + """Summary information about an available template. + + Attributes: + name: Template identifier (e.g. ``"startup"``). + display_name: Human-readable display name. + description: Short description. + source: Where the template was found (``"builtin"`` or ``"user"``). + """ + + name: str + display_name: str + description: str + source: str + + +@dataclass(frozen=True) +class LoadedTemplate: + """Result of loading a template: structured data + raw text. + + Attributes: + template: Validated ``CompanyTemplate`` from Pass 1. + raw_yaml: Raw YAML text for Pass 2 (Jinja2 rendering). + source_name: Label for error messages. + """ + + template: CompanyTemplate + raw_yaml: str + source_name: str + + +def list_templates() -> tuple[TemplateInfo, ...]: + """Return all available templates (user directory + built-in). + + User templates are listed first. If a user template has the same + name as a built-in, only the user template appears. + + Returns: + Sorted tuple of :class:`TemplateInfo` objects. + """ + seen: dict[str, TemplateInfo] = {} + + # User templates (higher priority). + if _USER_TEMPLATES_DIR.is_dir(): + for path in sorted(_USER_TEMPLATES_DIR.glob("*.yaml")): + name = path.stem + try: + loaded = _load_from_file(path) + meta = loaded.template.metadata + seen[name] = TemplateInfo( + name=name, + display_name=meta.name, + description=meta.description, + source="user", + ) + except Exception: + logger.warning("Skipping invalid user template: %s", path) + + # Built-in templates (lower priority). + for name in sorted(BUILTIN_TEMPLATES): + if name not in seen: + try: + loaded = _load_builtin(name) + meta = loaded.template.metadata + seen[name] = TemplateInfo( + name=name, + display_name=meta.name, + description=meta.description, + source="builtin", + ) + except Exception: + logger.warning("Skipping invalid builtin template: %s", name) + + return tuple(info for _, info in sorted(seen.items())) + + +def list_builtin_templates() -> tuple[str, ...]: + """Return names of all built-in templates. + + Returns: + Sorted tuple of built-in template names. + """ + return tuple(sorted(BUILTIN_TEMPLATES)) + + +def load_template(name: str) -> LoadedTemplate: + """Load a template by name: user directory first, then builtins. + + Args: + name: Template name (e.g. ``"startup"``). + + Returns: + :class:`LoadedTemplate` with validated data and raw YAML. + + Raises: + TemplateNotFoundError: If no template with *name* exists. + """ + name_clean = name.strip().lower() + + # Try user directory first. + if _USER_TEMPLATES_DIR.is_dir(): + user_path = _USER_TEMPLATES_DIR / f"{name_clean}.yaml" + if user_path.is_file(): + return _load_from_file(user_path) + + # Fall back to builtins. + if name_clean in BUILTIN_TEMPLATES: + return _load_builtin(name_clean) + + available = list_builtin_templates() + msg = f"Unknown template {name!r}. Available: {list(available)}" + raise TemplateNotFoundError( + msg, + locations=(ConfigLocation(file_path=f""),), + ) + + +def load_template_file(path: Path | str) -> LoadedTemplate: + """Load a template from an explicit file path. + + Args: + path: Path to the template YAML file. + + Returns: + :class:`LoadedTemplate` with validated data and raw YAML. + + Raises: + TemplateNotFoundError: If the file does not exist. + TemplateValidationError: If validation fails. + """ + path = Path(path) + if not path.is_file(): + msg = f"Template file not found: {path}" + raise TemplateNotFoundError( + msg, + locations=(ConfigLocation(file_path=str(path)),), + ) + return _load_from_file(path) + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + + +def _load_builtin(name: str) -> LoadedTemplate: + """Load a built-in template by name.""" + filename = BUILTIN_TEMPLATES.get(name) + if filename is None: + msg = f"Unknown built-in template: {name!r}" + raise TemplateNotFoundError( + msg, + locations=(ConfigLocation(file_path=f""),), + ) + ref = resources.files("ai_company.templates.builtins") / filename + yaml_text = ref.read_text(encoding="utf-8") + source_name = f"" + template = _parse_template_yaml(yaml_text, source_name=source_name) + return LoadedTemplate( + template=template, + raw_yaml=yaml_text, + source_name=source_name, + ) + + +def _load_from_file(path: Path) -> LoadedTemplate: + """Load a template from a file path.""" + yaml_text = path.read_text(encoding="utf-8") + source_name = str(path) + template = _parse_template_yaml(yaml_text, source_name=source_name) + return LoadedTemplate( + template=template, + raw_yaml=yaml_text, + source_name=source_name, + ) + + +def _strip_jinja2_for_pass1(yaml_text: str) -> str: + """Replace Jinja2 expressions with YAML-safe placeholders for Pass 1. + + Pass 1 only extracts metadata and the ``variables`` section (which + must be plain YAML). The rest of the template may contain unquoted + Jinja2 expressions (``{{ }}``, ``{% %}``, ``{# #}``) that are + invalid YAML. This function replaces them with safe placeholders + so that ``yaml.safe_load`` succeeds. + + Args: + yaml_text: Raw template YAML with possible Jinja2 expressions. + + Returns: + YAML text with Jinja2 expressions replaced by safe strings. + """ + # Replace {{ ... }} with a bare placeholder (no extra quotes, + # so it works both inside quoted strings and unquoted values). + text = re.sub(r"\{\{.*?\}\}", "__JINJA2__", yaml_text) + # Remove {% ... %} block tags (lines containing only a tag are removed). + text = re.sub(r"\{%.*?%\}", "", text) + # Remove {# ... #} comments. + return re.sub(r"\{#.*?#\}", "", text) + + +def _parse_template_yaml( + yaml_text: str, + *, + source_name: str, +) -> CompanyTemplate: + """Parse a template YAML string into a CompanyTemplate (Pass 1). + + Jinja2 expressions are stripped before YAML parsing so that + unquoted ``{{ }}`` syntax does not cause parse errors. Only + metadata and the ``variables`` section (which must be plain YAML) + are needed from this pass. + + Args: + yaml_text: Raw YAML content. + source_name: Label for error messages. + + Returns: + Validated :class:`CompanyTemplate`. + + Raises: + TemplateRenderError: If YAML parsing fails. + TemplateValidationError: If the structure fails validation. + """ + safe_text = _strip_jinja2_for_pass1(yaml_text) + try: + data = yaml.safe_load(safe_text) + except yaml.YAMLError as exc: + msg = f"Template YAML syntax error in {source_name}: {exc}" + raise TemplateRenderError( + msg, + locations=(ConfigLocation(file_path=source_name),), + ) from exc + + if not isinstance(data, dict) or "template" not in data: + msg = f"Template YAML must have a top-level 'template' key in {source_name}" + raise TemplateValidationError( + msg, + locations=(ConfigLocation(file_path=source_name),), + ) + + template_data = data["template"] + normalized = _normalize_template_data(template_data) + try: + return CompanyTemplate(**normalized) + except Exception as exc: + msg = f"Template validation failed for {source_name}: {exc}" + raise TemplateValidationError( + msg, + locations=(ConfigLocation(file_path=source_name),), + ) from exc + + +def _normalize_template_data(data: dict[str, Any]) -> dict[str, Any]: + """Transform raw YAML template data into CompanyTemplate kwargs. + + Bridges the human-friendly flat YAML format and the nested Pydantic + model shape. + + Args: + data: The dict under the top-level ``template`` key. + + Returns: + Dict suitable for ``CompanyTemplate(**result)``. + """ + company = data.get("company", {}) + + metadata = { + "name": data.get("name", ""), + "description": data.get("description", ""), + "version": data.get("version", "1.0.0"), + "company_type": company.get("type", "custom"), + "tags": tuple(data.get("tags", ())), + } + + return { + "metadata": metadata, + "variables": data.get("variables", ()), + "agents": data.get("agents", ()), + "departments": data.get("departments", ()), + "workflow": data.get("workflow", "agile_kanban"), + "communication": data.get("communication", "hybrid"), + "budget_monthly": _to_float(company.get("budget_monthly", 50.0)), + "autonomy": _to_float(company.get("autonomy", 0.5)), + } + + +def _to_float(value: Any) -> float: + """Coerce a value to float, handling string numerics. + + Args: + value: Raw value from YAML (may be str, int, float). + + Returns: + Float value. + """ + if isinstance(value, str): + try: + return float(value) + except ValueError: + return 0.0 + return float(value) diff --git a/src/ai_company/templates/presets.py b/src/ai_company/templates/presets.py new file mode 100644 index 0000000000..542f3483d9 --- /dev/null +++ b/src/ai_company/templates/presets.py @@ -0,0 +1,103 @@ +"""Minimal personality presets and auto-name generation for templates. + +This module provides placeholder presets for M1. A comprehensive +preset library is planned in a follow-up issue (see GH #80). +""" + +import random +from typing import Any + +# Preset name -> dict compatible with PersonalityConfig constructor. +PERSONALITY_PRESETS: dict[str, dict[str, Any]] = { + "visionary_leader": { + "traits": ("strategic", "decisive", "inspiring"), + "communication_style": "authoritative", + "risk_tolerance": "high", + "creativity": "high", + "description": ("A visionary leader who sets direction and inspires the team."), + }, + "pragmatic_builder": { + "traits": ("practical", "reliable", "detail-oriented"), + "communication_style": "concise", + "risk_tolerance": "medium", + "creativity": "medium", + "description": ("A pragmatic builder focused on shipping quality code."), + }, + "eager_learner": { + "traits": ("curious", "enthusiastic", "adaptable"), + "communication_style": "collaborative", + "risk_tolerance": "low", + "creativity": "medium", + "description": ("An eager learner who grows quickly and asks good questions."), + }, + "methodical_analyst": { + "traits": ("thorough", "systematic", "objective"), + "communication_style": "formal", + "risk_tolerance": "low", + "creativity": "low", + "description": ("A methodical analyst who values precision and completeness."), + }, +} + +# Role-aware auto-generated name pools (gender-neutral names). +_AUTO_NAMES: dict[str, tuple[str, ...]] = { + "ceo": ("Alex Chen", "Jordan Park", "Morgan Lee", "Taylor Kim"), + "cto": ("Quinn Zhang", "Sage Patel", "Avery Nakamura", "Reese Torres"), + "cfo": ("Drew Collins", "Casey Rivera", "Blake Morrison", "Ellis Ward"), + "coo": ("Rowan Blake", "Finley Cruz", "Emery Santos", "Harper Quinn"), + "cpo": ("Phoenix Reed", "Kendall Brooks", "Harley Stone", "Lennox Hayes"), + "full-stack developer": ("Riley Sharma", "Dakota Wei", "Skyler Okafor"), + "backend developer": ("Cameron Ito", "Hayden Reyes", "Jamie Novak"), + "frontend developer": ("Kai Jensen", "Noel Andersen", "Sage Hoffman"), + "product manager": ("Emery Cho", "Phoenix Larsen", "Lennox Dunn"), + "qa lead": ("Jordan Vega", "Taylor Marsh", "Morgan Frost"), + "qa engineer": ("Riley Tran", "Avery Grant", "Blake Russell"), + "devops/sre engineer": ("Quinn Mercer", "Drew Kemp", "Casey Mills"), + "software architect": ("Sage Holloway", "Rowan Fischer", "Emery Drake"), + "ux designer": ("Kai Sinclair", "Harper Lane", "Noel Ashford"), + "ui designer": ("Finley Archer", "Lennox Byrne", "Phoenix Dale"), + "data analyst": ("Drew Hartley", "Casey Lowe", "Blake Summers"), + "data engineer": ("Reese Gallagher", "Jordan Holt", "Taylor Crane"), + "security engineer": ("Quinn Steele", "Morgan Wolfe", "Avery Knox"), + "content writer": ("Harper Ellis", "Kendall Frost", "Sage Monroe"), + "_default": ("Agent Alpha", "Agent Beta", "Agent Gamma", "Agent Delta"), +} + + +def get_personality_preset(name: str) -> dict[str, Any]: + """Look up a personality preset by name. + + Args: + name: Preset name (case-insensitive, whitespace-stripped). + + Returns: + A *copy* of the personality configuration dict. + + Raises: + KeyError: If the preset name is not found. + """ + key = name.strip().lower() + if key not in PERSONALITY_PRESETS: + available = sorted(PERSONALITY_PRESETS) + msg = f"Unknown personality preset {name!r}. Available: {available}" + raise KeyError(msg) + return dict(PERSONALITY_PRESETS[key]) + + +def generate_auto_name(role: str, *, seed: int | None = None) -> str: + """Generate a contextual agent name based on role. + + Uses a deterministic PRNG when *seed* is provided, ensuring + reproducible name generation across runs. + + Args: + role: The agent's role name. + seed: Optional random seed for deterministic naming. + + Returns: + A generated agent name string. + """ + key = role.strip().lower() + pool = _AUTO_NAMES.get(key, _AUTO_NAMES["_default"]) + rng = random.Random(seed) # noqa: S311 + return rng.choice(pool) diff --git a/src/ai_company/templates/renderer.py b/src/ai_company/templates/renderer.py new file mode 100644 index 0000000000..e4fb861ef2 --- /dev/null +++ b/src/ai_company/templates/renderer.py @@ -0,0 +1,392 @@ +"""Template rendering: Jinja2 substitution + validation to RootConfig. + +Implements the second pass of the two-pass rendering pipeline: + +1. Collect user variables + defaults from the ``CompanyTemplate``. +2. Render the raw YAML text through a Jinja2 ``SandboxedEnvironment``. +3. YAML-parse the rendered text. +4. Build a ``RootConfig``-compatible dict and validate. +""" + +import logging +from typing import TYPE_CHECKING, Any + +import yaml +from jinja2.sandbox import SandboxedEnvironment +from pydantic import ValidationError + +from ai_company.config.defaults import default_config_dict +from ai_company.config.errors import ConfigLocation +from ai_company.config.schema import RootConfig +from ai_company.config.utils import deep_merge +from ai_company.templates.errors import ( + TemplateRenderError, + TemplateValidationError, +) +from ai_company.templates.presets import ( + generate_auto_name, + get_personality_preset, +) + +if TYPE_CHECKING: + from ai_company.templates.loader import LoadedTemplate + from ai_company.templates.schema import CompanyTemplate + +logger = logging.getLogger(__name__) + + +def render_template( + loaded: LoadedTemplate, + variables: dict[str, Any] | None = None, +) -> RootConfig: + """Render a loaded template into a validated RootConfig. + + Pipeline: + + 1. Collect variables (user-supplied + defaults from template). + 2. Jinja2-render the raw YAML text with collected variables. + 3. YAML-parse the rendered text. + 4. Normalize into ``RootConfig`` shape. + 5. Deep-merge with ``default_config_dict()``. + 6. Validate as ``RootConfig``. + + Args: + loaded: :class:`LoadedTemplate` from the loader. + variables: User-supplied variable values (overrides defaults). + + Returns: + Validated, frozen :class:`RootConfig`. + + Raises: + TemplateRenderError: If variable collection or Jinja2 rendering + fails. + TemplateValidationError: If the rendered result fails + ``RootConfig`` validation. + """ + template = loaded.template + vars_dict = _collect_variables(template, variables or {}) + + # Jinja2-render the raw YAML (Pass 2). + rendered_text = _render_jinja2( + loaded.raw_yaml, + vars_dict, + source_name=loaded.source_name, + ) + + # Parse the rendered YAML. + rendered_data = _parse_rendered_yaml(rendered_text, loaded.source_name) + + # Build RootConfig dict from the rendered data. + config_dict = _build_config_dict(rendered_data, template, vars_dict) + + # Merge with defaults and validate. + merged = deep_merge(default_config_dict(), config_dict) + return _validate_as_root_config(merged, loaded.source_name) + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + + +def _collect_variables( + template: CompanyTemplate, + user_vars: dict[str, Any], +) -> dict[str, Any]: + """Merge user variables with template defaults. + + Args: + template: Template with variable declarations. + user_vars: User-supplied values. + + Returns: + Complete variable dict. + + Raises: + TemplateRenderError: If a required variable is missing. + """ + result: dict[str, Any] = {} + for var in template.variables: + if var.name in user_vars: + result[var.name] = user_vars[var.name] + elif var.default is not None: + result[var.name] = var.default + elif var.required: + msg = f"Required template variable {var.name!r} was not provided" + raise TemplateRenderError(msg) + # Optional vars with no default and no user value are omitted; + # the Jinja2 template will get ``Undefined`` for them. + + # Pass through extra user vars not declared in the template. + for key, value in user_vars.items(): + if key not in result: + result[key] = value + + return result + + +def _create_jinja_env() -> SandboxedEnvironment: + """Create a sandboxed Jinja2 environment with custom filters. + + Returns: + Configured :class:`SandboxedEnvironment`. + """ + env = SandboxedEnvironment( + keep_trailing_newline=True, + ) + # ``auto`` filter: returns empty string for falsy values (triggers + # auto-name generation in _expand_agents). + env.filters["auto"] = lambda value: value or "" + return env + + +def _render_jinja2( + raw_yaml: str, + variables: dict[str, Any], + *, + source_name: str, +) -> str: + """Render raw YAML text through Jinja2 with given variables. + + Args: + raw_yaml: Template YAML text with Jinja2 expressions. + variables: Collected variable values. + source_name: Label for error messages. + + Returns: + Rendered YAML text with all expressions resolved. + + Raises: + TemplateRenderError: If Jinja2 rendering fails. + """ + env = _create_jinja_env() + try: + jinja_template = env.from_string(raw_yaml) + return jinja_template.render(**variables) + except Exception as exc: + msg = f"Jinja2 rendering failed for {source_name}: {exc}" + raise TemplateRenderError( + msg, + locations=(ConfigLocation(file_path=source_name),), + ) from exc + + +def _parse_rendered_yaml( + rendered_text: str, + source_name: str, +) -> dict[str, Any]: + """Parse the Jinja2-rendered YAML text. + + Args: + rendered_text: YAML text with all Jinja2 expressions resolved. + source_name: Label for error messages. + + Returns: + Parsed dict from the ``template`` key. + + Raises: + TemplateRenderError: If YAML parsing fails. + """ + try: + data = yaml.safe_load(rendered_text) + except yaml.YAMLError as exc: + msg = f"Rendered template YAML is invalid for {source_name}: {exc}" + raise TemplateRenderError( + msg, + locations=(ConfigLocation(file_path=source_name),), + ) from exc + + if not isinstance(data, dict) or "template" not in data: + msg = f"Rendered template missing 'template' key: {source_name}" + raise TemplateRenderError( + msg, + locations=(ConfigLocation(file_path=source_name),), + ) + + result: dict[str, Any] = data["template"] + return result + + +def _build_config_dict( + rendered_data: dict[str, Any], + template: CompanyTemplate, + variables: dict[str, Any], +) -> dict[str, Any]: + """Build a RootConfig-compatible dict from rendered template data. + + Args: + rendered_data: Parsed dict from the rendered YAML. + template: Original template metadata (for fallback values). + variables: Collected variables. + + Returns: + Dict suitable for ``RootConfig(**deep_merge(defaults, result))``. + """ + company = rendered_data.get("company", {}) + company_name = variables.get( + "company_name", + template.metadata.name, + ) + + # Expand agents. + raw_agents = rendered_data.get("agents", []) + agents = _expand_agents(raw_agents, variables) + + # Build departments for RootConfig. + raw_depts = rendered_data.get("departments", []) + departments = _build_departments(raw_depts) + + return { + "company_name": company_name, + "company_type": company.get("type", template.metadata.company_type.value), + "agents": agents, + "departments": departments, + "config": { + "autonomy": _safe_float(company.get("autonomy", template.autonomy)), + "budget_monthly": _safe_float( + company.get("budget_monthly", template.budget_monthly), + ), + "communication_pattern": rendered_data.get( + "communication", + template.communication, + ), + }, + } + + +def _expand_agents( + raw_agents: list[dict[str, Any]], + _variables: dict[str, Any], +) -> list[dict[str, Any]]: + """Expand template agent dicts into AgentConfig-compatible dicts. + + Handles auto-name generation and personality preset expansion. + + Args: + raw_agents: List of agent dicts from rendered YAML. + variables: Collected variables. + + Returns: + List of dicts suitable for ``AgentConfig`` construction. + """ + expanded: list[dict[str, Any]] = [] + used_names: set[str] = set() + + for idx, agent in enumerate(raw_agents): + role = agent.get("role", "Agent") + name = str(agent.get("name", "")).strip() + + # Auto-generate name if empty or still a Jinja2 expression. + if not name or name.startswith("{{"): + name = generate_auto_name(role, seed=idx) + + # Ensure uniqueness by appending a suffix if needed. + base_name = name + counter = 2 + while name in used_names: + name = f"{base_name} {counter}" + counter += 1 + used_names.add(name) + + agent_dict: dict[str, Any] = { + "name": name, + "role": role, + "department": agent.get("department", "engineering"), + "level": agent.get("level", "mid"), + } + + # Expand personality preset. + preset_name = agent.get("personality_preset") + if preset_name: + try: + agent_dict["personality"] = get_personality_preset(preset_name) + except KeyError: + logger.warning( + "Unknown personality preset %r for agent %r, using defaults", + preset_name, + name, + ) + + # Model config (raw dict for AgentConfig). + model_tier = agent.get("model", "sonnet") + agent_dict["model"] = {"provider": "default", "model_id": model_tier} + + expanded.append(agent_dict) + + return expanded + + +def _build_departments( + raw_depts: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Build RootConfig-compatible department dicts. + + Args: + raw_depts: List of department dicts from rendered YAML. + + Returns: + List of dicts suitable for ``Department`` construction. + """ + departments: list[dict[str, Any]] = [] + for dept in raw_depts: + dept_dict: dict[str, Any] = { + "name": dept.get("name", ""), + "head": dept.get("head_role", dept.get("name", "")), + "budget_percent": _safe_float(dept.get("budget_percent", 0.0)), + } + departments.append(dept_dict) + return departments + + +def _validate_as_root_config( + merged: dict[str, Any], + source_name: str, +) -> RootConfig: + """Validate a merged config dict as RootConfig. + + Args: + merged: Merged config dict. + source_name: Label for error messages. + + Returns: + Validated, frozen :class:`RootConfig`. + + Raises: + TemplateValidationError: If validation fails. + """ + try: + return RootConfig(**merged) + except ValidationError as exc: + field_errors: list[tuple[str, str]] = [] + locations: list[ConfigLocation] = [] + for error in exc.errors(): + key_path = ".".join(str(p) for p in error["loc"]) + error_msg = error["msg"] + field_errors.append((key_path, error_msg)) + locations.append( + ConfigLocation( + file_path=source_name, + key_path=key_path, + ), + ) + msg = f"Rendered template failed RootConfig validation: {source_name}" + raise TemplateValidationError( + msg, + locations=tuple(locations), + field_errors=tuple(field_errors), + ) from exc + + +def _safe_float(value: Any) -> float: + """Coerce a value to float safely. + + Args: + value: Value from rendered YAML (str, int, or float). + + Returns: + Float value, or 0.0 on conversion failure. + """ + try: + return float(value) + except TypeError, ValueError: + return 0.0 diff --git a/src/ai_company/templates/schema.py b/src/ai_company/templates/schema.py new file mode 100644 index 0000000000..99a9578eac --- /dev/null +++ b/src/ai_company/templates/schema.py @@ -0,0 +1,237 @@ +"""Template schema: Pydantic models for company templates.""" + +from collections import Counter +from typing import Any, Literal, Self + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ai_company.core.enums import CompanyType, SeniorityLevel +from ai_company.core.types import NotBlankStr # noqa: TC001 + + +class TemplateVariable(BaseModel): + """A user-configurable variable within a template. + + Variables declared here are extracted from the template YAML during + the first parsing pass (before Jinja2 rendering). The ``variables`` + section must use plain YAML — no Jinja2 expressions. + + Attributes: + name: Variable name (used in ``{{ name }}`` placeholders). + description: Human-readable description for prompts/docs. + var_type: Expected Python type name. + default: Default value (``None`` means the variable is required). + required: Whether the user must provide this value. + """ + + model_config = ConfigDict(frozen=True) + + name: NotBlankStr = Field(description="Variable name") + description: str = Field(default="", description="Human-readable description") + var_type: Literal["str", "int", "float", "bool"] = Field( + default="str", + description="Expected value type", + ) + default: Any = Field(default=None, description="Default value") + required: bool = Field(default=False, description="Whether required") + + @model_validator(mode="after") + def _validate_required_has_no_default(self) -> Self: + """Required variables must not define a default.""" + if self.required and self.default is not None: + msg = f"Variable {self.name!r} is required but defines a default" + raise ValueError(msg) + return self + + +class TemplateAgentConfig(BaseModel): + """Agent definition within a template. + + Uses string references and presets rather than full ``AgentConfig``. + The renderer expands these into full agent configuration dicts. + + Attributes: + role: Built-in role name (case-insensitive match to role catalog). + name: Agent name (may contain Jinja2 placeholders; empty triggers + auto-generation). + level: Seniority level override. + model: Model tier alias (e.g. ``"opus"``, ``"sonnet"``, ``"haiku"``). + personality_preset: Named personality preset from the presets registry. + department: Department override (``None`` uses the role's default). + """ + + model_config = ConfigDict(frozen=True) + + role: NotBlankStr = Field(description="Built-in role name") + name: str = Field(default="", description="Agent name (may have Jinja2 vars)") + level: SeniorityLevel = Field( + default=SeniorityLevel.MID, + description="Seniority level", + ) + model: str = Field(default="sonnet", description="Model tier alias") + personality_preset: str | None = Field( + default=None, + description="Named personality preset", + ) + department: str | None = Field( + default=None, + description="Department override", + ) + + +class TemplateDepartmentConfig(BaseModel): + """Department definition within a template. + + Provides structural information only — department names, budget + allocations, and the head role. Internal team composition and + reporting lines are defined separately (see follow-up issues). + + Attributes: + name: Department name (standard or custom). + budget_percent: Percentage of company budget (0-100). + head_role: Role name of the department head. + """ + + model_config = ConfigDict(frozen=True) + + name: NotBlankStr = Field(description="Department name") + budget_percent: float = Field( + default=0.0, + ge=0.0, + le=100.0, + description="Percentage of company budget", + ) + head_role: str | None = Field( + default=None, + description="Role name of department head", + ) + + +class TemplateMetadata(BaseModel): + """Metadata about a company template. + + Attributes: + name: Template display name. + description: What this template is for. + version: Semantic version string. + company_type: Which ``CompanyType`` this template creates. + min_agents: Minimum number of agents. + max_agents: Maximum number of agents. + tags: Categorization tags. + """ + + model_config = ConfigDict(frozen=True) + + name: NotBlankStr = Field(description="Template display name") + description: str = Field(default="", description="Template description") + version: NotBlankStr = Field(default="1.0.0", description="Semantic version") + company_type: CompanyType = Field( + description="Company type this template creates", + ) + min_agents: int = Field(default=1, ge=1, description="Minimum agents") + max_agents: int = Field(default=100, ge=1, description="Maximum agents") + tags: tuple[str, ...] = Field(default=(), description="Categorization tags") + + @model_validator(mode="after") + def _validate_agent_range(self) -> Self: + """Ensure min_agents <= max_agents.""" + if self.min_agents > self.max_agents: + msg = f"min_agents ({self.min_agents}) > max_agents ({self.max_agents})" + raise ValueError(msg) + return self + + +class CompanyTemplate(BaseModel): + """A complete company template definition. + + This is the top-level model parsed from a template YAML file + during the first pass (before Jinja2 rendering). It holds + metadata, variable declarations, and the structural definitions + for agents and departments. + + The raw YAML text is stored separately by the loader for the + second pass (Jinja2 rendering). + + Attributes: + metadata: Template metadata. + variables: Declared template variables (plain YAML, no Jinja2). + agents: Template agent definitions. + departments: Template department definitions. + workflow: Workflow name. + communication: Communication pattern name. + budget_monthly: Default monthly budget in USD. + autonomy: Autonomy level (0.0 = full human oversight, + 1.0 = fully autonomous). + """ + + model_config = ConfigDict(frozen=True) + + metadata: TemplateMetadata = Field(description="Template metadata") + variables: tuple[TemplateVariable, ...] = Field( + default=(), + description="Declared template variables", + ) + agents: tuple[TemplateAgentConfig, ...] = Field( + description="Template agent definitions", + ) + departments: tuple[TemplateDepartmentConfig, ...] = Field( + default=(), + description="Template department definitions", + ) + workflow: str = Field( + default="agile_kanban", + description="Workflow name", + ) + communication: str = Field( + default="hybrid", + description="Communication pattern", + ) + budget_monthly: float = Field( + default=50.0, + ge=0.0, + description="Default monthly budget in USD", + ) + autonomy: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Autonomy level", + ) + + @model_validator(mode="after") + def _validate_agent_count_in_range(self) -> Self: + """Agent count must be within metadata min/max.""" + count = len(self.agents) + if count < self.metadata.min_agents: + msg = ( + f"Template defines {count} agent(s), " + f"minimum is {self.metadata.min_agents}" + ) + raise ValueError(msg) + if count > self.metadata.max_agents: + msg = ( + f"Template defines {count} agent(s), " + f"maximum is {self.metadata.max_agents}" + ) + raise ValueError(msg) + return self + + @model_validator(mode="after") + def _validate_unique_variable_names(self) -> Self: + """Variable names must be unique.""" + names = [v.name for v in self.variables] + if len(names) != len(set(names)): + dupes = sorted(n for n, c in Counter(names).items() if c > 1) + msg = f"Duplicate variable names: {dupes}" + raise ValueError(msg) + return self + + @model_validator(mode="after") + def _validate_unique_department_names(self) -> Self: + """Department names must be unique.""" + names = [d.name for d in self.departments] + if len(names) != len(set(names)): + dupes = sorted(n for n, c in Counter(names).items() if c > 1) + msg = f"Duplicate department names: {dupes}" + raise ValueError(msg) + return self diff --git a/tests/unit/config/test_loader.py b/tests/unit/config/test_loader.py index 3d9099c6e7..46d08ce3d8 100644 --- a/tests/unit/config/test_loader.py +++ b/tests/unit/config/test_loader.py @@ -11,7 +11,6 @@ ) from ai_company.config.loader import ( _build_line_map, - _deep_merge, _parse_yaml_file, _parse_yaml_string, _read_config_text, @@ -22,6 +21,7 @@ load_config_from_string, ) from ai_company.config.schema import RootConfig +from ai_company.config.utils import deep_merge from .conftest import ( ENV_VAR_MISSING_YAML, @@ -34,7 +34,7 @@ MISSING_REQUIRED_YAML, ) -# ── _deep_merge ────────────────────────────────────────────────── +# ── deep_merge ────────────────────────────────────────────────── @pytest.mark.unit @@ -42,31 +42,31 @@ class TestDeepMerge: def test_simple_override(self): base = {"a": 1, "b": 2} override = {"b": 3} - result = _deep_merge(base, override) + result = deep_merge(base, override) assert result == {"a": 1, "b": 3} def test_nested_merge(self): base = {"x": {"a": 1, "b": 2}} override = {"x": {"b": 3, "c": 4}} - result = _deep_merge(base, override) + result = deep_merge(base, override) assert result == {"x": {"a": 1, "b": 3, "c": 4}} def test_list_replaced_entirely(self): base = {"items": [1, 2, 3]} override = {"items": [4, 5]} - result = _deep_merge(base, override) + result = deep_merge(base, override) assert result == {"items": [4, 5]} def test_base_preserved(self): base = {"a": 1, "b": 2} override = {"c": 3} - result = _deep_merge(base, override) + result = deep_merge(base, override) assert result == {"a": 1, "b": 2, "c": 3} def test_new_keys_added(self): base = {"a": 1} override = {"b": 2, "c": 3} - result = _deep_merge(base, override) + result = deep_merge(base, override) assert result == {"a": 1, "b": 2, "c": 3} def test_inputs_not_mutated(self): @@ -74,28 +74,28 @@ def test_inputs_not_mutated(self): override = {"x": {"b": 2}} base_copy = {"x": {"a": 1}} override_copy = {"x": {"b": 2}} - _deep_merge(base, override) + deep_merge(base, override) assert base == base_copy assert override == override_copy def test_deeply_nested(self): base = {"a": {"b": {"c": 1}}} override = {"a": {"b": {"d": 2}}} - result = _deep_merge(base, override) + result = deep_merge(base, override) assert result == {"a": {"b": {"c": 1, "d": 2}}} def test_result_does_not_share_mutable_refs_with_base(self): base = {"x": {"nested": [1, 2, 3]}} - result = _deep_merge(base, {}) + result = deep_merge(base, {}) result["x"]["nested"].append(4) assert base["x"]["nested"] == [1, 2, 3] def test_empty_base(self): - result = _deep_merge({}, {"a": 1}) + result = deep_merge({}, {"a": 1}) assert result == {"a": 1} def test_empty_override(self): - result = _deep_merge({"a": 1}, {}) + result = deep_merge({"a": 1}, {}) assert result == {"a": 1} diff --git a/tests/unit/config/test_utils.py b/tests/unit/config/test_utils.py new file mode 100644 index 0000000000..d9aba4f2a3 --- /dev/null +++ b/tests/unit/config/test_utils.py @@ -0,0 +1,68 @@ +"""Tests for shared configuration utilities.""" + +import pytest + +from ai_company.config.utils import deep_merge + + +@pytest.mark.unit +class TestDeepMerge: + def test_empty_base(self): + result = deep_merge({}, {"a": 1}) + assert result == {"a": 1} + + def test_empty_override(self): + result = deep_merge({"a": 1}, {}) + assert result == {"a": 1} + + def test_both_empty(self): + result = deep_merge({}, {}) + assert result == {} + + def test_simple_override(self): + result = deep_merge({"a": 1, "b": 2}, {"b": 3}) + assert result == {"a": 1, "b": 3} + + def test_nested_merge(self): + base = {"x": {"a": 1, "b": 2}} + override = {"x": {"b": 3, "c": 4}} + result = deep_merge(base, override) + assert result == {"x": {"a": 1, "b": 3, "c": 4}} + + def test_deeply_nested(self): + base = {"x": {"y": {"z": 1, "w": 2}}} + override = {"x": {"y": {"z": 99}}} + result = deep_merge(base, override) + assert result == {"x": {"y": {"z": 99, "w": 2}}} + + def test_override_dict_with_scalar(self): + base = {"x": {"a": 1}} + override = {"x": 42} + result = deep_merge(base, override) + assert result == {"x": 42} + + def test_override_scalar_with_dict(self): + base = {"x": 42} + override = {"x": {"a": 1}} + result = deep_merge(base, override) + assert result == {"x": {"a": 1}} + + def test_does_not_mutate_base(self): + base = {"x": {"a": 1}} + override = {"x": {"b": 2}} + original_base = {"x": {"a": 1}} + deep_merge(base, override) + assert base == original_base + + def test_does_not_mutate_override(self): + base = {"x": 1} + override = {"y": {"a": [1, 2]}} + original_override = {"y": {"a": [1, 2]}} + deep_merge(base, override) + assert override == original_override + + def test_list_replaced_not_merged(self): + base = {"items": [1, 2, 3]} + override = {"items": [4, 5]} + result = deep_merge(base, override) + assert result == {"items": [4, 5]} diff --git a/tests/unit/templates/__init__.py b/tests/unit/templates/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/templates/conftest.py b/tests/unit/templates/conftest.py new file mode 100644 index 0000000000..2697c73a2d --- /dev/null +++ b/tests/unit/templates/conftest.py @@ -0,0 +1,131 @@ +"""Unit test configuration and fixtures for templates.""" + +from typing import TYPE_CHECKING, Any + +import pytest + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + + +MINIMAL_TEMPLATE_YAML = """\ +template: + name: "Test Template" + description: "A minimal test template" + version: "1.0.0" + + company: + type: "custom" + + agents: + - role: "Backend Developer" + level: "mid" + model: "sonnet" + department: "engineering" +""" + +TEMPLATE_WITH_VARIABLES_YAML = """\ +template: + name: "Var Template" + description: "Template with variables" + version: "1.0.0" + + variables: + - name: "company_name" + description: "Name of your company" + default: "Default Corp" + - name: "budget" + description: "Monthly budget" + var_type: "float" + default: 42.0 + + company: + type: "startup" + budget_monthly: {{ budget | default(42.0) }} + autonomy: 0.7 + + departments: + - name: "engineering" + budget_percent: 100 + head_role: "Backend Developer" + + agents: + - role: "Backend Developer" + name: "{{ company_name }} Dev" + level: "senior" + model: "sonnet" + department: "engineering" +""" + +TEMPLATE_REQUIRED_VAR_YAML = """\ +template: + name: "Required Var" + description: "Has a required variable" + version: "1.0.0" + + variables: + - name: "team_lead" + description: "Name of the team lead" + required: true + + company: + type: "custom" + + agents: + - role: "Backend Developer" + name: "{{ team_lead }}" + level: "mid" + model: "sonnet" + department: "engineering" +""" + +INVALID_SYNTAX_YAML = """\ +template: + name: "Bad YAML" + agents: [unterminated +""" + +MISSING_TEMPLATE_KEY_YAML = """\ +name: "No Template Key" +agents: [] +""" + + +def _make_template_dict(**overrides: Any) -> dict[str, Any]: + """Build a minimal valid CompanyTemplate kwargs dict with overrides.""" + base: dict[str, Any] = { + "metadata": { + "name": "Test", + "description": "desc", + "version": "1.0.0", + "company_type": "custom", + }, + "agents": ( + { + "role": "Backend Developer", + "level": "mid", + "model": "sonnet", + }, + ), + } + base.update(overrides) + return base + + +@pytest.fixture +def make_template_dict() -> Callable[..., dict[str, Any]]: + """Factory fixture for building template kwargs dicts.""" + return _make_template_dict + + +@pytest.fixture +def tmp_template_file(tmp_path: Path) -> Callable[[str, str], Path]: + """Factory fixture for writing a temporary template YAML file.""" + + def _create(content: str, name: str = "test_template.yaml") -> Path: + path = tmp_path / name + path.write_text(content, encoding="utf-8") + return path + + return _create diff --git a/tests/unit/templates/test_loader.py b/tests/unit/templates/test_loader.py new file mode 100644 index 0000000000..d57136ec41 --- /dev/null +++ b/tests/unit/templates/test_loader.py @@ -0,0 +1,201 @@ +"""Tests for template loading from built-in and file-system sources.""" + +from pathlib import Path +from typing import TYPE_CHECKING +from unittest.mock import patch + +import pytest + +from ai_company.templates.errors import ( + TemplateNotFoundError, + TemplateRenderError, + TemplateValidationError, +) +from ai_company.templates.loader import ( + BUILTIN_TEMPLATES, + LoadedTemplate, + TemplateInfo, + list_builtin_templates, + list_templates, + load_template, + load_template_file, +) + +if TYPE_CHECKING: + from collections.abc import Callable + +from .conftest import ( + INVALID_SYNTAX_YAML, + MINIMAL_TEMPLATE_YAML, + MISSING_TEMPLATE_KEY_YAML, + TEMPLATE_WITH_VARIABLES_YAML, +) + +# ── list_builtin_templates ─────────────────────────────────────── + + +@pytest.mark.unit +class TestListBuiltinTemplates: + def test_returns_sorted_tuple(self): + names = list_builtin_templates() + assert isinstance(names, tuple) + assert names == tuple(sorted(names)) + + def test_contains_all_registered(self): + names = list_builtin_templates() + for name in BUILTIN_TEMPLATES: + assert name in names + + def test_count_matches_registry(self): + assert len(list_builtin_templates()) == len(BUILTIN_TEMPLATES) + + +# ── list_templates ─────────────────────────────────────────────── + + +@pytest.mark.unit +class TestListTemplates: + def test_returns_tuple_of_template_info(self): + templates = list_templates() + assert isinstance(templates, tuple) + assert all(isinstance(t, TemplateInfo) for t in templates) + + def test_all_builtins_present(self): + templates = list_templates() + names = {t.name for t in templates} + for builtin_name in BUILTIN_TEMPLATES: + assert builtin_name in names + + def test_builtin_source_label(self): + templates = list_templates() + for t in templates: + if t.name in BUILTIN_TEMPLATES: + assert t.source == "builtin" + + def test_user_template_overrides_builtin( + self, + tmp_path: Path, + tmp_template_file: Callable[[str, str], Path], + ): + user_dir = tmp_path / "user_templates" + user_dir.mkdir() + user_yaml = MINIMAL_TEMPLATE_YAML + (user_dir / "solo_founder.yaml").write_text(user_yaml, encoding="utf-8") + + with patch( + "ai_company.templates.loader._USER_TEMPLATES_DIR", + user_dir, + ): + templates = list_templates() + solo = next(t for t in templates if t.name == "solo_founder") + assert solo.source == "user" + + +# ── load_template ──────────────────────────────────────────────── + + +@pytest.mark.unit +class TestLoadTemplate: + def test_load_builtin_by_name(self): + loaded = load_template("solo_founder") + assert isinstance(loaded, LoadedTemplate) + assert loaded.template.metadata.name == "Solo Founder" + assert " 0 + assert len(loaded.template.agents) >= 1 + + def test_unknown_name_raises_not_found(self): + with pytest.raises(TemplateNotFoundError, match="Unknown template"): + load_template("does_not_exist") + + def test_user_template_preferred(self, tmp_path: Path): + user_dir = tmp_path / "user_templates" + user_dir.mkdir() + (user_dir / "solo_founder.yaml").write_text( + MINIMAL_TEMPLATE_YAML, encoding="utf-8" + ) + + with patch( + "ai_company.templates.loader._USER_TEMPLATES_DIR", + user_dir, + ): + loaded = load_template("solo_founder") + # User template has "Test Template" name, not "Solo Founder". + assert loaded.template.metadata.name == "Test Template" + + +# ── load_template_file ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestLoadTemplateFile: + def test_load_from_path( + self, + tmp_template_file: Callable[[str, str], Path], + ): + path = tmp_template_file(MINIMAL_TEMPLATE_YAML) + loaded = load_template_file(path) + assert isinstance(loaded, LoadedTemplate) + assert loaded.template.metadata.name == "Test Template" + + def test_load_with_variables( + self, + tmp_template_file: Callable[[str, str], Path], + ): + path = tmp_template_file(TEMPLATE_WITH_VARIABLES_YAML) + loaded = load_template_file(path) + assert len(loaded.template.variables) == 2 + assert loaded.template.variables[0].name == "company_name" + + def test_nonexistent_file_raises_not_found(self): + with pytest.raises(TemplateNotFoundError, match="not found"): + load_template_file(Path("/nonexistent/template.yaml")) + + def test_invalid_yaml_raises_render_error( + self, + tmp_template_file: Callable[[str, str], Path], + ): + path = tmp_template_file(INVALID_SYNTAX_YAML) + with pytest.raises(TemplateRenderError, match="syntax error"): + load_template_file(path) + + def test_missing_template_key_raises_validation_error( + self, + tmp_template_file: Callable[[str, str], Path], + ): + path = tmp_template_file(MISSING_TEMPLATE_KEY_YAML) + with pytest.raises(TemplateValidationError, match="template"): + load_template_file(path) + + def test_accepts_string_path( + self, + tmp_template_file: Callable[[str, str], Path], + ): + path = tmp_template_file(MINIMAL_TEMPLATE_YAML) + loaded = load_template_file(str(path)) + assert isinstance(loaded, LoadedTemplate) + + +# ── LoadedTemplate dataclass ───────────────────────────────────── + + +@pytest.mark.unit +class TestLoadedTemplate: + def test_frozen(self): + loaded = load_template("solo_founder") + with pytest.raises(AttributeError): + loaded.source_name = "changed" # type: ignore[misc] + + def test_raw_yaml_is_string(self): + loaded = load_template("startup") + assert isinstance(loaded.raw_yaml, str) + assert "template:" in loaded.raw_yaml diff --git a/tests/unit/templates/test_presets.py b/tests/unit/templates/test_presets.py new file mode 100644 index 0000000000..2c554ac158 --- /dev/null +++ b/tests/unit/templates/test_presets.py @@ -0,0 +1,74 @@ +"""Tests for template personality presets and auto-name generation.""" + +import pytest + +from ai_company.templates.presets import ( + PERSONALITY_PRESETS, + generate_auto_name, + get_personality_preset, +) + + +@pytest.mark.unit +class TestGetPersonalityPreset: + def test_valid_preset_returns_dict(self): + result = get_personality_preset("visionary_leader") + assert isinstance(result, dict) + assert "traits" in result + assert "communication_style" in result + + def test_case_insensitive(self): + result = get_personality_preset("VISIONARY_LEADER") + assert result == get_personality_preset("visionary_leader") + + def test_whitespace_stripped(self): + result = get_personality_preset(" pragmatic_builder ") + assert result["communication_style"] == "concise" + + def test_returns_copy(self): + a = get_personality_preset("eager_learner") + b = get_personality_preset("eager_learner") + assert a == b + assert a is not b + + def test_unknown_preset_raises_key_error(self): + with pytest.raises(KeyError, match="Unknown personality preset"): + get_personality_preset("nonexistent") + + def test_all_presets_have_required_keys(self): + required_keys = {"traits", "communication_style", "description"} + for name in PERSONALITY_PRESETS: + preset = get_personality_preset(name) + assert required_keys.issubset(preset.keys()), f"{name} missing keys" + + +@pytest.mark.unit +class TestGenerateAutoName: + def test_known_role_returns_from_pool(self): + name = generate_auto_name("CEO", seed=0) + assert isinstance(name, str) + assert len(name) > 0 + + def test_unknown_role_uses_default_pool(self): + name = generate_auto_name("Alien Commander", seed=0) + assert name.startswith("Agent ") + + def test_deterministic_with_seed(self): + a = generate_auto_name("Backend Developer", seed=42) + b = generate_auto_name("Backend Developer", seed=42) + assert a == b + + def test_different_seeds_may_differ(self): + names = {generate_auto_name("CEO", seed=i) for i in range(10)} + # With 4 names in the pool, at least 2 distinct names expected. + assert len(names) >= 2 + + def test_case_insensitive_role(self): + a = generate_auto_name("ceo", seed=0) + b = generate_auto_name("CEO", seed=0) + assert a == b + + def test_whitespace_stripped_from_role(self): + a = generate_auto_name(" CEO ", seed=0) + b = generate_auto_name("CEO", seed=0) + assert a == b diff --git a/tests/unit/templates/test_renderer.py b/tests/unit/templates/test_renderer.py new file mode 100644 index 0000000000..977ea8d4c7 --- /dev/null +++ b/tests/unit/templates/test_renderer.py @@ -0,0 +1,191 @@ +"""Tests for the two-pass template rendering pipeline.""" + +from typing import TYPE_CHECKING + +import pytest +from pydantic import ValidationError + +from ai_company.config.schema import RootConfig +from ai_company.templates.errors import TemplateRenderError +from ai_company.templates.loader import load_template, load_template_file +from ai_company.templates.renderer import render_template + +from .conftest import TEMPLATE_REQUIRED_VAR_YAML, TEMPLATE_WITH_VARIABLES_YAML + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + +# ── render_template basic ──────────────────────────────────────── + + +@pytest.mark.unit +class TestRenderTemplateBasic: + def test_render_builtin_solo_founder(self): + loaded = load_template("solo_founder") + config = render_template(loaded) + assert isinstance(config, RootConfig) + assert config.company_name == "My Company" + assert len(config.agents) == 2 + + def test_render_builtin_startup(self): + loaded = load_template("startup") + config = render_template(loaded) + assert isinstance(config, RootConfig) + assert config.company_name == "Startup Co" + assert len(config.agents) == 5 + + def test_render_all_builtins_produce_valid_root_config(self): + from ai_company.templates.loader import BUILTIN_TEMPLATES + + for name in BUILTIN_TEMPLATES: + loaded = load_template(name) + config = render_template(loaded) + assert isinstance(config, RootConfig), f"{name} failed" + assert len(config.agents) >= 1, f"{name} has no agents" + + def test_render_returns_frozen_config(self): + loaded = load_template("solo_founder") + config = render_template(loaded) + with pytest.raises(ValidationError): + config.company_name = "Changed" # type: ignore[misc] + + +# ── Variables ──────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestRenderTemplateVariables: + def test_default_variables_applied( + self, + tmp_template_file: Callable[[str, str], Path], + ): + path = tmp_template_file(TEMPLATE_WITH_VARIABLES_YAML) + loaded = load_template_file(path) + config = render_template(loaded) + assert config.company_name == "Default Corp" + + def test_user_variables_override_defaults( + self, + tmp_template_file: Callable[[str, str], Path], + ): + path = tmp_template_file(TEMPLATE_WITH_VARIABLES_YAML) + loaded = load_template_file(path) + config = render_template(loaded, variables={"company_name": "Acme Inc"}) + assert config.company_name == "Acme Inc" + + def test_budget_variable_applied( + self, + tmp_template_file: Callable[[str, str], Path], + ): + path = tmp_template_file(TEMPLATE_WITH_VARIABLES_YAML) + loaded = load_template_file(path) + config = render_template(loaded, variables={"budget": 100.0}) + assert config.config.budget_monthly == 100.0 + + def test_required_variable_missing_raises_error( + self, + tmp_template_file: Callable[[str, str], Path], + ): + path = tmp_template_file(TEMPLATE_REQUIRED_VAR_YAML) + loaded = load_template_file(path) + with pytest.raises(TemplateRenderError, match="Required template variable"): + render_template(loaded) + + def test_required_variable_provided( + self, + tmp_template_file: Callable[[str, str], Path], + ): + path = tmp_template_file(TEMPLATE_REQUIRED_VAR_YAML) + loaded = load_template_file(path) + config = render_template(loaded, variables={"team_lead": "Alice"}) + assert isinstance(config, RootConfig) + + def test_extra_variables_passed_through( + self, + tmp_template_file: Callable[[str, str], Path], + ): + path = tmp_template_file(TEMPLATE_WITH_VARIABLES_YAML) + loaded = load_template_file(path) + # Extra variables don't cause errors. + config = render_template( + loaded, + variables={"company_name": "Test", "extra_key": "ignored"}, + ) + assert isinstance(config, RootConfig) + + +# ── Agent expansion ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestRenderTemplateAgents: + def test_agents_have_unique_names(self): + loaded = load_template("startup") + config = render_template(loaded) + names = [a.name for a in config.agents] + assert len(names) == len(set(names)) + + def test_agent_name_from_jinja2(self): + loaded = load_template("solo_founder") + config = render_template(loaded, variables={"company_name": "ACME"}) + # The CEO agent's name is "{{ company_name }} CEO" → "ACME CEO". + ceo_agents = [a for a in config.agents if a.role == "CEO"] + assert len(ceo_agents) == 1 + assert "ACME" in ceo_agents[0].name + + def test_auto_name_for_unnamed_agents(self): + loaded = load_template("research_lab") + config = render_template(loaded) + # research_lab agents don't have explicit names. + for agent in config.agents: + assert agent.name != "" + assert len(agent.name) > 0 + + +# ── Departments ────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestRenderTemplateDepartments: + def test_departments_included(self): + loaded = load_template("startup") + config = render_template(loaded) + assert len(config.departments) >= 1 + + def test_department_names(self): + loaded = load_template("solo_founder") + config = render_template(loaded) + dept_names = {d.name for d in config.departments} + assert "executive" in dept_names or "engineering" in dept_names + + +# ── Error cases ────────────────────────────────────────────────── + + +@pytest.mark.unit +class TestRenderTemplateErrors: + def test_invalid_jinja2_raises_render_error( + self, + tmp_template_file: Callable[[str, str], Path], + ): + bad_yaml = """\ +template: + name: "Bad Jinja" + description: "test" + version: "1.0.0" + + company: + type: "custom" + + agents: + - role: "Dev" + name: "{{ undefined_func() | bad_filter }}" + level: "mid" + model: "sonnet" + department: "engineering" +""" + path = tmp_template_file(bad_yaml) + loaded = load_template_file(path) + with pytest.raises(TemplateRenderError, match="Jinja2 rendering failed"): + render_template(loaded) diff --git a/tests/unit/templates/test_schema.py b/tests/unit/templates/test_schema.py new file mode 100644 index 0000000000..05c5d65937 --- /dev/null +++ b/tests/unit/templates/test_schema.py @@ -0,0 +1,255 @@ +"""Tests for template schema models.""" + +import pytest +from pydantic import ValidationError + +from ai_company.core.enums import CompanyType, SeniorityLevel +from ai_company.templates.schema import ( + CompanyTemplate, + TemplateAgentConfig, + TemplateDepartmentConfig, + TemplateMetadata, + TemplateVariable, +) + +# ── TemplateVariable ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestTemplateVariable: + def test_valid_minimal(self): + v = TemplateVariable(name="my_var") + assert v.name == "my_var" + assert v.description == "" + assert v.var_type == "str" + assert v.default is None + assert v.required is False + + def test_valid_full(self): + v = TemplateVariable( + name="budget", + description="Monthly budget", + var_type="float", + default=50.0, + required=False, + ) + assert v.var_type == "float" + assert v.default == 50.0 + + def test_blank_name_rejected(self): + with pytest.raises(ValidationError): + TemplateVariable(name="") + + def test_whitespace_name_rejected(self): + with pytest.raises(ValidationError): + TemplateVariable(name=" ") + + def test_required_with_default_rejected(self): + with pytest.raises(ValidationError, match="required but defines a default"): + TemplateVariable(name="x", required=True, default="oops") + + def test_required_without_default_accepted(self): + v = TemplateVariable(name="x", required=True) + assert v.required is True + assert v.default is None + + def test_frozen(self): + v = TemplateVariable(name="x") + with pytest.raises(ValidationError): + v.name = "y" # type: ignore[misc] + + +# ── TemplateAgentConfig ────────────────────────────────────────── + + +@pytest.mark.unit +class TestTemplateAgentConfig: + def test_valid_minimal(self): + a = TemplateAgentConfig(role="Backend Developer") + assert a.role == "Backend Developer" + assert a.name == "" + assert a.level == SeniorityLevel.MID + assert a.model == "sonnet" + assert a.personality_preset is None + assert a.department is None + + def test_valid_full(self): + a = TemplateAgentConfig( + role="CEO", + name="{{ company_name }} CEO", + level="c_suite", + model="opus", + personality_preset="visionary_leader", + department="executive", + ) + assert a.level == SeniorityLevel.C_SUITE + assert a.personality_preset == "visionary_leader" + + def test_blank_role_rejected(self): + with pytest.raises(ValidationError): + TemplateAgentConfig(role="") + + +# ── TemplateDepartmentConfig ───────────────────────────────────── + + +@pytest.mark.unit +class TestTemplateDepartmentConfig: + def test_valid_minimal(self): + d = TemplateDepartmentConfig(name="engineering") + assert d.name == "engineering" + assert d.budget_percent == 0.0 + assert d.head_role is None + + def test_valid_full(self): + d = TemplateDepartmentConfig( + name="engineering", + budget_percent=60.0, + head_role="CTO", + ) + assert d.budget_percent == 60.0 + assert d.head_role == "CTO" + + def test_budget_percent_negative_rejected(self): + with pytest.raises(ValidationError): + TemplateDepartmentConfig(name="eng", budget_percent=-1.0) + + def test_budget_percent_over_100_rejected(self): + with pytest.raises(ValidationError): + TemplateDepartmentConfig(name="eng", budget_percent=101.0) + + def test_blank_name_rejected(self): + with pytest.raises(ValidationError): + TemplateDepartmentConfig(name="") + + +# ── TemplateMetadata ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestTemplateMetadata: + def test_valid_minimal(self): + m = TemplateMetadata(name="Test", company_type="custom") + assert m.name == "Test" + assert m.company_type == CompanyType.CUSTOM + assert m.min_agents == 1 + assert m.max_agents == 100 + assert m.tags == () + + def test_valid_full(self): + m = TemplateMetadata( + name="My Template", + description="A description", + version="2.0.0", + company_type="startup", + min_agents=2, + max_agents=10, + tags=("startup", "mvp"), + ) + assert m.version == "2.0.0" + assert m.tags == ("startup", "mvp") + + def test_min_greater_than_max_rejected(self): + with pytest.raises(ValidationError, match="min_agents"): + TemplateMetadata( + name="Bad", + company_type="custom", + min_agents=10, + max_agents=5, + ) + + def test_blank_name_rejected(self): + with pytest.raises(ValidationError): + TemplateMetadata(name="", company_type="custom") + + def test_invalid_company_type_rejected(self): + with pytest.raises(ValidationError): + TemplateMetadata(name="T", company_type="nonexistent_type") + + +# ── CompanyTemplate ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestCompanyTemplate: + def test_valid_minimal(self, make_template_dict): + t = CompanyTemplate(**make_template_dict()) + assert t.metadata.name == "Test" + assert len(t.agents) == 1 + assert t.workflow == "agile_kanban" + assert t.communication == "hybrid" + assert t.budget_monthly == 50.0 + assert t.autonomy == 0.5 + + def test_agent_count_below_min_rejected(self, make_template_dict): + with pytest.raises(ValidationError, match="minimum"): + CompanyTemplate( + **make_template_dict( + metadata={ + "name": "T", + "company_type": "custom", + "min_agents": 3, + }, + agents=({"role": "Dev", "level": "mid"},), + ) + ) + + def test_agent_count_above_max_rejected(self, make_template_dict): + agents = tuple({"role": f"Dev{i}", "level": "mid"} for i in range(5)) + with pytest.raises(ValidationError, match="maximum"): + CompanyTemplate( + **make_template_dict( + metadata={ + "name": "T", + "company_type": "custom", + "max_agents": 2, + }, + agents=agents, + ) + ) + + def test_duplicate_variable_names_rejected(self, make_template_dict): + with pytest.raises(ValidationError, match="Duplicate variable names"): + CompanyTemplate( + **make_template_dict( + variables=( + {"name": "x", "var_type": "str"}, + {"name": "x", "var_type": "int"}, + ), + ) + ) + + def test_duplicate_department_names_rejected(self, make_template_dict): + with pytest.raises(ValidationError, match="Duplicate department names"): + CompanyTemplate( + **make_template_dict( + departments=( + {"name": "eng", "budget_percent": 50}, + {"name": "eng", "budget_percent": 50}, + ), + ) + ) + + def test_unique_variables_accepted(self, make_template_dict): + t = CompanyTemplate( + **make_template_dict( + variables=( + {"name": "x"}, + {"name": "y"}, + ), + ) + ) + assert len(t.variables) == 2 + + def test_autonomy_out_of_range_rejected(self, make_template_dict): + with pytest.raises(ValidationError): + CompanyTemplate(**make_template_dict(autonomy=1.5)) + + def test_negative_budget_rejected(self, make_template_dict): + with pytest.raises(ValidationError): + CompanyTemplate(**make_template_dict(budget_monthly=-10.0)) + + def test_frozen(self, make_template_dict): + t = CompanyTemplate(**make_template_dict()) + with pytest.raises(ValidationError): + t.workflow = "scrum" # type: ignore[misc] diff --git a/uv.lock b/uv.lock index 64d209e012..ab3346a49c 100644 --- a/uv.lock +++ b/uv.lock @@ -6,6 +6,7 @@ requires-python = ">=3.14" name = "ai-company" source = { editable = "." } dependencies = [ + { name = "jinja2" }, { name = "pydantic" }, { name = "pyyaml" }, { name = "structlog" }, @@ -41,6 +42,7 @@ test = [ [package.metadata] requires-dist = [ + { name = "jinja2", specifier = "==3.1.6" }, { name = "pydantic", specifier = "==2.12.5" }, { name = "pyyaml", specifier = "==6.0.2" }, { name = "structlog", specifier = "==25.5.0" }, From 0ad651938fbe841095077d88e554eb5e69cbce12 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Sun, 1 Mar 2026 08:43:10 +0100 Subject: [PATCH 2/3] fix: address 29 PR review items from local agents, CodeRabbit, Gemini, and Copilot - Narrow exception handling: replace broad `except Exception` with specific types - Add path traversal protection in template loader - Add filesystem error handling for template file reads (OSError, UnicodeDecodeError) - Extract `to_float()` utility to config/utils.py for strict numeric validation - Add type-match validation for TemplateVariable defaults - Strengthen type hints: use NotBlankStr | None for optional identifiers - Add __str__ method to TemplateValidationError with per-field details - Remove duplicate TestDeepMerge class from test_loader.py - Fix docstrings: module-level, class-level, and field descriptions - Include exception details in warning logs - Add template data type guard before Pydantic validation --- src/ai_company/config/loader.py | 5 +- src/ai_company/config/utils.py | 23 ++++++ src/ai_company/templates/errors.py | 28 ++++++- src/ai_company/templates/loader.py | 112 ++++++++++++++++++++------- src/ai_company/templates/renderer.py | 77 ++++++++++-------- src/ai_company/templates/schema.py | 33 ++++++-- tests/unit/config/test_loader.py | 66 ---------------- 7 files changed, 212 insertions(+), 132 deletions(-) diff --git a/src/ai_company/config/loader.py b/src/ai_company/config/loader.py index 32bd49cf55..b24e329b97 100644 --- a/src/ai_company/config/loader.py +++ b/src/ai_company/config/loader.py @@ -187,10 +187,11 @@ def _build_line_map(yaml_text: str) -> dict[str, tuple[int, int]]: """ try: root = yaml.compose(yaml_text, Loader=yaml.SafeLoader) - except yaml.YAMLError: + except yaml.YAMLError as exc: logger.warning( "Failed to compose YAML AST for line mapping; " - "validation errors will lack line/column information", + "validation errors will lack line/column information: %s", + exc, ) return {} if root is None or not isinstance(root, yaml.MappingNode): diff --git a/src/ai_company/config/utils.py b/src/ai_company/config/utils.py index df9d1185f3..ff36a497b9 100644 --- a/src/ai_company/config/utils.py +++ b/src/ai_company/config/utils.py @@ -4,6 +4,29 @@ from typing import Any +def to_float(value: Any, *, field_name: str = "value") -> float: + """Coerce a value to float with clear error reporting. + + Args: + value: Value to convert (str, int, float, etc.). + field_name: Field name for error messages. + + Returns: + Float value. + + Raises: + ValueError: If *value* cannot be converted to float. + """ + if value is None: + msg = f"Expected numeric value for {field_name}, got None" + raise ValueError(msg) + try: + return float(value) + except (TypeError, ValueError) as exc: + msg = f"Invalid numeric value for {field_name}: {value!r}" + raise ValueError(msg) from exc + + def deep_merge( base: dict[str, Any], override: dict[str, Any], diff --git a/src/ai_company/templates/errors.py b/src/ai_company/templates/errors.py index ff372ef821..42e1873e66 100644 --- a/src/ai_company/templates/errors.py +++ b/src/ai_company/templates/errors.py @@ -12,7 +12,12 @@ class TemplateNotFoundError(TemplateError): class TemplateRenderError(TemplateError): - """Raised when Jinja2 rendering fails or a required variable is missing.""" + """Raised when template rendering fails. + + Covers Jinja2 evaluation errors, missing required variables, + YAML parse errors during template processing, and invalid + numeric values in rendered output. + """ class TemplateValidationError(TemplateError): @@ -31,3 +36,24 @@ def __init__( ) -> None: super().__init__(message, locations) self.field_errors = field_errors + + def __str__(self) -> str: + """Format validation error with per-field details.""" + if not self.field_errors: + return super().__str__() + parts = [f"{self.message} ({len(self.field_errors)} errors):"] + loc_by_key: dict[str, ConfigLocation] = { + loc.key_path: loc for loc in self.locations if loc.key_path + } + for key_path, msg in self.field_errors: + parts.append(f" {key_path}: {msg}") + loc = loc_by_key.get(key_path) + if loc and loc.file_path: + if loc.line is not None and loc.column is not None: + line_info = f" at line {loc.line}, column {loc.column}" + elif loc.line is not None: + line_info = f" at line {loc.line}" + else: + line_info = "" + parts.append(f" in {loc.file_path}{line_info}") + return "\n".join(parts) diff --git a/src/ai_company/templates/loader.py b/src/ai_company/templates/loader.py index 3a4bbafe64..6337acdca2 100644 --- a/src/ai_company/templates/loader.py +++ b/src/ai_company/templates/loader.py @@ -7,8 +7,7 @@ - **Pass 2**: Performed later by the renderer — Jinja2-renders the raw YAML text, then YAML-parses the result. -The loader returns both the structured :class:`CompanyTemplate` (from -Pass 1) and the raw YAML text (for Pass 2). +Both are returned bundled as a :class:`LoadedTemplate` dataclass. """ import logging @@ -16,9 +15,10 @@ from dataclasses import dataclass from importlib import resources from pathlib import Path -from typing import Any +from typing import Any, Literal import yaml +from pydantic import ValidationError from ai_company.config.errors import ConfigLocation from ai_company.templates.errors import ( @@ -58,7 +58,7 @@ class TemplateInfo: name: str display_name: str description: str - source: str + source: Literal["builtin", "user"] @dataclass(frozen=True) @@ -79,8 +79,9 @@ class LoadedTemplate: def list_templates() -> tuple[TemplateInfo, ...]: """Return all available templates (user directory + built-in). - User templates are listed first. If a user template has the same - name as a built-in, only the user template appears. + User templates override built-in templates of the same name. + The result is sorted alphabetically by template name. Templates + that fail to load are silently skipped with a warning log. Returns: Sorted tuple of :class:`TemplateInfo` objects. @@ -89,7 +90,9 @@ def list_templates() -> tuple[TemplateInfo, ...]: # User templates (higher priority). if _USER_TEMPLATES_DIR.is_dir(): - for path in sorted(_USER_TEMPLATES_DIR.glob("*.yaml")): + for path in sorted( + p for p in _USER_TEMPLATES_DIR.glob("*.yaml") if p.is_file() + ): name = path.stem try: loaded = _load_from_file(path) @@ -100,8 +103,12 @@ def list_templates() -> tuple[TemplateInfo, ...]: description=meta.description, source="user", ) - except Exception: - logger.warning("Skipping invalid user template: %s", path) + except (TemplateRenderError, TemplateValidationError, OSError) as exc: + logger.warning( + "Skipping invalid user template %s: %s", + path, + exc, + ) # Built-in templates (lower priority). for name in sorted(BUILTIN_TEMPLATES): @@ -115,8 +122,11 @@ def list_templates() -> tuple[TemplateInfo, ...]: description=meta.description, source="builtin", ) - except Exception: - logger.warning("Skipping invalid builtin template: %s", name) + except TemplateRenderError, TemplateValidationError, OSError: + logger.exception( + "Built-in template %r is invalid (packaging defect)", + name, + ) return tuple(info for _, info in sorted(seen.items())) @@ -144,9 +154,18 @@ def load_template(name: str) -> LoadedTemplate: """ name_clean = name.strip().lower() + # Sanitize to prevent path traversal. + safe_name = Path(name_clean).name + if safe_name != name_clean: + msg = f"Invalid template name {name!r}: must not contain path separators" + raise TemplateNotFoundError( + msg, + locations=(ConfigLocation(file_path=f""),), + ) + # Try user directory first. if _USER_TEMPLATES_DIR.is_dir(): - user_path = _USER_TEMPLATES_DIR / f"{name_clean}.yaml" + user_path = _USER_TEMPLATES_DIR / f"{safe_name}.yaml" if user_path.is_file(): return _load_from_file(user_path) @@ -211,9 +230,27 @@ def _load_builtin(name: str) -> LoadedTemplate: def _load_from_file(path: Path) -> LoadedTemplate: - """Load a template from a file path.""" - yaml_text = path.read_text(encoding="utf-8") + """Load a template from a file path. + + Raises: + TemplateRenderError: If the file cannot be read. + TemplateValidationError: If validation fails. + """ source_name = str(path) + try: + yaml_text = path.read_text(encoding="utf-8") + except OSError as exc: + msg = f"Unable to read template file: {path}" + raise TemplateRenderError( + msg, + locations=(ConfigLocation(file_path=source_name),), + ) from exc + except UnicodeDecodeError as exc: + msg = f"Template file is not valid UTF-8: {path}" + raise TemplateRenderError( + msg, + locations=(ConfigLocation(file_path=source_name),), + ) from exc template = _parse_template_yaml(yaml_text, source_name=source_name) return LoadedTemplate( template=template, @@ -287,10 +324,19 @@ def _parse_template_yaml( ) template_data = data["template"] - normalized = _normalize_template_data(template_data) try: + if not isinstance(template_data, dict): + msg = f"Template 'template' key must map to an object in {source_name}" + raise TypeError(msg) # noqa: TRY301 + normalized = _normalize_template_data(template_data) return CompanyTemplate(**normalized) - except Exception as exc: + except ValidationError as exc: + msg = f"Template validation failed for {source_name}: {exc}" + raise TemplateValidationError( + msg, + locations=(ConfigLocation(file_path=source_name),), + ) from exc + except (ValueError, TypeError) as exc: msg = f"Template validation failed for {source_name}: {exc}" raise TemplateValidationError( msg, @@ -312,13 +358,14 @@ def _normalize_template_data(data: dict[str, Any]) -> dict[str, Any]: """ company = data.get("company", {}) - metadata = { - "name": data.get("name", ""), + metadata: dict[str, Any] = { "description": data.get("description", ""), "version": data.get("version", "1.0.0"), "company_type": company.get("type", "custom"), "tags": tuple(data.get("tags", ())), } + if "name" in data: + metadata["name"] = data["name"] return { "metadata": metadata, @@ -333,17 +380,28 @@ def _normalize_template_data(data: dict[str, Any]) -> dict[str, Any]: def _to_float(value: Any) -> float: - """Coerce a value to float, handling string numerics. + """Coerce a value to float for Pass 1 normalization. + + Returns ``0.0`` for values that cannot be converted (e.g. Jinja2 + placeholders like ``__JINJA2__``) since the real value will be + resolved in Pass 2. Args: - value: Raw value from YAML (may be str, int, float). + value: Raw value from YAML (may be str, int, float, or + ``None``). Returns: - Float value. + Float value, or ``0.0`` for ``None`` or unconvertible strings + (typically Jinja2 placeholders). """ - if isinstance(value, str): - try: - return float(value) - except ValueError: - return 0.0 - return float(value) + if value is None: + return 0.0 + try: + return float(value) + except TypeError, ValueError: + logger.debug( + "Cannot convert %r to float in Pass 1 " + "(may be a Jinja2 placeholder), using 0.0", + value, + ) + return 0.0 diff --git a/src/ai_company/templates/renderer.py b/src/ai_company/templates/renderer.py index e4fb861ef2..052c764ee8 100644 --- a/src/ai_company/templates/renderer.py +++ b/src/ai_company/templates/renderer.py @@ -12,18 +12,20 @@ from typing import TYPE_CHECKING, Any import yaml +from jinja2 import TemplateError as Jinja2TemplateError from jinja2.sandbox import SandboxedEnvironment from pydantic import ValidationError from ai_company.config.defaults import default_config_dict from ai_company.config.errors import ConfigLocation from ai_company.config.schema import RootConfig -from ai_company.config.utils import deep_merge +from ai_company.config.utils import deep_merge, to_float from ai_company.templates.errors import ( TemplateRenderError, TemplateValidationError, ) from ai_company.templates.presets import ( + PERSONALITY_PRESETS, generate_auto_name, get_personality_preset, ) @@ -47,7 +49,7 @@ def render_template( 2. Jinja2-render the raw YAML text with collected variables. 3. YAML-parse the rendered text. 4. Normalize into ``RootConfig`` shape. - 5. Deep-merge with ``default_config_dict()``. + 5. Deep-merge template output onto ``default_config_dict()`` base. 6. Validate as ``RootConfig``. Args: @@ -134,8 +136,9 @@ def _create_jinja_env() -> SandboxedEnvironment: env = SandboxedEnvironment( keep_trailing_newline=True, ) - # ``auto`` filter: returns empty string for falsy values (triggers - # auto-name generation in _expand_agents). + # ``auto`` filter: converts falsy values to empty string, which + # triggers auto-name generation downstream (empty names are + # detected by ``_expand_agents``). env.filters["auto"] = lambda value: value or "" return env @@ -163,7 +166,7 @@ def _render_jinja2( try: jinja_template = env.from_string(raw_yaml) return jinja_template.render(**variables) - except Exception as exc: + except Jinja2TemplateError as exc: msg = f"Jinja2 rendering failed for {source_name}: {exc}" raise TemplateRenderError( msg, @@ -203,8 +206,14 @@ def _parse_rendered_yaml( locations=(ConfigLocation(file_path=source_name),), ) - result: dict[str, Any] = data["template"] - return result + template_data = data["template"] + if not isinstance(template_data, dict): + msg = f"Rendered template 'template' key must be a mapping: {source_name}" + raise TemplateRenderError( + msg, + locations=(ConfigLocation(file_path=source_name),), + ) + return template_data def _build_config_dict( @@ -230,22 +239,34 @@ def _build_config_dict( # Expand agents. raw_agents = rendered_data.get("agents", []) - agents = _expand_agents(raw_agents, variables) + agents = _expand_agents(raw_agents) # Build departments for RootConfig. raw_depts = rendered_data.get("departments", []) departments = _build_departments(raw_depts) + source_name = template.metadata.name + try: + autonomy = to_float( + company.get("autonomy", template.autonomy), + field_name="autonomy", + ) + budget_monthly = to_float( + company.get("budget_monthly", template.budget_monthly), + field_name="budget_monthly", + ) + except ValueError as exc: + msg = f"Invalid numeric value in rendered template {source_name!r}: {exc}" + raise TemplateRenderError(msg) from exc + return { "company_name": company_name, "company_type": company.get("type", template.metadata.company_type.value), "agents": agents, "departments": departments, "config": { - "autonomy": _safe_float(company.get("autonomy", template.autonomy)), - "budget_monthly": _safe_float( - company.get("budget_monthly", template.budget_monthly), - ), + "autonomy": autonomy, + "budget_monthly": budget_monthly, "communication_pattern": rendered_data.get( "communication", template.communication, @@ -256,7 +277,6 @@ def _build_config_dict( def _expand_agents( raw_agents: list[dict[str, Any]], - _variables: dict[str, Any], ) -> list[dict[str, Any]]: """Expand template agent dicts into AgentConfig-compatible dicts. @@ -264,7 +284,6 @@ def _expand_agents( Args: raw_agents: List of agent dicts from rendered YAML. - variables: Collected variables. Returns: List of dicts suitable for ``AgentConfig`` construction. @@ -301,10 +320,13 @@ def _expand_agents( try: agent_dict["personality"] = get_personality_preset(preset_name) except KeyError: + available = sorted(PERSONALITY_PRESETS) logger.warning( - "Unknown personality preset %r for agent %r, using defaults", + "Unknown personality preset %r for agent %r; " + "using default personality. Available presets: %s", preset_name, name, + available, ) # Model config (raw dict for AgentConfig). @@ -329,10 +351,18 @@ def _build_departments( """ departments: list[dict[str, Any]] = [] for dept in raw_depts: + try: + budget_pct = to_float( + dept.get("budget_percent", 0.0), + field_name=f"departments[{dept.get('name', '')}].budget_percent", + ) + except ValueError as exc: + msg = f"Invalid department budget value: {exc}" + raise TemplateRenderError(msg) from exc dept_dict: dict[str, Any] = { "name": dept.get("name", ""), "head": dept.get("head_role", dept.get("name", "")), - "budget_percent": _safe_float(dept.get("budget_percent", 0.0)), + "budget_percent": budget_pct, } departments.append(dept_dict) return departments @@ -375,18 +405,3 @@ def _validate_as_root_config( locations=tuple(locations), field_errors=tuple(field_errors), ) from exc - - -def _safe_float(value: Any) -> float: - """Coerce a value to float safely. - - Args: - value: Value from rendered YAML (str, int, or float). - - Returns: - Float value, or 0.0 on conversion failure. - """ - try: - return float(value) - except TypeError, ValueError: - return 0.0 diff --git a/src/ai_company/templates/schema.py b/src/ai_company/templates/schema.py index 99a9578eac..83e7e3516a 100644 --- a/src/ai_company/templates/schema.py +++ b/src/ai_company/templates/schema.py @@ -20,7 +20,9 @@ class TemplateVariable(BaseModel): name: Variable name (used in ``{{ name }}`` placeholders). description: Human-readable description for prompts/docs. var_type: Expected Python type name. - default: Default value (``None`` means the variable is required). + default: Default value (``None`` means no default is provided). + The ``required`` attribute determines whether the user must + supply a value. required: Whether the user must provide this value. """ @@ -43,6 +45,26 @@ def _validate_required_has_no_default(self) -> Self: raise ValueError(msg) return self + @model_validator(mode="after") + def _validate_default_matches_var_type(self) -> Self: + """Default value type must match ``var_type`` when provided.""" + if self.default is None: + return self + type_map: dict[str, type | tuple[type, ...]] = { + "str": str, + "int": int, + "float": (int, float), + "bool": bool, + } + expected = type_map[self.var_type] + if not isinstance(self.default, expected): + msg = ( + f"Variable {self.name!r}: default {self.default!r} " + f"is not compatible with var_type {self.var_type!r}" + ) + raise ValueError(msg) # noqa: TRY004 + return self + class TemplateAgentConfig(BaseModel): """Agent definition within a template. @@ -57,7 +79,8 @@ class TemplateAgentConfig(BaseModel): level: Seniority level override. model: Model tier alias (e.g. ``"opus"``, ``"sonnet"``, ``"haiku"``). personality_preset: Named personality preset from the presets registry. - department: Department override (``None`` uses the role's default). + department: Department override (``None`` defaults to + ``"engineering"`` during rendering). """ model_config = ConfigDict(frozen=True) @@ -69,11 +92,11 @@ class TemplateAgentConfig(BaseModel): description="Seniority level", ) model: str = Field(default="sonnet", description="Model tier alias") - personality_preset: str | None = Field( + personality_preset: NotBlankStr | None = Field( default=None, description="Named personality preset", ) - department: str | None = Field( + department: NotBlankStr | None = Field( default=None, description="Department override", ) @@ -101,7 +124,7 @@ class TemplateDepartmentConfig(BaseModel): le=100.0, description="Percentage of company budget", ) - head_role: str | None = Field( + head_role: NotBlankStr | None = Field( default=None, description="Role name of department head", ) diff --git a/tests/unit/config/test_loader.py b/tests/unit/config/test_loader.py index 46d08ce3d8..a7c4001c54 100644 --- a/tests/unit/config/test_loader.py +++ b/tests/unit/config/test_loader.py @@ -21,7 +21,6 @@ load_config_from_string, ) from ai_company.config.schema import RootConfig -from ai_company.config.utils import deep_merge from .conftest import ( ENV_VAR_MISSING_YAML, @@ -34,71 +33,6 @@ MISSING_REQUIRED_YAML, ) -# ── deep_merge ────────────────────────────────────────────────── - - -@pytest.mark.unit -class TestDeepMerge: - def test_simple_override(self): - base = {"a": 1, "b": 2} - override = {"b": 3} - result = deep_merge(base, override) - assert result == {"a": 1, "b": 3} - - def test_nested_merge(self): - base = {"x": {"a": 1, "b": 2}} - override = {"x": {"b": 3, "c": 4}} - result = deep_merge(base, override) - assert result == {"x": {"a": 1, "b": 3, "c": 4}} - - def test_list_replaced_entirely(self): - base = {"items": [1, 2, 3]} - override = {"items": [4, 5]} - result = deep_merge(base, override) - assert result == {"items": [4, 5]} - - def test_base_preserved(self): - base = {"a": 1, "b": 2} - override = {"c": 3} - result = deep_merge(base, override) - assert result == {"a": 1, "b": 2, "c": 3} - - def test_new_keys_added(self): - base = {"a": 1} - override = {"b": 2, "c": 3} - result = deep_merge(base, override) - assert result == {"a": 1, "b": 2, "c": 3} - - def test_inputs_not_mutated(self): - base = {"x": {"a": 1}} - override = {"x": {"b": 2}} - base_copy = {"x": {"a": 1}} - override_copy = {"x": {"b": 2}} - deep_merge(base, override) - assert base == base_copy - assert override == override_copy - - def test_deeply_nested(self): - base = {"a": {"b": {"c": 1}}} - override = {"a": {"b": {"d": 2}}} - result = deep_merge(base, override) - assert result == {"a": {"b": {"c": 1, "d": 2}}} - - def test_result_does_not_share_mutable_refs_with_base(self): - base = {"x": {"nested": [1, 2, 3]}} - result = deep_merge(base, {}) - result["x"]["nested"].append(4) - assert base["x"]["nested"] == [1, 2, 3] - - def test_empty_base(self): - result = deep_merge({}, {"a": 1}) - assert result == {"a": 1} - - def test_empty_override(self): - result = deep_merge({"a": 1}, {}) - assert result == {"a": 1} - - # ── _read_config_text ──────────────────────────────────────────── From e52bfe7c54e4d5ee4e859a4140116b967b56d0fc Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Sun, 1 Mar 2026 09:02:29 +0100 Subject: [PATCH 3/3] fix: address round-2 CodeRabbit review (5 items) - Guard template.company type before calling .get() in normalization - Add extra="forbid" to all 5 template schema models (catches YAML typos) - Reject bool defaults for int/float var_type (bool subclasses int) - Validate company/agents/departments shapes in renderer before use --- src/ai_company/templates/loader.py | 8 +++++++- src/ai_company/templates/renderer.py | 16 ++++++++++++++++ src/ai_company/templates/schema.py | 18 +++++++++++++----- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/ai_company/templates/loader.py b/src/ai_company/templates/loader.py index 6337acdca2..205d369871 100644 --- a/src/ai_company/templates/loader.py +++ b/src/ai_company/templates/loader.py @@ -356,7 +356,13 @@ def _normalize_template_data(data: dict[str, Any]) -> dict[str, Any]: Returns: Dict suitable for ``CompanyTemplate(**result)``. """ - company = data.get("company", {}) + company_raw = data.get("company", {}) + if company_raw is None: + company_raw = {} + if not isinstance(company_raw, dict): + msg = "Template field 'template.company' must be a mapping" + raise TypeError(msg) + company: dict[str, Any] = company_raw metadata: dict[str, Any] = { "description": data.get("description", ""), diff --git a/src/ai_company/templates/renderer.py b/src/ai_company/templates/renderer.py index 052c764ee8..3f28ec027b 100644 --- a/src/ai_company/templates/renderer.py +++ b/src/ai_company/templates/renderer.py @@ -232,6 +232,12 @@ def _build_config_dict( Dict suitable for ``RootConfig(**deep_merge(defaults, result))``. """ company = rendered_data.get("company", {}) + if company is None: + company = {} + if not isinstance(company, dict): + msg = "Rendered template 'company' must be a mapping" + raise TemplateRenderError(msg) + company_name = variables.get( "company_name", template.metadata.name, @@ -239,10 +245,20 @@ def _build_config_dict( # Expand agents. raw_agents = rendered_data.get("agents", []) + if raw_agents is None: + raw_agents = [] + if not isinstance(raw_agents, list): + msg = "Rendered template 'agents' must be a list" + raise TemplateRenderError(msg) agents = _expand_agents(raw_agents) # Build departments for RootConfig. raw_depts = rendered_data.get("departments", []) + if raw_depts is None: + raw_depts = [] + if not isinstance(raw_depts, list): + msg = "Rendered template 'departments' must be a list" + raise TemplateRenderError(msg) departments = _build_departments(raw_depts) source_name = template.metadata.name diff --git a/src/ai_company/templates/schema.py b/src/ai_company/templates/schema.py index 83e7e3516a..cd2624c489 100644 --- a/src/ai_company/templates/schema.py +++ b/src/ai_company/templates/schema.py @@ -26,7 +26,7 @@ class TemplateVariable(BaseModel): required: Whether the user must provide this value. """ - model_config = ConfigDict(frozen=True) + model_config = ConfigDict(frozen=True, extra="forbid") name: NotBlankStr = Field(description="Variable name") description: str = Field(default="", description="Human-readable description") @@ -50,6 +50,14 @@ def _validate_default_matches_var_type(self) -> Self: """Default value type must match ``var_type`` when provided.""" if self.default is None: return self + # Reject bools explicitly for numeric types because + # ``isinstance(True, int)`` is ``True`` in Python. + if isinstance(self.default, bool) and self.var_type in ("int", "float"): + msg = ( + f"Variable {self.name!r}: default {self.default!r} " + f"is not compatible with var_type {self.var_type!r}" + ) + raise ValueError(msg) type_map: dict[str, type | tuple[type, ...]] = { "str": str, "int": int, @@ -83,7 +91,7 @@ class TemplateAgentConfig(BaseModel): ``"engineering"`` during rendering). """ - model_config = ConfigDict(frozen=True) + model_config = ConfigDict(frozen=True, extra="forbid") role: NotBlankStr = Field(description="Built-in role name") name: str = Field(default="", description="Agent name (may have Jinja2 vars)") @@ -115,7 +123,7 @@ class TemplateDepartmentConfig(BaseModel): head_role: Role name of the department head. """ - model_config = ConfigDict(frozen=True) + model_config = ConfigDict(frozen=True, extra="forbid") name: NotBlankStr = Field(description="Department name") budget_percent: float = Field( @@ -143,7 +151,7 @@ class TemplateMetadata(BaseModel): tags: Categorization tags. """ - model_config = ConfigDict(frozen=True) + model_config = ConfigDict(frozen=True, extra="forbid") name: NotBlankStr = Field(description="Template display name") description: str = Field(default="", description="Template description") @@ -187,7 +195,7 @@ class CompanyTemplate(BaseModel): 1.0 = fully autonomous). """ - model_config = ConfigDict(frozen=True) + model_config = ConfigDict(frozen=True, extra="forbid") metadata: TemplateMetadata = Field(description="Template metadata") variables: tuple[TemplateVariable, ...] = Field(