Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion wren/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,21 @@ pip install wren-engine[all] # All connectors
}
```

**3. Run queries** — `wren` auto-discovers both files from `~/.wren`:
**3. (Optional) Create `~/.wren/config.json`** — security policy:

```json
{
"strict_mode": true,
"denied_functions": ["pg_read_file", "dblink", "lo_import"]
}
```

| Key | Default | Description |
|-----|---------|-------------|
| `strict_mode` | `false` | When `true`, every table in a query must be defined in the MDL. Queries referencing undeclared tables are rejected before execution. |
| `denied_functions` | `[]` | List of function names (case-insensitive) that are forbidden in queries. |

**4. Run queries** — `wren` auto-discovers all files from `~/.wren` (override with `WREN_HOME=/path/to/dir`):

```bash
wren --sql 'SELECT order_id FROM "orders" LIMIT 10'
Expand Down
29 changes: 24 additions & 5 deletions wren/src/wren/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from __future__ import annotations

import json
import os
from pathlib import Path
from typing import Annotated, Optional

import typer

app = typer.Typer(name="wren", help="Wren Engine CLI", no_args_is_help=False)

_WREN_HOME = Path.home() / ".wren"
_WREN_HOME = Path(os.environ.get("WREN_HOME", Path.home() / ".wren")).expanduser()
_DEFAULT_MDL = _WREN_HOME / "mdl.json"
_DEFAULT_CONN = _WREN_HOME / "connection_info.json"

Expand Down Expand Up @@ -126,8 +127,10 @@ def _build_engine(
*,
conn_required: bool = True,
):
from wren.config import load_config # noqa: PLC0415
from wren.engine import WrenEngine # noqa: PLC0415
from wren.model.data_source import DataSource # noqa: PLC0415
from wren.model.error import WrenError # noqa: PLC0415

