Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/ai_company/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
.. autosummary::
load_config
load_config_from_string
discover_config
default_config_dict
RootConfig
AgentConfig
Expand All @@ -27,7 +28,11 @@
ConfigParseError,
ConfigValidationError,
)
from ai_company.config.loader import load_config, load_config_from_string
from ai_company.config.loader import (
discover_config,
load_config,
load_config_from_string,
)
from ai_company.config.schema import (
AgentConfig,
ProviderConfig,
Expand All @@ -50,6 +55,7 @@
"RoutingConfig",
"RoutingRuleConfig",
"default_config_dict",
"discover_config",
"load_config",
"load_config_from_string",
]
150 changes: 146 additions & 4 deletions src/ai_company/config/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import copy
import logging
import os
import re
from pathlib import Path
from typing import Any

Expand All @@ -19,6 +21,15 @@

logger = logging.getLogger(__name__)

_ENV_VAR_PATTERN = re.compile(r"\$\{([^}:]+?)(?::-([^}]*))?\}")

_CWD_CONFIG_LOCATIONS: tuple[Path, ...] = (
Path("ai-company.yaml"),
Path("config/ai-company.yaml"),
)

_HOME_CONFIG_RELATIVE = Path(".ai-company") / "config.yaml"

# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -263,13 +274,128 @@ def _validate_config_dict(
) from exc


def _resolve_env_var_match(
match: re.Match[str],
*,
source_file: str | None,
) -> str:
"""Resolve a single ``${VAR}`` or ``${VAR:-default}`` match.

Args:
match: Regex match from :data:`_ENV_VAR_PATTERN`.
source_file: File path label for error messages.

Returns:
Resolved environment variable value or default.

Raises:
ConfigValidationError: If the env var is not set and no
default is provided.
"""
var_name = match.group(1)
default = match.group(2)
value = os.environ.get(var_name)
if value is not None:
return value
if default is not None:
return default
msg = f"Environment variable '{var_name}' is not set and no default was provided"
raise ConfigValidationError(
msg,
locations=(ConfigLocation(file_path=source_file),),
)


def _walk_substitute(node: Any, *, source_file: str | None) -> Any:
"""Recursively substitute env var placeholders in a config node.

Args:
node: Config value (str, dict, list, or scalar).
source_file: File path label for error messages.

Returns:
Node with all ``${VAR}`` placeholders resolved.
"""
if isinstance(node, str):
return _ENV_VAR_PATTERN.sub(
lambda m: _resolve_env_var_match(m, source_file=source_file),
node,
)
if isinstance(node, dict):
return {
key: _walk_substitute(value, source_file=source_file)
for key, value in node.items()
}
if isinstance(node, list):
return [_walk_substitute(item, source_file=source_file) for item in node]
return node


def _substitute_env_vars(
data: dict[str, Any],
*,
source_file: str | None = None,
) -> dict[str, Any]:
"""Substitute ``${VAR}`` and ``${VAR:-default}`` in string values.

Walks the dict recursively, replacing environment variable
placeholders in string values. Non-string values (int, float,
bool, None) are passed through unchanged. Returns a new dict;
the input is never mutated.

Args:
data: Configuration dict to process.
source_file: File path label for error messages.

Returns:
A new dict with all env var placeholders resolved.

Raises:
ConfigValidationError: If a referenced env var is not set
and no default is provided.
"""
result: dict[str, Any] = _walk_substitute(data, source_file=source_file)
return result
Comment thread
coderabbitai[bot] marked this conversation as resolved.


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def discover_config() -> Path:
"""Auto-discover a configuration file from well-known locations.

Search order:

1. ``./ai-company.yaml``
2. ``./config/ai-company.yaml``
3. ``~/.ai-company/config.yaml``

Returns:
Resolved absolute :class:`~pathlib.Path` to the first file found.

Raises:
ConfigFileNotFoundError: If no configuration file is found
at any searched location.
"""
candidates = [*_CWD_CONFIG_LOCATIONS, Path.home() / _HOME_CONFIG_RELATIVE]
for candidate in candidates:
if candidate.is_file():
return candidate.resolve()

