Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support optional AS keyword in CTE #199

Merged
merged 1 commit into from
Dec 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sqllineage/core/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from sqlparse.sql import (
Comment,
Function,
Identifier,
IdentifierList,
Statement,
Expand Down Expand Up @@ -128,8 +129,9 @@ def __token_negligible_before_tablename(cls, token: TokenList) -> bool:
@classmethod
def parse_subquery(cls, token: TokenList) -> List[SubQuery]:
result = []
if isinstance(token, Identifier):
if isinstance(token, (Identifier, Function)):
# usually SubQuery is an Identifier, but not all Identifiers are SubQuery
# Function for CTE without AS keyword
result = cls._parse_subquery_from_identifier(token)
elif isinstance(token, IdentifierList):
# IdentifierList for SQL89 style of JOIN or multiple CTEs, this is actually SubQueries
Expand Down
10 changes: 7 additions & 3 deletions sqllineage/core/handlers/cte.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlparse.sql import Identifier, IdentifierList, Token
from sqlparse.sql import Function, Identifier, IdentifierList, Token

from sqllineage.core.handlers.base import NextTokenBaseHandler
from sqllineage.core.holders import SubQueryLineageHolder
Expand All @@ -15,10 +15,14 @@ def _indicate(self, token: Token) -> bool:
return token.normalized in self.CTE_TOKENS

def _handle(self, token: Token, holder: SubQueryLineageHolder) -> None:
if isinstance(token, Identifier):
# when CTE used without AS, it will be parsed as Function. This syntax is valid in SparkSQL
column_token_types = (Identifier, Function)
if isinstance(token, column_token_types):
cte = [token]
elif isinstance(token, IdentifierList):
cte = [token for token in token.tokens if isinstance(token, Identifier)]
cte = [
token for token in token.tokens if isinstance(token, column_token_types)
]
else:
raise SQLLineageException(
"An Identifier or IdentifierList is expected, got %s[value: %s] instead."
Expand Down
5 changes: 2 additions & 3 deletions sqllineage/core/handlers/source.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Union

from sqlparse.sql import Function, Identifier, IdentifierList, Parenthesis, Token
from sqlparse.sql import Identifier, IdentifierList, Parenthesis, Token

from sqllineage.core.handlers.base import NextTokenBaseHandler
from sqllineage.core.holders import SubQueryLineageHolder
Expand All @@ -20,10 +20,9 @@ class SourceHandler(NextTokenBaseHandler):
)

def _indicate(self, token: Token) -> bool:
# SELECT trim(BOTH ' ' FROM ' abc '); Here FROM is not a source table flag
return any(
re.match(regex, token.normalized) for regex in self.SOURCE_TABLE_TOKENS
) and not isinstance(token.parent.parent, Function)
)

def _handle(self, token: Token, holder: SubQueryLineageHolder) -> None:
if isinstance(token, Identifier):
Expand Down
12 changes: 8 additions & 4 deletions sqllineage/utils/sqlparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@ def get_subquery_parentheses(token: Identifier) -> List[SubQueryTuple]:
the returned list is either empty when no subquery parsed or list of [parenthesis, alias] tuple
"""
subquery = []
kw_idx, kw = token.token_next_by(m=(Keyword, "AS"))
as_idx, as_ = token.token_next_by(m=(Keyword, "AS"))
sublist = list(token.get_sublists())
if kw is not None and len(sublist) == 1:
if as_ is not None and len(sublist) == 1:
# CTE: tbl AS (SELECT 1)
target = sublist[0]
else:
# normal subquery: (SELECT 1) tbl
target = token.token_first(skip_cm=True)
if isinstance(token, Function):
# CTE without AS: tbl (SELECT 1)
target = token.tokens[-1]
else:
# normal subquery: (SELECT 1) tbl
target = token.token_first(skip_cm=True)
if isinstance(target, Case):
# CASE WHEN (SELECT count(*) from tab1) > 0 THEN (SELECT count(*) FROM tab1) ELSE -1
for tk in target.get_sublists():
Expand Down
10 changes: 10 additions & 0 deletions tests/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ def test_with_select_one():
)


def test_with_select_one_without_as():
# AS in CTE is negligible in SparkSQL, however it is required in MySQL. See below reference
# https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-cte.html
# https://dev.mysql.com/doc/refman/8.0/en/with.html
assert_table_lineage_equal(
"WITH wtab1 (SELECT * FROM schema1.tab1) SELECT * FROM wtab1",
{"schema1.tab1"},
)


def test_with_select_many():
assert_table_lineage_equal(
"""WITH
Expand Down