manifest_str = _load_manifest(_require_mdl(mdl))
conn_dict = _load_conn(connection_info, connection_file, required=conn_required)
Expand All @@ -139,9 +142,17 @@ def _build_engine(
typer.echo(f"Error: unknown datasource '{ds_str}'", err=True)
raise typer.Exit(1)

return WrenEngine(
manifest_str=manifest_str, data_source=ds, connection_info=conn_dict
)
try:
config = load_config(_WREN_HOME)
return WrenEngine(
manifest_str=manifest_str,
data_source=ds,
connection_info=conn_dict,
config=config,
)
except (WrenError, OSError) as e:
typer.echo(f"Error: {e}", err=True)
raise typer.Exit(1) from e


# ── Shared option types ────────────────────────────────────────────────────
Expand Down Expand Up @@ -281,8 +292,10 @@ def dry_plan(
connection_file: ConnFileOpt = None,
):
"""Plan SQL through MDL and print the expanded SQL (no DB required)."""
from wren.config import load_config # noqa: PLC0415
from wren.engine import WrenEngine # noqa: PLC0415
from wren.model.data_source import DataSource # noqa: PLC0415
from wren.model.error import WrenError # noqa: PLC0415

manifest_str = _load_manifest(_require_mdl(mdl))
# Read datasource from connection_info.json only when --datasource is not given
Expand All @@ -299,8 +312,14 @@ def dry_plan(
typer.echo(f"Error: unknown datasource '{ds_str}'", err=True)
raise typer.Exit(1)

try:
config = load_config(_WREN_HOME)
except (WrenError, OSError) as e:
typer.echo(f"Error: {e}", err=True)
raise typer.Exit(1) from e

with WrenEngine(
manifest_str=manifest_str, data_source=ds, connection_info={}
manifest_str=manifest_str, data_source=ds, connection_info={}, config=config
) as engine:
try:
result = engine.dry_plan(sql)
Expand Down
74 changes: 74 additions & 0 deletions wren/src/wren/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""Wren CLI configuration loaded from ~/.wren/config.json."""

from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path

from wren.model.error import ErrorCode, WrenError


@dataclass(frozen=True)
class WrenConfig:
"""Immutable configuration for the Wren CLI.

Attributes
----------
strict_mode:
When ``True``, all table references in SQL must be defined in the MDL
manifest. Queries referencing non-MDL tables are rejected.
denied_functions:
Set of function names (lowercase) that are forbidden in SQL queries.
Matching is case-insensitive.
"""

strict_mode: bool = False
denied_functions: frozenset[str] = field(default_factory=frozenset)


def load_config(wren_home: Path) -> WrenConfig:
"""Load configuration from ``wren_home/config.json``.

Returns default ``WrenConfig`` when the file does not exist.
Raises ``WrenError`` when the file exists but contains invalid JSON.
"""
config_path = wren_home / "config.json"
if not config_path.exists():
return WrenConfig()

try:
raw = json.loads(config_path.read_text())
except (json.JSONDecodeError, OSError) as e:
raise WrenError(
ErrorCode.GENERIC_USER_ERROR,
f"Failed to read {config_path}: {e}",
) from e

if not isinstance(raw, dict):
raise WrenError(
ErrorCode.GENERIC_USER_ERROR,
f"{config_path} must contain a JSON object.",
)

strict_mode_raw = raw.get("strict_mode", False)
if not isinstance(strict_mode_raw, bool):
raise WrenError(
ErrorCode.GENERIC_USER_ERROR,
f"{config_path}: 'strict_mode' must be a JSON boolean.",
)

denied_raw = raw.get("denied_functions", [])
if not isinstance(denied_raw, list):
raise WrenError(
ErrorCode.GENERIC_USER_ERROR,
f"{config_path}: 'denied_functions' must be a JSON array.",
)
if any(not isinstance(f, str) for f in denied_raw):
raise WrenError(
ErrorCode.GENERIC_USER_ERROR,
f"{config_path}: 'denied_functions' must contain only strings.",
)
denied_functions = frozenset(f.lower() for f in denied_raw)

return WrenConfig(strict_mode=strict_mode_raw, denied_functions=denied_functions)
24 changes: 23 additions & 1 deletion wren/src/wren/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@

from __future__ import annotations

import base64
import json
from typing import Any

import pyarrow as pa
from sqlglot import exp, parse_one

from wren.config import WrenConfig
from wren.connector.factory import get_connector
from wren.mdl import get_manifest_extractor, get_session_context, to_json_base64
from wren.mdl.cte_rewriter import CTERewriter, get_sqlglot_dialect
from wren.model.data_source import DataSource
from wren.model.error import DIALECT_SQL, ErrorCode, ErrorPhase, WrenError
from wren.policy import validate_sql_policy


class WrenEngine:
Expand All @@ -56,6 +60,7 @@ def __init__(
function_path: str | None = None,
*,
fallback: bool = True,
config: WrenConfig | None = None,
):
if isinstance(data_source, str):
data_source = DataSource(data_source)
Expand All @@ -64,6 +69,7 @@ def __init__(
self.data_source = data_source
self.function_path = function_path
self._fallback = fallback
self._config = config or WrenConfig()

# Build typed ConnectionInfo if a raw dict was given.
# An empty dict is allowed for transpile-only usage (no DB connection).
Expand Down Expand Up @@ -164,11 +170,27 @@ def _plan(self, sql: str, properties: dict | None) -> str:
# Use sqlglot (not DataFusion parser) since input is target dialect.
dialect = get_sqlglot_dialect(self.data_source)
ast = parse_one(sql, dialect=dialect)

# Policy validation: check tables and functions before execution.
if self._config.strict_mode or self._config.denied_functions:
manifest_json = json.loads(base64.b64decode(self.manifest_str))
model_names = {m["name"] for m in manifest_json.get("models", [])}
validate_sql_policy(ast, model_names, self._config)

tables = [t.name for t in ast.find_all(exp.Table)]
extractor = get_manifest_extractor(self.manifest_str)
manifest = extractor.extract_by(tables)
effective_manifest = to_json_base64(manifest)
except Exception:
except WrenError:
raise
except Exception as e:
if self._config.strict_mode or self._config.denied_functions:
raise WrenError(
ErrorCode.INVALID_SQL,
str(e),
phase=ErrorPhase.SQL_PLANNING,
metadata={DIALECT_SQL: sql},
) from e
effective_manifest = self.manifest_str

try:
Expand Down
3 changes: 3 additions & 0 deletions wren/src/wren/model/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class ErrorCode(int, Enum):
VALIDATION_PARAMETER_ERROR = 10
GET_CONNECTION_ERROR = 11
INVALID_CONNECTION_INFO = 12
MODEL_NOT_FOUND = 13
BLOCKED_FUNCTION = 14
GENERIC_INTERNAL_ERROR = 100
LEGACY_ENGINE_ERROR = 101
NOT_IMPLEMENTED = 102
Expand All @@ -40,6 +42,7 @@ class ErrorPhase(int, Enum):
METADATA_FETCHING = 9
VALIDATION = 10
SQL_SUBSTITUTE = 11
SQL_POLICY_CHECK = 12


class WrenError(Exception):
Expand Down
120 changes: 120 additions & 0 deletions wren/src/wren/policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""SQL policy validation for strict query mode.

Validates that a parsed SQL AST only references tables defined in the MDL
manifest and does not use any denied functions.
"""

from __future__ import annotations

from sqlglot import exp

from wren.config import WrenConfig
from wren.model.error import ErrorCode, ErrorPhase, WrenError


def validate_sql_policy(
ast: exp.Expression,
model_names: set[str],
config: WrenConfig,
) -> None:
"""Raise ``WrenError`` if the SQL violates strict-mode policies.

Parameters
----------
ast:
Parsed sqlglot AST of the user query.
model_names:
Set of model names defined in the MDL manifest.
config:
Wren configuration with strict_mode and denied_functions settings.
"""
if config.strict_mode:
_check_tables(ast, model_names)
if config.denied_functions:
_check_functions(ast, config.denied_functions)


def _visible_cte_names(node: exp.Expression) -> set[str]:
"""Return CTE names visible at *node*'s scope by walking up the AST."""
names: set[str] = set()
cursor = node.parent
while cursor is not None:
# A WITH clause is visible to its parent SELECT and siblings.
with_clause = cursor.args.get("with_") if hasattr(cursor, "args") else None
if isinstance(with_clause, exp.With):
for cte in with_clause.expressions:
alias = cte.args.get("alias")
if alias:
cte_name = (
alias.this.name
if isinstance(alias.this, exp.Identifier)
else str(alias.this)
)
names.add(cte_name.lower())
cursor = cursor.parent
return names


def _check_tables(
ast: exp.Expression,
model_names: set[str],
) -> None:
model_names_lower = {n.lower() for n in model_names}

for table in ast.find_all(exp.Table):
name = table.name
if not name:
# Table nodes with no name are table-valued functions
# (e.g. read_csv(), generate_series()). Block them in strict mode.
sql_text = table.sql()
if sql_text:
raise WrenError(
ErrorCode.MODEL_NOT_FOUND,
f"Table-valued function '{sql_text}' is not allowed. "
"In strict mode, all table references must correspond to MDL models.",
phase=ErrorPhase.SQL_POLICY_CHECK,
)
continue
name_lower = name.lower()
if name_lower in model_names_lower:
continue
if name_lower in _visible_cte_names(table):
continue
raise WrenError(
ErrorCode.MODEL_NOT_FOUND,
f"Table '{name}' is not defined in the MDL manifest. "
"In strict mode, all table references must correspond to MDL models.",
phase=ErrorPhase.SQL_POLICY_CHECK,
)

# Func subclasses used as FROM sources (e.g. UNNEST) produce no exp.Table
# node at all. Scan for Func nodes inside From clauses.
for from_clause in ast.find_all(exp.From):
source = from_clause.this
if isinstance(source, exp.Alias):
source = source.this
if isinstance(source, exp.Func):
raise WrenError(
ErrorCode.MODEL_NOT_FOUND,
f"Table-valued function '{source.sql()}' is not allowed. "
"In strict mode, all table references must correspond to MDL models.",
phase=ErrorPhase.SQL_POLICY_CHECK,
)


def _check_functions(
ast: exp.Expression,
denied: frozenset[str],
) -> None:
for func in ast.find_all(exp.Func):
if isinstance(func, exp.Anonymous):
name = func.name
else:
name = type(func).key
if name.lower() in denied:
raise WrenError(
ErrorCode.BLOCKED_FUNCTION,
f"Function '{name}' is not allowed. "
"This function is on the denied list.",
phase=ErrorPhase.SQL_POLICY_CHECK,
)
Loading
Loading