Skip to content

Commit cbd915c

Browse files
authored
refactor: merge sqlfluff handler to extractor (#436)
1 parent 2eae7da commit cbd915c

24 files changed

+730
-912
lines changed

sqllineage/core/parser/sqlfluff/analyzer.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44

55
from sqllineage.core.analyzer import LineageAnalyzer
66
from sqllineage.core.holders import StatementLineageHolder
7-
from sqllineage.core.parser.sqlfluff.extractors.lineage_holder_extractor import (
8-
LineageHolderExtractor,
9-
)
7+
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
108
from sqllineage.exceptions import (
119
InvalidSyntaxException,
1210
UnsupportedStatementException,
@@ -58,7 +56,7 @@ def analyze(self, sql: str) -> StatementLineageHolder:
5856
else:
5957
for extractor in [
6058
extractor_cls(self._dialect)
61-
for extractor_cls in LineageHolderExtractor.__subclasses__()
59+
for extractor_cls in BaseExtractor.__subclasses__()
6260
]:
6361
if extractor.can_extract(statement_segment.type):
6462
lineage_holder = extractor.extract(

sqllineage/core/parser/sqlfluff/extractors/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
import os
33
import pkgutil
44

5-
# import each module so that LineageHolderExtractor's __subclasses__ will work
5+
# import each module so that BaseExtractor's __subclasses__ will work
66
for module in pkgutil.iter_modules([os.path.dirname(__file__)]):
77
importlib.import_module(__name__ + "." + module.name)

sqllineage/core/parser/sqlfluff/extractors/lineage_holder_extractor.py renamed to sqllineage/core/parser/sqlfluff/extractors/base.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from functools import reduce
22
from operator import add
3-
from typing import List
3+
from typing import List, Optional, Type
44

55
from sqlfluff.core.parser import BaseSegment
66

77
from sqllineage.core.holders import SubQueryLineageHolder
8-
from sqllineage.core.models import SubQuery
9-
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery
8+
from sqllineage.core.models import SubQuery, Table
9+
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery, SqlFluffTable
1010
from sqllineage.core.parser.sqlfluff.utils import (
1111
get_children,
1212
is_subquery,
@@ -15,7 +15,7 @@
1515
from sqllineage.utils.entities import AnalyzerContext, SubQueryTuple
1616

1717

18-
class LineageHolderExtractor:
18+
class BaseExtractor:
1919
"""
2020
Abstract class implementation for extract 'SubQueryLineageHolder' from different statement types
2121
"""
@@ -46,7 +46,14 @@ def extract(
4646
raise NotImplementedError
4747

4848
@classmethod
49-
def parse_subquery(cls, segment: BaseSegment) -> List[SubQuery]:
49+
def find_table(cls, segment: BaseSegment) -> Optional[Table]:
50+
table = None
51+
if segment.type in ["table_reference", "object_reference"]:
52+
table = SqlFluffTable.of(segment)
53+
return table
54+
55+
@classmethod
56+
def list_subquery(cls, segment: BaseSegment) -> List[SubQuery]:
5057
"""
5158
The parse_subquery function takes a segment as an argument.
5259
:param segment: segment to determine if it is a subquery
@@ -84,6 +91,17 @@ def _parse_subquery(cls, subqueries: List[SubQueryTuple]) -> List[SubQuery]:
8491
for bracketed_segment, alias in subqueries
8592
]
8693

94+
def delegate_to(
95+
self,
96+
extractor_cls: "Type[BaseExtractor]",
97+
segment: BaseSegment,
98+
context: AnalyzerContext,
99+
) -> SubQueryLineageHolder:
100+
"""
101+
delegate to another type of extractor to extract
102+
"""
103+
return extractor_cls(self.dialect).extract(segment, context)
104+
87105
@staticmethod
88106
def _init_holder(context: AnalyzerContext) -> SubQueryLineageHolder:
89107
"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from sqlfluff.core.parser import BaseSegment
2+
3+
from sqllineage.core.holders import StatementLineageHolder
4+
from sqllineage.core.models import Path
5+
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
6+
from sqllineage.core.parser.sqlfluff.utils import (
7+
find_from_expression_element,
8+
get_child,
9+
get_children,
10+
list_child_segments,
11+
)
12+
from sqllineage.utils.entities import AnalyzerContext
13+
from sqllineage.utils.helpers import escape_identifier_name
14+
15+
16+
class CopyExtractor(BaseExtractor):
17+
"""
18+
Copy statement lineage extractor
19+
"""
20+
21+
SUPPORTED_STMT_TYPES = [
22+
"copy_statement",
23+
"copy_into_table_statement",
24+
]
25+
26+
def extract(
27+
self, statement: BaseSegment, context: AnalyzerContext
28+
) -> StatementLineageHolder:
29+
holder = StatementLineageHolder()
30+
src_flag = tgt_flag = False
31+
for segment in list_child_segments(statement):
32+
if segment.type == "from_clause":
33+
if from_expression_element := find_from_expression_element(segment):
34+
for table_expression in get_children(
35+
from_expression_element, "table_expression"
36+
):
37+
if storage_location := get_child(
38+
table_expression, "storage_location"
39+
):
40+
holder.add_read(Path(storage_location.raw))
41+
elif segment.type == "keyword":
42+
if segment.raw_upper in ["COPY", "INTO"]:
43+
tgt_flag = True
44+
elif segment.raw_upper == "FROM":
45+
src_flag = True
46+
continue
47+
48+
if tgt_flag:
49+
if table := self.find_table(segment):
50+
holder.add_write(table)
51+
tgt_flag = False
52+
if src_flag:
53+
if segment.type in ["literal", "storage_location"]:
54+
path = Path(escape_identifier_name(segment.raw))
55+
holder.add_read(path)
56+
src_flag = False
57+
58+
return holder
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from sqlfluff.core.parser import BaseSegment
2+
3+
from sqllineage.core.holders import SubQueryLineageHolder
4+
from sqllineage.core.models import Path
5+
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
6+
from sqllineage.core.parser.sqlfluff.extractors.select import SelectExtractor
7+
from sqllineage.core.parser.sqlfluff.models import SqlFluffColumn, SqlFluffTable
8+
from sqllineage.core.parser.sqlfluff.utils import (
9+
get_child,
10+
is_set_expression,
11+
list_child_segments,
12+
)
13+
from sqllineage.utils.entities import AnalyzerContext
14+
from sqllineage.utils.helpers import escape_identifier_name
15+
16+
17+
class CreateInsertExtractor(BaseExtractor):
18+
"""
19+
Create statement and Insert statement lineage extractor
20+
"""
21+
22+
SUPPORTED_STMT_TYPES = [
23+
"create_table_statement",
24+
"create_table_as_statement",
25+
"create_view_statement",
26+
"insert_statement",
27+
"insert_overwrite_directory_hive_fmt_statement",
28+
]
29+
30+
def extract(
31+
self,
32+
statement: BaseSegment,
33+
context: AnalyzerContext,
34+
) -> SubQueryLineageHolder:
35+
holder = self._init_holder(context)
36+
src_flag = tgt_flag = False
37+
for segment in list_child_segments(statement):
38+
if segment.type == "with_compound_statement":
39+
holder |= self.delegate_to_cte(segment, holder)
40+
elif segment.type == "bracketed" and any(
41+
s.type == "with_compound_statement" for s in segment.segments
42+
):
43+
for sgmt in segment.segments:
44+
if sgmt.type == "with_compound_statement":
45+
holder |= self.delegate_to_cte(segment, holder)
46+
elif segment.type in ("select_statement", "set_expression"):
47+
holder |= self.delegate_to_select(segment, holder)
48+
elif segment.type == "bracketed" and (
49+
self.list_subquery(segment) or is_set_expression(segment)
50+
):
51+
# note regular subquery within SELECT statement is handled by SelectExtractor, this is only to handle
52+
# top-level subquery in DML like: 1) create table foo as (subquery); 2) insert into foo (subquery)
53+
# subquery here isn't added as read source, and it inherits DML-level write_columns if parsed
54+
if subquery_segment := get_child(
55+
segment, "select_statement", "set_expression"
56+
):
57+
holder |= self.delegate_to_select(subquery_segment, holder)
58+
59+
elif segment.type == "bracketed":
60+
# In case of bracketed column reference, add these target columns to holder
61+
# so that when we compute the column level lineage
62+
# we keep these columns into consideration
63+
sub_segments = list_child_segments(segment)
64+
if all(
65+
sub_segment.type in ["column_reference", "column_definition"]
66+
for sub_segment in sub_segments
67+
):
68+
# target columns only apply to bracketed column_reference and column_definition
69+
columns = []
70+
for sub_segment in sub_segments:
71+
if sub_segment.type == "column_definition":
72+
sub_segment = get_child(sub_segment, "identifier")
73+
columns.append(SqlFluffColumn.of(sub_segment))
74+
holder.add_write_column(*columns)
75+
76+
elif segment.type == "keyword":
77+
if segment.raw_upper in [
78+
"INTO",
79+
"OVERWRITE",
80+
"TABLE",
81+
"VIEW",
82+
"DIRECTORY",
83+
] or (
84+
tgt_flag is True and segment.raw_upper in ["IF", "NOT", "EXISTS"]
85+
):
86+
tgt_flag = True
87+
elif segment.raw_upper in ["LIKE", "CLONE"]:
88+
src_flag = True
89+
continue
90+
91+
if tgt_flag:
92+
if segment.type in ["table_reference", "object_reference"]:
93+
write_obj = SqlFluffTable.of(segment)
94+
holder.add_write(write_obj)
95+
elif segment.type == "literal":
96+
if segment.raw.isnumeric():
97+
# Special Handling for Spark Bucket Table DDL
98+
pass
99+
else:
100+
holder.add_write(Path(escape_identifier_name(segment.raw)))
101+
tgt_flag = False
102+
if src_flag:
103+
if segment.type in ["table_reference", "object_reference"]:
104+
holder.add_read(SqlFluffTable.of(segment))
105+
src_flag = False
106+
return holder
107+
108+
def delegate_to_cte(
109+
self, segment: BaseSegment, holder: SubQueryLineageHolder
110+
) -> SubQueryLineageHolder:
111+
from .cte import CteExtractor
112+
113+
return self.delegate_to(
114+
CteExtractor, segment, AnalyzerContext(cte=holder.cte, write=holder.write)
115+
)
116+
117+
def delegate_to_select(
118+
self,
119+
segment: BaseSegment,
120+
holder: SubQueryLineageHolder,
121+
) -> SubQueryLineageHolder:
122+
return self.delegate_to(
123+
SelectExtractor,
124+
segment,
125+
AnalyzerContext(
126+
cte=holder.cte,
127+
write=holder.write,
128+
write_columns=holder.write_columns,
129+
),
130+
)

sqllineage/core/parser/sqlfluff/extractors/cte_extractor.py renamed to sqllineage/core/parser/sqlfluff/extractors/cte.py

+15-30
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,43 @@
11
from sqlfluff.core.parser import BaseSegment
22

33
from sqllineage.core.holders import SubQueryLineageHolder
4-
from sqllineage.core.parser.sqlfluff.extractors.dml_insert_extractor import (
5-
DmlInsertExtractor,
6-
)
7-
from sqllineage.core.parser.sqlfluff.extractors.dml_select_extractor import (
8-
DmlSelectExtractor,
9-
)
10-
from sqllineage.core.parser.sqlfluff.extractors.lineage_holder_extractor import (
11-
LineageHolderExtractor,
4+
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
5+
from sqllineage.core.parser.sqlfluff.extractors.create_insert import (
6+
CreateInsertExtractor,
127
)
8+
from sqllineage.core.parser.sqlfluff.extractors.select import SelectExtractor
139
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery
1410
from sqllineage.core.parser.sqlfluff.utils import get_children, list_child_segments
1511
from sqllineage.utils.entities import AnalyzerContext
1612

1713

18-
class DmlCteExtractor(LineageHolderExtractor):
14+
class CteExtractor(BaseExtractor):
1915
"""
20-
DML CTE queries lineage extractor
16+
CTE queries lineage extractor
2117
"""
2218

2319
SUPPORTED_STMT_TYPES = ["with_compound_statement"]
2420

25-
def __init__(self, dialect: str):
26-
super().__init__(dialect)
27-
2821
def extract(
2922
self,
3023
statement: BaseSegment,
3124
context: AnalyzerContext,
3225
) -> SubQueryLineageHolder:
33-
"""
34-
Extract lineage for a given statement.
35-
:param statement: a sqlfluff segment with a statement
36-
:param context: 'AnalyzerContext'
37-
:return 'SubQueryLineageHolder' object
38-
"""
3926
holder = self._init_holder(context)
4027
subqueries = []
4128
for segment in list_child_segments(statement):
4229
if segment.type in ["select_statement", "set_expression"]:
43-
holder |= DmlSelectExtractor(self.dialect).extract(
30+
holder |= self.delegate_to(
31+
SelectExtractor,
4432
segment,
4533
AnalyzerContext(cte=holder.cte, write=holder.write),
4634
)
47-
48-
if segment.type == "insert_statement":
49-
holder |= DmlInsertExtractor(self.dialect).extract(
50-
segment,
51-
AnalyzerContext(cte=holder.cte),
35+
elif segment.type == "insert_statement":
36+
holder |= self.delegate_to(
37+
CreateInsertExtractor, segment, AnalyzerContext(cte=holder.cte)
5238
)
53-
54-
identifier = None
55-
if segment.type == "common_table_expression":
39+
elif segment.type == "common_table_expression":
40+
identifier = None
5641
segment_has_alias = any(
5742
s for s in get_children(segment, "keyword") if s.raw_upper == "AS"
5843
)
@@ -63,7 +48,7 @@ def extract(
6348
if not segment_has_alias:
6449
holder.add_cte(SqlFluffSubQuery.of(sub_segment, identifier))
6550
if sub_segment.type == "bracketed":
66-
for sq in self.parse_subquery(sub_segment):
51+
for sq in self.list_subquery(sub_segment):
6752
if identifier:
6853
sq.alias = identifier
6954
subqueries.append(sq)
@@ -72,7 +57,7 @@ def extract(
7257

7358
# By recursively extracting each extractor of the parent and merge, we're doing Depth-first search
7459
for sq in subqueries:
75-
holder |= DmlSelectExtractor(self.dialect).extract(
60+
holder |= SelectExtractor(self.dialect).extract(
7661
sq.query,
7762
AnalyzerContext(cte=holder.cte, write={sq}),
7863
)

0 commit comments

Comments
 (0)