searched = [str(c) for c in candidates]
msg = "No configuration file found. Searched:\n" + "\n".join(
f" - {p}" for p in searched
)
raise ConfigFileNotFoundError(
msg,
locations=tuple(ConfigLocation(file_path=p) for p in searched),
)


def load_config(
config_path: Path | str,
config_path: Path | str | None = None,
*,
override_paths: tuple[Path | str, ...] = (),
) -> RootConfig:
Expand All @@ -280,6 +406,11 @@ def load_config(
1. Built-in defaults (from :func:`default_config_dict`).
2. Primary config file at *config_path*.
3. Override files in order.
4. Environment variable substitution (``${VAR}`` /
``${VAR:-default}``).

When *config_path* is ``None``, :func:`discover_config` is called
to auto-discover the configuration file.

.. note::

Expand All @@ -289,18 +420,23 @@ def load_config(
file's line numbers or lack location information entirely.

Args:
config_path: Path to the primary config file.
config_path: Path to the primary config file, or ``None``
to auto-discover.
override_paths: Additional config files layered on top.

Returns:
Validated, frozen :class:`RootConfig`.

Raises:
ConfigFileNotFoundError: If any config file does not exist.
ConfigFileNotFoundError: If any config file does not exist
(or discovery finds nothing).
ConfigParseError: If any file contains invalid YAML or cannot
be read.
ConfigValidationError: If the merged config fails validation.
ConfigValidationError: If the merged config fails validation
or an env var is missing.
"""
if config_path is None:
config_path = discover_config()
config_path = Path(config_path)

# 1. Start with built-in defaults
Expand All @@ -317,6 +453,11 @@ def load_config(
override = _parse_yaml_file(Path(override_path))
merged = _deep_merge(merged, override)

# 4. Substitute environment variables on the fully merged config.
# Use a neutral label so env-var errors aren't misattributed solely
# to the primary config file when they may originate from overrides.
merged = _substitute_env_vars(merged, source_file="<merged config>")

# Build line map from primary file for error enrichment
line_map = _build_line_map(yaml_text)

Expand Down Expand Up @@ -351,6 +492,7 @@ def load_config_from_string(
"""
data = _parse_yaml_string(yaml_string, source_name)
merged = _deep_merge(default_config_dict(), data)
merged = _substitute_env_vars(merged, source_file=source_name)
line_map = _build_line_map(yaml_string)
return _validate_config_dict(
merged,
Expand Down
9 changes: 1 addition & 8 deletions src/ai_company/observability/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,11 @@ def _validate_no_duplicate_file_paths(self) -> Self:

@model_validator(mode="after")
def _validate_log_dir_safe(self) -> Self:
"""Ensure ``log_dir`` is not blank, relative, and has no traversal."""
"""Ensure ``log_dir`` is not blank and has no path traversal."""
if not self.log_dir.strip():
msg = "log_dir must not be blank"
raise ValueError(msg)
path = PurePath(self.log_dir)
if (
path.is_absolute()
or PurePosixPath(self.log_dir).is_absolute()
or PureWindowsPath(self.log_dir).is_absolute()
):
msg = f"log_dir must be relative: {self.log_dir}"
raise ValueError(msg)
if ".." in path.parts:
msg = f"log_dir must not contain '..' components: {self.log_dir}"
raise ValueError(msg)
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/config/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,27 @@ class RootConfigFactory(ModelFactory):
total_monthly: -100.0
"""

ENV_VAR_SIMPLE_YAML = """\
company_name: ${COMPANY_NAME}
"""

ENV_VAR_NESTED_YAML = """\
company_name: ${COMPANY_NAME}
budget:
total_monthly: 500.0
alerts:
warn_at: 75
critical_at: 90
hard_stop_at: 100
providers:
anthropic:
base_url: ${ANTHROPIC_BASE_URL:-https://api.anthropic.com}
"""

ENV_VAR_MISSING_YAML = """\
company_name: ${UNDEFINED_VAR}
"""


# ── Fixtures ──────────────────────────────────────────────────────

Expand Down
Loading