Skip to content

Commit

Permalink
Refactor: clean up join mark elimination rule (#3653)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas authored Jun 14, 2024
1 parent e7a158b commit e8cab58
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 252 deletions.
8 changes: 5 additions & 3 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[

klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()

Expand Down Expand Up @@ -232,6 +234,9 @@ class Dialect(metaclass=_Dialect):
SUPPORTS_COLUMN_JOIN_MARKS = False
"""Whether the old-style outer join (+) syntax is supported."""

COPY_PARAMS_ARE_CSV = True
"""Separator of COPY statement parameters."""

NORMALIZE_FUNCTIONS: bool | str = "upper"
"""
Determines how function names are going to be normalized.
Expand Down Expand Up @@ -342,9 +347,6 @@ class Dialect(metaclass=_Dialect):
UNICODE_START: t.Optional[str] = None
UNICODE_END: t.Optional[str] = None

# Separator of COPY statement parameters
COPY_PARAMS_ARE_CSV = True

@classmethod
def get_or_raise(cls, dialect: DialectType) -> Dialect:
"""
Expand Down
161 changes: 0 additions & 161 deletions sqlglot/dialects/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,171 +33,10 @@ def _build_timetostr_or_tochar(args: t.List) -> exp.TimeToStr | exp.ToChar:
return exp.ToChar.from_arg_list(args)


def eliminate_join_marks(ast: exp.Expression) -> exp.Expression:
from sqlglot.optimizer.scope import traverse_scope

"""Remove join marks from an expression
SELECT * FROM a, b WHERE a.id = b.id(+)
becomes:
SELECT * FROM a LEFT JOIN b ON a.id = b.id
- for each scope
- for each column with a join mark
- find the predicate it belongs to
- remove the predicate from the where clause
- convert the predicate to a join with the (+) side as the left join table
- replace the existing join with the new join
Args:
ast: The AST to remove join marks from
Returns:
The AST with join marks removed"""
for scope in traverse_scope(ast):
_eliminate_join_marks_from_scope(scope)
return ast


def _update_from(
select: exp.Select,
new_join_dict: t.Dict[str, exp.Join],
old_join_dict: t.Dict[str, exp.Join],
) -> None:
"""If the from clause needs to become a new join, find an appropriate table to use as the new from.
updates select in place
Args:
select: The select statement to update
new_join_dict: The dictionary of new joins
old_join_dict: The dictionary of old joins
"""
old_from = select.args["from"]
if old_from.alias_or_name not in new_join_dict:
return
in_old_not_new = old_join_dict.keys() - new_join_dict.keys()
if len(in_old_not_new) >= 1:
new_from_name = list(old_join_dict.keys() - new_join_dict.keys())[0]
new_from_this = old_join_dict[new_from_name].this
new_from = exp.From(this=new_from_this)
del old_join_dict[new_from_name]
select.set("from", new_from)
else:
raise ValueError("Cannot determine which table to use as the new from")


def _has_join_mark(col: exp.Expression) -> bool:
"""Check if the column has a join mark
Args:
The column to check
"""
return col.args.get("join_mark", False)


def _predicate_to_join(
eq: exp.Binary, old_joins: t.Dict[str, exp.Join], old_from: exp.From
) -> t.Optional[exp.Join]:
"""Convert an equality predicate to a join if it contains a join mark
Args:
eq: The equality expression to convert to a join
Returns:
The join expression if the equality contains a join mark (otherwise None)
"""

# if not (isinstance(eq.left, exp.Column) or isinstance(eq.right, exp.Column)):
# return None

left_columns = [col for col in eq.left.find_all(exp.Column) if _has_join_mark(col)]
right_columns = [col for col in eq.right.find_all(exp.Column) if _has_join_mark(col)]

left_has_join_mark = len(left_columns) > 0
right_has_join_mark = len(right_columns) > 0

if left_has_join_mark:
for col in left_columns:
col.set("join_mark", False)
join_on = col.table
elif right_has_join_mark:
for col in right_columns:
col.set("join_mark", False)
join_on = col.table
else:
return None

join_this = old_joins.get(join_on, old_from).this
return exp.Join(this=join_this, on=eq, kind="LEFT")


if t.TYPE_CHECKING:
from sqlglot.optimizer.scope import Scope


def _eliminate_join_marks_from_scope(scope: Scope) -> None:
"""Remove join marks columns in scope's where clause.
Converts them to left joins and replaces any existing joins.
Updates scope in place.
Args:
scope: The scope to remove join marks from
"""
select_scope = scope.expression
where = select_scope.args.get("where")
joins = select_scope.args.get("joins")
if not where:
return
if not joins:
return

# dictionaries used to keep track of joins to be replaced
old_joins = {join.alias_or_name: join for join in list(joins)}
new_joins: t.Dict[str, exp.Join] = {}

for node in scope.find_all(exp.Column):
if _has_join_mark(node):
predicate = node.find_ancestor(exp.Predicate)
if not isinstance(predicate, exp.Binary):
continue
predicate_parent = predicate.parent

join_on = predicate.pop()
new_join = _predicate_to_join(
join_on, old_joins=old_joins, old_from=select_scope.args["from"]
)
# upsert new_join into new_joins dictionary
if new_join:
if new_join.alias_or_name in new_joins:
new_joins[new_join.alias_or_name].set(
"on",
exp.and_(
new_joins[new_join.alias_or_name].args["on"],
new_join.args["on"],
),
)
else:
new_joins[new_join.alias_or_name] = new_join
# If the parent is a binary node with only one child, promote the child to the parent
if predicate_parent:
if isinstance(predicate_parent, exp.Binary):
if predicate_parent.left is None:
predicate_parent.replace(predicate_parent.right)
elif predicate_parent.right is None:
predicate_parent.replace(predicate_parent.left)

_update_from(select_scope, new_joins, old_joins)
replacement_joins = [new_joins.get(join.alias_or_name, join) for join in old_joins.values()]
select_scope.set("joins", replacement_joins)
if not where.this:
where.pop()


class Oracle(Dialect):
ALIAS_POST_TABLESAMPLE = True
LOCKING_READS_SUPPORTED = True
TABLESAMPLE_SIZE_IS_PERCENT = True
SUPPORTS_COLUMN_JOIN_MARKS = True

# See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm
NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE
Expand Down
1 change: 0 additions & 1 deletion sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class Redshift(Postgres):
INDEX_OFFSET = 0
COPY_PARAMS_ARE_CSV = False
HEX_LOWERCASE = True
SUPPORTS_COLUMN_JOIN_MARKS = True

TIME_FORMAT = "'YYYY-MM-DD HH:MI:SS'"
TIME_MAPPING = {
Expand Down
162 changes: 132 additions & 30 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,52 @@
from sqlglot.generator import Generator


def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Args:
transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
"""

def _to_sql(self, expression: exp.Expression) -> str:
expression_type = type(expression)

expression = transforms[0](expression)
for transform in transforms[1:]:
expression = transform(expression)

_sql_handler = getattr(self, expression.key + "_sql", None)
if _sql_handler:
return _sql_handler(expression)

transforms_handler = self.TRANSFORMS.get(type(expression))
if transforms_handler:
if expression_type is type(expression):
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)

# Ensures we don't enter an infinite loop. This can happen when the original expression
# has the same type as the final expression and there's no _sql method available for it,
# because then it'd re-enter _to_sql.
raise ValueError(
f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
)

return transforms_handler(self, expression)

raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

return _to_sql


def unalias_group(expression: exp.Expression) -> exp.Expression:
"""
Replace references to select aliases in GROUP BY clauses.
Expand Down Expand Up @@ -623,47 +669,103 @@ def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
return expression


def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
"""
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.
For example,
SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to
SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Args:
transforms: sequence of transform functions. These will be called in order.
expression: The AST to remove join marks from.
Returns:
Function that can be used as a generator transform.
The AST with join marks removed.
"""
from sqlglot.optimizer.scope import traverse_scope

def _to_sql(self, expression: exp.Expression) -> str:
expression_type = type(expression)
for scope in traverse_scope(expression):
query = scope.expression

expression = transforms[0](expression)
for transform in transforms[1:]:
expression = transform(expression)
where = query.args.get("where")
joins = query.args.get("joins")

_sql_handler = getattr(self, expression.key + "_sql", None)
if _sql_handler:
return _sql_handler(expression)
if not where or not joins:
continue

transforms_handler = self.TRANSFORMS.get(type(expression))
if transforms_handler:
if expression_type is type(expression):
if isinstance(expression, exp.Func):
return self.function_fallback_sql(expression)
query_from = query.args["from"]

# Ensures we don't enter an infinite loop. This can happen when the original expression
# has the same type as the final expression and there's no _sql method available for it,
# because then it'd re-enter _to_sql.
raise ValueError(
f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
)
# These keep track of the joins to be replaced
new_joins: t.Dict[str, exp.Join] = {}
old_joins = {join.alias_or_name: join for join in joins}

return transforms_handler(self, expression)
for column in scope.columns:
if not column.args.get("join_mark"):
continue

raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
predicate = column.find_ancestor(exp.Predicate, exp.Select)
assert isinstance(
predicate, exp.Binary
), "Columns can only be marked with (+) when involved in a binary operation"

return _to_sql
predicate_parent = predicate.parent
join_predicate = predicate.pop()

left_columns = [
c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
]
right_columns = [
c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
]

assert not (
left_columns and right_columns
), "The (+) marker cannot appear in both sides of a binary predicate"

marked_column_tables = set()
for col in left_columns or right_columns:
table = col.table
assert table, f"Column {col} needs to be qualified with a table"

col.set("join_mark", False)
marked_column_tables.add(table)

assert (
len(marked_column_tables) == 1
), "Columns of only a single table can be marked with (+) in a given binary predicate"

join_this = old_joins.get(col.table, query_from).this
new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

# Upsert new_join into new_joins dictionary
new_join_alias_or_name = new_join.alias_or_name
existing_join = new_joins.get(new_join_alias_or_name)
if existing_join:
existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"]))
else:
new_joins[new_join_alias_or_name] = new_join

# If the parent of the target predicate is a binary node, then it now has only one child
if isinstance(predicate_parent, exp.Binary):
if predicate_parent.left is None:
predicate_parent.replace(predicate_parent.right)
else:
predicate_parent.replace(predicate_parent.left)

if query_from.alias_or_name in new_joins:
only_old_joins = old_joins.keys() - new_joins.keys()
assert (
len(only_old_joins) >= 1
), "Cannot determine which table to use in the new FROM clause"

new_from_name = list(only_old_joins)[0]
query.set("from", exp.From(this=old_joins[new_from_name].this))

query.set("joins", list(new_joins.values()))

if not where.this:
where.pop()

return expression
Loading

0 comments on commit e8cab58

Please sign in to comment.