From ef591986aef75131a92569e8dc3c596c7df61a4f Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 25 Mar 2026 21:59:58 +0800 Subject: [PATCH 01/10] chore: update duckdb arrow method --- wren/src/wren/connector/duckdb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren/src/wren/connector/duckdb.py b/wren/src/wren/connector/duckdb.py index c27fe93e5..d52e3defe 100644 --- a/wren/src/wren/connector/duckdb.py +++ b/wren/src/wren/connector/duckdb.py @@ -75,7 +75,7 @@ def __init__(self, connection_info): def query(self, sql: str, limit: int | None = None) -> pa.Table: if limit is not None: sql = f"SELECT * FROM ({sql}) AS _q LIMIT {int(limit)}" - return self.connection.execute(sql).fetch_arrow_table() + return self.connection.execute(sql).to_arrow_table() def dry_run(self, sql: str) -> None: self.connection.execute(f"EXPLAIN {sql}") From 13b71791ecb17dd094242b5449b77aa483cdddcb Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 25 Mar 2026 22:00:31 +0800 Subject: [PATCH 02/10] feat(wren): CTE-based SQL planning with per-model expansion Replace the whole-query `session_context.transform_sql()` approach with a CTE-based rewriter that: 1. Parses user SQL with sqlglot (source dialect) 2. Uses qualify_tables + normalize_identifiers + qualify_columns to resolve all column references (including SELECT *, aliases, and correlated subqueries) 3. Calls wren-core transform_sql once per model with a simple SELECT col1, col2 FROM model 4. Parses each wren-core output with Wren dialect and injects as CTEs 5. Outputs final SQL in the target data source dialect This eliminates the need for a separate _transpile() step since dialect conversion happens naturally via sqlglot AST generation. 
Also: - Replace CLI `transpile` command with `dry-plan` - Remove WrenEngine.transpile(), unify on dry_plan() - Add fallback parameter to CTERewriter/WrenEngine for test strictness - Fix relationship column filter (handle null values from extracted manifests) - All tests use fallback=False to catch silent CTE path failures Co-Authored-By: Claude Opus 4.6 (1M context) --- wren/pyproject.toml | 2 +- wren/src/wren/cli.py | 9 +- wren/src/wren/engine.py | 58 ++---- wren/src/wren/mdl/cte_rewriter.py | 226 ++++++++++++++++++++++ wren/src/wren/mdl/wren_dialect.py | 19 ++ wren/tests/connectors/test_duckdb.py | 2 +- wren/tests/connectors/test_postgres.py | 2 +- wren/tests/suite/query.py | 12 +- wren/tests/unit/test_cte_rewriter.py | 250 +++++++++++++++++++++++++ wren/tests/unit/test_engine.py | 40 ++-- wren/uv.lock | 18 +- 11 files changed, 543 insertions(+), 95 deletions(-) create mode 100644 wren/src/wren/mdl/cte_rewriter.py create mode 100644 wren/src/wren/mdl/wren_dialect.py create mode 100644 wren/tests/unit/test_cte_rewriter.py diff --git a/wren/pyproject.toml b/wren/pyproject.toml index c3393bdb8..8913cd852 100644 --- a/wren/pyproject.toml +++ b/wren/pyproject.toml @@ -46,7 +46,7 @@ redshift = ["redshift_connector"] spark = ["pyspark>=3.5"] athena = ["ibis-framework[athena]"] oracle = ["ibis-framework[oracle]", "oracledb"] -all = ["wren[postgres,mysql,bigquery,snowflake,clickhouse,trino,mssql,databricks,redshift,athena,oracle]"] +all = ["wren-engine[postgres,mysql,bigquery,snowflake,clickhouse,trino,mssql,databricks,redshift,athena,oracle]"] dev = ["pytest>=8", "ruff>=0.4", "orjson>=3", "testcontainers[postgres]>=4"] [project.scripts] diff --git a/wren/src/wren/cli.py b/wren/src/wren/cli.py index 8369b1173..b19338fe4 100644 --- a/wren/src/wren/cli.py +++ b/wren/src/wren/cli.py @@ -147,13 +147,13 @@ def dry_run( engine.close() -@app.command() -def transpile( +@app.command(name="dry-plan") +def dry_plan( sql: SqlArg, datasource: DatasourceOpt, mdl: MdlOpt, ): - 
"""Transform SQL through MDL and emit the target dialect SQL (no DB required).""" + """Plan SQL through MDL and print the expanded SQL (no DB required).""" from wren.engine import WrenEngine # noqa: PLC0415 from wren.model.data_source import DataSource # noqa: PLC0415 @@ -164,10 +164,9 @@ def transpile( typer.echo(f"Error: unknown datasource '{datasource}'", err=True) raise typer.Exit(1) - # For transpile we don't need real connection_info — pass a dummy dict engine = WrenEngine(manifest_str=manifest_str, data_source=ds, connection_info={}) try: - result = engine.transpile(sql) + result = engine.dry_plan(sql) typer.echo(result) except Exception as e: typer.echo(f"Error: {e}", err=True) diff --git a/wren/src/wren/engine.py b/wren/src/wren/engine.py index 8d2ad5472..48dcdb2dc 100644 --- a/wren/src/wren/engine.py +++ b/wren/src/wren/engine.py @@ -11,8 +11,8 @@ connection_info={"host": "localhost", "port": 5432, ...}, ) - # Transform only (no DB required) - planned_sql = engine.transpile("SELECT * FROM orders") + # Plan only (no DB required) + planned_sql = engine.dry_plan("SELECT * FROM orders") # Execute against the data source arrow_table = engine.query("SELECT * FROM orders", limit=100) @@ -23,25 +23,12 @@ from typing import Any import pyarrow as pa -import sqlglot from wren.connector.factory import get_connector from wren.mdl import get_manifest_extractor, get_session_context, to_json_base64 +from wren.mdl.cte_rewriter import CTERewriter from wren.model.data_source import DataSource -from wren.model.error import DIALECT_SQL, PLANNED_SQL, ErrorCode, ErrorPhase, WrenError - - -def _get_write_dialect(data_source: DataSource) -> str: - if data_source == DataSource.canner: - return "trino" - if data_source in { - DataSource.local_file, - DataSource.s3_file, - DataSource.minio_file, - DataSource.gcs_file, - }: - return "duckdb" - return data_source.name +from wren.model.error import DIALECT_SQL, ErrorCode, ErrorPhase, WrenError class WrenEngine: @@ -66,6 +53,8 @@ def 
__init__( data_source: DataSource | str, connection_info: dict[str, Any] | object, function_path: str | None = None, + *, + fallback: bool = True, ): if isinstance(data_source, str): data_source = DataSource(data_source) @@ -73,6 +62,7 @@ def __init__( self.manifest_str = manifest_str self.data_source = data_source self.function_path = function_path + self._fallback = fallback # Build typed ConnectionInfo if a raw dict was given. # An empty dict is allowed for transpile-only usage (no DB connection). @@ -87,16 +77,8 @@ def __init__( # SQL transformation (no DB access) # ------------------------------------------------------------------ - def transpile(self, sql: str, properties: dict | None = None) -> str: - """Transform SQL through MDL and transpile to the target dialect. - - Returns the dialect SQL string without executing it. - """ - planned = self._plan(sql, properties) - return self._transpile(planned) - def dry_plan(self, sql: str, properties: dict | None = None) -> str: - """Return the wren-core planned SQL (DataFusion dialect, before transpile).""" + """Plan SQL through MDL and return the expanded SQL in the target dialect.""" return self._plan(sql, properties) # ------------------------------------------------------------------ @@ -110,7 +92,7 @@ def query( properties: dict | None = None, ) -> pa.Table: """Transpile and execute SQL, return results as an Arrow table.""" - dialect_sql = self.transpile(sql, properties) + dialect_sql = self.dry_plan(sql, properties) connector = self._get_connector() try: return connector.query(dialect_sql, limit) @@ -126,7 +108,7 @@ def query( def dry_run(self, sql: str, properties: dict | None = None) -> None: """Transpile and dry-run SQL without returning results.""" - dialect_sql = self.transpile(sql, properties) + dialect_sql = self.dry_plan(sql, properties) connector = self._get_connector() try: connector.dry_run(dialect_sql) @@ -180,7 +162,13 @@ def _plan(self, sql: str, properties: dict | None) -> str: processed, 
self.data_source.name, ) - return session.transform_sql(sql) + rewriter = CTERewriter( + effective_manifest, + session, + self.data_source, + fallback=self._fallback, + ) + return rewriter.rewrite(sql) except Exception as e: raise WrenError( ErrorCode.INVALID_SQL, @@ -189,18 +177,6 @@ def _plan(self, sql: str, properties: dict | None) -> str: metadata={DIALECT_SQL: sql}, ) from e - def _transpile(self, planned_sql: str) -> str: - try: - write = _get_write_dialect(self.data_source) - return sqlglot.transpile(planned_sql, read="duckdb", write=write)[0] - except Exception as e: - raise WrenError( - ErrorCode.SQLGLOT_ERROR, - str(e), - phase=ErrorPhase.SQL_TRANSPILE, - metadata={PLANNED_SQL: planned_sql}, - ) from e - def _get_connector(self): if self._connector is None: self._connector = get_connector(self.data_source, self.connection_info) diff --git a/wren/src/wren/mdl/cte_rewriter.py b/wren/src/wren/mdl/cte_rewriter.py new file mode 100644 index 000000000..a5f3fc5eb --- /dev/null +++ b/wren/src/wren/mdl/cte_rewriter.py @@ -0,0 +1,226 @@ +"""CTE-based SQL rewriter. + +Parses user SQL with sqlglot, uses ``qualify_columns`` to fully resolve +all column references, then calls wren-core +``transform_sql`` once per model with a simple ``SELECT col1, col2 FROM +model`` and injects each expanded result as a CTE into the original query. +""" + +from __future__ import annotations + +import base64 +import json +from collections import defaultdict + +from sqlglot import exp, parse_one +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from sqlglot.optimizer.qualify_columns import qualify_columns +from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.schema import MappingSchema + +# Ensure the Wren dialect is registered with sqlglot on import. 
+import wren.mdl.wren_dialect as _wren_dialect # noqa: F401 +from wren.model.data_source import DataSource + +_SQLGLOT_DIALECT_MAP: dict[DataSource, str] = { + DataSource.canner: "trino", + DataSource.mssql: "tsql", + DataSource.local_file: "duckdb", + DataSource.s3_file: "duckdb", + DataSource.minio_file: "duckdb", + DataSource.gcs_file: "duckdb", +} + + +def get_sqlglot_dialect(data_source: DataSource) -> str: + """Map a DataSource to a valid sqlglot dialect name.""" + return _SQLGLOT_DIALECT_MAP.get(data_source, data_source.name) + + +class CTERewriter: + """Rewrite user SQL by expanding MDL model references into CTEs. + + Parameters + ---------- + manifest_str: + Base64-encoded MDL JSON string. + session_context: + A ``wren_core.SessionContext`` used to expand per-model SQL. + data_source: + The target data source (determines sqlglot dialect). + fallback: + When ``True`` (default), if no model references are detected in the + SQL, fall back to ``session_context.transform_sql()`` directly. + Set to ``False`` in tests to ensure the CTE path is always exercised + and silent fallbacks don't mask bugs. 
+ """ + + def __init__( + self, + manifest_str: str, + session_context, + data_source: DataSource, + *, + fallback: bool = True, + ): + self.session_context = session_context + self.fallback = fallback + self.dialect = get_sqlglot_dialect(data_source) + self.manifest = json.loads(base64.b64decode(manifest_str)) + + self.model_dict: dict[str, dict] = {} + self.schema = MappingSchema(dialect=self.dialect) + # normalized column name → original manifest column name, per model + self._col_orig_name: dict[str, dict[str, str]] = {} + + for model in self.manifest.get("models", []): + name = model["name"] + self.model_dict[name] = model + cols: dict[str, str] = {} + orig: dict[str, str] = {} + for col in model.get("columns", []): + if col.get("isHidden"): + continue + if col.get("relationship"): + continue + col_name = col["name"] + cols[col_name] = col.get("type", "TEXT") + orig[col_name.lower()] = col_name + self.schema.add_table(name, cols, dialect=self.dialect) + self._col_orig_name[name] = orig + + def rewrite(self, sql: str) -> str: + """Rewrite *sql* by injecting model CTEs. + + Returns the transformed SQL string in the target sqlglot dialect. + If no model references are found, falls back to + ``session_context.transform_sql(sql)`` directly. + """ + ast = parse_one(sql, dialect=self.dialect) + + user_cte_names = self._collect_user_cte_names(ast) + used_columns = self._collect_model_columns(ast, user_cte_names) + + # No model references detected — either fall back to the legacy + # whole-query transform, or raise so tests can catch the miss. 
+ if not used_columns: + if self.fallback: + return self.session_context.transform_sql(sql) + raise ValueError(f"No model references found in SQL: {sql}") + + model_ctes = self._build_model_ctes(used_columns) + self._inject_ctes(ast, model_ctes) + return ast.sql(dialect=self.dialect) + + # ------------------------------------------------------------------ + # Column collection via qualify + # ------------------------------------------------------------------ + + def _collect_model_columns( + self, ast: exp.Expression, user_cte_names: set[str] + ) -> dict[str, set[str]]: + """Return ``{model_name: {col1, col2, ...}}`` for all referenced models. + + Uses sqlglot's ``qualify_columns`` to fully resolve all column + references (including ``SELECT *`` expansion and + correlated subquery outer references), then walks the qualified AST + to collect model→column mappings. + """ + copy = ast.copy() + copy = qualify_tables(copy, dialect=self.dialect) + copy = normalize_identifiers(copy, dialect=self.dialect) + qualified = qualify_columns( + copy, + schema=self.schema, + dialect=self.dialect, + allow_partial_qualification=True, + ) + + # Build alias → model name mapping from Table nodes + alias_to_model: dict[str, str] = {} + for table in qualified.find_all(exp.Table): + table_name = table.name + if table_name not in self.model_dict or table_name in user_cte_names: + continue + alias = table.alias + if alias: + alias_to_model[alias] = table_name + alias_to_model[table_name] = table_name + + used: dict[str, set[str]] = defaultdict(set) + for col in qualified.find_all(exp.Column): + table_ref = col.table + if not table_ref: + continue + model_name = alias_to_model.get(table_ref) + if model_name: + used[model_name].add(col.name) + + return dict(used) + + # ------------------------------------------------------------------ + # CTE generation + # ------------------------------------------------------------------ + + def _build_model_ctes(self, used_columns: dict[str, set[str]]) -> 
list[exp.CTE]: + """Generate one CTE per model via wren-core transform_sql.""" + ctes: list[exp.CTE] = [] + for model_name, columns in used_columns.items(): + orig = self._col_orig_name.get(model_name, {}) + resolved = [orig.get(c, c) for c in sorted(columns)] + col_list = ", ".join(f'"{model_name}"."{c}"' for c in resolved) + model_sql = f'SELECT {col_list} FROM "{model_name}"' + expanded = self.session_context.transform_sql(model_sql) + + expanded_ast = parse_one(expanded, dialect="wren") + cte = exp.CTE( + this=expanded_ast, + alias=exp.TableAlias(this=exp.to_identifier(model_name, quoted=True)), + ) + ctes.append(cte) + return ctes + + # ------------------------------------------------------------------ + # CTE injection + # ------------------------------------------------------------------ + + def _inject_ctes(self, ast: exp.Expression, model_ctes: list[exp.CTE]) -> None: + """Prepend *model_ctes* before any existing user CTEs in *ast*.""" + if not model_ctes: + return + + existing_with = ast.args.get("with_") + + if existing_with: + # Prepend model CTEs before user CTEs + existing_ctes = list(existing_with.expressions) + all_ctes = model_ctes + existing_ctes + existing_with.set("expressions", all_ctes) + else: + with_clause = exp.With(expressions=model_ctes) + ast.set("with_", with_clause) + + # Preserve RECURSIVE if the original WITH had it + final_with = ast.args.get("with_") + if existing_with and existing_with.args.get("recursive"): + final_with.set("recursive", True) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @staticmethod + def _collect_user_cte_names(ast: exp.Expression) -> set[str]: + """Collect all CTE names defined in the user's SQL.""" + names: set[str] = set() + with_clause = ast.args.get("with_") + if with_clause: + for cte in with_clause.expressions: + alias = cte.args.get("alias") + if alias: + names.add( + alias.this.name + if 
isinstance(alias.this, exp.Identifier) + else str(alias.this) + ) + return names diff --git a/wren/src/wren/mdl/wren_dialect.py b/wren/src/wren/mdl/wren_dialect.py new file mode 100644 index 000000000..ceebe3c26 --- /dev/null +++ b/wren/src/wren/mdl/wren_dialect.py @@ -0,0 +1,19 @@ +"""Wren sqlglot dialect — mirrors wren-core's DataFusion-based output.""" + +from sqlglot import exp, parser +from sqlglot.dialects.dialect import Dialect, build_date_delta_with_interval + + +class Wren(Dialect): + class Parser(parser.Parser): + FUNCTIONS = { + **parser.Parser.FUNCTIONS, + "DATE_ADD": build_date_delta_with_interval(exp.DateAdd), + "DATE_SUB": build_date_delta_with_interval(exp.DateSub), + "DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd), + "DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub), + "TIME_ADD": build_date_delta_with_interval(exp.TimeAdd), + "TIME_SUB": build_date_delta_with_interval(exp.TimeSub), + "TIMESTAMP_ADD": build_date_delta_with_interval(exp.TimestampAdd), + "TIMESTAMP_SUB": build_date_delta_with_interval(exp.TimestampSub), + } diff --git a/wren/tests/connectors/test_duckdb.py b/wren/tests/connectors/test_duckdb.py index 1bcc4fd7f..acb9c66e0 100644 --- a/wren/tests/connectors/test_duckdb.py +++ b/wren/tests/connectors/test_duckdb.py @@ -44,5 +44,5 @@ def engine(self, tmp_path_factory) -> WrenEngine: # type: ignore[override] manifest_str = base64.b64encode(orjson.dumps(self.manifest)).decode() conn_info = {"url": str(db_dir), "format": "duckdb"} - with WrenEngine(manifest_str, DataSource.duckdb, conn_info) as e: + with WrenEngine(manifest_str, DataSource.duckdb, conn_info, fallback=False) as e: yield e diff --git a/wren/tests/connectors/test_postgres.py b/wren/tests/connectors/test_postgres.py index 9f2ae064a..fbd4593ba 100644 --- a/wren/tests/connectors/test_postgres.py +++ b/wren/tests/connectors/test_postgres.py @@ -86,5 +86,5 @@ def engine(self) -> WrenEngine: # type: ignore[override] "password": parsed.password, } 
manifest_str = base64.b64encode(orjson.dumps(self.manifest)).decode() - with WrenEngine(manifest_str, DataSource.postgres, conn_info) as e: + with WrenEngine(manifest_str, DataSource.postgres, conn_info, fallback=False) as e: yield e diff --git a/wren/tests/suite/query.py b/wren/tests/suite/query.py index daa4f1d7e..310186e30 100644 --- a/wren/tests/suite/query.py +++ b/wren/tests/suite/query.py @@ -34,7 +34,7 @@ def engine(self): # (4) engine fixture — class-scoped _load_tpch(ch.get_connection_url()) conn_info = { ... } manifest_str = base64.b64encode(orjson.dumps(self.manifest)).decode() - with WrenEngine(manifest_str, DataSource.clickhouse, conn_info) as e: + with WrenEngine(manifest_str, DataSource.clickhouse, conn_info, fallback=False) as e: yield e Step 2 — Register the pytest marker @@ -201,17 +201,11 @@ def test_dry_run_invalid_table(self, engine: WrenEngine) -> None: engine.dry_run('SELECT * FROM "NotFound"') # ------------------------------------------------------------------ - # Transpile / dry-plan tests (no DB access) + # dry-plan tests (no DB access) # ------------------------------------------------------------------ - def test_transpile_returns_sql(self, engine: WrenEngine) -> None: - sql = engine.transpile('SELECT o_orderkey FROM "orders" LIMIT 1') - assert isinstance(sql, str) - assert len(sql) > 0 - - def test_dry_plan_returns_datafusion_sql(self, engine: WrenEngine) -> None: + def test_dry_plan_returns_sql(self, engine: WrenEngine) -> None: planned = engine.dry_plan('SELECT o_orderkey FROM "orders" LIMIT 1') assert isinstance(planned, str) assert len(planned) > 0 - # dry_plan returns wren-core / DataFusion SQL, not dialect SQL assert "orders" in planned.lower() diff --git a/wren/tests/unit/test_cte_rewriter.py b/wren/tests/unit/test_cte_rewriter.py new file mode 100644 index 000000000..294ba7c8b --- /dev/null +++ b/wren/tests/unit/test_cte_rewriter.py @@ -0,0 +1,250 @@ +"""Unit tests for CTERewriter — no database required.""" + +from __future__ 
import annotations + +import base64 + +import orjson +import pytest +import sqlglot +from sqlglot import exp + +from wren.mdl import get_session_context, to_json_base64 +from wren.mdl.cte_rewriter import CTERewriter, get_sqlglot_dialect +from wren.model.data_source import DataSource + +pytestmark = pytest.mark.unit + +# --------------------------------------------------------------------------- +# Manifests +# --------------------------------------------------------------------------- + +_SINGLE_MODEL_MANIFEST = { + "catalog": "wren", + "schema": "public", + "models": [ + { + "name": "orders", + "tableReference": {"schema": "main", "table": "orders"}, + "columns": [ + {"name": "o_orderkey", "type": "integer"}, + {"name": "o_custkey", "type": "integer"}, + {"name": "o_orderstatus", "type": "varchar"}, + { + "name": "order_cust_key", + "type": "varchar", + "expression": "concat(cast(o_orderkey as varchar), '_', cast(o_custkey as varchar))", + }, + ], + "primaryKey": "o_orderkey", + } + ], +} + +_MULTI_MODEL_MANIFEST = { + "catalog": "wren", + "schema": "public", + "models": [ + { + "name": "orders", + "tableReference": {"schema": "main", "table": "orders"}, + "columns": [ + {"name": "o_orderkey", "type": "integer"}, + {"name": "o_custkey", "type": "integer"}, + {"name": "o_orderstatus", "type": "varchar"}, + { + "name": "order_cust_key", + "type": "varchar", + "expression": "concat(cast(o_orderkey as varchar), '_', cast(o_custkey as varchar))", + }, + ], + "primaryKey": "o_orderkey", + }, + { + "name": "customer", + "tableReference": {"schema": "main", "table": "customer"}, + "columns": [ + {"name": "c_custkey", "type": "integer"}, + {"name": "c_name", "type": "varchar"}, + ], + "primaryKey": "c_custkey", + }, + ], + "relationships": [ + { + "name": "orders_customer", + "models": ["orders", "customer"], + "joinType": "many_to_one", + "condition": '"orders".o_custkey = "customer".c_custkey', + } + ], +} + + +def _b64(manifest: dict) -> str: + return 
base64.b64encode(orjson.dumps(manifest)).decode() + + +def _make_rewriter( + manifest: dict, data_source: DataSource = DataSource.duckdb +) -> CTERewriter: + manifest_str = _b64(manifest) + session = get_session_context(manifest_str, None, None, data_source.name) + return CTERewriter(manifest_str, session, data_source) + + +# --------------------------------------------------------------------------- +# Helper: parse and check CTE presence +# --------------------------------------------------------------------------- + + +def _has_cte(sql: str, cte_name: str, dialect: str = "duckdb") -> bool: + """Return True if *sql* contains a CTE named *cte_name*.""" + ast = sqlglot.parse_one(sql, dialect=dialect) + with_clause = ast.args.get("with_") + if not with_clause: + return False + for cte in with_clause.expressions: + alias = cte.args.get("alias") + if alias and alias.this.name == cte_name: + return True + return False + + +def _count_ctes(sql: str, dialect: str = "duckdb") -> int: + ast = sqlglot.parse_one(sql, dialect=dialect) + with_clause = ast.args.get("with_") + if not with_clause: + return 0 + return len(with_clause.expressions) + + +# --------------------------------------------------------------------------- +# Tests: basic model resolution +# --------------------------------------------------------------------------- + + +class TestSingleModel: + def test_single_model_generates_cte(self): + rw = _make_rewriter(_SINGLE_MODEL_MANIFEST) + result = rw.rewrite('SELECT o_orderkey FROM "orders" LIMIT 1') + assert _has_cte(result, "orders") + assert isinstance(result, str) + assert len(result) > 0 + + def test_star_expansion(self): + rw = _make_rewriter(_SINGLE_MODEL_MANIFEST) + result = rw.rewrite('SELECT * FROM "orders"') + assert _has_cte(result, "orders") + + def test_alias_resolution(self): + rw = _make_rewriter(_SINGLE_MODEL_MANIFEST) + result = rw.rewrite('SELECT o.o_orderkey FROM "orders" o') + assert _has_cte(result, "orders") + + def 
test_calculated_column(self): + rw = _make_rewriter(_SINGLE_MODEL_MANIFEST) + result = rw.rewrite('SELECT order_cust_key FROM "orders"') + assert _has_cte(result, "orders") + # The CTE should contain the expanded expression + assert "concat" in result.lower() or "||" in result.lower() + + +class TestMultiModel: + def test_join_generates_two_ctes(self): + rw = _make_rewriter(_MULTI_MODEL_MANIFEST) + result = rw.rewrite( + 'SELECT o.o_orderkey, c.c_name ' + 'FROM "orders" o JOIN "customer" c ON o.o_custkey = c.c_custkey' + ) + assert _has_cte(result, "orders") + assert _has_cte(result, "customer") + assert _count_ctes(result) >= 2 + + def test_qualified_star(self): + rw = _make_rewriter(_MULTI_MODEL_MANIFEST) + result = rw.rewrite( + 'SELECT "orders".* FROM "orders"' + ) + assert _has_cte(result, "orders") + + +# --------------------------------------------------------------------------- +# Tests: CTE edge cases +# --------------------------------------------------------------------------- + + +class TestCTEEdgeCases: + def test_user_cte_references_model(self): + """User CTE that references a model — model CTE should be prepended.""" + rw = _make_rewriter(_SINGLE_MODEL_MANIFEST) + result = rw.rewrite( + 'WITH summary AS (SELECT o_orderkey FROM "orders") ' + "SELECT * FROM summary" + ) + assert _has_cte(result, "orders") + assert _has_cte(result, "summary") + # Model CTE should come before user CTE + ast = sqlglot.parse_one(result, dialect="duckdb") + cte_names = [ + cte.args["alias"].this.name for cte in ast.args["with_"].expressions + ] + assert cte_names.index("orders") < cte_names.index("summary") + + def test_user_cte_shadows_model(self): + """User CTE with same name as model — no model CTE generated for that name.""" + rw = _make_rewriter(_SINGLE_MODEL_MANIFEST) + result = rw.rewrite( + "WITH orders AS (SELECT 1 AS id) SELECT * FROM orders" + ) + # The model name is shadowed by user CTE, so CTERewriter skips it + # and falls back to session_context.transform_sql 
which processes + # the whole query (wren-core handles user CTEs natively). + assert isinstance(result, str) + # No model-generated CTE named "orders" should be present + # (wren-core may restructure the query but won't add a model CTE) + assert not _has_cte(result, "orders") or _count_ctes(result) <= 1 + + def test_no_model_references_fallback(self): + """Query referencing no models falls back to direct transform_sql.""" + rw = _make_rewriter(_SINGLE_MODEL_MANIFEST) + # This should raise because 'unknown_table' is not in the manifest, + # and the fallback transform_sql will also fail. + with pytest.raises(Exception): + rw.rewrite("SELECT * FROM unknown_table") + + +# --------------------------------------------------------------------------- +# Tests: correlated subquery (the key fix) +# --------------------------------------------------------------------------- + + +class TestCorrelatedSubquery: + def test_exists_subquery(self): + """Correlated subquery with EXISTS — both models should get CTEs.""" + rw = _make_rewriter(_MULTI_MODEL_MANIFEST) + result = rw.rewrite( + 'SELECT * FROM "orders" WHERE EXISTS ' + '(SELECT 1 FROM "customer" WHERE c_custkey = o_custkey)' + ) + assert _has_cte(result, "orders") + assert _has_cte(result, "customer") + + +# --------------------------------------------------------------------------- +# Tests: dialect mapping +# --------------------------------------------------------------------------- + + +class TestDialectMapping: + def test_postgres(self): + assert get_sqlglot_dialect(DataSource.postgres) == "postgres" + + def test_mssql_maps_to_tsql(self): + assert get_sqlglot_dialect(DataSource.mssql) == "tsql" + + def test_canner_maps_to_trino(self): + assert get_sqlglot_dialect(DataSource.canner) == "trino" + + def test_local_file_maps_to_duckdb(self): + assert get_sqlglot_dialect(DataSource.local_file) == "duckdb" diff --git a/wren/tests/unit/test_engine.py b/wren/tests/unit/test_engine.py index 09a4ba491..d3c82d5b9 100644 --- 
a/wren/tests/unit/test_engine.py +++ b/wren/tests/unit/test_engine.py @@ -48,7 +48,7 @@ def duckdb_engine(tmp_path_factory): """A WrenEngine pointed at a temporary DuckDB file (not queried by unit tests).""" db_dir = tmp_path_factory.mktemp("unit_duckdb") conn_info = {"url": str(db_dir), "format": "duckdb"} - with WrenEngine(_MANIFEST_STR, DataSource.duckdb, conn_info) as e: + with WrenEngine(_MANIFEST_STR, DataSource.duckdb, conn_info, fallback=False) as e: yield e @@ -62,55 +62,39 @@ def pg_engine(): "user": "test", "password": "test", } - with WrenEngine(_MANIFEST_STR, DataSource.postgres, conn_info) as e: + with WrenEngine(_MANIFEST_STR, DataSource.postgres, conn_info, fallback=False) as e: yield e # ------------------------------------------------------------------ -# transpile +# dry_plan (no DB access) # ------------------------------------------------------------------ -def test_transpile_returns_string(duckdb_engine: WrenEngine) -> None: - sql = duckdb_engine.transpile('SELECT o_orderkey FROM "orders" LIMIT 1') +def test_dry_plan_returns_string(duckdb_engine: WrenEngine) -> None: + sql = duckdb_engine.dry_plan('SELECT o_orderkey FROM "orders" LIMIT 1') assert isinstance(sql, str) assert len(sql) > 0 -def test_transpile_postgres_dialect(pg_engine: WrenEngine) -> None: - """Transpile should produce Postgres-flavoured SQL (no backtick quoting, etc.).""" - sql = pg_engine.transpile('SELECT o_orderkey FROM "orders" LIMIT 1') +def test_dry_plan_postgres_dialect(pg_engine: WrenEngine) -> None: + """dry_plan should produce Postgres-flavoured SQL (no backtick quoting, etc.).""" + sql = pg_engine.dry_plan('SELECT o_orderkey FROM "orders" LIMIT 1') assert isinstance(sql, str) # sqlglot Postgres output uses double-quote identifiers, not backticks assert "`" not in sql -def test_transpile_calculated_field(duckdb_engine: WrenEngine) -> None: - sql = duckdb_engine.transpile('SELECT order_cust_key FROM "orders" LIMIT 1') +def test_dry_plan_calculated_field(duckdb_engine: 
WrenEngine) -> None: + sql = duckdb_engine.dry_plan('SELECT order_cust_key FROM "orders" LIMIT 1') assert isinstance(sql, str) # The calculated column expression should be expanded in the SQL assert "concat" in sql.lower() or "||" in sql.lower() -def test_transpile_invalid_sql_raises(duckdb_engine: WrenEngine) -> None: - with pytest.raises(WrenError): - duckdb_engine.transpile("SELECT * FROM not_a_model_in_manifest") - - -# ------------------------------------------------------------------ -# dry_plan -# ------------------------------------------------------------------ - - -def test_dry_plan_returns_datafusion_sql(duckdb_engine: WrenEngine) -> None: - planned = duckdb_engine.dry_plan('SELECT o_orderkey FROM "orders" LIMIT 1') - assert isinstance(planned, str) - assert len(planned) > 0 - - def test_dry_plan_invalid_sql_raises(duckdb_engine: WrenEngine) -> None: with pytest.raises(WrenError): - duckdb_engine.dry_plan("SELECT * FROM nonexistent_model") + duckdb_engine.dry_plan("SELECT * FROM not_a_model_in_manifest") # ------------------------------------------------------------------ @@ -120,7 +104,7 @@ def test_dry_plan_invalid_sql_raises(duckdb_engine: WrenEngine) -> None: def test_context_manager_closes_connector() -> None: conn_info = {"url": "/tmp", "format": "duckdb"} - with WrenEngine(_MANIFEST_STR, DataSource.duckdb, conn_info) as e: + with WrenEngine(_MANIFEST_STR, DataSource.duckdb, conn_info, fallback=False) as e: assert e._connector is None # connector is lazily initialized # After __exit__, internal state is cleaned up diff --git a/wren/uv.lock b/wren/uv.lock index cc2d2d3c9..61b405363 100644 --- a/wren/uv.lock +++ b/wren/uv.lock @@ -2282,7 +2282,15 @@ wheels = [ ] [[package]] -name = "wren" +name = "wren-core-py" +version = "0.1.0" +source = { registry = "../wren-core-py/target/wheels" } +wheels = [ + { path = "wren_core_py-0.1.0-cp311-abi3-macosx_11_0_arm64.whl" }, +] + +[[package]] +name = "wren-engine" source = { editable = "." 
} dependencies = [ { name = "boto3" }, @@ -2416,14 +2424,6 @@ requires-dist = [ ] provides-extras = ["all", "athena", "bigquery", "clickhouse", "databricks", "dev", "mssql", "mysql", "oracle", "postgres", "redshift", "snowflake", "spark", "trino"] -[[package]] -name = "wren-core-py" -version = "0.1.0" -source = { registry = "../wren-core-py/target/wheels" } -wheels = [ - { path = "wren_core_py-0.1.0-cp311-abi3-macosx_11_0_arm64.whl" }, -] - [[package]] name = "zstandard" version = "0.25.0" From 432988cec826b8347bbc41c509b7166376b2ac9e Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 25 Mar 2026 22:05:34 +0800 Subject: [PATCH 03/10] fix(wren): handle COUNT(*) and other column-less model references in CTE rewriter Models referenced in FROM clause but without specific column references (e.g. SELECT COUNT(*) FROM model) now generate a minimal row-only CTE (SELECT 1 FROM model) instead of being treated as "no model references found". Co-Authored-By: Claude Opus 4.6 (1M context) --- wren/src/wren/mdl/cte_rewriter.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/wren/src/wren/mdl/cte_rewriter.py b/wren/src/wren/mdl/cte_rewriter.py index a5f3fc5eb..832ac598d 100644 --- a/wren/src/wren/mdl/cte_rewriter.py +++ b/wren/src/wren/mdl/cte_rewriter.py @@ -10,8 +10,6 @@ import base64 import json -from collections import defaultdict - from sqlglot import exp, parse_one from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import qualify_columns @@ -147,7 +145,9 @@ def _collect_model_columns( alias_to_model[alias] = table_name alias_to_model[table_name] = table_name - used: dict[str, set[str]] = defaultdict(set) + # Ensure every referenced model appears in the result, even if no + # specific columns are referenced (e.g. SELECT COUNT(*) FROM model).
+ used: dict[str, set[str]] = {m: set() for m in alias_to_model.values()} for col in qualified.find_all(exp.Column): table_ref = col.table if not table_ref: @@ -166,9 +166,13 @@ def _build_model_ctes(self, used_columns: dict[str, set[str]]) -> list[exp.CTE]: """Generate one CTE per model via wren-core transform_sql.""" ctes: list[exp.CTE] = [] for model_name, columns in used_columns.items(): - orig = self._col_orig_name.get(model_name, {}) - resolved = [orig.get(c, c) for c in sorted(columns)] - col_list = ", ".join(f'"{model_name}"."{c}"' for c in resolved) + if columns: + orig = self._col_orig_name.get(model_name, {}) + resolved = [orig.get(c, c) for c in sorted(columns)] + col_list = ", ".join(f'"{model_name}"."{c}"' for c in resolved) + else: + # No specific columns referenced (e.g. COUNT(*)) — only need rows + col_list = "1" model_sql = f'SELECT {col_list} FROM "{model_name}"' expanded = self.session_context.transform_sql(model_sql) From 77ae5fd171227d29cecddc269b61daf8f5ab9a38 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 25 Mar 2026 22:08:27 +0800 Subject: [PATCH 04/10] chore: fix fmt --- wren/src/wren/mdl/cte_rewriter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/wren/src/wren/mdl/cte_rewriter.py b/wren/src/wren/mdl/cte_rewriter.py index 832ac598d..fee288d07 100644 --- a/wren/src/wren/mdl/cte_rewriter.py +++ b/wren/src/wren/mdl/cte_rewriter.py @@ -10,6 +10,7 @@ import base64 import json + from sqlglot import exp, parse_one from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import qualify_columns From 1393bd886a47724f38ef0be9e1fed317fa47289f Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 25 Mar 2026 22:18:27 +0800 Subject: [PATCH 05/10] fix(wren): transpile fallback output from Wren dialect to target dialect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fallback path (session_context.transform_sql) outputs Wren dialect, which needs to 
be transpiled to the target data source dialect before returning — matching the CTE path's behavior. Co-Authored-By: Claude Opus 4.6 (1M context) --- wren/src/wren/mdl/cte_rewriter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/wren/src/wren/mdl/cte_rewriter.py b/wren/src/wren/mdl/cte_rewriter.py index fee288d07..8c0bbbc91 100644 --- a/wren/src/wren/mdl/cte_rewriter.py +++ b/wren/src/wren/mdl/cte_rewriter.py @@ -11,6 +11,7 @@ import base64 import json +import sqlglot from sqlglot import exp, parse_one from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import qualify_columns @@ -104,7 +105,8 @@ def rewrite(self, sql: str) -> str: # whole-query transform, or raise so tests can catch the miss. if not used_columns: if self.fallback: - return self.session_context.transform_sql(sql) + wren_sql = self.session_context.transform_sql(sql) + return sqlglot.transpile(wren_sql, read="wren", write=self.dialect)[0] raise ValueError(f"No model references found in SQL: {sql}") model_ctes = self._build_model_ctes(used_columns) From 7be6adf7c78e6f47b8a333069bfe33d61f24a628 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 25 Mar 2026 22:27:15 +0800 Subject: [PATCH 06/10] fix(wren): use sqlglot instead of DataFusion parser for table name extraction The input SQL is in the target dialect (e.g. Postgres), not DataFusion. Use sqlglot to parse and extract table names for manifest extraction. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- wren/src/wren/engine.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/wren/src/wren/engine.py b/wren/src/wren/engine.py index 48dcdb2dc..e0dcb3f8e 100644 --- a/wren/src/wren/engine.py +++ b/wren/src/wren/engine.py @@ -24,9 +24,11 @@ import pyarrow as pa +from sqlglot import exp, parse_one + from wren.connector.factory import get_connector from wren.mdl import get_manifest_extractor, get_session_context, to_json_base64 -from wren.mdl.cte_rewriter import CTERewriter +from wren.mdl.cte_rewriter import CTERewriter, get_sqlglot_dialect from wren.model.data_source import DataSource from wren.model.error import DIALECT_SQL, ErrorCode, ErrorPhase, WrenError @@ -78,7 +80,19 @@ def __init__( # ------------------------------------------------------------------ def dry_plan(self, sql: str, properties: dict | None = None) -> str: - """Plan SQL through MDL and return the expanded SQL in the target dialect.""" + """Plan SQL through MDL and return the expanded SQL in the target dialect. + + Transformation flow:: + + User SQL (target dialect, e.g. Postgres) + → sqlglot parse (target dialect) + → qualify_tables + normalize_identifiers + qualify_columns + → identify referenced models and columns + → per-model: wren-core transform_sql → Wren dialect SQL + → per-model: sqlglot parse (Wren dialect) → inject as CTE + → sqlglot generate (target dialect) + → output SQL with model CTEs in target dialect + """ return self._plan(sql, properties) # ------------------------------------------------------------------ @@ -147,9 +161,12 @@ def _plan(self, sql: str, properties: dict | None) -> str: processed = frozenset(properties.items()) try: - # Extract minimal manifest for the query + # Extract minimal manifest scoped to tables referenced in the SQL. + # Use sqlglot (not DataFusion parser) since input is target dialect. 
+ dialect = get_sqlglot_dialect(self.data_source) + ast = parse_one(sql, dialect=dialect) + tables = [t.name for t in ast.find_all(exp.Table)] extractor = get_manifest_extractor(self.manifest_str) - tables = extractor.resolve_used_table_names(sql) manifest = extractor.extract_by(tables) effective_manifest = to_json_base64(manifest) except Exception: From e0387f7ade744aefedce203314771db6bf6b418a Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 25 Mar 2026 22:29:44 +0800 Subject: [PATCH 07/10] =?UTF-8?q?fix(wren):=20address=20CodeRabbit=20revie?= =?UTF-8?q?w=20=E2=80=94=20nested=20CTE=20scoping,=20column=20order,=20tes?= =?UTF-8?q?t=20strictness?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Collect user CTE names from all scopes (find_all(exp.With)), not just root, so nested WITH clauses properly shadow model names - Preserve manifest column order in CTE generation instead of sorting alphabetically — fixes SELECT * column order - Default test helper _make_rewriter to fallback=False so tests exercise the CTE path strictly Co-Authored-By: Claude Opus 4.6 (1M context) --- wren/src/wren/mdl/cte_rewriter.py | 25 ++++++++++++++----------- wren/tests/unit/test_cte_rewriter.py | 11 +++++++---- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/wren/src/wren/mdl/cte_rewriter.py b/wren/src/wren/mdl/cte_rewriter.py index 8c0bbbc91..562719805 100644 --- a/wren/src/wren/mdl/cte_rewriter.py +++ b/wren/src/wren/mdl/cte_rewriter.py @@ -119,13 +119,14 @@ def rewrite(self, sql: str) -> str: def _collect_model_columns( self, ast: exp.Expression, user_cte_names: set[str] - ) -> dict[str, set[str]]: - """Return ``{model_name: {col1, col2, ...}}`` for all referenced models. + ) -> dict[str, list[str]]: + """Return ``{model_name: [col1, col2, ...]}`` for all referenced models. 
Uses sqlglot's ``qualify_columns`` to fully resolve all column references (including ``SELECT *`` expansion and correlated subquery outer references), then walks the qualified AST - to collect model→column mappings. + to collect model→column mappings. Column order follows the manifest + definition (via insertion order) so ``SELECT *`` preserves schema order. """ copy = ast.copy() copy = qualify_tables(copy, dialect=self.dialect) @@ -150,28 +151,31 @@ def _collect_model_columns( # Ensure every referenced model appears in the result, even if no # specific columns are referenced (e.g. SELECT COUNT(*) FROM model). - used: dict[str, set[str]] = {m: set() for m in alias_to_model.values()} + # Use dict as ordered set to preserve insertion order and deduplicate. + used: dict[str, dict[str, None]] = { + m: {} for m in alias_to_model.values() + } for col in qualified.find_all(exp.Column): table_ref = col.table if not table_ref: continue model_name = alias_to_model.get(table_ref) if model_name: - used[model_name].add(col.name) + used[model_name][col.name] = None - return dict(used) + return {m: list(cols) for m, cols in used.items()} # ------------------------------------------------------------------ # CTE generation # ------------------------------------------------------------------ - def _build_model_ctes(self, used_columns: dict[str, set[str]]) -> list[exp.CTE]: + def _build_model_ctes(self, used_columns: dict[str, list[str]]) -> list[exp.CTE]: """Generate one CTE per model via wren-core transform_sql.""" ctes: list[exp.CTE] = [] for model_name, columns in used_columns.items(): if columns: orig = self._col_orig_name.get(model_name, {}) - resolved = [orig.get(c, c) for c in sorted(columns)] + resolved = [orig.get(c, c) for c in columns] col_list = ", ".join(f'"{model_name}"."{c}"' for c in resolved) else: # No specific columns referenced (e.g. 
COUNT(*)) — only need rows @@ -218,10 +222,9 @@ def _inject_ctes(self, ast: exp.Expression, model_ctes: list[exp.CTE]) -> None: @staticmethod def _collect_user_cte_names(ast: exp.Expression) -> set[str]: - """Collect all CTE names defined in the user's SQL.""" + """Collect all CTE names defined in the user's SQL (all scopes).""" names: set[str] = set() - with_clause = ast.args.get("with_") - if with_clause: + for with_clause in ast.find_all(exp.With): for cte in with_clause.expressions: alias = cte.args.get("alias") if alias: diff --git a/wren/tests/unit/test_cte_rewriter.py b/wren/tests/unit/test_cte_rewriter.py index 294ba7c8b..c4cabe4f0 100644 --- a/wren/tests/unit/test_cte_rewriter.py +++ b/wren/tests/unit/test_cte_rewriter.py @@ -86,11 +86,14 @@ def _b64(manifest: dict) -> str: def _make_rewriter( - manifest: dict, data_source: DataSource = DataSource.duckdb + manifest: dict, + data_source: DataSource = DataSource.duckdb, + *, + fallback: bool = False, ) -> CTERewriter: manifest_str = _b64(manifest) session = get_session_context(manifest_str, None, None, data_source.name) - return CTERewriter(manifest_str, session, data_source) + return CTERewriter(manifest_str, session, data_source, fallback=fallback) # --------------------------------------------------------------------------- @@ -193,7 +196,7 @@ def test_user_cte_references_model(self): def test_user_cte_shadows_model(self): """User CTE with same name as model — no model CTE generated for that name.""" - rw = _make_rewriter(_SINGLE_MODEL_MANIFEST) + rw = _make_rewriter(_SINGLE_MODEL_MANIFEST, fallback=True) result = rw.rewrite( "WITH orders AS (SELECT 1 AS id) SELECT * FROM orders" ) @@ -207,7 +210,7 @@ def test_user_cte_shadows_model(self): def test_no_model_references_fallback(self): """Query referencing no models falls back to direct transform_sql.""" - rw = _make_rewriter(_SINGLE_MODEL_MANIFEST) + rw = _make_rewriter(_SINGLE_MODEL_MANIFEST, fallback=True) # This should raise because 'unknown_table' is 
not in the manifest, # and the fallback transform_sql will also fail. with pytest.raises(Exception): From 3e132a4e84f25c3c48ff2251971a2c5c7694adf7 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 25 Mar 2026 22:30:32 +0800 Subject: [PATCH 08/10] test(wren): add test for nested CTE shadowing model names Verifies that a WITH clause inside a subquery correctly shadows the model name, preventing a model CTE from being generated for it. Co-Authored-By: Claude Opus 4.6 (1M context) --- wren/tests/unit/test_cte_rewriter.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/wren/tests/unit/test_cte_rewriter.py b/wren/tests/unit/test_cte_rewriter.py index c4cabe4f0..07b5a195b 100644 --- a/wren/tests/unit/test_cte_rewriter.py +++ b/wren/tests/unit/test_cte_rewriter.py @@ -208,6 +208,20 @@ def test_user_cte_shadows_model(self): # (wren-core may restructure the query but won't add a model CTE) assert not _has_cte(result, "orders") or _count_ctes(result) <= 1 + def test_nested_cte_shadows_model(self): + """CTE defined in a subquery WITH should also shadow the model name.""" + rw = _make_rewriter(_MULTI_MODEL_MANIFEST, fallback=True) + result = rw.rewrite( + "SELECT * FROM orders WHERE o_custkey IN (" + " WITH customer AS (SELECT 1 AS c_custkey) " + " SELECT c_custkey FROM customer" + ")" + ) + # "customer" is shadowed by the nested CTE — only "orders" should + # get a model CTE, not "customer". 
+ assert _has_cte(result, "orders") + assert not _has_cte(result, "customer") + def test_no_model_references_fallback(self): """Query referencing no models falls back to direct transform_sql.""" rw = _make_rewriter(_SINGLE_MODEL_MANIFEST, fallback=True) From 364bc731a8cd58352ea9902720fcf7872c5a6552 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 25 Mar 2026 22:31:10 +0800 Subject: [PATCH 09/10] test(wren): add tests for nested CTE shadowing and WITH RECURSIVE preservation - Nested CTE in subquery correctly shadows model name - WITH RECURSIVE keyword preserved after model CTE injection Co-Authored-By: Claude Opus 4.6 (1M context) --- wren/tests/unit/test_cte_rewriter.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/wren/tests/unit/test_cte_rewriter.py b/wren/tests/unit/test_cte_rewriter.py index 07b5a195b..aec5d122e 100644 --- a/wren/tests/unit/test_cte_rewriter.py +++ b/wren/tests/unit/test_cte_rewriter.py @@ -222,6 +222,23 @@ def test_nested_cte_shadows_model(self): assert _has_cte(result, "orders") assert not _has_cte(result, "customer") + def test_recursive_cte_preserved(self): + """WITH RECURSIVE keyword should be preserved after model CTE injection.""" + rw = _make_rewriter(_MULTI_MODEL_MANIFEST) + result = rw.rewrite( + "WITH RECURSIVE hierarchy AS (" + ' SELECT o_orderkey, o_custkey FROM "orders" WHERE o_custkey = 1' + " UNION ALL" + ' SELECT o.o_orderkey, o.o_custkey FROM "orders" o' + " JOIN hierarchy h ON o.o_custkey = h.o_orderkey" + ") SELECT * FROM hierarchy" + ) + assert _has_cte(result, "orders") + assert _has_cte(result, "hierarchy") + # RECURSIVE keyword must be preserved + ast = sqlglot.parse_one(result, dialect="duckdb") + assert ast.args["with_"].args.get("recursive") + def test_no_model_references_fallback(self): """Query referencing no models falls back to direct transform_sql.""" rw = _make_rewriter(_SINGLE_MODEL_MANIFEST, fallback=True) From 6cfbb1a0626fbd8d0b21cd5d7ff4ad78bc88e04e Mon Sep 17 00:00:00 2001 From: Jax 
Liu Date: Wed, 25 Mar 2026 22:31:50 +0800 Subject: [PATCH 10/10] chore: fix fmt --- wren/src/wren/engine.py | 1 - wren/src/wren/mdl/cte_rewriter.py | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/wren/src/wren/engine.py b/wren/src/wren/engine.py index e0dcb3f8e..49a012a2c 100644 --- a/wren/src/wren/engine.py +++ b/wren/src/wren/engine.py @@ -23,7 +23,6 @@ from typing import Any import pyarrow as pa - from sqlglot import exp, parse_one from wren.connector.factory import get_connector diff --git a/wren/src/wren/mdl/cte_rewriter.py b/wren/src/wren/mdl/cte_rewriter.py index 562719805..d4a59fdfc 100644 --- a/wren/src/wren/mdl/cte_rewriter.py +++ b/wren/src/wren/mdl/cte_rewriter.py @@ -152,9 +152,7 @@ def _collect_model_columns( # Ensure every referenced model appears in the result, even if no # specific columns are referenced (e.g. SELECT COUNT(*) FROM model). # Use dict as ordered set to preserve insertion order and deduplicate. - used: dict[str, dict[str, None]] = { - m: {} for m in alias_to_model.values() - } + used: dict[str, dict[str, None]] = {m: {} for m in alias_to_model.values()} for col in qualified.find_all(exp.Column): table_ref = col.table if not table_ref: