Skip to content

Commit

Permalink
Feat: eliminate join marks (#3580)
Browse files Browse the repository at this point in the history
* tests passed

* with formatters

* added scope to eliminate_join_marks and removed from optimizer rules

* moved eliminate_join_marks to oracle dialect and minor fixes

* minor fixes

* minor changes - parser not implemented

* handle expressions and multiple join marks in lhs/rhs of where predicate

* remove dead function

* moved into parser
  • Loading branch information
mrhopko authored Jun 12, 2024
1 parent 2e43450 commit e998308
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 26 deletions.
161 changes: 161 additions & 0 deletions sqlglot/dialects/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,166 @@ def _build_timetostr_or_tochar(args: t.List) -> exp.TimeToStr | exp.ToChar:
return exp.ToChar.from_arg_list(args)


def eliminate_join_marks(ast: exp.Expression) -> exp.Expression:
    """Remove Oracle ``(+)`` join marks from an expression.

    Example:
        SELECT * FROM a, b WHERE a.id = b.id(+)
    becomes:
        SELECT * FROM a LEFT JOIN b ON a.id = b.id

    Every scope of the AST is processed as follows:
        - for each column with a join mark
            - find the predicate it belongs to
            - remove the predicate from the where clause
            - convert the predicate to a LEFT JOIN on the table of the (+) side
            - replace the existing join with the new join

    Args:
        ast: The AST to remove join marks from. It is modified in place.

    Returns:
        The same AST object, with join marks removed.
    """
    # Kept as a function-local import, as in the original code
    # (presumably to avoid an import cycle with the optimizer package -- confirm).
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(ast):
        _eliminate_join_marks_from_scope(scope)
    return ast


def _update_from(
    select: exp.Select,
    new_join_dict: t.Dict[str, exp.Join],
    old_join_dict: t.Dict[str, exp.Join],
) -> None:
    """Pick a new FROM table when the current FROM table is turning into a join.

    If the table currently in the FROM clause is a key of ``new_join_dict``, one of the
    old joins that is not being replaced is promoted to the FROM clause instead and
    removed from ``old_join_dict``. Otherwise nothing is changed.

    Both ``select`` and ``old_join_dict`` are updated in place.

    Args:
        select: The select statement to update.
        new_join_dict: The new joins, keyed by table alias or name.
        old_join_dict: The old joins, keyed by table alias or name.

    Raises:
        ValueError: If every old join is also being replaced, so no table is left
            to serve as the new FROM.
    """
    old_from = select.args["from"]
    if old_from.alias_or_name not in new_join_dict:
        return

    # Scan in insertion order so the promoted table is the same on every run;
    # indexing into a set difference depends on string hash ordering.
    new_from_name = next((name for name in old_join_dict if name not in new_join_dict), None)
    if new_from_name is None:
        raise ValueError("Cannot determine which table to use as the new from")

    # The promoted join must no longer be treated as a join by the caller.
    promoted_join = old_join_dict.pop(new_from_name)
    select.set("from", exp.From(this=promoted_join.this))


def _has_join_mark(col: exp.Expression) -> bool:
    """Check if the expression carries a join mark.

    Args:
        col: The expression (normally a column) whose ``join_mark`` arg is inspected.

    Returns:
        True if the ``join_mark`` arg is present and truthy, False otherwise.
    """
    # Coerce so the declared bool return type holds even when the arg is stored as None.
    return bool(col.args.get("join_mark"))


def _predicate_to_join(
    eq: exp.Binary, old_joins: t.Dict[str, exp.Join], old_from: exp.From
) -> t.Optional[exp.Join]:
    """Build a LEFT JOIN from a binary predicate that contains join-marked columns.

    The left operand is inspected first; the right operand is only used when the left
    one has no join-marked column. The marks on the chosen side are cleared in place.

    Args:
        eq: The binary predicate to convert; it becomes the ON condition of the join.
        old_joins: The existing joins, keyed by table alias or name.
        old_from: The current FROM clause, used when the marked table is not in old_joins.

    Returns:
        The join expression if the predicate contains a join mark (otherwise None)
    """
    marked_left = [column for column in eq.left.find_all(exp.Column) if _has_join_mark(column)]
    marked_right = [column for column in eq.right.find_all(exp.Column) if _has_join_mark(column)]

    # Left side takes precedence; marks on the other side are left untouched.
    marked_columns = marked_left or marked_right
    if not marked_columns:
        return None

    for column in marked_columns:
        column.set("join_mark", False)

    # The table of the last marked column decides which table is joined.
    marked_table = marked_columns[-1].table
    joined_source = old_joins.get(marked_table, old_from)
    return exp.Join(this=joined_source.this, on=eq, kind="LEFT")


if t.TYPE_CHECKING:
from sqlglot.optimizer.scope import Scope


def _eliminate_join_marks_from_scope(scope: Scope) -> None:
    """Remove join-marked columns from the scope's where clause.

    Each binary predicate containing a join-marked column is popped out of the where
    clause and converted to a LEFT JOIN, which replaces the existing join on the same
    table. Predicates targeting the same table are combined with AND. The scope's
    expression is updated in place.

    Nothing is done when the scope's expression has no where clause or no joins.

    Args:
        scope: The scope to remove join marks from.

    Raises:
        ValueError: Propagated from ``_update_from`` when the FROM table becomes a
            join and no other table can take its place.
    """
    select_scope = scope.expression
    where = select_scope.args.get("where")
    joins = select_scope.args.get("joins")
    if not where or not joins:
        return

    # dictionaries used to keep track of joins to be replaced
    old_joins = {join.alias_or_name: join for join in joins}
    new_joins: t.Dict[str, exp.Join] = {}

    for node in scope.find_all(exp.Column):
        if not _has_join_mark(node):
            continue

        predicate = node.find_ancestor(exp.Predicate)
        if not isinstance(predicate, exp.Binary):
            continue

        # Remember the parent before popping, since pop() detaches the predicate.
        predicate_parent = predicate.parent

        join_on = predicate.pop()
        new_join = _predicate_to_join(
            join_on, old_joins=old_joins, old_from=select_scope.args["from"]
        )

        # upsert new_join into new_joins dictionary
        if new_join:
            join_name = new_join.alias_or_name
            if join_name in new_joins:
                existing_join = new_joins[join_name]
                existing_join.set(
                    "on",
                    exp.and_(existing_join.args["on"], new_join.args["on"]),
                )
            else:
                new_joins[join_name] = new_join

        # If the parent is a binary node with only one child, promote the child to the parent
        if predicate_parent and isinstance(predicate_parent, exp.Binary):
            if predicate_parent.left is None:
                predicate_parent.replace(predicate_parent.right)
            elif predicate_parent.right is None:
                predicate_parent.replace(predicate_parent.left)

    _update_from(select_scope, new_joins, old_joins)

    replacement_joins = [new_joins.get(name, join) for name, join in old_joins.items()]
    # A join built from the former FROM table has no entry in old_joins. Append it
    # explicitly, otherwise the outer-joined table would vanish from the query
    # (e.g. FROM a, b WHERE a.id(+) = b.id).
    replacement_joins.extend(join for name, join in new_joins.items() if name not in old_joins)
    select_scope.set("joins", replacement_joins)

    # Drop the where clause entirely once all of its predicates have become joins.
    if not where.this:
        where.pop()

class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
LOCKING_READS_SUPPORTED = True
Expand Down Expand Up @@ -231,6 +391,7 @@ class Generator(generator.Generator):
[
transforms.eliminate_distinct_on,
transforms.eliminate_qualify,
eliminate_join_marks,
]
),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
Expand Down
79 changes: 66 additions & 13 deletions tests/dialects/test_oracle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from sqlglot import exp
from sqlglot.errors import UnsupportedError
from tests.dialects.test_dialect import Validator


