Skip to content
1 change: 1 addition & 0 deletions wren/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"opendal>=0.45",
"pandas>=2",
"boto3>=1.26",
"pyyaml>=6.0",
# Transitive dependency constraints for security patches
"requests>=2.33.0",
"pyasn1>=0.6.3",
Expand Down
64 changes: 62 additions & 2 deletions wren/src/wren/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,42 @@ def _build_engine(
connection_file: str | None,
*,
conn_required: bool = True,
datasource: str | None = None,
):
from wren.engine import WrenEngine # noqa: PLC0415
from wren.model.data_source import DataSource # noqa: PLC0415

manifest_str = _load_manifest(_require_mdl(mdl))

# Try active profile when no explicit connection flags given
if not connection_info and not connection_file:
from wren.profile import get_active_profile # noqa: PLC0415

prof_name, prof_dict = get_active_profile()
if prof_dict:
prof_ds = prof_dict.pop("datasource", None)
ds_str = datasource or prof_ds
if ds_str is None:
typer.echo("Error: no datasource in profile or --datasource.", err=True)
raise typer.Exit(1)
try:
ds = DataSource(ds_str.lower())
except ValueError:
typer.echo(f"Error: unknown datasource '{ds_str}'", err=True)
raise typer.Exit(1)
from pydantic import ValidationError # noqa: PLC0415

try:
return WrenEngine(
manifest_str=manifest_str, data_source=ds, connection_info=prof_dict
)
except ValidationError as e:
typer.echo(f"Error: invalid profile connection info: {e}", err=True)
raise typer.Exit(1)

# Existing path: explicit flags / legacy connection_info.json
conn_dict = _load_conn(connection_info, connection_file, required=conn_required)
ds_str = _resolve_datasource(conn_dict)
ds_str = _resolve_datasource(conn_dict, explicit=datasource)

