Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 30 additions & 15 deletions superset/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,17 @@ class BaseSQLStatement(Generic[InternalRepresentation]):

def __init__(
self,
statement: str,
engine: str,
statement: str | None = None,
engine: str = "base",
ast: InternalRepresentation | None = None,
):
self._sql = statement
self._parsed = ast or self._parse_statement(statement, engine)
if ast:
self._parsed = ast
elif statement:
self._parsed = self._parse_statement(statement, engine)
else:
raise SupersetParseError("Either statement or ast must be provided")

self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)

Expand Down Expand Up @@ -284,8 +289,8 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):

def __init__(
self,
statement: str,
engine: str,
statement: str | None = None,
engine: str = "base",
ast: exp.Expression | None = None,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
Expand Down Expand Up @@ -423,7 +428,10 @@ def is_mutating(self) -> bool:
and self._parsed.expression.name.upper().startswith("ANALYZE ")
):
analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :]
return SQLStatement(analyzed_sql, self.engine).is_mutating()
return SQLStatement(
statement=analyzed_sql,
engine=self.engine,
).is_mutating()

return False

Expand Down Expand Up @@ -459,12 +467,11 @@ def optimize(self) -> SQLStatement:
"""
# only optimize statements that have a custom dialect
if not self._dialect:
return SQLStatement(self._sql, self.engine, self._parsed.copy())
return SQLStatement(ast=self._parsed.copy(), engine=self.engine)

optimized = pushdown_predicates(self._parsed, dialect=self._dialect)
sql = optimized.sql(dialect=self._dialect)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we're removing this? Is it redundant?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is redundant. Before, to create a new statement we had to pass the SQL (a string) and the AST, so here we were generating the SQL. Now we just pass the AST directly.


return SQLStatement(sql, self.engine, optimized)
return SQLStatement(ast=optimized, engine=self.engine)

def check_functions_present(self, functions: set[str]) -> bool:
"""
Expand Down Expand Up @@ -668,6 +675,14 @@ class KustoKQLStatement(BaseSQLStatement[str]):
details about it.
"""

def __init__(
self,
statement: str | None = None,
engine: str = "kustokql",
ast: str | None = None,
):
super().__init__(statement, engine, ast)

@classmethod
def split_script(
cls,
Expand Down Expand Up @@ -725,7 +740,7 @@ def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
"""
return self._sql.strip()
return self._parsed.strip()

def get_settings(self) -> dict[str, str | bool]:
"""
Expand Down Expand Up @@ -756,7 +771,7 @@ def optimize(self) -> KustoKQLStatement:

Kusto KQL doesn't support optimization, so this method is a no-op.
"""
return KustoKQLStatement(self._sql, self.engine, self._parsed)
return KustoKQLStatement(ast=self._parsed, engine=self.engine)

def check_functions_present(self, functions: set[str]) -> bool:
"""
Expand All @@ -774,7 +789,7 @@ def get_limit_value(self) -> int | None:
"""
tokens = [
token
for token in tokenize_kql(self._sql)
for token in tokenize_kql(self._parsed)
if token[0] != KQLTokenType.WHITESPACE
]
for idx, (ttype, val) in enumerate(tokens):
Expand All @@ -796,7 +811,7 @@ def set_limit_value(
if method != LimitMethod.FORCE_LIMIT:
raise SupersetParseError("Kusto KQL only supports the FORCE_LIMIT method.")

tokens = tokenize_kql(self._sql)
tokens = tokenize_kql(self._parsed)
found_limit_token = False
for idx, (ttype, val) in enumerate(tokens):
if ttype != KQLTokenType.STRING and val.lower() in {"take", "limit"}:
Expand All @@ -817,7 +832,7 @@ def set_limit_value(
]
)

self._parsed = self._sql = "".join(val for _, val in tokens)
self._parsed = "".join(val for _, val in tokens)


class SQLScript:
Expand Down
5 changes: 4 additions & 1 deletion superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,10 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca
if not database.allow_dml:
errors = []
try:
parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
parsed_statement = SQLStatement(
statement=sql_statement,
engine=db_engine_spec.engine,
)
disallowed = parsed_statement.is_mutating()
except SupersetParseError as ex:
# if we fail to parse the query, disallow by default
Expand Down
2 changes: 1 addition & 1 deletion superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def has_table_query(expression: str, engine: str) -> bool:
expression = f"({expression})"

sql = f"SELECT {expression}"
statement = SQLStatement(sql, engine)
statement = SQLStatement(statement=sql, engine=engine)
return any(statement.tables)


Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/sql/parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ def test_split_no_dialect() -> None:
sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo"
statements = SQLScript(sql, "dremio").statements
assert len(statements) == 3
assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)"
assert statements[1]._sql == "SELECT * FROM t"
assert statements[2]._sql == "SELECT foo"
assert statements[0].format() == "SELECT\n col\nFROM t\nWHERE\n NOT col IN (1, 2)"
assert statements[1].format() == "SELECT\n *\nFROM t"
assert statements[2].format() == "SELECT\n foo"


def test_extract_tables_show_columns_from() -> None:
Expand Down
Loading