Expand Down Expand Up @@ -245,19 +244,20 @@ def test_oracle(self):
},
)

def test_join_marker(self):
self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y")
# def test_join_marker(self):
# self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y")

self.validate_all(
"SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)", write={"": UnsupportedError}
)
self.validate_all(
"SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
write={
"": "SELECT e1.x, e2.x FROM e AS e1, e AS e2 WHERE e1.y = e2.y",
"oracle": "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
},
)
# self.validate_all(
# "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
# write={"": UnsupportedError},
# )
# self.validate_all(
# "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
# write={
# "": "SELECT e1.x, e2.x FROM e AS e1, e AS e2 WHERE e1.y = e2.y",
# "oracle": "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
# },
# )

def test_hints(self):
self.validate_identity("SELECT /*+ USE_NL(A B) */ A.COL_TEST FROM TABLE_A A, TABLE_B B")
Expand Down Expand Up @@ -413,3 +413,56 @@ def test_connect_by(self):

for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"):
self.validate_identity(query, pretty, pretty=True)

def test_eliminate_join_marks(self):
    """Oracle ``(+)`` join marks are rewritten as LEFT JOINs when generating SQL.

    Each case is an (Oracle input, expected output) pair; the input is parsed with the
    Oracle dialect and generated with the dialect of this test class.
    """
    cases = [
        # marked equality plus a marked comparison on the same table
        (
            "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5",
            "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5",
        ),
        # marked IS NULL moves into the ON condition
        (
            "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL",
            "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL",
        ),
        # unmarked predicates stay in the WHERE clause
        (
            "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL",
            "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL",
        ),
        (
            "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4",
            "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4",
        ),
        # single marked predicate: the WHERE clause disappears
        (
            "SELECT * FROM table1, table2 WHERE table1.column = table2.column(+)",
            "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column",
        ),
        # several outer-joined tables
        (
            "SELECT * FROM table1, table2, table3, table4 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+) and table1.column = table4.column(+)",
            "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column LEFT JOIN table4 ON table1.column = table4.column",
        ),
        (
            "SELECT * FROM table1, table2, table3 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+)",
            "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column",
        ),
        # join marks inside a subquery as well as in the outer query
        (
            "SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)",
            "SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) table3 ON table1.id = table3.id",
        ),
        # 2 join marks on one side of predicate
        (
            "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + table2.column2(+)",
            "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + table2.column2",
        ),
        # join mark and expression
        (
            "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + 25",
            "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + 25",
        ),
    ]

    for oracle_sql, expected_sql in cases:
        with self.subTest(oracle_sql):
            generated_sql = self.parse_one(oracle_sql, dialect="oracle").sql(dialect=self.dialect)
            self.assertEqual(generated_sql, expected_sql)
