Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion wren/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ redshift = ["redshift_connector"]
spark = ["pyspark>=3.5"]
athena = ["ibis-framework[athena]"]
oracle = ["ibis-framework[oracle]", "oracledb"]
all = ["wren[postgres,mysql,bigquery,snowflake,clickhouse,trino,mssql,databricks,redshift,athena,oracle]"]
all = ["wren-engine[postgres,mysql,bigquery,snowflake,clickhouse,trino,mssql,databricks,redshift,athena,oracle]"]
dev = ["pytest>=8", "ruff>=0.4", "orjson>=3", "testcontainers[postgres]>=4"]

[project.scripts]
Expand Down
9 changes: 4 additions & 5 deletions wren/src/wren/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,13 @@ def dry_run(
engine.close()


@app.command()
def transpile(
@app.command(name="dry-plan")
def dry_plan(
sql: SqlArg,
datasource: DatasourceOpt,
mdl: MdlOpt,
):
"""Transform SQL through MDL and emit the target dialect SQL (no DB required)."""
"""Plan SQL through MDL and print the expanded SQL (no DB required)."""
from wren.engine import WrenEngine # noqa: PLC0415
from wren.model.data_source import DataSource # noqa: PLC0415

Expand All @@ -164,10 +164,9 @@ def transpile(
typer.echo(f"Error: unknown datasource '{datasource}'", err=True)
raise typer.Exit(1)

# For transpile we don't need real connection_info — pass a dummy dict
engine = WrenEngine(manifest_str=manifest_str, data_source=ds, connection_info={})
try:
result = engine.transpile(sql)
result = engine.dry_plan(sql)
typer.echo(result)
except Exception as e:
typer.echo(f"Error: {e}", err=True)
Expand Down
2 changes: 1 addition & 1 deletion wren/src/wren/connector/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, connection_info):
def query(self, sql: str, limit: int | None = None) -> pa.Table:
if limit is not None:
sql = f"SELECT * FROM ({sql}) AS _q LIMIT {int(limit)}"
return self.connection.execute(sql).fetch_arrow_table()
return self.connection.execute(sql).to_arrow_table()

def dry_run(self, sql: str) -> None:
self.connection.execute(f"EXPLAIN {sql}")
Expand Down
78 changes: 35 additions & 43 deletions wren/src/wren/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
connection_info={"host": "localhost", "port": 5432, ...},
)

# Transform only (no DB required)
planned_sql = engine.transpile("SELECT * FROM orders")
# Plan only (no DB required)
planned_sql = engine.dry_plan("SELECT * FROM orders")

# Execute against the data source
arrow_table = engine.query("SELECT * FROM orders", limit=100)
Expand All @@ -23,25 +23,13 @@
from typing import Any

import pyarrow as pa
import sqlglot
from sqlglot import exp, parse_one

from wren.connector.factory import get_connector
from wren.mdl import get_manifest_extractor, get_session_context, to_json_base64
from wren.mdl.cte_rewriter import CTERewriter, get_sqlglot_dialect
from wren.model.data_source import DataSource
from wren.model.error import DIALECT_SQL, PLANNED_SQL, ErrorCode, ErrorPhase, WrenError


def _get_write_dialect(data_source: DataSource) -> str:
if data_source == DataSource.canner:
return "trino"
if data_source in {
DataSource.local_file,
DataSource.s3_file,
DataSource.minio_file,
DataSource.gcs_file,
}:
return "duckdb"
return data_source.name
from wren.model.error import DIALECT_SQL, ErrorCode, ErrorPhase, WrenError


class WrenEngine:
Expand All @@ -66,13 +54,16 @@ def __init__(
data_source: DataSource | str,
connection_info: dict[str, Any] | object,
function_path: str | None = None,
*,
fallback: bool = True,
):
if isinstance(data_source, str):
data_source = DataSource(data_source)

self.manifest_str = manifest_str
self.data_source = data_source
self.function_path = function_path
self._fallback = fallback

# Build typed ConnectionInfo if a raw dict was given.
# An empty dict is allowed for plan-only usage, i.e. dry_plan (no DB connection).
Expand All @@ -87,16 +78,20 @@ def __init__(
# SQL transformation (no DB access)
# ------------------------------------------------------------------

def transpile(self, sql: str, properties: dict | None = None) -> str:
"""Transform SQL through MDL and transpile to the target dialect.

Returns the dialect SQL string without executing it.
"""
planned = self._plan(sql, properties)
return self._transpile(planned)

def dry_plan(self, sql: str, properties: dict | None = None) -> str:
"""Return the wren-core planned SQL (DataFusion dialect, before transpile)."""
"""Plan SQL through MDL and return the expanded SQL in the target dialect.

Transformation flow::

User SQL (target dialect, e.g. Postgres)
→ sqlglot parse (target dialect)
→ qualify_tables + normalize_identifiers + qualify_columns
→ identify referenced models and columns
→ per-model: wren-core transform_sql → Wren dialect SQL
→ per-model: sqlglot parse (Wren dialect) → inject as CTE
→ sqlglot generate (target dialect)
→ output SQL with model CTEs in target dialect
"""
return self._plan(sql, properties)

# ------------------------------------------------------------------
Expand All @@ -110,7 +105,7 @@ def query(
properties: dict | None = None,
) -> pa.Table:
"""Transpile and execute SQL, return results as an Arrow table."""
dialect_sql = self.transpile(sql, properties)
dialect_sql = self.dry_plan(sql, properties)
connector = self._get_connector()
try:
return connector.query(dialect_sql, limit)
Expand All @@ -126,7 +121,7 @@ def query(

def dry_run(self, sql: str, properties: dict | None = None) -> None:
"""Transpile and dry-run SQL without returning results."""
dialect_sql = self.transpile(sql, properties)
dialect_sql = self.dry_plan(sql, properties)
connector = self._get_connector()
try:
connector.dry_run(dialect_sql)
Expand Down Expand Up @@ -165,9 +160,12 @@ def _plan(self, sql: str, properties: dict | None) -> str:
processed = frozenset(properties.items())

try:
# Extract minimal manifest for the query
# Extract minimal manifest scoped to tables referenced in the SQL.
# Use sqlglot (not DataFusion parser) since input is target dialect.
dialect = get_sqlglot_dialect(self.data_source)
ast = parse_one(sql, dialect=dialect)
tables = [t.name for t in ast.find_all(exp.Table)]
extractor = get_manifest_extractor(self.manifest_str)
tables = extractor.resolve_used_table_names(sql)
manifest = extractor.extract_by(tables)
effective_manifest = to_json_base64(manifest)
except Exception:
Expand All @@ -180,7 +178,13 @@ def _plan(self, sql: str, properties: dict | None) -> str:
processed,
self.data_source.name,
)
return session.transform_sql(sql)
rewriter = CTERewriter(
effective_manifest,
session,
self.data_source,
fallback=self._fallback,
)
return rewriter.rewrite(sql)
except Exception as e:
raise WrenError(
ErrorCode.INVALID_SQL,
Expand All @@ -189,18 +193,6 @@ def _plan(self, sql: str, properties: dict | None) -> str:
metadata={DIALECT_SQL: sql},
) from e

def _transpile(self, planned_sql: str) -> str:
try:
write = _get_write_dialect(self.data_source)
return sqlglot.transpile(planned_sql, read="duckdb", write=write)[0]
except Exception as e:
raise WrenError(
ErrorCode.SQLGLOT_ERROR,
str(e),
phase=ErrorPhase.SQL_TRANSPILE,
metadata={PLANNED_SQL: planned_sql},
) from e

def _get_connector(self):
if self._connector is None:
self._connector = get_connector(self.data_source, self.connection_info)
Expand Down
Loading
Loading