Skip to content

Commit

Permalink
fix: misidentify-column-name-as-alias (#539)
Browse files (browse the repository at this point in the history)
  • Loading branch information
maoxingda committed Jan 10, 2024
1 parent 0326b59 commit 15f9569
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 21 deletions.
1 change: 1 addition & 0 deletions sqllineage/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def _to_src_col(
else:
# select unqualified column
source = _to_src_col(src_col, None)
setattr(source, "has_qualifier", False)
for table in set(alias_mapping.values()):
# in case of only one table, we get the right answer
# in case of multiple tables, a bunch of possible tables are set
Expand Down
64 changes: 44 additions & 20 deletions sqllineage/core/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,32 @@ def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None:
lateral_aliases = set()
for idx, tgt_col in enumerate(col_grp):
tgt_col.parent = tgt_tbl
for lateral_alias_ref in col_grp[idx + 1 :]: # noqa: E203
if any(
src_col[0] == tgt_col.raw_name
for src_col in lateral_alias_ref.source_columns
):
lateral_aliases.add(tgt_col.raw_name)
break
if (
hasattr(self, "metadata_provider")
and hasattr(tgt_col, "has_alias")
and tgt_col.has_alias is True
):
for lateral_alias_ref in col_grp[idx + 1 :]: # noqa: E203
if any(
src_col[0] == tgt_col.raw_name
for src_col in lateral_alias_ref.source_columns
if src_col[1] is None
and all(
src_col[0]
not in list(
map(
lambda x: str(x).rsplit(".", 1)[-1],
self.metadata_provider.get_table_columns(
read
),
)
)
for read in tbl_grp
if isinstance(read, Table)
)
):
lateral_aliases.add(tgt_col.raw_name)
break
for src_col in tgt_col.to_source_columns(
holder.get_alias_mapping_from_table_group(tbl_grp)
):
Expand All @@ -45,19 +64,24 @@ def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None:
# when the length doesn't match, we fall back to default behavior
tgt_col = write_columns[idx]
is_lateral_alias_ref = False
for wc in holder.write_columns:
if wc.raw_name == "*":
continue
if (
src_col.raw_name == wc.raw_name
and src_col.raw_name in lateral_aliases
):
is_lateral_alias_ref = True
for lateral_alias_col in holder.get_source_columns(wc):
holder.add_column_lineage(
lateral_alias_col, tgt_col
)
break
if idx > 0 and len(lateral_aliases) > 0:
for wc in holder.write_columns:
if wc.raw_name == "*":
continue

# Codecov (codecov/patch) warning on sqllineage/core/parser/__init__.py#L70: added line #L70 was not covered by tests.
if (
hasattr(src_col, "has_qualifier")
and src_col.has_qualifier is False
and src_col.raw_name == wc.raw_name
and src_col.raw_name in lateral_aliases
):
for lateral_alias_col in holder.get_source_columns(
wc
):
is_lateral_alias_ref = True
holder.add_column_lineage(
lateral_alias_col, tgt_col
)
break
if is_lateral_alias_ref:
continue
holder.add_column_lineage(src_col, tgt_col)
4 changes: 3 additions & 1 deletion sqllineage/core/parser/sqlfluff/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ def of(column: BaseSegment, **kwargs) -> Column:
if column.type == "select_clause_element":
source_columns, alias = SqlFluffColumn._get_column_and_alias(column)
if alias:
return Column(
alias_column = Column(
alias,
source_columns=source_columns,
)
setattr(alias_column, "has_alias", True)
return alias_column
if source_columns:
column_name = None
for sub_segment in list_child_segments(column):
Expand Down
109 changes: 109 additions & 0 deletions tests/sql/column/test_column_select_lateral_alias_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,113 @@ def test_column_top_level_lateral_ref():
ColumnQualifierTuple("id", "public.tgt_tbl1"),
),
],
test_sqlparse=False,
)
sql = """
insert into public.tgt_tbl1
(
name,
email
)
select
st1.name,
st1.name || st1.email || '@gmail.com' as email
from
public.src_tbl1 as st1
"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("name", "public.src_tbl1"),
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("name", "public.src_tbl1"),
ColumnQualifierTuple("email", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("email", "public.src_tbl1"),
ColumnQualifierTuple("email", "public.tgt_tbl1"),
),
],
test_sqlparse=False,
)
sql = """
insert into public.tgt_tbl1
(
id,
id_original
)
select
'a || b || c' || id as id,
id as id_original
from
public.src_tbl1
"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("id", "public.src_tbl1"),
ColumnQualifierTuple("id", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("id", "public.src_tbl1"),
ColumnQualifierTuple("id_original", "public.tgt_tbl1"),
),
],
test_sqlparse=False,
)
sql = """
insert into public.tgt_tbl1
(
id,
id_original
)
select
a || b || c || id as id,
id as id_original -- # noqa: E501 TODO: I need the metadata information for the table public.src_tbl1 to identify whether the column reference 'id' in this context is from the table public.src_tbl1 or from an alias reference, currently being used as an alias reference. Note: This decision may significantly deviate from the actual scenario.
from
public.src_tbl1
"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("a", "public.src_tbl1"),
ColumnQualifierTuple("id", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("b", "public.src_tbl1"),
ColumnQualifierTuple("id", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("c", "public.src_tbl1"),
ColumnQualifierTuple("id", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("id", "public.src_tbl1"),
ColumnQualifierTuple("id", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("a", "public.src_tbl1"),
ColumnQualifierTuple("id_original", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("b", "public.src_tbl1"),
ColumnQualifierTuple("id_original", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("c", "public.src_tbl1"),
ColumnQualifierTuple("id_original", "public.tgt_tbl1"),
),
(
ColumnQualifierTuple("id", "public.src_tbl1"),
ColumnQualifierTuple("id_original", "public.tgt_tbl1"),
),
],
test_sqlparse=False,
)


Expand Down Expand Up @@ -60,6 +167,7 @@ def test_column_lateral_ref_within_subquery():
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
],
test_sqlparse=False,
)

sql = """
Expand Down Expand Up @@ -95,4 +203,5 @@ def test_column_lateral_ref_within_subquery():
ColumnQualifierTuple("name", "public.tgt_tbl1"),
),
],
test_sqlparse=False,
)

0 comments on commit 15f9569

Please sign in to comment.