Skip to content

Commit

Permalink
fix: ParsedQuery subselect edge case (#13602)
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Ritter authored Mar 12, 2021
1 parent 4fc41e1 commit 06d6d7f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
14 changes: 11 additions & 3 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@
from urllib import parse

import sqlparse
from sqlparse.sql import Identifier, IdentifierList, remove_quotes, Token, TokenList
from sqlparse.sql import (
Identifier,
IdentifierList,
Parenthesis,
remove_quotes,
Token,
TokenList,
)
from sqlparse.tokens import Keyword, Name, Punctuation, String, Whitespace
from sqlparse.utils import imt

Expand Down Expand Up @@ -278,7 +285,9 @@ def _extract_from_token( # pylint: disable=too-many-branches
table_name_preceding_token = False

for item in token.tokens:
if item.is_group and not self._is_identifier(item):
if item.is_group and (
not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
):
self._extract_from_token(item)

if item.ttype in Keyword and (
Expand All @@ -291,7 +300,6 @@ def _extract_from_token( # pylint: disable=too-many-branches
if item.ttype in Keyword:
table_name_preceding_token = False
continue

if table_name_preceding_token:
if isinstance(item, Identifier):
self._process_tokenlist(item)
Expand Down
7 changes: 7 additions & 0 deletions tests/sql_parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,13 @@ def test_select_in_expression(self):
query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))

query = "SELECT f1, (SELECT count(1) FROM t2) as f2 FROM t1"
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))

def test_parentheses(self):
query = "SELECT f1, (x + y) AS f2 FROM t1"
self.assertEqual({Table("t1")}, self.extract_tables(query))

def test_union(self):
query = "SELECT * FROM t1 UNION SELECT * FROM t2"
self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
Expand Down

0 comments on commit 06d6d7f

Please sign in to comment.