Skip to content

Commit

Permalink
feat: lateral column alias reference (#507)
Browse files Browse the repository at this point in the history
  • Loading branch information
maoxingda committed Jan 1, 2024
1 parent b620d78 commit f0efc8a
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 4 deletions.
4 changes: 2 additions & 2 deletions sqllineage/core/holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def expand_wildcard(self, metadata_provider: MetaDataProvider) -> None:
for column in self.write_columns:
if column.raw_name == "*":
tgt_wildcard = column
for src_wildcard in self._get_source_columns(tgt_wildcard):
for src_wildcard in self.get_source_columns(tgt_wildcard):
if source_table := src_wildcard.parent:
src_table_columns = []
if isinstance(source_table, SubQuery):
Expand All @@ -169,7 +169,7 @@ def _get_target_table(self) -> Optional[Union[SubQuery, Table]]:
table = next(iter(write_only))
return table

def _get_source_columns(self, node: Column) -> List[Column]:
def get_source_columns(self, node: Column) -> List[Column]:
return [
src
for (src, tgt, edge_type) in self.graph.in_edges(nbunch=node, data="type")
Expand Down
4 changes: 2 additions & 2 deletions sqllineage/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,15 @@ def __init__(self, name: str, **kwargs):
"""
self._parent: Set[Union[Path, Table, SubQuery]] = set()
self.raw_name = escape_identifier_name(name)
self.source_columns = (
self.source_columns = [
(
escape_identifier_name(raw_name),
escape_identifier_name(qualifier) if qualifier is not None else None,
)
for raw_name, qualifier in kwargs.pop(
"source_columns", ((self.raw_name, None),)
)
)
]

def __str__(self):
return (
Expand Down
28 changes: 28 additions & 0 deletions sqllineage/core/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,17 @@ def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None:
if len(holder.write) > 1:
raise SQLLineageException
tgt_tbl = list(holder.write)[0]
lateral_aliases = set()
for idx, tgt_col in enumerate(col_grp):
tgt_col.parent = tgt_tbl
if idx + 1 < len(col_grp) and isinstance(tgt_tbl, SubQuery):
for _, lateral_alias_ref in enumerate(col_grp, idx + 1):
if any(
src_col[0] == tgt_col.raw_name
for src_col in lateral_alias_ref.source_columns
):
lateral_aliases.add(tgt_col.raw_name)
break
for src_col in tgt_col.to_source_columns(
self.get_alias_mapping_from_table_group(tbl_grp, holder)
):
Expand All @@ -37,6 +46,25 @@ def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None:
# for invalid query: create view test (col3, col4) select col1 as col2 from tab,
# when the length doesn't match, we fall back to default behavior
tgt_col = write_columns[idx]
if isinstance(tgt_tbl, SubQuery):
is_lateral_alias_ref = False
for wc in holder.write_columns:
if wc.raw_name == "*":
continue
if (
src_col.raw_name == wc.raw_name
and src_col.raw_name in lateral_aliases
):
is_lateral_alias_ref = True
for lateral_alias_col in holder.get_source_columns(
wc
):
holder.add_column_lineage(
lateral_alias_col, tgt_col
)
break
if is_lateral_alias_ref:
continue
holder.add_column_lineage(src_col, tgt_col)

@classmethod
Expand Down
70 changes: 70 additions & 0 deletions tests/sql/column/test_column_select_lateral_alias_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from sqllineage.utils.entities import ColumnQualifierTuple
from ...helpers import assert_column_lineage_equal


def test_column_lateral_ref():
sql = """
insert into public.tgt_tbl1
select
sq.name
from
(
select
id || name as alias1,
alias1 || email as name
from
public.src_tbl1
) as sq
"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("id", "public.src_tbl1"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("name", "public.src_tbl1"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("email", "public.src_tbl1"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
],
)

sql = """
insert into public.tgt_tbl1
select
sq.name
from
(
select
st1.id || st1.name as alias1,
alias1 || st2.email as name
from
public.src_tbl1 as st1
join
public.src_tbl2 as st2
on
st1.id = st2.id
) as sq
"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("id", "public.src_tbl1"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("name", "public.src_tbl1"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("email", "public.src_tbl2"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
],
)

0 comments on commit f0efc8a

Please sign in to comment.