52 changes: 41 additions & 11 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def parse_and_optimize(func, sql, read_dialect, **kwargs):

def qualify_columns(expression, **kwargs):
expression = optimizer.qualify.qualify(
expression, infer_schema=True, validate_qualify_columns=False, identify=False, **kwargs
expression,
infer_schema=True,
validate_qualify_columns=False,
identify=False,
**kwargs,
)
return expression

Expand Down Expand Up @@ -111,7 +115,14 @@ def setUp(self):
}

def check_file(
self, file, func, pretty=False, execute=False, set_dialect=False, only=None, **kwargs
self,
file,
func,
pretty=False,
execute=False,
set_dialect=False,
only=None,
**kwargs,
):
with ProcessPoolExecutor() as pool:
results = {}
Expand Down Expand Up @@ -331,7 +342,11 @@ def test_qualify_columns(self, logger):
)

self.check_file(
"qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True
"qualify_columns",
qualify_columns,
execute=True,
schema=self.schema,
set_dialect=True,
)
self.check_file(
"qualify_columns_ddl", qualify_columns, schema=self.schema, set_dialect=True
Expand All @@ -343,7 +358,8 @@ def test_qualify_columns__with_invisible(self):

def test_pushdown_cte_alias_columns(self):
self.check_file(
"pushdown_cte_alias_columns", optimizer.qualify_columns.pushdown_cte_alias_columns
"pushdown_cte_alias_columns",
optimizer.qualify_columns.pushdown_cte_alias_columns,
)

def test_qualify_columns__invalid(self):
Expand Down Expand Up @@ -405,7 +421,8 @@ def test_simplify(self):
self.assertEqual(optimizer.simplify.gen(query), optimizer.simplify.gen(query.copy()))

anon_unquoted_identifier = exp.Anonymous(
this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")]
this=exp.to_identifier("anonymous"),
expressions=[exp.column("x"), exp.column("y")],
)
self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS(x,y)")

Expand All @@ -416,7 +433,10 @@ def test_simplify(self):
anon_invalid = exp.Anonymous(this=5)
optimizer.simplify.gen(anon_invalid)

self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception))
self.assertIn(
"Anonymous.this expects a str or an Identifier, got 'int'.",
str(e.exception),
)

sql = parse_one(
"""
Expand Down Expand Up @@ -906,7 +926,8 @@ def test_cte_column_annotation(self):

# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
for d, t in zip(
cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
cte_select.find_all(exp.Subquery),
[exp.DataType.Type.CHAR, exp.DataType.Type.TEXT],
):
self.assertEqual(d.this.expressions[0].this.type.this, t)

Expand Down Expand Up @@ -1020,7 +1041,8 @@ def test_aggfunc_annotation(self):

for (func, col), target_type in tests.items():
expression = annotate_types(
parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"),
schema=schema,
)
self.assertEqual(expression.expressions[0].type.this, target_type)

Expand All @@ -1035,7 +1057,13 @@ def test_root_subquery_annotation(self):
self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this)

def test_nested_type_annotation(self):
schema = {"order": {"customer_id": "bigint", "item_id": "bigint", "item_price": "numeric"}}
schema = {
"order": {
"customer_id": "bigint",
"item_id": "bigint",
"item_price": "numeric",
}
}
sql = """
SELECT ARRAY_AGG(DISTINCT order.item_id) FILTER (WHERE order.item_price > 10) AS items,
FROM order AS order
Expand All @@ -1057,7 +1085,8 @@ def test_nested_type_annotation(self):

self.assertEqual(expression.selects[0].type.sql(dialect="bigquery"), "STRUCT<`f` STRING>")
self.assertEqual(
expression.selects[1].type.sql(dialect="bigquery"), "ARRAY<STRUCT<`f` STRING>>"
expression.selects[1].type.sql(dialect="bigquery"),
"ARRAY<STRUCT<`f` STRING>>",
)

expression = annotate_types(
Expand Down Expand Up @@ -1206,7 +1235,8 @@ def test_no_pseudocolumn_expansion(self):

self.assertEqual(
optimizer.optimize(
parse_one("SELECT * FROM a"), schema=MappingSchema(schema, dialect="bigquery")
parse_one("SELECT * FROM a"),
schema=MappingSchema(schema, dialect="bigquery"),
),
parse_one('SELECT "a"."a" AS "a", "a"."b" AS "b" FROM "a" AS "a"'),
)
Expand Down
Loading

0 comments on commit e998308

Please sign in to comment.