try:
ds = DataSource(ds_str.lower())
Expand Down Expand Up @@ -292,7 +321,7 @@ def dry_plan(
typer.Option(
"--datasource",
"-d",
help="Data source dialect (e.g. duckdb, postgres). Falls back to connection_info.json.",
help="Data source dialect (e.g. duckdb, postgres). Falls back to active profile or connection_info.json.",
),
] = None,
mdl: MdlOpt = None,
Expand All @@ -303,6 +332,33 @@ def dry_plan(
from wren.model.data_source import DataSource # noqa: PLC0415

manifest_str = _load_manifest(_require_mdl(mdl))

# Try active profile when no explicit flags given
if datasource is None and connection_file is None:
from wren.profile import get_active_profile # noqa: PLC0415

_prof_name, prof_dict = get_active_profile()
if prof_dict:
prof_ds = prof_dict.pop("datasource", None)
if prof_ds is None:
typer.echo("Error: no datasource in active profile.", err=True)
raise typer.Exit(1)
try:
ds = DataSource(prof_ds.lower())
except ValueError:
typer.echo(f"Error: unknown datasource '{prof_ds}'", err=True)
raise typer.Exit(1)
with WrenEngine(
manifest_str=manifest_str, data_source=ds, connection_info={}
) as engine:
try:
result = engine.dry_plan(sql)
typer.echo(result)
except Exception as e:
typer.echo(f"Error: {e}", err=True)
raise typer.Exit(1)
return

conn_dict = (
_load_conn(None, connection_file, required=False) if datasource is None else {}
)
Expand Down Expand Up @@ -391,6 +447,10 @@ def version():
except ImportError:
pass # wren[memory] not installed

from wren.profile_cli import profile_app # noqa: PLC0415, E402

app.add_typer(profile_app)


if __name__ == "__main__":
app()
181 changes: 181 additions & 0 deletions wren/src/wren/profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""Profile management — load, save, list, switch, add, remove profiles."""

from __future__ import annotations

import os
import tempfile
from pathlib import Path
from typing import Any

import yaml

_WREN_HOME = Path(os.environ.get("WREN_HOME", Path.home() / ".wren"))
_PROFILES_FILE = _WREN_HOME / "profiles.yml"


def _load_raw() -> dict:
"""Load profiles.yml, returning empty structure if missing.

Raises ValueError on malformed content so callers get a deterministic error
instead of an AttributeError deep inside library code.
"""
if not _PROFILES_FILE.exists():
return {"active": None, "profiles": {}}
try:
data = yaml.safe_load(_PROFILES_FILE.read_text())
except yaml.YAMLError as exc:
raise ValueError(
f"profiles.yml is not valid YAML: {exc}\n"
f"Fix or remove {_PROFILES_FILE} and try again."
) from exc
if data is None:
return {"active": None, "profiles": {}}
if not isinstance(data, dict):
raise ValueError(
f"profiles.yml must contain a YAML mapping; got {type(data).__name__}.\n"
f"Fix or remove {_PROFILES_FILE} and try again."
)
profiles = data.get("profiles", {})
if not isinstance(profiles, dict):
raise ValueError(
f"profiles.yml: 'profiles' must be a mapping; got {type(profiles).__name__}.\n"
f"Fix or remove {_PROFILES_FILE} and try again."
)
active = data.get("active")
if active is not None and not isinstance(active, str):
raise ValueError(
f"profiles.yml: 'active' must be a string or null; got {type(active).__name__}.\n"
f"Fix or remove {_PROFILES_FILE} and try again."
)
return data


def _save_raw(data: dict) -> None:
"""Write profiles.yml atomically with 0600 permissions."""
_WREN_HOME.mkdir(parents=True, exist_ok=True)
payload = yaml.dump(data, default_flow_style=False, sort_keys=False)
# Write to a temp file in the same directory then atomically replace
fd, tmp_path = tempfile.mkstemp(dir=_WREN_HOME, suffix=".yml.tmp")
try:
os.chmod(tmp_path, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as fh:
fh.write(payload)
os.replace(tmp_path, _PROFILES_FILE)
except Exception:
os.unlink(tmp_path)
raise
os.chmod(_PROFILES_FILE, 0o600)


def list_profiles() -> dict[str, dict]:
"""Return {name: profile_dict} for all profiles."""
return _load_raw().get("profiles", {})


def get_active_name() -> str | None:
"""Return the name of the currently active profile, or None."""
return _load_raw().get("active")


def get_active_profile() -> tuple[str | None, dict]:
"""Return (name, profile_dict) for the active profile. ({} if none set)."""
data = _load_raw()
name = data.get("active")
if name is None:
return None, {}
profiles = data.get("profiles", {})
return name, dict(profiles.get(name, {}))


def add_profile(name: str, profile: dict, *, activate: bool = False) -> None:
"""Add or overwrite a named profile."""
data = _load_raw()
data.setdefault("profiles", {})[name] = profile
if activate or data.get("active") is None:
data["active"] = name
_save_raw(data)


def remove_profile(name: str) -> bool:
"""Remove a profile. Returns True if found. Clears active if it was this profile."""
data = _load_raw()
profiles = data.get("profiles", {})
if name not in profiles:
return False
del profiles[name]
if data.get("active") == name:
data["active"] = next(iter(profiles), None)
_save_raw(data)
return True


def switch_profile(name: str) -> bool:
"""Set the active profile. Returns False if name not found."""
data = _load_raw()
if name not in data.get("profiles", {}):
return False
data["active"] = name
_save_raw(data)
return True


def resolve_connection(
explicit_datasource: str | None,
explicit_conn_info: str | None,
explicit_conn_file: str | None,
) -> tuple[str | None, dict]:
"""Resolve datasource + connection_info from explicit flags or active profile.

Priority: explicit flags > active profile.
Legacy ~/.wren/connection_info.json fallback is handled separately by
cli._load_conn() and is not performed here.
Returns (datasource_str_or_None, connection_dict).
"""
if explicit_datasource or explicit_conn_info or explicit_conn_file:
return explicit_datasource, {}

name, profile = get_active_profile()
if profile:
ds = profile.pop("datasource", None)
return ds, profile

return None, {}


def debug_profile(name: str | None = None) -> dict[str, Any]:
"""Return diagnostic info for a profile (or the active one).

Masks sensitive fields (password, credentials, secret, token).
"""
if name is None:
name = get_active_name()
if name is None:
return {"error": "no active profile"}
data = _load_raw()
profile = data.get("profiles", {}).get(name)
if profile is None:
return {"error": f"profile '{name}' not found"}

_SENSITIVE = {
"password",
"credentials",
"secret",
"token",
"private_key",
"access_key",
"key_id",
"client_id",
"bucket",
"endpoint",
"staging_dir",
"hostname",
"http_path",
"role_arn",
}
masked = {}
for k, v in profile.items():
if k.lower() in _SENSITIVE or any(s in k.lower() for s in _SENSITIVE):
masked[k] = "***"
else:
masked[k] = v
return {"name": name, "active": data.get("active") == name, "config": masked}
Loading
Loading