diff --git a/superset/sql/parse.py b/superset/sql/parse.py index d580a7a80e0a..99b8ab60d828 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -164,12 +164,17 @@ class BaseSQLStatement(Generic[InternalRepresentation]): def __init__( self, - statement: str, - engine: str, + statement: str | None = None, + engine: str = "base", ast: InternalRepresentation | None = None, ): - self._sql = statement - self._parsed = ast or self._parse_statement(statement, engine) + if ast: + self._parsed = ast + elif statement: + self._parsed = self._parse_statement(statement, engine) + else: + raise SupersetParseError("Either statement or ast must be provided") + self.engine = engine self.tables = self._extract_tables_from_statement(self._parsed, self.engine) @@ -284,8 +289,8 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): def __init__( self, - statement: str, - engine: str, + statement: str | None = None, + engine: str = "base", ast: exp.Expression | None = None, ): self._dialect = SQLGLOT_DIALECTS.get(engine) @@ -423,7 +428,10 @@ def is_mutating(self) -> bool: and self._parsed.expression.name.upper().startswith("ANALYZE ") ): analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :] - return SQLStatement(analyzed_sql, self.engine).is_mutating() + return SQLStatement( + statement=analyzed_sql, + engine=self.engine, + ).is_mutating() return False @@ -459,12 +467,11 @@ def optimize(self) -> SQLStatement: """ # only optimize statements that have a custom dialect if not self._dialect: - return SQLStatement(self._sql, self.engine, self._parsed.copy()) + return SQLStatement(ast=self._parsed.copy(), engine=self.engine) optimized = pushdown_predicates(self._parsed, dialect=self._dialect) - sql = optimized.sql(dialect=self._dialect) - return SQLStatement(sql, self.engine, optimized) + return SQLStatement(ast=optimized, engine=self.engine) def check_functions_present(self, functions: set[str]) -> bool: """ @@ -668,6 +675,14 @@ class KustoKQLStatement(BaseSQLStatement[str]): details about it. """ + def __init__( + self, + statement: str | None = None, + engine: str = "kustokql", + ast: str | None = None, + ): + super().__init__(statement, engine, ast) + @classmethod def split_script( cls, @@ -725,7 +740,7 @@ def format(self, comments: bool = True) -> str: """ Pretty-format the SQL statement. """ - return self._sql.strip() + return self._parsed.strip() def get_settings(self) -> dict[str, str | bool]: """ @@ -756,7 +771,7 @@ def optimize(self) -> KustoKQLStatement: Kusto KQL doesn't support optimization, so this method is a no-op. """ - return KustoKQLStatement(self._sql, self.engine, self._parsed) + return KustoKQLStatement(ast=self._parsed, engine=self.engine) def check_functions_present(self, functions: set[str]) -> bool: """ @@ -774,7 +789,7 @@ def get_limit_value(self) -> int | None: """ tokens = [ token - for token in tokenize_kql(self._sql) + for token in tokenize_kql(self._parsed) if token[0] != KQLTokenType.WHITESPACE ] for idx, (ttype, val) in enumerate(tokens): @@ -796,7 +811,7 @@ def set_limit_value( if method != LimitMethod.FORCE_LIMIT: raise SupersetParseError("Kusto KQL only supports the FORCE_LIMIT method.") - tokens = tokenize_kql(self._sql) + tokens = tokenize_kql(self._parsed) found_limit_token = False for idx, (ttype, val) in enumerate(tokens): if ttype != KQLTokenType.STRING and val.lower() in {"take", "limit"}: @@ -817,7 +832,7 @@ def set_limit_value( ] ) - self._parsed = self._sql = "".join(val for _, val in tokens) + self._parsed = "".join(val for _, val in tokens) class SQLScript: diff --git a/superset/sql_lab.py b/superset/sql_lab.py index f4ca12de9533..34e648f8fd23 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -242,7 +242,10 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca if not database.allow_dml: errors = [] try: - parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine) + parsed_statement = SQLStatement( + statement=sql_statement, + engine=db_engine_spec.engine, + ) disallowed = parsed_statement.is_mutating() except SupersetParseError as ex: # if we fail to parse the query, disallow by default diff --git a/superset/sql_parse.py b/superset/sql_parse.py index e141a4f2900d..d40b3db4e19a 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -533,7 +533,7 @@ def has_table_query(expression: str, engine: str) -> bool: expression = f"({expression})" sql = f"SELECT {expression}" - statement = SQLStatement(sql, engine) + statement = SQLStatement(statement=sql, engine=engine) return any(statement.tables) diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index daf5ebe71dfa..d4341524c13d 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -318,9 +318,9 @@ def test_split_no_dialect() -> None: sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo" statements = SQLScript(sql, "dremio").statements assert len(statements) == 3 - assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)" - assert statements[1]._sql == "SELECT * FROM t" - assert statements[2]._sql == "SELECT foo" + assert statements[0].format() == "SELECT\n col\nFROM t\nWHERE\n NOT col IN (1, 2)" + assert statements[1].format() == "SELECT\n *\nFROM t" + assert statements[2].format() == "SELECT\n foo" def test_extract_tables_show_columns_from() -> None: