Skip to content

Commit d8d968e

Browse files
committed
feat: update statement column lineage (reata#487)
1 parent b067eaf commit d8d968e

File tree

2 files changed

+258
-8
lines changed

2 files changed

+258
-8
lines changed

sqllineage/core/parser/sqlfluff/extractors/update.py

+125-8
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,48 @@
33
from sqlfluff.core.parser import BaseSegment
44

55
from sqllineage.core.holders import SubQueryLineageHolder
6-
from sqllineage.core.models import SubQuery, Table
6+
from sqllineage.core.metadata_provider import MetaDataProvider
7+
from sqllineage.core.models import SubQuery, Table, Column, Path
8+
from sqllineage.core.parser import SourceHandlerMixin
79
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
8-
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery
10+
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery, SqlFluffTable
911
from sqllineage.core.parser.sqlfluff.utils import (
1012
find_from_expression_element,
1113
find_table_identifier,
1214
list_child_segments,
1315
list_join_clause,
16+
extract_column_qualifier,
17+
list_subqueries,
1418
)
1519
from sqllineage.utils.entities import AnalyzerContext
20+
from sqllineage.utils.helpers import escape_identifier_name
1621

1722

18-
class UpdateExtractor(BaseExtractor):
23+
class UpdateExtractor(BaseExtractor, SourceHandlerMixin):
1924
"""
2025
Update statement lineage extractor
2126
"""
2227

28+
def __init__(self, dialect: str, metadata_provider: MetaDataProvider):
    """
    Set up the extractor and the per-instance collections used while parsing.

    :param dialect: forwarded unchanged to the parent initializer
    :param metadata_provider: forwarded unchanged to the parent initializer
    """
    super().__init__(dialect, metadata_provider)
    # columns: target columns gathered from SET clauses
    # tables: source datasets gathered from FROM / JOIN clauses
    # union_barriers: starts empty; presumably consumed by SourceHandlerMixin -- confirm
    self.columns, self.tables, self.union_barriers = [], [], []
33+
2334
SUPPORTED_STMT_TYPES = ["update_statement"]
2435

2536
def extract(
2637
self, statement: BaseSegment, context: AnalyzerContext
2738
) -> SubQueryLineageHolder:
2839
holder = self._init_holder(context)
2940
tgt_flag = False
41+
subqueries = []
3042
for segment in list_child_segments(statement):
43+
for sq in self.list_subquery(segment):
44+
# Collecting subquery on the way, hold on parsing until last
45+
# so that each handler don't have to worry about what's inside subquery
46+
subqueries.append(sq)
47+
3148
if segment.type == "from_expression":
3249
# UPDATE with JOIN, mysql only syntax
3350
if table := self.find_table_from_from_expression_or_join_clause(
@@ -49,16 +66,116 @@ def extract(
4966
holder.add_write(table)
5067
tgt_flag = False
5168

69+
if segment.type == "set_clause_list":
70+
for set_clause in segment.get_children("set_clause"):
71+
columns = set_clause.get_children("column_reference")
72+
if len(columns) == 2:
73+
tgt_cqt = extract_column_qualifier(columns[0])
74+
src_cqt = extract_column_qualifier(columns[1])
75+
if tgt_cqt is not None and src_cqt is not None:
76+
self.columns.append(
77+
Column(tgt_cqt.column, source_columns=[src_cqt])
78+
)
79+
5280
if segment.type == "from_clause":
5381
# UPDATE FROM, ansi syntax
54-
if from_expression := segment.get_child("from_expression"):
55-
if table := self.find_table_from_from_expression_or_join_clause(
56-
from_expression, holder
57-
):
58-
holder.add_read(table)
82+
self._handle_table(segment, holder)
83+
84+
self.end_of_query_cleanup(holder)
85+
86+
self.extract_subquery(subqueries, holder)
5987

6088
return holder
6189

90+
def _handle_table(
    self, segment: BaseSegment, holder: SubQueryLineageHolder
) -> None:
    """
    Collect the datasets referenced by a from_clause or join_clause segment.

    A join_clause is a child node of from_clause, so join clauses are handled
    by recursing into this same method. Segments of any other type are ignored.
    """
    if segment.type not in ("from_clause", "join_clause"):
        return

    from_expressions = segment.get_children("from_expression")
    if len(from_expressions) > 1:
        # SQL89 style of join: several from_expression siblings, resolve each one
        for expression in from_expressions:
            element = find_from_expression_element(expression)
            if element:
                self._add_dataset_from_expression_element(element, holder)
        return

    # single from_expression: resolve its element, then walk the explicit joins
    element = find_from_expression_element(segment)
    if element:
        self._add_dataset_from_expression_element(element, holder)
    for join_clause in list_join_clause(segment):
        self._handle_table(join_clause, holder)
114+
115+
def _add_dataset_from_expression_element(
116+
self, segment: BaseSegment, holder: SubQueryLineageHolder
117+
) -> None:
118+
"""
119+
Append tables and subqueries identified in the 'from_expression_element' type segment to the table and
120+
holder extra subqueries sets
121+
"""
122+
all_segments = [
123+
seg for seg in list_child_segments(segment) if seg.type != "keyword"
124+
]
125+
if table_expression := segment.get_child("table_expression"):
126+
if table_expression.get_child("function"):
127+
# for UNNEST or generator function, no dataset involved
128+
return
129+
first_segment = all_segments[0]
130+
if first_segment.type == "bracketed":
131+
if table_expression := first_segment.get_child("table_expression"):
132+
if table_expression.get_child("values_clause"):
133+
# (VALUES ...) AS alias, no dataset involved
134+
return
135+
subqueries = list_subqueries(segment)
136+
if subqueries:
137+
for sq in subqueries:
138+
bracketed, alias = sq
139+
read_sq = SqlFluffSubQuery.of(bracketed, alias)
140+
self.tables.append(read_sq)
141+
else:
142+
table_identifier = find_table_identifier(segment)
143+
if table_identifier:
144+
subquery_flag = False
145+
alias = None
146+
if len(all_segments) > 1 and all_segments[1].type == "alias_expression":
147+
all_segments = list_child_segments(all_segments[1])
148+
alias = str(
149+
all_segments[1].raw
150+
if len(all_segments) > 1
151+
else all_segments[0].raw
152+
)
153+
if "." not in table_identifier.raw:
154+
cte_dict = {s.alias: s for s in holder.cte}
155+
cte = cte_dict.get(table_identifier.raw)
156+
if cte is not None:
157+
# could reference CTE with or without alias
158+
self.tables.append(
159+
SqlFluffSubQuery.of(
160+
cte.query,
161+
alias or table_identifier.raw,
162+
)
163+
)
164+
subquery_flag = True
165+
if subquery_flag is False:
166+
if table_identifier.type == "file_reference":
167+
self.tables.append(
168+
Path(
169+
escape_identifier_name(
170+
table_identifier.segments[-1].raw
171+
)
172+
)
173+
)
174+
else:
175+
self.tables.append(
176+
SqlFluffTable.of(table_identifier, alias=alias)
177+
)
178+
62179
def find_table_from_from_expression_or_join_clause(
63180
self, segment, holder
64181
) -> Optional[Union[Table, SubQuery]]:
+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from sqllineage.utils.entities import ColumnQualifierTuple
2+
from ...helpers import assert_column_lineage_equal
3+
4+
5+
def test_column_update_with_single_table():
    """UPDATE ... FROM one aliased table: each SET source column maps to its target column."""
    update_sql = """
update
public.tgt_tbl1 as t
set
name = s.name,
email = s.address
from
public.src_tbl1 as s
where
s.id = t.id
"""
    # (source column, target column) pairs
    expected_lineage = [
        (
            ColumnQualifierTuple("name", "public.src_tbl1"),
            ColumnQualifierTuple("name", "public.tgt_tbl1"),
        ),
        (
            ColumnQualifierTuple("address", "public.src_tbl1"),
            ColumnQualifierTuple("email", "public.tgt_tbl1"),
        ),
    ]
    assert_column_lineage_equal(update_sql, expected_lineage, test_sqlparse=False)
31+
32+
33+
def test_column_update_with_cte():
    """UPDATE ... FROM a CTE: lineage is expected to resolve through the CTE to its base table."""
    update_sql = """
with
s as (
select name, address from public.src_tbl1
)
update
public.tgt_tbl1 as t
set
name = s.name,
email = s.address
from
s
where
s.id = t.id
"""
    # (source column, target column) pairs
    expected_lineage = [
        (
            ColumnQualifierTuple("name", "public.src_tbl1"),
            ColumnQualifierTuple("name", "public.tgt_tbl1"),
        ),
        (
            ColumnQualifierTuple("address", "public.src_tbl1"),
            ColumnQualifierTuple("email", "public.tgt_tbl1"),
        ),
    ]
    assert_column_lineage_equal(update_sql, expected_lineage, test_sqlparse=False)
63+
64+
65+
def test_column_update_with_multi_tables():
    """UPDATE ... FROM two joined tables: each SET column is traced to the table its alias names."""
    update_sql = """
update
public.tgt_tbl1 as t
set
name = s1.name,
email = s2.email
from
public.src_tbl1 as s1
join
public.src_tbl2 as s2
on
s1.id = s2.id
where
s1.id = t.id
"""
    # (source column, target column) pairs
    expected_lineage = [
        (
            ColumnQualifierTuple("name", "public.src_tbl1"),
            ColumnQualifierTuple("name", "public.tgt_tbl1"),
        ),
        (
            ColumnQualifierTuple("email", "public.src_tbl2"),
            ColumnQualifierTuple("email", "public.tgt_tbl1"),
        ),
    ]
    assert_column_lineage_equal(update_sql, expected_lineage, test_sqlparse=False)
95+
96+
97+
def test_column_update_with_subquery():
    """UPDATE ... FROM an aliased subquery: lineage is expected to resolve to the subquery's base tables."""
    update_sql = """
update
public.tgt_tbl1 as t
set
name = s.name,
email = s.email
from
(
select
s1.id,
s1.name,
s2.email
from
public.src_tbl1 as s1
join
public.src_tbl2 as s2
on
s1.id = s2.id
) as s
where
s.id = t.id
"""
    # (source column, target column) pairs
    expected_lineage = [
        (
            ColumnQualifierTuple("name", "public.src_tbl1"),
            ColumnQualifierTuple("name", "public.tgt_tbl1"),
        ),
        (
            ColumnQualifierTuple("email", "public.src_tbl2"),
            ColumnQualifierTuple("email", "public.tgt_tbl1"),
        ),
    ]
    assert_column_lineage_equal(update_sql, expected_lineage, test_sqlparse=False)

0 commit comments

Comments
 (0)