diff --git a/wren/README.md b/wren/README.md index ae5b9aa37..de956b4f8 100644 --- a/wren/README.md +++ b/wren/README.md @@ -50,7 +50,21 @@ pip install wren-engine[all] # All connectors } ``` -**3. Run queries** — `wren` auto-discovers both files from `~/.wren`: +**3. (Optional) Create `~/.wren/config.json`** — security policy: + +```json +{ + "strict_mode": true, + "denied_functions": ["pg_read_file", "dblink", "lo_import"] +} +``` + +| Key | Default | Description | +|-----|---------|-------------| +| `strict_mode` | `false` | When `true`, every table in a query must be defined in the MDL. Queries referencing undeclared tables are rejected before execution. | +| `denied_functions` | `[]` | List of function names (case-insensitive) that are forbidden in queries. | + +**4. Run queries** — `wren` auto-discovers all files from `~/.wren` (override with `WREN_HOME=/path/to/dir`): ```bash wren --sql 'SELECT order_id FROM "orders" LIMIT 10' diff --git a/wren/src/wren/cli.py b/wren/src/wren/cli.py index 7fc069275..2882be7fe 100644 --- a/wren/src/wren/cli.py +++ b/wren/src/wren/cli.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +import os from pathlib import Path from typing import Annotated, Optional @@ -10,7 +11,7 @@ app = typer.Typer(name="wren", help="Wren Engine CLI", no_args_is_help=False) -_WREN_HOME = Path.home() / ".wren" +_WREN_HOME = Path(os.environ.get("WREN_HOME", Path.home() / ".wren")).expanduser() _DEFAULT_MDL = _WREN_HOME / "mdl.json" _DEFAULT_CONN = _WREN_HOME / "connection_info.json" @@ -126,8 +127,10 @@ def _build_engine( *, conn_required: bool = True, ): + from wren.config import load_config # noqa: PLC0415 from wren.engine import WrenEngine # noqa: PLC0415 from wren.model.data_source import DataSource # noqa: PLC0415 + from wren.model.error import WrenError # noqa: PLC0415 manifest_str = _load_manifest(_require_mdl(mdl)) conn_dict = _load_conn(connection_info, connection_file, required=conn_required) @@ -139,9 +142,17 @@ def _build_engine( typer.echo(f"Error: unknown datasource '{ds_str}'", err=True) raise typer.Exit(1) - return WrenEngine( - manifest_str=manifest_str, data_source=ds, connection_info=conn_dict - ) + try: + config = load_config(_WREN_HOME) + return WrenEngine( + manifest_str=manifest_str, + data_source=ds, + connection_info=conn_dict, + config=config, + ) + except (WrenError, OSError) as e: + typer.echo(f"Error: {e}", err=True) + raise typer.Exit(1) from e # ── Shared option types ──────────────────────────────────────────────────── @@ -281,8 +292,10 @@ def dry_plan( connection_file: ConnFileOpt = None, ): """Plan SQL through MDL and print the expanded SQL (no DB required).""" + from wren.config import load_config # noqa: PLC0415 from wren.engine import WrenEngine # noqa: PLC0415 from wren.model.data_source import DataSource # noqa: PLC0415 + from wren.model.error import WrenError # noqa: PLC0415 manifest_str = _load_manifest(_require_mdl(mdl)) # Read datasource from connection_info.json only when --datasource is not given @@ -299,8 +312,14 @@ def dry_plan( typer.echo(f"Error: unknown datasource '{ds_str}'", err=True) raise typer.Exit(1) + try: + config = load_config(_WREN_HOME) + except (WrenError, OSError) as e: + typer.echo(f"Error: {e}", err=True) + raise typer.Exit(1) from e + with WrenEngine( - manifest_str=manifest_str, data_source=ds, connection_info={} + manifest_str=manifest_str, data_source=ds, connection_info={}, config=config ) as engine: try: result = engine.dry_plan(sql) diff --git a/wren/src/wren/config.py b/wren/src/wren/config.py new file mode 100644 index 000000000..790f016cd --- /dev/null +++ b/wren/src/wren/config.py @@ -0,0 +1,74 @@ +"""Wren CLI configuration loaded from ~/.wren/config.json.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from pathlib import Path + +from wren.model.error import ErrorCode, WrenError + + +@dataclass(frozen=True) +class WrenConfig: + """Immutable configuration for the Wren CLI. + + Attributes + ---------- + strict_mode: + When ``True``, all table references in SQL must be defined in the MDL + manifest. Queries referencing non-MDL tables are rejected. + denied_functions: + Set of function names (lowercase) that are forbidden in SQL queries. + Matching is case-insensitive. + """ + + strict_mode: bool = False + denied_functions: frozenset[str] = field(default_factory=frozenset) + + +def load_config(wren_home: Path) -> WrenConfig: + """Load configuration from ``wren_home/config.json``. + + Returns default ``WrenConfig`` when the file does not exist. + Raises ``WrenError`` when the file exists but contains invalid JSON. + """ + config_path = wren_home / "config.json" + if not config_path.exists(): + return WrenConfig() + + try: + raw = json.loads(config_path.read_text()) + except (json.JSONDecodeError, OSError) as e: + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + f"Failed to read {config_path}: {e}", + ) from e + + if not isinstance(raw, dict): + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + f"{config_path} must contain a JSON object.", + ) + + strict_mode_raw = raw.get("strict_mode", False) + if not isinstance(strict_mode_raw, bool): + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + f"{config_path}: 'strict_mode' must be a JSON boolean.", + ) + + denied_raw = raw.get("denied_functions", []) + if not isinstance(denied_raw, list): + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + f"{config_path}: 'denied_functions' must be a JSON array.", + ) + if any(not isinstance(f, str) for f in denied_raw): + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + f"{config_path}: 'denied_functions' must contain only strings.", + ) + denied_functions = frozenset(f.lower() for f in denied_raw) + + return WrenConfig(strict_mode=strict_mode_raw, denied_functions=denied_functions) diff --git a/wren/src/wren/engine.py b/wren/src/wren/engine.py index 49a012a2c..49711949b 100644 --- a/wren/src/wren/engine.py +++ b/wren/src/wren/engine.py @@ -20,16 +20,20 @@ from __future__ import annotations +import base64 +import json from typing import Any import pyarrow as pa from sqlglot import exp, parse_one +from wren.config import WrenConfig from wren.connector.factory import get_connector from wren.mdl import get_manifest_extractor, get_session_context, to_json_base64 from wren.mdl.cte_rewriter import CTERewriter, get_sqlglot_dialect from wren.model.data_source import DataSource from wren.model.error import DIALECT_SQL, ErrorCode, ErrorPhase, WrenError +from wren.policy import validate_sql_policy class WrenEngine: @@ -56,6 +60,7 @@ def __init__( function_path: str | None = None, *, fallback: bool = True, + config: WrenConfig | None = None, ): if isinstance(data_source, str): data_source = DataSource(data_source) @@ -64,6 +69,7 @@ def __init__( self.data_source = data_source self.function_path = function_path self._fallback = fallback + self._config = config or WrenConfig() # Build typed ConnectionInfo if a raw dict was given. # An empty dict is allowed for transpile-only usage (no DB connection). @@ -164,11 +170,27 @@ def _plan(self, sql: str, properties: dict | None) -> str: # Use sqlglot (not DataFusion parser) since input is target dialect. dialect = get_sqlglot_dialect(self.data_source) ast = parse_one(sql, dialect=dialect) + + # Policy validation: check tables and functions before execution. + if self._config.strict_mode or self._config.denied_functions: + manifest_json = json.loads(base64.b64decode(self.manifest_str)) + model_names = {m["name"] for m in manifest_json.get("models", [])} + validate_sql_policy(ast, model_names, self._config) + tables = [t.name for t in ast.find_all(exp.Table)] extractor = get_manifest_extractor(self.manifest_str) manifest = extractor.extract_by(tables) effective_manifest = to_json_base64(manifest) - except Exception: + except WrenError: + raise + except Exception as e: + if self._config.strict_mode or self._config.denied_functions: + raise WrenError( + ErrorCode.INVALID_SQL, + str(e), + phase=ErrorPhase.SQL_PLANNING, + metadata={DIALECT_SQL: sql}, + ) from e effective_manifest = self.manifest_str try: diff --git a/wren/src/wren/model/error.py b/wren/src/wren/model/error.py index 263f0710e..6ee5f57ab 100644 --- a/wren/src/wren/model/error.py +++ b/wren/src/wren/model/error.py @@ -19,6 +19,8 @@ class ErrorCode(int, Enum): VALIDATION_PARAMETER_ERROR = 10 GET_CONNECTION_ERROR = 11 INVALID_CONNECTION_INFO = 12 + MODEL_NOT_FOUND = 13 + BLOCKED_FUNCTION = 14 GENERIC_INTERNAL_ERROR = 100 LEGACY_ENGINE_ERROR = 101 NOT_IMPLEMENTED = 102 @@ -40,6 +42,7 @@ class ErrorPhase(int, Enum): METADATA_FETCHING = 9 VALIDATION = 10 SQL_SUBSTITUTE = 11 + SQL_POLICY_CHECK = 12 class WrenError(Exception): diff --git a/wren/src/wren/policy.py b/wren/src/wren/policy.py new file mode 100644 index 000000000..80a9acb9a --- /dev/null +++ b/wren/src/wren/policy.py @@ -0,0 +1,120 @@ +"""SQL policy validation for strict query mode. + +Validates that a parsed SQL AST only references tables defined in the MDL +manifest and does not use any denied functions. +""" + +from __future__ import annotations + +from sqlglot import exp + +from wren.config import WrenConfig +from wren.model.error import ErrorCode, ErrorPhase, WrenError + + +def validate_sql_policy( + ast: exp.Expression, + model_names: set[str], + config: WrenConfig, +) -> None: + """Raise ``WrenError`` if the SQL violates strict-mode policies. + + Parameters + ---------- + ast: + Parsed sqlglot AST of the user query. + model_names: + Set of model names defined in the MDL manifest. + config: + Wren configuration with strict_mode and denied_functions settings. + """ + if config.strict_mode: + _check_tables(ast, model_names) + if config.denied_functions: + _check_functions(ast, config.denied_functions) + + +def _visible_cte_names(node: exp.Expression) -> set[str]: + """Return CTE names visible at *node*'s scope by walking up the AST.""" + names: set[str] = set() + cursor = node.parent + while cursor is not None: + # A WITH clause is visible to its parent SELECT and siblings. + with_clause = cursor.args.get("with_") if hasattr(cursor, "args") else None + if isinstance(with_clause, exp.With): + for cte in with_clause.expressions: + alias = cte.args.get("alias") + if alias: + cte_name = ( + alias.this.name + if isinstance(alias.this, exp.Identifier) + else str(alias.this) + ) + names.add(cte_name.lower()) + cursor = cursor.parent + return names + + +def _check_tables( + ast: exp.Expression, + model_names: set[str], +) -> None: + model_names_lower = {n.lower() for n in model_names} + + for table in ast.find_all(exp.Table): + name = table.name + if not name: + # Table nodes with no name are table-valued functions + # (e.g. read_csv(), generate_series()). Block them in strict mode. + sql_text = table.sql() + if sql_text: + raise WrenError( + ErrorCode.MODEL_NOT_FOUND, + f"Table-valued function '{sql_text}' is not allowed. " + "In strict mode, all table references must correspond to MDL models.", + phase=ErrorPhase.SQL_POLICY_CHECK, + ) + continue + name_lower = name.lower() + if name_lower in model_names_lower: + continue + if name_lower in _visible_cte_names(table): + continue + raise WrenError( + ErrorCode.MODEL_NOT_FOUND, + f"Table '{name}' is not defined in the MDL manifest. " + "In strict mode, all table references must correspond to MDL models.", + phase=ErrorPhase.SQL_POLICY_CHECK, + ) + + # Func subclasses used as FROM sources (e.g. UNNEST) produce no exp.Table + # node at all. Scan for Func nodes inside From clauses. + for from_clause in ast.find_all(exp.From): + source = from_clause.this + if isinstance(source, exp.Alias): + source = source.this + if isinstance(source, exp.Func): + raise WrenError( + ErrorCode.MODEL_NOT_FOUND, + f"Table-valued function '{source.sql()}' is not allowed. " + "In strict mode, all table references must correspond to MDL models.", + phase=ErrorPhase.SQL_POLICY_CHECK, + ) + + +def _check_functions( + ast: exp.Expression, + denied: frozenset[str], +) -> None: + for func in ast.find_all(exp.Func): + if isinstance(func, exp.Anonymous): + name = func.name + else: + name = type(func).key + if name.lower() in denied: + raise WrenError( + ErrorCode.BLOCKED_FUNCTION, + f"Function '{name}' is not allowed. " + "This function is on the denied list.", + phase=ErrorPhase.SQL_POLICY_CHECK, + ) diff --git a/wren/tests/unit/test_config.py b/wren/tests/unit/test_config.py new file mode 100644 index 000000000..cd3382a1e --- /dev/null +++ b/wren/tests/unit/test_config.py @@ -0,0 +1,104 @@ +"""Unit tests for wren.config — config loading from ~/.wren/config.json.""" + +from __future__ import annotations + +import json + +import pytest + +from wren.config import WrenConfig, load_config +from wren.model.error import WrenError + +pytestmark = pytest.mark.unit + + +def test_load_config_no_file(tmp_path): + config = load_config(tmp_path) + assert config == WrenConfig() + assert config.strict_mode is False + assert config.denied_functions == frozenset() + + +def test_load_config_strict_enabled(tmp_path): + (tmp_path / "config.json").write_text(json.dumps({"strict_mode": True})) + config = load_config(tmp_path) + assert config.strict_mode is True + assert config.denied_functions == frozenset() + + +def test_load_config_with_denied_functions(tmp_path): + data = { + "strict_mode": True, + "denied_functions": ["pg_read_file", "DBLINK", "Lo_Import"], + } + (tmp_path / "config.json").write_text(json.dumps(data)) + config = load_config(tmp_path) + assert config.strict_mode is True + assert config.denied_functions == frozenset(["pg_read_file", "dblink", "lo_import"]) + + +def test_load_config_function_names_lowercased(tmp_path): + data = {"denied_functions": ["PG_READ_FILE"]} + (tmp_path / "config.json").write_text(json.dumps(data)) + config = load_config(tmp_path) + assert "pg_read_file" in config.denied_functions + + +def test_load_config_malformed_json(tmp_path): + (tmp_path / "config.json").write_text("not valid json{{{") + with pytest.raises(WrenError): + load_config(tmp_path) + + +def test_load_config_not_a_dict(tmp_path): + (tmp_path / "config.json").write_text(json.dumps([1, 2, 3])) + with pytest.raises(WrenError): + load_config(tmp_path) + + +def test_load_config_denied_functions_not_array(tmp_path): + data = {"denied_functions": "pg_read_file"} + (tmp_path / "config.json").write_text(json.dumps(data)) + with pytest.raises(WrenError): + load_config(tmp_path) + + +def test_load_config_unknown_keys_ignored(tmp_path): + data = {"strict_mode": True, "unknown_key": "value", "another": 42} + (tmp_path / "config.json").write_text(json.dumps(data)) + config = load_config(tmp_path) + assert config.strict_mode is True + + +def test_load_config_partial_only_denied_functions(tmp_path): + data = {"denied_functions": ["dblink"]} + (tmp_path / "config.json").write_text(json.dumps(data)) + config = load_config(tmp_path) + assert config.strict_mode is False + assert config.denied_functions == frozenset(["dblink"]) + + +def test_load_config_empty_object(tmp_path): + (tmp_path / "config.json").write_text(json.dumps({})) + config = load_config(tmp_path) + assert config == WrenConfig() + + +def test_load_config_strict_mode_string_rejected(tmp_path): + """'strict_mode': 'false' must not silently coerce to True.""" + (tmp_path / "config.json").write_text(json.dumps({"strict_mode": "false"})) + with pytest.raises(WrenError): + load_config(tmp_path) + + +def test_load_config_strict_mode_int_rejected(tmp_path): + (tmp_path / "config.json").write_text(json.dumps({"strict_mode": 1})) + with pytest.raises(WrenError): + load_config(tmp_path) + + +def test_load_config_denied_functions_mixed_types_rejected(tmp_path): + data = {"denied_functions": ["safe", 1, {"obj": True}]} + (tmp_path / "config.json").write_text(json.dumps(data)) + with pytest.raises(WrenError): + load_config(tmp_path) diff --git a/wren/tests/unit/test_cte_rewriter.py b/wren/tests/unit/test_cte_rewriter.py index aec5d122e..add5b2b86 100644 --- a/wren/tests/unit/test_cte_rewriter.py +++ b/wren/tests/unit/test_cte_rewriter.py @@ -7,9 +7,8 @@ import orjson import pytest import sqlglot -from sqlglot import exp -from wren.mdl import get_session_context, to_json_base64 +from wren.mdl import get_session_context from wren.mdl.cte_rewriter import CTERewriter, get_sqlglot_dialect from wren.model.data_source import DataSource @@ -157,7 +156,7 @@ class TestMultiModel: def test_join_generates_two_ctes(self): rw = _make_rewriter(_MULTI_MODEL_MANIFEST) result = rw.rewrite( - 'SELECT o.o_orderkey, c.c_name ' + "SELECT o.o_orderkey, c.c_name " 'FROM "orders" o JOIN "customer" c ON o.o_custkey = c.c_custkey' ) assert _has_cte(result, "orders") @@ -166,9 +165,7 @@ def test_join_generates_two_ctes(self): def test_qualified_star(self): rw = _make_rewriter(_MULTI_MODEL_MANIFEST) - result = rw.rewrite( - 'SELECT "orders".* FROM "orders"' - ) + result = rw.rewrite('SELECT "orders".* FROM "orders"') assert _has_cte(result, "orders") @@ -182,8 +179,7 @@ def test_user_cte_references_model(self): """User CTE that references a model — model CTE should be prepended.""" rw = _make_rewriter(_SINGLE_MODEL_MANIFEST) result = rw.rewrite( - 'WITH summary AS (SELECT o_orderkey FROM "orders") ' - "SELECT * FROM summary" + 'WITH summary AS (SELECT o_orderkey FROM "orders") SELECT * FROM summary' ) assert _has_cte(result, "orders") assert _has_cte(result, "summary") @@ -197,9 +193,7 @@ def test_user_cte_references_model(self): def test_user_cte_shadows_model(self): """User CTE with same name as model — no model CTE generated for that name.""" rw = _make_rewriter(_SINGLE_MODEL_MANIFEST, fallback=True) - result = rw.rewrite( - "WITH orders AS (SELECT 1 AS id) SELECT * FROM orders" - ) + result = rw.rewrite("WITH orders AS (SELECT 1 AS id) SELECT * FROM orders") # The model name is shadowed by user CTE, so CTERewriter skips it # and falls back to session_context.transform_sql which processes # the whole query (wren-core handles user CTEs natively). diff --git a/wren/tests/unit/test_engine.py b/wren/tests/unit/test_engine.py index d3c82d5b9..1abb5b847 100644 --- a/wren/tests/unit/test_engine.py +++ b/wren/tests/unit/test_engine.py @@ -7,14 +7,14 @@ from __future__ import annotations import base64 -import math import orjson import pytest from wren import WrenEngine +from wren.config import WrenConfig from wren.model.data_source import DataSource -from wren.model.error import WrenError +from wren.model.error import ErrorCode, WrenError pytestmark = pytest.mark.unit @@ -109,3 +109,65 @@ def test_context_manager_closes_connector() -> None: # After __exit__, internal state is cleaned up assert e._connector is None + + +# ------------------------------------------------------------------ +# Strict mode (no DB access) +# ------------------------------------------------------------------ + +_STRICT_CONFIG = WrenConfig(strict_mode=True) +_BLACKLIST_CONFIG = WrenConfig(denied_functions=frozenset(["pg_read_file"])) +_COMBINED_CONFIG = WrenConfig( + strict_mode=True, denied_functions=frozenset(["pg_read_file"]) +) + + +def test_strict_mode_blocks_unknown_table(): + conn_info = {"url": "/tmp", "format": "duckdb"} + with WrenEngine( + _MANIFEST_STR, + DataSource.duckdb, + conn_info, + fallback=False, + config=_STRICT_CONFIG, + ) as engine: + with pytest.raises(WrenError) as exc_info: + engine.dry_plan("SELECT * FROM secret_table") + assert exc_info.value.error_code == ErrorCode.MODEL_NOT_FOUND + + +def test_strict_mode_allows_mdl_table(): + conn_info = {"url": "/tmp", "format": "duckdb"} + with WrenEngine( + _MANIFEST_STR, + DataSource.duckdb, + conn_info, + fallback=False, + config=_STRICT_CONFIG, + ) as engine: + sql = engine.dry_plan('SELECT o_orderkey FROM "orders" LIMIT 1') + assert isinstance(sql, str) + assert len(sql) > 0 + + +def test_strict_mode_blocks_denied_function(): + conn_info = {"url": "/tmp", "format": "duckdb"} + with WrenEngine( + _MANIFEST_STR, + DataSource.duckdb, + conn_info, + fallback=False, + config=_BLACKLIST_CONFIG, + ) as engine: + with pytest.raises(WrenError) as exc_info: + engine.dry_plan("SELECT pg_read_file('/etc/passwd')") + assert exc_info.value.error_code == ErrorCode.BLOCKED_FUNCTION + + +def test_non_strict_mode_allows_unknown_table(duckdb_engine: WrenEngine): + # Default config (no strict mode) — non-MDL tables should not be blocked + # by policy (may still fail during planning, but not with MODEL_NOT_FOUND) + try: + duckdb_engine.dry_plan("SELECT * FROM unknown_table") + except WrenError as e: + assert e.error_code != ErrorCode.MODEL_NOT_FOUND diff --git a/wren/tests/unit/test_policy.py b/wren/tests/unit/test_policy.py new file mode 100644 index 000000000..7ccafca0f --- /dev/null +++ b/wren/tests/unit/test_policy.py @@ -0,0 +1,208 @@ +"""Unit tests for wren.policy — SQL policy validation. + +These tests use sqlglot parsing only and do not require a database or wren-core. +""" + +from __future__ import annotations + +import pytest +from sqlglot import parse_one + +from wren.config import WrenConfig +from wren.model.error import ErrorCode, WrenError +from wren.policy import validate_sql_policy + +pytestmark = pytest.mark.unit + +_MODELS = {"orders", "customers"} + + +# ── Table validation ────────────────────────────────────────────────────── + + +def test_valid_query_all_tables_in_mdl(): + ast = parse_one('SELECT * FROM "orders"', dialect="duckdb") + config = WrenConfig(strict_mode=True) + validate_sql_policy(ast, _MODELS, config) + + +def test_table_not_in_mdl_raises(): + ast = parse_one("SELECT * FROM pg_shadow", dialect="postgres") + config = WrenConfig(strict_mode=True) + with pytest.raises(WrenError) as exc_info: + validate_sql_policy(ast, _MODELS, config) + assert exc_info.value.error_code == ErrorCode.MODEL_NOT_FOUND + assert "pg_shadow" in str(exc_info.value) + + +def test_user_cte_not_flagged(): + sql = "WITH foo AS (SELECT 1 AS x) SELECT * FROM foo" + ast = parse_one(sql, dialect="duckdb") + config = WrenConfig(strict_mode=True) + validate_sql_policy(ast, _MODELS, config) + + +def test_mixed_mdl_and_non_mdl_table_raises(): + sql = 'SELECT * FROM "orders" JOIN secret_table ON 1=1' + ast = parse_one(sql, dialect="duckdb") + config = WrenConfig(strict_mode=True) + with pytest.raises(WrenError) as exc_info: + validate_sql_policy(ast, _MODELS, config) + assert exc_info.value.error_code == ErrorCode.MODEL_NOT_FOUND + assert "secret_table" in str(exc_info.value) + + +def test_strict_mode_off_allows_unknown_table(): + ast = parse_one("SELECT * FROM unknown_table", dialect="duckdb") + config = WrenConfig(strict_mode=False) + validate_sql_policy(ast, _MODELS, config) + + +def test_subquery_alias_not_flagged(): + sql = "SELECT * FROM (SELECT 1 AS x) AS t" + ast = parse_one(sql, dialect="duckdb") + config = WrenConfig(strict_mode=True) + # 't' is a subquery alias, not a real table — but sqlglot doesn't emit + # an exp.Table for it, so this should pass without error. + # The only table nodes come from real FROM references. + validate_sql_policy(ast, _MODELS, config) + + +def test_multiple_valid_tables(): + sql = 'SELECT * FROM "orders" o JOIN "customers" c ON o.id = c.id' + ast = parse_one(sql, dialect="duckdb") + config = WrenConfig(strict_mode=True) + validate_sql_policy(ast, _MODELS, config) + + +# ── Denied functions ────────────────────────────────────────────────────── + + +def test_denied_function_raises(): + ast = parse_one("SELECT pg_read_file('/etc/passwd')", dialect="postgres") + config = WrenConfig(denied_functions=frozenset(["pg_read_file"])) + with pytest.raises(WrenError) as exc_info: + validate_sql_policy(ast, _MODELS, config) + assert exc_info.value.error_code == ErrorCode.BLOCKED_FUNCTION + assert "pg_read_file" in str(exc_info.value) + + +def test_denied_function_case_insensitive(): + ast = parse_one("SELECT PG_READ_FILE('/etc/passwd')", dialect="postgres") + config = WrenConfig(denied_functions=frozenset(["pg_read_file"])) + with pytest.raises(WrenError) as exc_info: + validate_sql_policy(ast, _MODELS, config) + assert exc_info.value.error_code == ErrorCode.BLOCKED_FUNCTION + + +def test_allowed_function_passes(): + ast = parse_one('SELECT COUNT(*) FROM "orders"', dialect="duckdb") + config = WrenConfig(denied_functions=frozenset(["pg_read_file"])) + validate_sql_policy(ast, _MODELS, config) + + +def test_builtin_function_on_denied_list(): + ast = parse_one('SELECT COUNT(*) FROM "orders"', dialect="duckdb") + config = WrenConfig(denied_functions=frozenset(["count"])) + with pytest.raises(WrenError) as exc_info: + validate_sql_policy(ast, _MODELS, config) + assert exc_info.value.error_code == ErrorCode.BLOCKED_FUNCTION + + +def test_nested_denied_function(): + sql = "SELECT * FROM (SELECT dblink('host=evil', 'SELECT 1') AS x) AS t" + ast = parse_one(sql, dialect="postgres") + config = WrenConfig(denied_functions=frozenset(["dblink"])) + with pytest.raises(WrenError) as exc_info: + validate_sql_policy(ast, _MODELS, config) + assert exc_info.value.error_code == ErrorCode.BLOCKED_FUNCTION + + +def test_no_denied_list_allows_everything(): + ast = parse_one("SELECT pg_read_file('/etc/passwd')", dialect="postgres") + config = WrenConfig(denied_functions=frozenset()) + validate_sql_policy(ast, _MODELS, config) + + +def test_empty_denied_list_allows_everything(): + ast = parse_one("SELECT dblink('host=evil', 'SELECT 1')", dialect="postgres") + config = WrenConfig() + validate_sql_policy(ast, _MODELS, config) + + +# ── Combined strict_mode + denied_functions ─────────────────────────────── + + +def test_strict_mode_and_denied_functions_together(): + sql = 'SELECT pg_read_file(o_orderkey) FROM "orders"' + ast = parse_one(sql, dialect="postgres") + config = WrenConfig(strict_mode=True, denied_functions=frozenset(["pg_read_file"])) + with pytest.raises(WrenError) as exc_info: + validate_sql_policy(ast, _MODELS, config) + # Either error is acceptable — table check runs first, but orders is valid + # so function check should fire + assert exc_info.value.error_code == ErrorCode.BLOCKED_FUNCTION + + +# ── CTE scope shadowing ────────────────────────────────────────────────── + + +def test_nested_cte_does_not_shadow_outer_table(): + """A CTE defined inside a subquery must not hide an outer FROM reference.""" + sql = """ + SELECT * + FROM secret_table + WHERE EXISTS ( + WITH secret_table AS (SELECT 1) + SELECT 1 FROM secret_table + ) + """ + ast = parse_one(sql, dialect="duckdb") + config = WrenConfig(strict_mode=True) + with pytest.raises(WrenError) as exc_info: + validate_sql_policy(ast, _MODELS, config) + assert exc_info.value.error_code == ErrorCode.MODEL_NOT_FOUND + assert "secret_table" in str(exc_info.value) + + +def test_outer_cte_visible_in_body(): + """A CTE defined at the top level should be visible in the main SELECT.""" + sql = 'WITH tmp AS (SELECT 1 AS x FROM "orders") SELECT * FROM tmp' + ast = parse_one(sql, dialect="duckdb") + config = WrenConfig(strict_mode=True) + validate_sql_policy(ast, _MODELS, config) + + +# ── Table-valued functions ──────────────────────────────────────────────── + + +def test_tvf_read_csv_blocked(): + ast = parse_one("SELECT * FROM read_csv('s3://bucket/file.csv')", dialect="duckdb") + config = WrenConfig(strict_mode=True) + with pytest.raises(WrenError) as exc_info: + validate_sql_policy(ast, _MODELS, config) + assert exc_info.value.error_code == ErrorCode.MODEL_NOT_FOUND + + +def test_tvf_generate_series_blocked(): + sql = "SELECT * FROM generate_series(1, 10) AS t(x)" + ast = parse_one(sql, dialect="postgres") + config = WrenConfig(strict_mode=True) + with pytest.raises(WrenError) as exc_info: + validate_sql_policy(ast, _MODELS, config) + assert exc_info.value.error_code == ErrorCode.MODEL_NOT_FOUND + + +def test_tvf_unnest_blocked(): + sql = "SELECT * FROM unnest(ARRAY[1,2,3]) AS t(x)" + ast = parse_one(sql, dialect="duckdb") + config = WrenConfig(strict_mode=True) + with pytest.raises(WrenError) as exc_info: + validate_sql_policy(ast, _MODELS, config) + assert exc_info.value.error_code == ErrorCode.MODEL_NOT_FOUND + + +def test_tvf_allowed_when_not_strict(): + ast = parse_one("SELECT * FROM read_csv('file.csv')", dialect="duckdb") + config = WrenConfig(strict_mode=False) + validate_sql_policy(ast, _MODELS, config)