# Placeholder grammar: ``${NAME}`` or ``${NAME:-default}``.
# Group 1 captures NAME (one or more characters other than ``}`` and ``:``).
# Group 2 captures the default text, which may be empty, and is ``None``
# when the ``:-`` part is absent.
_ENV_VAR_PATTERN = re.compile(r"\$\{([^}:]+?)(?::-([^}]*))?\}")

# Config file candidates relative to the current working directory,
# in the order they are tried.
_CWD_CONFIG_LOCATIONS: tuple[Path, ...] = (
    Path("ai-company.yaml"),
    Path("config/ai-company.yaml"),
)

# Per-user fallback location, relative to the home directory.
_HOME_CONFIG_RELATIVE = Path(".ai-company") / "config.yaml"


def _resolve_env_var_match(
    match: re.Match[str],
    *,
    source_file: str | None,
) -> str:
    """Resolve a single ``${VAR}`` or ``${VAR:-default}`` match.

    Follows the POSIX shell meaning of ``:-``: the default is used when
    the variable is unset *or* set to the empty string.  A plain
    ``${VAR}`` (no default) returns whatever the variable holds,
    including an empty string, and fails only when it is unset.

    Args:
        match: Regex match from :data:`_ENV_VAR_PATTERN`.
        source_file: File path label for error messages.

    Returns:
        Resolved environment variable value or default.

    Raises:
        ConfigValidationError: If the env var is not set and no
            default is provided.
    """
    var_name = match.group(1)
    default = match.group(2)
    value = os.environ.get(var_name)

    if default is not None:
        # ``${VAR:-default}``: an empty value counts as "not provided".
        return value if value else default

    if value is not None:
        return value

    msg = f"Environment variable '{var_name}' is not set and no default was provided"
    raise ConfigValidationError(
        msg,
        locations=(ConfigLocation(file_path=source_file),),
    )


def _walk_substitute(node: Any, *, source_file: str | None) -> Any:
    """Recursively substitute env var placeholders in a config node.

    Strings are rewritten; dicts and lists are rebuilt so the input
    containers are never modified; every other value is returned as-is.
    Dict keys are left untouched.  Substituted text is not scanned
    again, so a variable whose value itself contains ``${...}`` is not
    expanded recursively.

    Args:
        node: Config value (str, dict, list, or scalar).
        source_file: File path label for error messages.

    Returns:
        Node with all ``${VAR}`` placeholders resolved.
    """
    if isinstance(node, str):
        return _ENV_VAR_PATTERN.sub(
            lambda m: _resolve_env_var_match(m, source_file=source_file),
            node,
        )
    if isinstance(node, dict):
        return {
            key: _walk_substitute(value, source_file=source_file)
            for key, value in node.items()
        }
    if isinstance(node, list):
        return [_walk_substitute(item, source_file=source_file) for item in node]
    return node


def _substitute_env_vars(
    data: dict[str, Any],
    *,
    source_file: str | None = None,
) -> dict[str, Any]:
    """Substitute ``${VAR}`` and ``${VAR:-default}`` in string values.

    Walks the dict recursively, replacing environment variable
    placeholders in string values.  Non-string values (int, float,
    bool, None) are passed through unchanged.  Returns a new dict;
    the input is never mutated.

    With ``${VAR:-default}`` the default applies when ``VAR`` is unset
    or empty, matching POSIX shell parameter expansion.

    Args:
        data: Configuration dict to process.
        source_file: File path label for error messages.

    Returns:
        A new dict with all env var placeholders resolved.

    Raises:
        ConfigValidationError: If a referenced env var is not set
            and no default is provided.
    """
    result: dict[str, Any] = _walk_substitute(data, source_file=source_file)
    return result
