Skip to content

Commit fd78aa3

Browse files
committed
feat: support optional AS keyword in CTE
1 parent dcc402e commit fd78aa3

File tree

5 files changed

+30
-11
lines changed

5 files changed

+30
-11
lines changed

sqllineage/core/analyzer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from sqlparse.sql import (
66
Comment,
7+
Function,
78
Identifier,
89
IdentifierList,
910
Statement,
@@ -128,8 +129,9 @@ def __token_negligible_before_tablename(cls, token: TokenList) -> bool:
128129
@classmethod
129130
def parse_subquery(cls, token: TokenList) -> List[SubQuery]:
130131
result = []
131-
if isinstance(token, Identifier):
132+
if isinstance(token, (Identifier, Function)):
132133
# usually SubQuery is an Identifier, but not all Identifiers are SubQuery
134+
# a CTE written without the AS keyword is parsed as a Function
133135
result = cls._parse_subquery_from_identifier(token)
134136
elif isinstance(token, IdentifierList):
135137
# IdentifierList for SQL89-style JOINs or multiple CTEs; these are actually SubQueries

sqllineage/core/handlers/cte.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sqlparse.sql import Identifier, IdentifierList, Token
1+
from sqlparse.sql import Function, Identifier, IdentifierList, Token
22

33
from sqllineage.core.handlers.base import NextTokenBaseHandler
44
from sqllineage.core.holders import SubQueryLineageHolder
@@ -15,10 +15,14 @@ def _indicate(self, token: Token) -> bool:
1515
return token.normalized in self.CTE_TOKENS
1616

1717
def _handle(self, token: Token, holder: SubQueryLineageHolder) -> None:
18-
if isinstance(token, Identifier):
18+
# When a CTE is used without AS, it is parsed as a Function. This syntax is valid in SparkSQL.
19+
column_token_types = (Identifier, Function)
20+
if isinstance(token, column_token_types):
1921
cte = [token]
2022
elif isinstance(token, IdentifierList):
21-
cte = [token for token in token.tokens if isinstance(token, Identifier)]
23+
cte = [
24+
token for token in token.tokens if isinstance(token, column_token_types)
25+
]
2226
else:
2327
raise SQLLineageException(
2428
"An Identifier or IdentifierList is expected, got %s[value: %s] instead."

sqllineage/core/handlers/source.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
from typing import Union
33

4-
from sqlparse.sql import Function, Identifier, IdentifierList, Parenthesis, Token
4+
from sqlparse.sql import Identifier, IdentifierList, Parenthesis, Token
55

66
from sqllineage.core.handlers.base import NextTokenBaseHandler
77
from sqllineage.core.holders import SubQueryLineageHolder
@@ -20,10 +20,9 @@ class SourceHandler(NextTokenBaseHandler):
2020
)
2121

2222
def _indicate(self, token: Token) -> bool:
23-
# SELECT trim(BOTH ' ' FROM ' abc '); Here FROM is not a source table flag
2423
return any(
2524
re.match(regex, token.normalized) for regex in self.SOURCE_TABLE_TOKENS
26-
) and not isinstance(token.parent.parent, Function)
25+
)
2726

2827
def _handle(self, token: Token, holder: SubQueryLineageHolder) -> None:
2928
if isinstance(token, Identifier):

sqllineage/utils/sqlparse.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,18 @@ def get_subquery_parentheses(token: Identifier) -> List[SubQueryTuple]:
2424
the returned list is either empty when no subquery is parsed, or a list of [parenthesis, alias] tuples
2525
"""
2626
subquery = []
27-
kw_idx, kw = token.token_next_by(m=(Keyword, "AS"))
27+
as_idx, as_ = token.token_next_by(m=(Keyword, "AS"))
2828
sublist = list(token.get_sublists())
29-
if kw is not None and len(sublist) == 1:
29+
if as_ is not None and len(sublist) == 1:
3030
# CTE: tbl AS (SELECT 1)
3131
target = sublist[0]
3232
else:
33-
# normal subquery: (SELECT 1) tbl
34-
target = token.token_first(skip_cm=True)
33+
if isinstance(token, Function):
34+
# CTE without AS: tbl (SELECT 1)
35+
target = token.tokens[-1]
36+
else:
37+
# normal subquery: (SELECT 1) tbl
38+
target = token.token_first(skip_cm=True)
3539
if isinstance(target, Case):
3640
# CASE WHEN (SELECT count(*) from tab1) > 0 THEN (SELECT count(*) FROM tab1) ELSE -1
3741
for tk in target.get_sublists():

tests/test_cte.py

+10
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@ def test_with_select_one():
1212
)
1313

1414

15+
def test_with_select_one_without_as():
16+
# AS in a CTE is optional in SparkSQL; however, it is required in MySQL. See the references below:
17+
# https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-cte.html
18+
# https://dev.mysql.com/doc/refman/8.0/en/with.html
19+
assert_table_lineage_equal(
20+
"WITH wtab1 (SELECT * FROM schema1.tab1) SELECT * FROM wtab1",
21+
{"schema1.tab1"},
22+
)
23+
24+
1525
def test_with_select_many():
1626
assert_table_lineage_equal(
1727
"""WITH

0 commit comments

Comments
 (0)