diff --git a/sqllineage/core/holders.py b/sqllineage/core/holders.py index f96e7f3d..c62a553f 100644 --- a/sqllineage/core/holders.py +++ b/sqllineage/core/holders.py @@ -144,7 +144,7 @@ def expand_wildcard(self, metadata_provider: MetaDataProvider) -> None: for column in self.write_columns: if column.raw_name == "*": tgt_wildcard = column - for src_wildcard in self._get_source_columns(tgt_wildcard): + for src_wildcard in self.get_source_columns(tgt_wildcard): if source_table := src_wildcard.parent: src_table_columns = [] if isinstance(source_table, SubQuery): @@ -169,7 +169,7 @@ def _get_target_table(self) -> Optional[Union[SubQuery, Table]]: table = next(iter(write_only)) return table - def _get_source_columns(self, node: Column) -> List[Column]: + def get_source_columns(self, node: Column) -> List[Column]: return [ src for (src, tgt, edge_type) in self.graph.in_edges(nbunch=node, data="type") diff --git a/sqllineage/core/models.py b/sqllineage/core/models.py index 610571bd..85457d45 100644 --- a/sqllineage/core/models.py +++ b/sqllineage/core/models.py @@ -147,7 +147,7 @@ def __init__(self, name: str, **kwargs): """ self._parent: Set[Union[Path, Table, SubQuery]] = set() self.raw_name = escape_identifier_name(name) - self.source_columns = ( + self.source_columns = [ ( escape_identifier_name(raw_name), escape_identifier_name(qualifier) if qualifier is not None else None, @@ -155,7 +155,7 @@ def __init__(self, name: str, **kwargs): for raw_name, qualifier in kwargs.pop( "source_columns", ((self.raw_name, None),) ) - ) + ] def __str__(self): return ( diff --git a/sqllineage/core/parser/__init__.py b/sqllineage/core/parser/__init__.py index 73d24af2..f9c7ea1b 100644 --- a/sqllineage/core/parser/__init__.py +++ b/sqllineage/core/parser/__init__.py @@ -25,8 +25,16 @@ def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None: if len(holder.write) > 1: raise SQLLineageException tgt_tbl = list(holder.write)[0] + lateral_aliases = set() for idx, tgt_col in enumerate(col_grp): tgt_col.parent = tgt_tbl + for lateral_alias_ref in col_grp[idx + 1 :]: # noqa: E203 + if any( + src_col[0] == tgt_col.raw_name + for src_col in lateral_alias_ref.source_columns + ): + lateral_aliases.add(tgt_col.raw_name) + break for src_col in tgt_col.to_source_columns( self.get_alias_mapping_from_table_group(tbl_grp, holder) ): @@ -37,6 +45,22 @@ def end_of_query_cleanup(self, holder: SubQueryLineageHolder) -> None: # for invalid query: create view test (col3, col4) select col1 as col2 from tab, # when the length doesn't match, we fall back to default behavior tgt_col = write_columns[idx] + is_lateral_alias_ref = False + for wc in holder.write_columns: + if wc.raw_name == "*": + continue + if ( + src_col.raw_name == wc.raw_name + and src_col.raw_name in lateral_aliases + ): + is_lateral_alias_ref = True + for lateral_alias_col in holder.get_source_columns(wc): + holder.add_column_lineage( + lateral_alias_col, tgt_col + ) + break + if is_lateral_alias_ref: + continue holder.add_column_lineage(src_col, tgt_col) @classmethod diff --git a/tests/sql/column/test_column_select_lateral_alias_ref.py b/tests/sql/column/test_column_select_lateral_alias_ref.py new file mode 100644 index 00000000..28b7dc81 --- /dev/null +++ b/tests/sql/column/test_column_select_lateral_alias_ref.py @@ -0,0 +1,98 @@ +from sqllineage.utils.entities import ColumnQualifierTuple +from ...helpers import assert_column_lineage_equal + + +def test_column_top_level_lateral_ref(): + sql = """ + insert into public.tgt_tbl1 + select + name as user_name, + user_name || email as id -- lateral ref + from + public.src_tbl1 + """ + assert_column_lineage_equal( + sql, + [ + ( + ColumnQualifierTuple("name", "public.src_tbl1"), + ColumnQualifierTuple("user_name", "public.tgt_tbl1"), + ), + ( + ColumnQualifierTuple("name", "public.src_tbl1"), + ColumnQualifierTuple("id", "public.tgt_tbl1"), + ), + ( + ColumnQualifierTuple("email", "public.src_tbl1"), + ColumnQualifierTuple("id", "public.tgt_tbl1"), + ), + ], + ) + + +def test_column_lateral_ref_within_subquery(): + sql = """ + insert into public.tgt_tbl1 + select + sq.name + from + ( + select + id || name as alias1, + alias1 || email as name + from + public.src_tbl1 + ) as sq + """ + assert_column_lineage_equal( + sql, + [ + ( + ColumnQualifierTuple("id", "public.src_tbl1"), + ColumnQualifierTuple("name", "public.tgt_tbl1"), + ), + ( + ColumnQualifierTuple("name", "public.src_tbl1"), + ColumnQualifierTuple("name", "public.tgt_tbl1"), + ), + ( + ColumnQualifierTuple("email", "public.src_tbl1"), + ColumnQualifierTuple("name", "public.tgt_tbl1"), + ), + ], + ) + + sql = """ + insert into public.tgt_tbl1 + select + sq.name + from + ( + select + st1.id || st1.name as alias1, + alias1 || st2.email as name + from + public.src_tbl1 as st1 + join + public.src_tbl2 as st2 + on + st1.id = st2.id + ) as sq + """ + assert_column_lineage_equal( + sql, + [ + ( + ColumnQualifierTuple("id", "public.src_tbl1"), + ColumnQualifierTuple("name", "public.tgt_tbl1"), + ), + ( + ColumnQualifierTuple("name", "public.src_tbl1"), + ColumnQualifierTuple("name", "public.tgt_tbl1"), + ), + ( + ColumnQualifierTuple("email", "public.src_tbl2"), + ColumnQualifierTuple("name", "public.tgt_tbl1"), + ), + ], + )