def discover_config() -> Path:
    """Auto-discover a configuration file from well-known locations.

    Search order:

    1. ``./ai-company.yaml``
    2. ``./config/ai-company.yaml``
    3. ``~/.ai-company/config.yaml``

    The home directory is resolved only after both working-directory
    candidates have been ruled out.  If it cannot be determined at all
    (``Path.home()`` raises ``RuntimeError``), the home candidate is
    skipped rather than aborting discovery.

    Returns:
        Resolved absolute :class:`~pathlib.Path` to the first file found.

    Raises:
        ConfigFileNotFoundError: If no configuration file is found
            at any searched location.
    """
    searched: list[Path] = []

    for candidate in _CWD_CONFIG_LOCATIONS:
        searched.append(candidate)
        if candidate.is_file():
            return candidate.resolve()

    try:
        home_candidate: Path | None = Path.home() / _HOME_CONFIG_RELATIVE
    except RuntimeError:
        # No resolvable home directory (e.g. a stripped-down container):
        # there is simply no per-user location to look at.
        home_candidate = None

    if home_candidate is not None:
        searched.append(home_candidate)
        if home_candidate.is_file():
            return home_candidate.resolve()

    searched_labels = [str(path) for path in searched]
    msg = "No configuration file found. Searched:\n" + "\n".join(
        f"  - {p}" for p in searched_labels
    )
    raise ConfigFileNotFoundError(
        msg,
        locations=tuple(ConfigLocation(file_path=p) for p in searched_labels),
    )
