3
3
from sqlfluff .core .parser import BaseSegment
4
4
5
5
from sqllineage .core .holders import SubQueryLineageHolder
6
- from sqllineage .core .models import SubQuery , Table
6
+ from sqllineage .core .metadata_provider import MetaDataProvider
7
+ from sqllineage .core .models import SubQuery , Table , Column , Path
8
+ from sqllineage .core .parser import SourceHandlerMixin
7
9
from sqllineage .core .parser .sqlfluff .extractors .base import BaseExtractor
8
- from sqllineage .core .parser .sqlfluff .models import SqlFluffSubQuery
10
+ from sqllineage .core .parser .sqlfluff .models import SqlFluffSubQuery , SqlFluffTable
9
11
from sqllineage .core .parser .sqlfluff .utils import (
10
12
find_from_expression_element ,
11
13
find_table_identifier ,
12
14
list_child_segments ,
13
15
list_join_clause ,
16
+ extract_column_qualifier ,
17
+ list_subqueries ,
14
18
)
15
19
from sqllineage .utils .entities import AnalyzerContext
20
+ from sqllineage .utils .helpers import escape_identifier_name
16
21
17
22
18
- class UpdateExtractor (BaseExtractor ):
23
+ class UpdateExtractor (BaseExtractor , SourceHandlerMixin ):
19
24
"""
20
25
Update statement lineage extractor
21
26
"""
22
27
28
def __init__(self, dialect: str, metadata_provider: MetaDataProvider):
    """
    Set up the base extractor and the per-instance collections filled during extraction.

    :param dialect: forwarded unchanged to the parent initializer
    :param metadata_provider: forwarded unchanged to the parent initializer
    """
    super().__init__(dialect, metadata_provider)
    # Three independent, initially empty lists:
    #   columns        - receives Column objects built from SET clauses in extract()
    #   tables         - receives the source datasets found under FROM / JOIN
    #   union_barriers - not written by this class itself
    #                    NOTE(review): presumably read by SourceHandlerMixin - confirm
    self.columns, self.tables, self.union_barriers = [], [], []
