Skip to content

Commit

Permalink
fix(internals): ensure that CTEs are emitted in topological order (#9726
Browse files Browse the repository at this point in the history
)

Co-authored-by: Jim Crist-Harif <[email protected]>
  • Loading branch information
cpcloud and jcrist authored Jul 30, 2024
1 parent 4cf0014 commit acd7d82
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
10 changes: 5 additions & 5 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,17 +249,17 @@ def merge_select_select(_, **kwargs):
return result if complexity(result) <= complexity(_) else _


def extract_ctes(node):
result = []
def extract_ctes(node: ops.Relation) -> set[ops.Relation]:
cte_types = (Select, ops.Aggregate, ops.JoinChain, ops.Set, ops.Limit, ops.Sample)
dont_count = (ops.Field, ops.CountStar, ops.CountDistinctStar)

g = Graph.from_bfs(node, filter=~InstanceOf(dont_count))
result = set()
for op, dependents in g.invert().items():
if isinstance(op, ops.View) or (
len(dependents) > 1 and isinstance(op, cte_types)
):
result.append(op)
result.add(op)

return result

Expand Down Expand Up @@ -315,14 +315,14 @@ def sqlize(
simplified = sqlized

# extract common table expressions while wrapping them in a CTE node
ctes = frozenset(extract_ctes(simplified))
ctes = extract_ctes(simplified)

def wrap(node, _, **kwargs):
new = node.__recreate__(kwargs)
return CTE(new) if node in ctes else new

result = simplified.replace(wrap)
ctes = reversed([cte.parent for cte in result.find(CTE)])
ctes = [cte.parent for cte in result.find(CTE, ordered=True)]

return result, ctes

Expand Down
14 changes: 14 additions & 0 deletions ibis/backends/tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,17 @@ def test_no_cartesian_join(snapshot):
]
)
snapshot.assert_match(ibis.to_sql(final, dialect="duckdb"), "out.sql")


def test_ctes_in_order():
table1 = ibis.table({"id": "int"}, name="table1")
table2 = ibis.table({"id": "int"}, name="table2")
table3 = ibis.table({"id": "int"}, name="table3")

ids_table = table1.union(table2).alias("first")
info_table = ids_table.union(table3).alias("second")

expr = ids_table.union(info_table)

sql = ibis.to_sql(expr, dialect="duckdb")
assert sql.find('"first" AS (') < sql.find('"second" AS (')
5 changes: 5 additions & 0 deletions ibis/common/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ def find(
finder: FinderLike,
filter: Optional[FinderLike] = None,
context: Optional[dict] = None,
ordered: bool = False,
) -> list[Node]:
"""Find all nodes matching a given pattern or type in the graph.
Expand All @@ -355,6 +356,8 @@ def find(
the given filter and stop otherwise.
context
Optional context to use if `finder` or `filter` is a pattern.
ordered
Emit nodes in topological order if `True`.
Returns
-------
Expand All @@ -364,6 +367,8 @@ def find(
"""
graph = Graph.from_bfs(self, filter=filter, context=context)
finder = _coerce_finder(finder, context)
if ordered:
graph, _ = graph.toposort()
return [node for node in graph.nodes() if finder(node)]

@experimental
Expand Down

0 comments on commit acd7d82

Please sign in to comment.