Raises: - ConfigFileNotFoundError: If any config file does not exist. + ConfigFileNotFoundError: If any config file does not exist + (or discovery finds nothing). ConfigParseError: If any file contains invalid YAML or cannot be read. - ConfigValidationError: If the merged config fails validation. + ConfigValidationError: If the merged config fails validation + or an env var is missing. """ + if config_path is None: + config_path = discover_config() config_path = Path(config_path) # 1. Start with built-in defaults @@ -317,6 +453,11 @@ def load_config( override = _parse_yaml_file(Path(override_path)) merged = _deep_merge(merged, override) + # 4. Substitute environment variables on the fully merged config. + # Use a neutral label so env-var errors aren't misattributed solely + # to the primary config file when they may originate from overrides. + merged = _substitute_env_vars(merged, source_file="") + # Build line map from primary file for error enrichment line_map = _build_line_map(yaml_text) @@ -351,6 +492,7 @@ def load_config_from_string( """ data = _parse_yaml_string(yaml_string, source_name) merged = _deep_merge(default_config_dict(), data) + merged = _substitute_env_vars(merged, source_file=source_name) line_map = _build_line_map(yaml_string) return _validate_config_dict( merged, diff --git a/src/ai_company/observability/config.py b/src/ai_company/observability/config.py index 7c534def73..8a55bb1a7d 100644 --- a/src/ai_company/observability/config.py +++ b/src/ai_company/observability/config.py @@ -181,18 +181,11 @@ def _validate_no_duplicate_file_paths(self) -> Self: @model_validator(mode="after") def _validate_log_dir_safe(self) -> Self: - """Ensure ``log_dir`` is not blank, relative, and has no traversal.""" + """Ensure ``log_dir`` is not blank and has no path traversal.""" if not self.log_dir.strip(): msg = "log_dir must not be blank" raise ValueError(msg) path = PurePath(self.log_dir) - if ( - path.is_absolute() - or 
PurePosixPath(self.log_dir).is_absolute() - or PureWindowsPath(self.log_dir).is_absolute() - ): - msg = f"log_dir must be relative: {self.log_dir}" - raise ValueError(msg) if ".." in path.parts: msg = f"log_dir must not contain '..' components: {self.log_dir}" raise ValueError(msg) diff --git a/tests/unit/config/conftest.py b/tests/unit/config/conftest.py index cea3d2b7a7..5fef9952a5 100644 --- a/tests/unit/config/conftest.py +++ b/tests/unit/config/conftest.py @@ -108,6 +108,27 @@ class RootConfigFactory(ModelFactory): total_monthly: -100.0 """ +ENV_VAR_SIMPLE_YAML = """\ +company_name: ${COMPANY_NAME} +""" + +ENV_VAR_NESTED_YAML = """\ +company_name: ${COMPANY_NAME} +budget: + total_monthly: 500.0 + alerts: + warn_at: 75 + critical_at: 90 + hard_stop_at: 100 +providers: + anthropic: + base_url: ${ANTHROPIC_BASE_URL:-https://api.anthropic.com} +""" + +ENV_VAR_MISSING_YAML = """\ +company_name: ${UNDEFINED_VAR} +""" + # ── Fixtures ────────────────────────────────────────────────────── diff --git a/tests/unit/config/test_loader.py b/tests/unit/config/test_loader.py index 6346ef937b..3d9099c6e7 100644 --- a/tests/unit/config/test_loader.py +++ b/tests/unit/config/test_loader.py @@ -1,5 +1,7 @@ """Tests for config loader (parsing, merging, validation).""" +from pathlib import Path + import pytest from ai_company.config.errors import ( @@ -13,13 +15,18 @@ _parse_yaml_file, _parse_yaml_string, _read_config_text, + _substitute_env_vars, _validate_config_dict, + discover_config, load_config, load_config_from_string, ) from ai_company.config.schema import RootConfig from .conftest import ( + ENV_VAR_MISSING_YAML, + ENV_VAR_NESTED_YAML, + ENV_VAR_SIMPLE_YAML, FULL_VALID_YAML, INVALID_FIELD_VALUES_YAML, INVALID_SYNTAX_YAML, @@ -349,6 +356,12 @@ def test_nested_override_merge(self, tmp_config_file): assert cfg.budget.total_monthly == 200.0 assert cfg.budget.per_task_limit == 10.0 + def test_string_path_accepted(self, tmp_config_file): + """String paths are coerced to Path 
@pytest.mark.unit
class TestSubstituteEnvVars:
    """Unit tests for ``_substitute_env_vars``.

    Every test that depends on a variable being *absent* removes it
    explicitly first, so results do not depend on the environment the
    test runner happens to inherit.
    """

    def test_simple_substitution(self, monkeypatch):
        monkeypatch.setenv("FOO", "bar")
        data = {"key": "${FOO}"}
        result = _substitute_env_vars(data)
        assert result == {"key": "bar"}

    def test_missing_var_raises(self, monkeypatch):
        monkeypatch.delenv("MISSING_VAR_XYZ", raising=False)
        data = {"key": "${MISSING_VAR_XYZ}"}
        with pytest.raises(ConfigValidationError, match="MISSING_VAR_XYZ"):
            _substitute_env_vars(data)

    def test_default_used_when_missing(self, monkeypatch):
        monkeypatch.delenv("MISSING_VAR_XYZ", raising=False)
        data = {"key": "${MISSING_VAR_XYZ:-fallback}"}
        result = _substitute_env_vars(data)
        assert result == {"key": "fallback"}

    def test_default_ignored_when_present(self, monkeypatch):
        monkeypatch.setenv("SET_VAR", "real")
        data = {"key": "${SET_VAR:-fallback}"}
        result = _substitute_env_vars(data)
        assert result == {"key": "real"}

    def test_empty_default(self, monkeypatch):
        monkeypatch.delenv("MISSING_VAR_XYZ", raising=False)
        data = {"key": "${MISSING_VAR_XYZ:-}"}
        result = _substitute_env_vars(data)
        assert result == {"key": ""}

    def test_nested_dict(self, monkeypatch):
        monkeypatch.setenv("INNER", "resolved")
        data = {"outer": {"inner": "${INNER}"}}
        result = _substitute_env_vars(data)
        assert result == {"outer": {"inner": "resolved"}}

    def test_list_values(self, monkeypatch):
        monkeypatch.setenv("ITEM", "hello")
        data = {"items": ["${ITEM}", "static"]}
        result = _substitute_env_vars(data)
        assert result == {"items": ["hello", "static"]}

    def test_non_string_unchanged(self):
        data = {"int": 42, "float": 3.14, "bool": True, "null": None}
        result = _substitute_env_vars(data)
        assert result == {"int": 42, "float": 3.14, "bool": True, "null": None}

    def test_multiple_vars_in_one_string(self, monkeypatch):
        monkeypatch.setenv("A", "alpha")
        monkeypatch.setenv("B", "beta")
        data = {"key": "${A}:${B}"}
        result = _substitute_env_vars(data)
        assert result == {"key": "alpha:beta"}

    def test_partial_string(self, monkeypatch):
        monkeypatch.setenv("VAR", "middle")
        data = {"key": "prefix-${VAR}-suffix"}
        result = _substitute_env_vars(data)
        assert result == {"key": "prefix-middle-suffix"}

    def test_input_not_mutated(self, monkeypatch):
        monkeypatch.setenv("X", "replaced")
        original = {"key": "${X}", "nested": {"deep": "${X}"}}
        original_copy = {"key": "${X}", "nested": {"deep": "${X}"}}
        _substitute_env_vars(original)
        assert original == original_copy

    def test_no_placeholders_passthrough(self):
        data = {"key": "no vars here", "num": 123}
        result = _substitute_env_vars(data)
        assert result == {"key": "no vars here", "num": 123}

    def test_deeply_nested(self, monkeypatch):
        monkeypatch.setenv("DEEP", "found")
        data = {"a": {"b": {"c": {"d": {"e": "${DEEP}"}}}}}
        result = _substitute_env_vars(data)
        assert result == {"a": {"b": {"c": {"d": {"e": "found"}}}}}

    def test_no_recursive_expansion(self, monkeypatch):
        """Env var values containing ${...} syntax are NOT recursively expanded."""
        monkeypatch.setenv("OUTER", "${INNER}")
        monkeypatch.setenv("INNER", "should_not_appear")
        data = {"key": "${OUTER}"}
        result = _substitute_env_vars(data)
        assert result == {"key": "${INNER}"}

    def test_special_chars_in_env_value(self, monkeypatch):
        """Env var values with regex/URL special chars are preserved verbatim."""
        monkeypatch.setenv("URL", "https://example.com/path?a=1&b=2#frag")
        data = {"endpoint": "${URL}"}
        result = _substitute_env_vars(data)
        assert result == {"endpoint": "https://example.com/path?a=1&b=2#frag"}

    def test_missing_var_error_includes_source_file(self, monkeypatch):
        """Error for missing env var includes the source_file in locations."""
        monkeypatch.delenv("MISSING_XYZ", raising=False)
        data = {"key": "${MISSING_XYZ}"}
        with pytest.raises(ConfigValidationError) as exc_info:
            _substitute_env_vars(data, source_file="my-config.yaml")
        assert exc_info.value.locations[0].file_path == "my-config.yaml"
@pytest.mark.unit
class TestDiscoverConfig:
    """Checks the search order and failure mode of ``discover_config``."""

    @staticmethod
    def _write(target, body):
        """Create *target* (and any missing parent directories) holding *body*."""
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(body, encoding="utf-8")
        return target

    @staticmethod
    def _use_home(monkeypatch, home_dir):
        """Make ``Path.home()`` report *home_dir* for the duration of a test."""
        monkeypatch.setattr(Path, "home", classmethod(lambda cls: home_dir))

    def test_finds_cwd_config(self, tmp_path, monkeypatch):
        expected = self._write(tmp_path / "ai-company.yaml", "company_name: Test\n")
        monkeypatch.chdir(tmp_path)
        assert discover_config() == expected.resolve()

    def test_finds_config_subdir(self, tmp_path, monkeypatch):
        expected = self._write(
            tmp_path / "config" / "ai-company.yaml",
            "company_name: Test\n",
        )
        monkeypatch.chdir(tmp_path)
        assert discover_config() == expected.resolve()

    def test_finds_home_config(self, tmp_path, monkeypatch):
        # Nothing in the working directory; only the home location exists.
        monkeypatch.chdir(tmp_path)
        home_dir = tmp_path / "fakehome"
        expected = self._write(
            home_dir / ".ai-company" / "config.yaml",
            "company_name: Test\n",
        )
        self._use_home(monkeypatch, home_dir)
        assert discover_config() == expected.resolve()

    def test_precedence_cwd_over_subdir(self, tmp_path, monkeypatch):
        # A file in the working directory wins over one in ``config/``.
        winner = self._write(tmp_path / "ai-company.yaml", "company_name: CWD\n")
        self._write(
            tmp_path / "config" / "ai-company.yaml",
            "company_name: SubDir\n",
        )
        monkeypatch.chdir(tmp_path)
        assert discover_config() == winner.resolve()

    def test_precedence_subdir_over_home(self, tmp_path, monkeypatch):
        # A file in ``config/`` wins over one in the home directory.
        monkeypatch.chdir(tmp_path)
        winner = self._write(
            tmp_path / "config" / "ai-company.yaml",
            "company_name: SubDir\n",
        )
        home_dir = tmp_path / "fakehome"
        self._write(
            home_dir / ".ai-company" / "config.yaml",
            "company_name: Home\n",
        )
        self._use_home(monkeypatch, home_dir)
        assert discover_config() == winner.resolve()

    def test_no_config_raises(self, tmp_path, monkeypatch):
        monkeypatch.chdir(tmp_path)
        home_dir = tmp_path / "fakehome"
        home_dir.mkdir()
        self._use_home(monkeypatch, home_dir)
        with pytest.raises(
            ConfigFileNotFoundError, match="No configuration file"
        ) as exc_info:
            discover_config()
        # All 3 search locations should be reported
        assert len(exc_info.value.locations) == 3

    def test_returns_resolved_path(self, tmp_path, monkeypatch):
        self._write(tmp_path / "ai-company.yaml", "company_name: Test\n")
        monkeypatch.chdir(tmp_path)
        assert discover_config().is_absolute()
@pytest.mark.unit
class TestLoadConfigEnvVar:
    """End-to-end env var substitution through the public loaders.

    Tests that expect a variable to be absent remove it explicitly, so
    they cannot be affected by variables exported in the environment
    that launches the test run.
    """

    def test_env_var_in_load_config(self, tmp_config_file, monkeypatch):
        monkeypatch.setenv("COMPANY_NAME", "Env Corp")
        path = tmp_config_file(ENV_VAR_SIMPLE_YAML)
        cfg = load_config(path)
        assert cfg.company_name == "Env Corp"

    def test_env_var_with_default_in_load_config(self, tmp_config_file, monkeypatch):
        monkeypatch.delenv("UNDEFINED_TEST_VAR", raising=False)
        yaml_content = "company_name: ${UNDEFINED_TEST_VAR:-Default Corp}\n"
        path = tmp_config_file(yaml_content)
        cfg = load_config(path)
        assert cfg.company_name == "Default Corp"

    def test_missing_env_var_raises_in_load_config(self, tmp_config_file, monkeypatch):
        monkeypatch.delenv("UNDEFINED_VAR", raising=False)
        path = tmp_config_file(ENV_VAR_MISSING_YAML)
        with pytest.raises(ConfigValidationError, match="UNDEFINED_VAR"):
            load_config(path)

    def test_env_var_in_nested_config(self, tmp_config_file, monkeypatch):
        monkeypatch.setenv("COMPANY_NAME", "Nested Corp")
        monkeypatch.setenv("ANTHROPIC_BASE_URL", "https://custom.api")
        path = tmp_config_file(ENV_VAR_NESTED_YAML)
        cfg = load_config(path)
        assert cfg.company_name == "Nested Corp"
        assert cfg.providers["anthropic"].base_url == "https://custom.api"

    def test_env_var_in_load_config_from_string(self, monkeypatch):
        monkeypatch.setenv("COMPANY_NAME", "String Corp")
        cfg = load_config_from_string(ENV_VAR_SIMPLE_YAML)
        assert cfg.company_name == "String Corp"

    def test_env_var_default_in_load_config_from_string(self, monkeypatch):
        monkeypatch.delenv("UNDEFINED_TEST_VAR", raising=False)
        yaml_content = "company_name: ${UNDEFINED_TEST_VAR:-FromString Corp}\n"
        cfg = load_config_from_string(yaml_content)
        assert cfg.company_name == "FromString Corp"

    def test_missing_env_var_raises_in_load_config_from_string(self, monkeypatch):
        monkeypatch.delenv("UNDEFINED_VAR", raising=False)
        with pytest.raises(ConfigValidationError, match="UNDEFINED_VAR"):
            load_config_from_string(ENV_VAR_MISSING_YAML)

    def test_env_var_in_override_file(self, tmp_config_file, monkeypatch):
        # Placeholders that only appear in an override file must resolve too.
        monkeypatch.setenv("OVERRIDE_NAME", "Override Corp")
        base = tmp_config_file(MINIMAL_VALID_YAML, name="base.yaml")
        override = tmp_config_file(
            "company_name: ${OVERRIDE_NAME}\n",
            name="override.yaml",
        )
        cfg = load_config(base, override_paths=(override,))
        assert cfg.company_name == "Override Corp"
@pytest.mark.unit
class TestLoadConfigDiscovery:
    """``load_config`` called with ``None`` versus an explicit path."""

    def test_load_config_none_uses_discovery(self, tmp_path, monkeypatch):
        discovered = tmp_path / "ai-company.yaml"
        discovered.write_text(MINIMAL_VALID_YAML, encoding="utf-8")
        monkeypatch.chdir(tmp_path)
        assert load_config(None).company_name == "Test Corp"

    def test_load_config_none_no_config_raises(self, tmp_path, monkeypatch):
        # Neither the working directory nor the (empty) fake home holds a config.
        empty_home = tmp_path / "fakehome"
        empty_home.mkdir()
        monkeypatch.chdir(tmp_path)
        monkeypatch.setattr(Path, "home", classmethod(lambda cls: empty_home))
        with pytest.raises(ConfigFileNotFoundError):
            load_config(None)

    def test_load_config_explicit_path_still_works(self, tmp_config_file):
        """Backward compatibility: explicit path still works as before."""
        explicit = tmp_config_file(MINIMAL_VALID_YAML)
        assert load_config(explicit).company_name == "Test Corp"