33
+
23
34
SUPPORTED_STMT_TYPES = ["update_statement" ]
24
35
25
36
def extract (
26
37
self , statement : BaseSegment , context : AnalyzerContext
27
38
) -> SubQueryLineageHolder :
28
39
holder = self ._init_holder (context )
29
40
tgt_flag = False
41
+ subqueries = []
30
42
for segment in list_child_segments (statement ):
43
+ for sq in self .list_subquery (segment ):
44
+ # Collecting subquery on the way, hold on parsing until last
45
+ # so that each handler don't have to worry about what's inside subquery
46
+ subqueries .append (sq )
47
+
31
48
if segment .type == "from_expression" :
32
49
# UPDATE with JOIN, mysql only syntax
33
50
if table := self .find_table_from_from_expression_or_join_clause (
@@ -49,16 +66,116 @@ def extract(
49
66
holder .add_write (table )
50
67
tgt_flag = False
51
68
69
+ if segment .type == "set_clause_list" :
70
+ for set_clause in segment .get_children ("set_clause" ):
71
+ columns = set_clause .get_children ("column_reference" )
72
+ if len (columns ) == 2 :
73
+ tgt_cqt = extract_column_qualifier (columns [0 ])
74
+ src_cqt = extract_column_qualifier (columns [1 ])
75
+ if tgt_cqt is not None and src_cqt is not None :
76
+ self .columns .append (
77
+ Column (tgt_cqt .column , source_columns = [src_cqt ])
78
+ )
79
+
52
80
if segment .type == "from_clause" :
53
81
# UPDATE FROM, ansi syntax
54
- if from_expression := segment . get_child ( "from_expression" ):
55
- if table := self . find_table_from_from_expression_or_join_clause (
56
- from_expression , holder
57
- ):
58
- holder . add_read ( table )
82
+ self . _handle_table ( segment , holder )
83
+
84
+ self . end_of_query_cleanup ( holder )
85
+
86
+ self . extract_subquery ( subqueries , holder )
59
87
60
88
return holder
61
89
90
def _handle_table(
    self, segment: BaseSegment, holder: SubQueryLineageHolder
) -> None:
    """
    Register the datasets referenced by a from_clause or a join_clause segment.

    A join_clause is a child node of a from_clause, so joins are handled by
    recursing into this same method. Segments of any other type are ignored.

    :param segment: the segment to inspect
    :param holder: passed through to _add_dataset_from_expression_element
    """
    if segment.type not in ("from_clause", "join_clause"):
        return

    expressions = segment.get_children("from_expression")
    if len(expressions) > 1:
        # SQL89 style of join: several from_expression children,
        # each one is resolved on its own and no join_clause is walked
        for expression in expressions:
            element = find_from_expression_element(expression)
            if element:
                self._add_dataset_from_expression_element(element, holder)
        return

    # zero or one from_expression: resolve the element from the segment itself ...
    element = find_from_expression_element(segment)
    if element:
        self._add_dataset_from_expression_element(element, holder)
    # ... then descend into every join clause listed for it
    for nested_join in list_join_clause(segment):
        self._handle_table(nested_join, holder)
114
+
115
def _add_dataset_from_expression_element(
    self, segment: BaseSegment, holder: SubQueryLineageHolder
) -> None:
    """
    Append the dataset identified in a 'from_expression_element' type segment to self.tables.

    Depending on what the segment contains, the appended object is a subquery,
    a CTE reference wrapped as a subquery, a Path or a table. Nothing is
    appended for table functions or bracketed VALUES clauses.

    :param segment: the from_expression_element segment to inspect
    :param holder: only read here, to look up CTEs through holder.cte
    """
    children = [
        child for child in list_child_segments(segment) if child.type != "keyword"
    ]

    table_expr = segment.get_child("table_expression")
    if table_expr and table_expr.get_child("function"):
        # for UNNEST or generator function, no dataset involved
        return

    head = children[0]
    if head.type == "bracketed":
        inner_expr = head.get_child("table_expression")
        if inner_expr and inner_expr.get_child("values_clause"):
            # (VALUES ...) AS alias, no dataset involved
            return

    nested_queries = list_subqueries(segment)
    if nested_queries:
        # each entry unpacks into the bracketed segment and its alias
        for bracketed, sq_alias in nested_queries:
            self.tables.append(SqlFluffSubQuery.of(bracketed, sq_alias))
        return

    identifier = find_table_identifier(segment)
    if not identifier:
        return

    alias = None
    if len(children) > 1 and children[1].type == "alias_expression":
        alias_parts = list_child_segments(children[1])
        # second child of the alias expression when there is one, otherwise the only child
        # NOTE(review): presumably this skips a leading AS keyword - confirm
        chosen = alias_parts[1] if len(alias_parts) > 1 else alias_parts[0]
        alias = str(chosen.raw)

    if "." not in identifier.raw:
        # an identifier without a dot is looked up among the CTEs known to the holder
        ctes_by_alias = {cte.alias: cte for cte in holder.cte}
        matched_cte = ctes_by_alias.get(identifier.raw)
        if matched_cte is not None:
            # could reference CTE with or without alias
            self.tables.append(
                SqlFluffSubQuery.of(matched_cte.query, alias or identifier.raw)
            )
            return

    if identifier.type == "file_reference":
        # only the raw text of the last child segment is kept for the path
        self.tables.append(Path(escape_identifier_name(identifier.segments[-1].raw)))
    else:
        self.tables.append(SqlFluffTable.of(identifier, alias=alias))
178
+
62
179
def find_table_from_from_expression_or_join_clause (
63
180
self , segment , holder
64
181
) -> Optional [Union [Table , SubQuery ]]:
0 commit comments