Skip to content

Commit b63e0fd

Browse files
cpcloud authored and jcrist committed
fix(dot-sql): ensure that CTEs can be used in .sql
1 parent 2786fe4 commit b63e0fd

File tree

5 files changed

+60
-31
lines changed

5 files changed

+60
-31
lines changed

ibis/backends/polars/__init__.py

+5-10
Original file line number | Diff line number | Diff line change
@@ -449,17 +449,12 @@ def compile(
449449

450450
return translate(node, ctx=self._context)
451451

452-
def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema:
453-
import sqlglot as sg
454-
455-
cte = sg.parse_one(str(ibis.to_sql(table, dialect="postgres")), read="postgres")
456-
parsed = sg.parse_one(query, read=self.dialect)
457-
parsed.args["with"] = cte.args.pop("with", [])
458-
parsed = parsed.with_(
459-
sg.to_identifier(name, quoted=True), as_=cte, dialect=self.dialect
460-
)
452+
def _get_sql_string_view_schema(
453+
self, *, name: str, table: ir.Table, query: str
454+
) -> sch.Schema:
455+
from ibis.backends.sql.compilers.postgres import compiler
461456

462-
sql = parsed.sql(self.dialect)
457+
sql = compiler.add_query_to_expr(name=name, table=table, query=query)
463458
return self._get_schema_using_query(sql)
464459

465460
def _get_schema_using_query(self, query: str) -> sch.Schema:

ibis/backends/sql/__init__.py

+4-12
Original file line number | Diff line number | Diff line change
@@ -191,18 +191,10 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
191191
The schema inferred from `query`
192192
"""
193193

194-
def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema:
195-
compiler = self.compiler
196-
dialect = compiler.dialect
197-
198-
cte = compiler.to_sqlglot(table)
199-
parsed = sg.parse_one(query, read=dialect)
200-
parsed.args["with"] = cte.args.pop("with", [])
201-
parsed = parsed.with_(
202-
sg.to_identifier(name, quoted=compiler.quoted), as_=cte, dialect=dialect
203-
)
204-
205-
sql = parsed.sql(dialect)
194+
def _get_sql_string_view_schema(
195+
self, *, name: str, table: ir.Table, query: str
196+
) -> sch.Schema:
197+
sql = self.compiler.add_query_to_expr(name=name, table=table, query=query)
206198
return self._get_schema_using_query(sql)
207199

208200
def _register_udfs(self, expr: ir.Expr) -> None:

ibis/backends/sql/compilers/base.py

+25
Original file line number | Diff line number | Diff line change
@@ -1613,6 +1613,31 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop):
16131613
)
16141614
return sg.select(*columns_to_keep).from_(parent)
16151615

1616+
def add_query_to_expr(self, *, name: str, table: ir.Table, query: str) -> str:
1617+
dialect = self.dialect
1618+
1619+
compiled_ibis_expr = self.to_sqlglot(table)
1620+
1621+
# pull existing CTEs from the compiled Ibis expression and combine them
1622+
# with the new query
1623+
parsed = reduce(
1624+
lambda parsed, cte: parsed.with_(cte.args["alias"], as_=cte.args["this"]),
1625+
compiled_ibis_expr.ctes,
1626+
sg.parse_one(query, read=dialect),
1627+
)
1628+
1629+
# remove all ctes from the compiled expression, since they're now in
1630+
# our larger expression
1631+
compiled_ibis_expr.args.pop("with", None)
1632+
1633+
# add the new str query as a CTE
1634+
parsed = parsed.with_(
1635+
sg.to_identifier(name, quoted=self.quoted), as_=compiled_ibis_expr
1636+
)
1637+
1638+
# generate the SQL string
1639+
return parsed.sql(dialect)
1640+
16161641

16171642
# `__init_subclass__` is uncalled for subclasses - we manually call it here to
16181643
# autogenerate the base class implementations as well.

ibis/backends/tests/test_dot_sql.py

+25-8
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
import pytest
77
import sqlglot as sg
8+
import sqlglot.expressions as sge
89
from pytest import param
910

1011
import ibis
@@ -24,11 +25,21 @@
2425

2526
_NAMES = {
2627
"bigquery": f"ibis_gbq_testing_{getpass.getuser()}_{PYTHON_SHORT_VERSION}.functional_alltypes",
27-
"exasol": '"functional_alltypes"',
2828
}
2929

3030

31-
@pytest.mark.notyet(["oracle"], reason="table quoting behavior")
31+
@pytest.fixture(scope="module")
32+
def ftname_raw(con):
33+
return _NAMES.get(con.name, "functional_alltypes")
34+
35+
36+
@pytest.fixture(scope="module")
37+
def ftname(con, ftname_raw):
38+
table = sg.parse_one(ftname_raw, into=sge.Table)
39+
table = sg.table(table.name, db=table.db, catalog=table.catalog, quoted=True)
40+
return table.sql(con.dialect)
41+
42+
3243
@dot_sql_never
3344
@pytest.mark.parametrize(
3445
"schema",
@@ -37,10 +48,9 @@
3748
param({"s": "string", "new_col": "double"}, id="explicit_schema"),
3849
],
3950
)
40-
def test_con_dot_sql(backend, con, schema):
51+
def test_con_dot_sql(backend, con, schema, ftname):
4152
alltypes = backend.functional_alltypes
4253
# pull out the quoted name
43-
name = _NAMES.get(con.name, "functional_alltypes")
4454
quoted = True
4555
cols = [
4656
sg.column("string_col", quoted=quoted).as_("s", quoted=quoted).sql(con.dialect),
@@ -50,7 +60,7 @@ def test_con_dot_sql(backend, con, schema):
5060
]
5161
t = (
5262
con.sql(
53-
f"SELECT {', '.join(cols)} FROM {name}",
63+
f"SELECT {', '.join(cols)} FROM {ftname}",
5464
schema=schema,
5565
)
5666
.group_by("s") # group by a column from SQL
@@ -325,9 +335,16 @@ def test_cte(alltypes, df):
325335

326336

327337
@dot_sql_never
328-
def test_bare_minimum(con, alltypes, df):
338+
def test_bare_minimum(alltypes, df, ftname_raw):
329339
"""Test that a backend that supports dot sql can do the most basic thing."""
330340

331-
name = _NAMES.get(con.name, "functional_alltypes").replace('"', "")
332-
expr = alltypes.sql(f'SELECT COUNT(*) AS "n" FROM "{name}"', dialect="duckdb")
341+
expr = alltypes.sql(f'SELECT COUNT(*) AS "n" FROM "{ftname_raw}"', dialect="duckdb")
333342
assert expr.to_pandas().iat[0, 0] == len(df)
343+
344+
345+
@dot_sql_never
346+
def test_embedded_cte(alltypes, ftname_raw):
347+
sql = f'WITH "x" AS (SELECT * FROM "{ftname_raw}") SELECT * FROM "x"'
348+
expr = alltypes.sql(sql, dialect="duckdb")
349+
result = expr.head(1).execute()
350+
assert len(result) == 1

ibis/expr/types/relations.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -3553,7 +3553,7 @@ def sql(self, query: str, dialect: str | None = None) -> ir.Table:
35533553
name = util.gen_name("sql_query")
35543554
expr = self
35553555

3556-
schema = backend._get_sql_string_view_schema(name, expr, query)
3556+
schema = backend._get_sql_string_view_schema(name=name, table=expr, query=query)
35573557
node = ops.SQLStringView(child=self.op(), query=query, schema=schema)
35583558
return node.to_expr()
35593559

0 commit comments

Comments
 (0)