Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: lateral column alias reference (#507) #521

Merged
merged 1 commit into from
Jan 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sqllineage/core/holders.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def expand_wildcard(self, metadata_provider: MetaDataProvider) -> None:
for column in self.write_columns:
if column.raw_name == "*":
tgt_wildcard = column
for src_wildcard in self._get_source_columns(tgt_wildcard):
for src_wildcard in self.get_source_columns(tgt_wildcard):
if source_table := src_wildcard.parent:
src_table_columns = []
if isinstance(source_table, SubQuery):
Expand All @@ -169,7 +169,7 @@ def _get_target_table(self) -> Optional[Union[SubQuery, Table]]:
table = next(iter(write_only))
return table

def _get_source_columns(self, node: Column) -> List[Column]:
def get_source_columns(self, node: Column) -> List[Column]:
return [
src
for (src, tgt, edge_type) in self.graph.in_edges(nbunch=node, data="type")
Expand Down
4 changes: 2 additions & 2 deletions sqllineage/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,15 @@ def __init__(self, name: str, **kwargs):
"""
self._parent: Set[Union[Path, Table, SubQuery]] = set()
self.raw_name = escape_identifier_name(name)
self.source_columns = (
self.source_columns = [
(
escape_identifier_name(raw_name),
escape_identifier_name(qualifier) if qualifier is not None else None,
)
for raw_name, qualifier in kwargs.pop(
"source_columns", ((self.raw_name, None),)
)
)
]

def __str__(self):
return (
Expand Down
24 changes: 24 additions & 0 deletions sqllineage/core/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,16 @@ def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None:
if len(holder.write) > 1:
raise SQLLineageException
tgt_tbl = list(holder.write)[0]
lateral_aliases = set()
for idx, tgt_col in enumerate(col_grp):
tgt_col.parent = tgt_tbl
for lateral_alias_ref in col_grp[idx + 1 :]: # noqa: E203
if any(
src_col[0] == tgt_col.raw_name
for src_col in lateral_alias_ref.source_columns
):
lateral_aliases.add(tgt_col.raw_name)
break
for src_col in tgt_col.to_source_columns(
self.get_alias_mapping_from_table_group(tbl_grp, holder)
):
Expand All @@ -37,6 +45,22 @@ def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None:
# for invalid query: create view test (col3, col4) select col1 as col2 from tab,
# when the length doesn't match, we fall back to default behavior
tgt_col = write_columns[idx]
is_lateral_alias_ref = False
for wc in holder.write_columns:
if wc.raw_name == "*":
continue
if (
src_col.raw_name == wc.raw_name
and src_col.raw_name in lateral_aliases
):
is_lateral_alias_ref = True
for lateral_alias_col in holder.get_source_columns(wc):
holder.add_column_lineage(
lateral_alias_col, tgt_col
)
break
if is_lateral_alias_ref:
continue
holder.add_column_lineage(src_col, tgt_col)

@classmethod
Expand Down
98 changes: 98 additions & 0 deletions tests/sql/column/test_column_select_lateral_alias_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from sqllineage.utils.entities import ColumnQualifierTuple
from ...helpers import assert_column_lineage_equal


def test_column_top_level_lateral_ref():
sql = """
insert into public.tgt_tbl1
select
name as user_name,
user_name || email as id -- lateral ref
from
public.src_tbl1
"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("name", "public.src_tbl1"),
ColumnQualifierTuple("user_name", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("name", "public.src_tbl1"),
ColumnQualifierTuple("id", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("email", "public.src_tbl1"),
ColumnQualifierTuple("id", "public.tgt_tbl1"),
),
],
)


def test_column_lateral_ref_within_subquery():
sql = """
insert into public.tgt_tbl1
select
sq.name
from
(
select
id || name as alias1,
alias1 || email as name
from
public.src_tbl1
) as sq
"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("id", "public.src_tbl1"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("name", "public.src_tbl1"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("email", "public.src_tbl1"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
],
)

sql = """
insert into public.tgt_tbl1
select
sq.name
from
(
select
st1.id || st1.name as alias1,
alias1 || st2.email as name
from
public.src_tbl1 as st1
join
public.src_tbl2 as st2
on
st1.id = st2.id
) as sq
"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("id", "public.src_tbl1"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("name", "public.src_tbl1"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("email", "public.src_tbl2"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
],
)