diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000000..f57c1f6ed6eb --- /dev/null +++ b/.coveragerc @@ -0,0 +1,36 @@ +# .coveragerc to control coverage.py +[run] +branch = True +source = superset +# omit = bad_file.py + +[paths] +source = + superset/ + */site-packages/ + +[report] +# Regexes for lines to exclude from consideration +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + if self\.debug + + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: + + # Ignore importlib backport + from importlib + + if TYPE_CHECKING: + +#fail_under = 100 +show_missing = True diff --git a/.github/workflows/superset-python-unittest.yml b/.github/workflows/superset-python-unittest.yml index 208cf82421a7..615993164d82 100644 --- a/.github/workflows/superset-python-unittest.yml +++ b/.github/workflows/superset-python-unittest.yml @@ -45,6 +45,13 @@ jobs: SUPERSET_SECRET_KEY: not-a-secret run: | pytest --durations-min=0.5 --cov-report= --cov=superset ./tests/common ./tests/unit_tests --cache-clear --maxfail=50 + - name: Python 100% coverage unit tests + if: steps.check.outputs.python + env: + SUPERSET_TESTENV: true + SUPERSET_SECRET_KEY: not-a-secret + run: | + pytest --durations-min=0.5 --cov-report= --cov=superset/sql/ ./tests/unit_tests/sql/ --cache-clear --cov-fail-under=100 - name: Upload code coverage uses: codecov/codecov-action@v5 with: diff --git a/superset/exceptions.py b/superset/exceptions.py index d80e025f37d8..6dddfb8302cf 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -334,6 +334,9 @@ def __init__( # pylint: disable=too-many-arguments ) super().__init__(error) + def __str__(self) -> str: + return self.error.message + class OAuth2RedirectError(SupersetErrorException): """ diff --git a/superset/sql/dialects/firebolt.py b/superset/sql/dialects/firebolt.py index cfdfc2dfe6c6..d0a562dfa5e1 100644 --- a/superset/sql/dialects/firebolt.py +++ b/superset/sql/dialects/firebolt.py @@ -48,7 +48,7 @@ def _negate_range( self, this: exp.Expression | None = None, ) -> exp.Expression | None: - if not this: + if not this: # pragma: no cover return this return self.expression(exp.Not, this=self.expression(exp.Paren, this=this)) @@ -109,42 +109,15 @@ def _parse_unnest(self, with_alias: bool = True) -> exp.Unnest | None: expressions = self._parse_wrapped_csv(self._parse_expression) offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) - alias = self._parse_table_alias() if with_alias else None - - if alias: - if self.dialect.UNNEST_COLUMN_ONLY: - if alias.args.get("columns"): - self.raise_error("Unexpected extra column alias in unnest.") - - alias.set("columns", [alias.this]) - alias.set("this", None) - - columns = alias.args.get("columns") or [] - if offset and len(expressions) < len(columns): - offset = columns.pop() - - if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET): - self._match(TokenType.ALIAS) - offset = self._parse_id_var( - any_token=False, tokens=self.UNNEST_OFFSET_ALIAS_TOKENS - ) or exp.to_identifier("offset") - return self.expression( exp.Unnest, expressions=expressions, - alias=alias, offset=offset, ) class Generator(Firebolt.Generator): def join_sql(self, expression: exp.Join) -> str: - if not self.SEMI_ANTI_JOIN_WITH_SIDE and expression.kind in ( - "SEMI", - "ANTI", - ): - side = None - else: - side = expression.side + side = expression.side op_sql = " ".join( op @@ -168,9 +141,6 @@ def join_sql(self, expression: exp.Join) -> str: this = expression.this this_sql = self.sql(this) - if exprs := self.expressions(expression): - this_sql = f"{this_sql},{self.seg(exprs)}" - if on_sql: on_sql = self.indent(on_sql, skip_first=True) space = self.seg(" " * self.pad) if self.pretty else " " @@ -189,7 +159,6 @@ def join_sql(self, expression: exp.Join) -> str: return f", {this_sql}" - if op_sql != "STRAIGHT_JOIN": - op_sql = f"{op_sql} JOIN" if op_sql else "JOIN" + op_sql = f"{op_sql} JOIN" if op_sql else "JOIN" return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}" diff --git a/superset/sql/parse.py b/superset/sql/parse.py index e49abc5a0ce6..7487f3250db7 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -551,7 +551,7 @@ def _parse(cls, script: str, engine: str) -> list[exp.Expression]: last_statement = statements.pop() target = statements[-1] for node in statements[-1].walk(): - if hasattr(node, "comments"): + if hasattr(node, "comments"): # pragma: no cover target = node target.comments = target.comments or [] @@ -565,47 +565,9 @@ def split_script( script: str, engine: str, ) -> list[SQLStatement]: - if dialect := SQLGLOT_DIALECTS.get(engine): - try: - return [ - cls(ast.sql(), engine, ast) - for ast in cls._parse(script, engine) - if ast - ] - except ValueError: - # `ast.sql()` might raise an error on some cases (eg, `SHOW TABLES - # FROM`). In this case, we rely on the tokenizer to generate the - # statements. - pass - - # When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly - # generate the SQL of each statement, so we tokenize the script and split it - # based on the location of semi-colons. - statements = [] - start = 0 - remainder = script - - try: - tokens = sqlglot.tokenize(script, dialect) - except sqlglot.errors.TokenError as ex: - raise SupersetParseError( - script, - engine, - message="Unable to tokenize script", - ) from ex - - for token in tokens: - if token.token_type == sqlglot.TokenType.SEMICOLON: - statement, start = script[start : token.start], token.end + 1 - ast = cls._parse(statement, engine)[0] - statements.append(cls(statement.strip(), engine, ast)) - remainder = script[start:] - - if remainder.strip(): - ast = cls._parse(remainder, engine)[0] - statements.append(cls(remainder.strip(), engine, ast)) - - return statements + return [ + cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) if ast + ] @classmethod def _parse_statement( @@ -618,7 +580,11 @@ def _parse_statement( """ statements = cls.split_script(statement, engine) if len(statements) != 1: - raise SupersetParseError("SQLStatement should have exactly one statement") + raise SupersetParseError( + statement, + engine, + message="SQLStatement should have exactly one statement", + ) return statements[0]._parsed # pylint: disable=protected-access @@ -657,10 +623,13 @@ def is_mutating(self) -> bool: exp.Create, exp.Drop, exp.TruncateTable, + exp.Alter, ), ): return True + # depending on the dialect (Oracle, MS SQL) the `ALTER` is parsed as a + # command, not an expression if isinstance(node, exp.Command) and node.name == "ALTER": return True @@ -821,9 +790,16 @@ def has_subquery(self) -> bool: """ Check if the statement has a subquery. - :return: True if the statement has a subquery at the top level. + :return: True if the statement has a subquery. """ - return bool(self._parsed.find(exp.Subquery)) + return bool(self._parsed.find(exp.Subquery)) or ( + isinstance(self._parsed, exp.Select) + and any( + isinstance(expression, exp.Select) + for expression in self._parsed.walk() + if expression != self._parsed + ) + ) def parse_predicate(self, predicate: str) -> exp.Expression: """ @@ -933,11 +909,8 @@ def tokenize_kql(kql: str) -> list[tuple[KQLTokenType, str]]: ) buffer = ch elif ch == "`" and script[i - 2 : i] == "``": - if buffer: - tokens.extend(classify_non_string_kql(buffer)) - buffer = "" state = KQLSplitState.INSIDE_MULTILINE_STRING - buffer = "`" + buffer = "```" else: buffer += ch else: @@ -1042,11 +1015,19 @@ def _parse_statement( engine: str, ) -> str: if engine != "kustokql": - raise SupersetParseError(f"Invalid engine: {engine}") + raise SupersetParseError( + statement, + engine, + message=f"Invalid engine: {engine}", + ) statements = split_kql(statement) if len(statements) != 1: - raise SupersetParseError("SQLStatement should have exactly one statement") + raise SupersetParseError( + statement, + engine, + message="KustoKQLStatement should have exactly one statement", + ) return statements[0].strip() @@ -1122,7 +1103,7 @@ def check_functions_present(self, functions: set[str]) -> bool: :return: True if any of the functions are present """ logger.warning("Kusto KQL doesn't support checking for functions present.") - return True + return False def get_limit_value(self) -> int | None: """ @@ -1150,7 +1131,11 @@ def set_limit_value( Add a limit to the statement. """ if method != LimitMethod.FORCE_LIMIT: - raise SupersetParseError("Kusto KQL only supports the FORCE_LIMIT method.") + raise SupersetParseError( + self._parsed, + self.engine, + message="Kusto KQL only supports the FORCE_LIMIT method.", + ) tokens = tokenize_kql(self._parsed) found_limit_token = False diff --git a/tests/unit_tests/sql/dialects/__init__.py b/tests/unit_tests/sql/dialects/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/sql/dialects/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/sql/dialects/firebolt_tests.py b/tests/unit_tests/sql/dialects/firebolt_tests.py new file mode 100644 index 000000000000..49b31146df17 --- /dev/null +++ b/tests/unit_tests/sql/dialects/firebolt_tests.py @@ -0,0 +1,433 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import sqlglot + +from superset.sql.dialects.firebolt import Firebolt, FireboltOld + + +def test_not_sql() -> None: # pylint: disable=invalid-name + """ + Test the `not_sql` method in the generator. + """ + # use generic parser, since the Firebolt dialect will parenthesize + ast = sqlglot.parse_one("SELECT * FROM t WHERE NOT col IN (1, 2)") + + # make sure generated SQL is parenthesized + assert ( + Firebolt().generate(expression=ast, pretty=True) + == """ +SELECT + * +FROM t +WHERE + NOT (col IN (1, 2)) + """.strip() + ) + + +@pytest.mark.parametrize( + "sql, expected", + [ + ( + "SELECT price, quantity, price * quantity AS sales_amount FROM Sales", + """ +SELECT + price, + quantity, + price * quantity AS sales_amount +FROM Sales + """.strip(), + ), + ( + "SELECT ALL * FROM Sales", + """ +SELECT + * +FROM Sales + """.strip(), + ), + ( + "SELECT DISTINCT product FROM Sales", + """ +SELECT DISTINCT + product +FROM Sales + """.strip(), + ), + ( + "SELECT * FROM Sales, Products", + """ +SELECT + * +FROM Sales, Products + """.strip(), + ), + ], +) +def test_select_from(sql: str, expected: str) -> None: + """ + Test the `SELECT` statement in the old dialect. + """ + ast = sqlglot.parse_one(sql, FireboltOld) + assert FireboltOld().generate(expression=ast, pretty=True) == expected + + +def test_unnest() -> None: + """ + Test the `UNNEST` in the old dialect. + """ + ast = sqlglot.parse_one( + """ +SELECT + id, + tags +FROM visits + UNNEST(tags); + """, + FireboltOld, + ) + + assert ( + FireboltOld().generate(expression=ast, pretty=True) + == """ +SELECT + id, + tags +FROM visits UNNEST(tags) + """.strip() + ) + + +def test_unnest_with_array() -> None: + """ + Test the `UNNEST` in the old dialect with array columns. + """ + ast = sqlglot.parse_one( + """ +SELECT + id, + a_keys, + a_vals +FROM + visits + UNNEST(agent_props_keys as a_keys, + agent_props_vals as a_vals) + + """, + FireboltOld, + ) + + assert ( + FireboltOld().generate(expression=ast, pretty=True) + == """ +SELECT + id, + a_keys, + a_vals +FROM visits UNNEST(agent_props_keys AS a_keys, agent_props_vals AS a_vals) + """.strip() + ) + + +def test_unnest_multiple() -> None: + """ + Test multiple `UNNEST` in the old dialect. + """ + ast = sqlglot.parse_one( + """ +SELECT + id, + a_keys, + a_vals +FROM + visits +UNNEST(agent_props_keys as a_keys) +UNNEST(agent_props_vals as a_vals) + """, + FireboltOld, + ) + + assert ( + FireboltOld().generate(expression=ast, pretty=True) + == """ +SELECT + id, + a_keys, + a_vals +FROM visits UNNEST(agent_props_keys AS a_keys) UNNEST(agent_props_vals AS a_vals) + """.strip() + ) + + +def test_unnest_translating() -> None: + """ + Test translating the `UNNEST` from the old to the new dialect. + """ + ast = sqlglot.parse_one( + """ +SELECT + id, + tags +FROM visits + UNNEST(tags); + """, + FireboltOld, + ) + + assert ( + Firebolt().generate(expression=ast, pretty=True) + == """ +SELECT + id, + tags +FROM visits, UNNEST(tags) + """.strip() + ) + + +def test_join_on() -> None: + """ + Test the `JOIN ... ON` syntax in the Firebolt dialect. + """ + ast = sqlglot.parse_one( + """ +SELECT + * +FROM + t1 +JOIN + t2 +ON t1.foo = t2.id; + """, + FireboltOld, + ) + + assert ( + FireboltOld().generate(expression=ast, pretty=True) + == """ +SELECT + * +FROM t1 +JOIN t2 + ON t1.foo = t2.id + """.strip() + ) + + +def test_join_using() -> None: + """ + Test the `JOIN ... USING` syntax in the Firebolt dialect. + """ + ast = sqlglot.parse_one( + """ +SELECT + * +FROM + t1 +JOIN + t2 USING (id, age); + """, + FireboltOld, + ) + + assert ( + FireboltOld().generate(expression=ast, pretty=True) + == """ +SELECT + * +FROM t1 +JOIN t2 + USING (id, age) + """.strip() + ) + + +def test_cte() -> None: + """ + Test the `WITH` clause in the Firebolt dialect. + """ + ast = sqlglot.parse_one( + """ +WITH nl_subscribers AS ( + SELECT + * + FROM + players + WHERE + issubscribedtonewsletter=TRUE +) +SELECT + nickname, + email +FROM + players +ORDER BY + nickname + """, + FireboltOld, + ) + + assert ( + FireboltOld().generate(expression=ast, pretty=True) + == """ +WITH nl_subscribers AS ( + SELECT + * + FROM players + WHERE + issubscribedtonewsletter = TRUE +) +SELECT + nickname, + email +FROM players +ORDER BY + nickname + """.strip() + ) + + +@pytest.mark.parametrize( + "sql, expected", + [ + ( + """ +SELECT + * +FROM + num_test +INNER JOIN + num_test2 + USING ( + firstname, + score + ); + """, + """ +SELECT + * +FROM num_test +INNER JOIN num_test2 + USING (firstname, score) + """.strip(), + ), + ( + """ +SELECT + * +FROM + num_test +INNER JOIN + num_test2 + ON num_test.firstname = num_test2.firstname + AND num_test.score = num_test2.score; + """, + """ +SELECT + * +FROM num_test +INNER JOIN num_test2 + ON num_test.firstname = num_test2.firstname AND num_test.score = num_test2.score + """.strip(), + ), + ( + """ +SELECT + num_test.firstname, + num_test2.firstname +FROM num_test +LEFT OUTER JOIN + num_test2 + USING (firstname); + """, + """ +SELECT + num_test.firstname, + num_test2.firstname +FROM num_test +LEFT OUTER JOIN num_test2 + USING (firstname) + """.strip(), + ), + ( + """ +SELECT + num_test.firstname, + num_test2.firstname +FROM + num_test +RIGHT OUTER JOIN + num_test2 + USING (firstname); + """, + """ +SELECT + num_test.firstname, + num_test2.firstname +FROM num_test +RIGHT OUTER JOIN num_test2 + USING (firstname) + """.strip(), + ), + ( + """ +SELECT + num_test.firstname, + num_test2.firstname +FROM + num_test +FULL OUTER JOIN + num_test2 + USING (firstname); + """, + """ +SELECT + num_test.firstname, + num_test2.firstname +FROM num_test +FULL OUTER JOIN num_test2 + USING (firstname) + """.strip(), + ), + ( + """ +SELECT + crossjoin_test.letter, + crossjoin_test2.letter +FROM + crossjoin_test +CROSS JOIN + crossjoin_test2; + """, + """ +SELECT + crossjoin_test.letter, + crossjoin_test2.letter +FROM crossjoin_test +CROSS JOIN crossjoin_test2 + """.strip(), + ), + ], +) +def test_join(sql: str, expected: str) -> None: + """ + Test different joins in the old dialect. + """ + ast = sqlglot.parse_one(sql, FireboltOld) + assert FireboltOld().generate(expression=ast, pretty=True) == expected diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 64df9c5415d7..46774759abe4 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -19,15 +19,18 @@ import pytest from pytest_mock import MockerFixture -from sqlglot import Dialects, parse_one +from sqlglot import Dialects, exp, parse_one from superset.exceptions import QueryClauseValidationException, SupersetParseError +from superset.jinja_context import JinjaTemplateProcessor from superset.sql.parse import ( CTASMethod, extract_tables_from_jinja_sql, extract_tables_from_statement, + KQLTokenType, KustoKQLStatement, LimitMethod, + remove_quotes, RLSMethod, sanitize_clause, split_kql, @@ -35,6 +38,7 @@ SQLScript, SQLStatement, Table, + tokenize_kql, ) from tests.integration_tests.conftest import with_feature_flags @@ -776,6 +780,20 @@ def test_sqlscript() -> None: "postgresql", ["SELECT\n 1 /* extraneous comment */"], ), + ( + "SHOW TABLES FROM s1 like '%order%';", + "mysql", + ["SHOW TABLES FROM s1 LIKE '%order%'"], + ), + ( + "SELECT 1; SELECT 2; SELECT 3;", + "unknown-engine", + [ + "SELECT\n 1", + "SELECT\n 2", + "SELECT\n 3", + ], + ), ], ) def test_sqlscript_split(sql: str, engine: str, expected: list[str]) -> None: @@ -795,18 +813,59 @@ def test_sqlstatement() -> None: "sqlite", ) + assert ( + statement.format() + == "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2" + ) + assert str(statement) == statement.format() + assert statement.tables == { Table(table="table1", schema=None, catalog=None), Table(table="table2", schema=None, catalog=None), } - assert ( - statement.format() - == "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2" + + assert statement.parse_predicate("a > 1") == exp.GT( + this=exp.Column(this=exp.Identifier(this="a", quoted=False)), + expression=exp.Literal(this="1", is_string=False), ) statement = SQLStatement("SET a=1", "sqlite") assert statement.get_settings() == {"a": "1"} + with pytest.raises( + ValueError, + match="Either statement or ast must be provided", + ): + SQLStatement() + + +def test_kustokqlstatement() -> None: + """ + Test the `KustoKQLStatement` class. + """ + statement = KustoKQLStatement("foo | take 100", "kustokql") + + assert statement.format() == "foo | take 100" + assert str(statement) == statement.format() + + # doesn't support table extraction + assert statement.tables == set() + + # optimize is a no-op + assert statement.optimize().format() == "foo | take 100" + + # predicate parsing is also no-op + assert statement.parse_predicate("a > 1") == "a > 1" + + with pytest.raises(SupersetParseError, match="Invalid engine: invalid-engine"): + KustoKQLStatement("foo | take 100", "invalid-engine") + + with pytest.raises( + SupersetParseError, + match="KustoKQLStatement should have exactly one statement", + ): + KustoKQLStatement("foo | take 1; bar | take 2", "kustokql") + def test_kustokqlstatement_split_script() -> None: """ @@ -887,11 +946,13 @@ def test_kustokql_statement_split_special(kql: str, statements: int) -> None: assert len(KustoKQLStatement.split_script(kql, "kustokql")) == statements -def test_split_kql() -> None: - """ - Test the `split_kql` function. - """ - kql = """ +@pytest.mark.parametrize( + "kql, expected", + [ + (";Table | take 5", ["Table | take 5"]), + (";Table | take 5;", ["Table | take 5"]), + ( + """ let totalPagesPerDay = PageViews | summarize by Page, Day = startofday(Timestamp) | summarize count() by Day; @@ -912,18 +973,18 @@ def test_split_kql() -> None: totalPagesPerDay on $left.Day1 == $right.Day | project Day1, Day2, Percentage = count_*100.0/count_1 - """ - assert split_kql(kql) == [ - """ + """, + [ + """ let totalPagesPerDay = PageViews | summarize by Page, Day = startofday(Timestamp) | summarize count() by Day""", - """ + """ let materializedScope = PageViews | summarize by Page, Day = startofday(Timestamp)""", - """ + """ let cachedResult = materialize(materializedScope)""", - """ + """ cachedResult | project Page, Day1 = Day | join kind = inner @@ -938,8 +999,16 @@ def test_split_kql() -> None: totalPagesPerDay on $left.Day1 == $right.Day | project Day1, Day2, Percentage = count_*100.0/count_1 - """, - ] + """, + ], + ), + ], +) +def test_split_kql(kql: str, expected: list[str]) -> None: + """ + Test the `split_kql` function. + """ + assert split_kql(kql) == expected @pytest.mark.parametrize( @@ -1106,14 +1175,19 @@ def test_custom_dialect(app: None) -> None: "vertica", ], ) -def test_is_mutating(engine: str) -> None: +@pytest.mark.parametrize( + "sql, expected", + [ + ("SELECT 1", False), + ("with source as ( select 1 as one ) select * from source", False), + ("ALTER TABLE foo ADD COLUMN bar INT", True), + ], +) +def test_is_mutating(sql: str, engine: str, expected: bool) -> None: """ Global tests for `is_mutating`, covering all supported engines. """ - assert not SQLStatement( - "with source as ( select 1 as one ) select * from source", - engine=engine, - ).is_mutating() + assert SQLStatement(sql, engine).is_mutating() == expected def test_optimize() -> None: @@ -1164,6 +1238,9 @@ def test_optimize() -> None: assert SQLStatement(sql, "sqlite").optimize().format() == optimized assert SQLStatement(sql, "dremio").optimize().format() == not_optimized + # also works for scripts + assert SQLScript(sql, "sqlite").optimize().format() == optimized + def test_firebolt() -> None: """ @@ -1285,6 +1362,8 @@ def test_firebolt_old_escape_string() -> None: "postgresql", None, ), + # not really valid SQL, but let's roll with it + ("SELECT * FROM my_table LIMIT invalid", "postgresql", None), ], ) def test_get_limit_value(sql: str, engine: str, expected: str) -> None: @@ -1307,6 +1386,7 @@ def test_get_limit_value(sql: str, engine: str, expected: str) -> None: """, 5, ), + ("table | take five", None), ], ) def test_get_kql_limit_value(kql: str, expected: str) -> None: @@ -1492,6 +1572,13 @@ def test_get_kql_limit_value(kql: str, expected: str) -> None: LimitMethod.FORCE_LIMIT, "SELECT\n *\nFROM birth_names\nLIMIT 1000", ), + ( + "SELECT * FROM birth_names LIMIT 555", + "postgresql", + 1000, + LimitMethod.FETCH_MANY, + "SELECT\n *\nFROM birth_names\nLIMIT 555", + ), ], ) def test_set_limit_value( @@ -1539,11 +1626,28 @@ def test_set_limit_value( ], ) def test_set_kql_limit_value(kql: str, limit: int, expected: str) -> None: + """ + Test the `set_limit_value` method for KustoKQLStatement. + """ statement = KustoKQLStatement(kql, "kustokql") statement.set_limit_value(limit) assert statement.format() == expected +@pytest.mark.parametrize("method", [LimitMethod.WRAP_SQL, LimitMethod.FETCH_MANY]) +def test_set_kql_limit_value_invalid_method(method: LimitMethod) -> None: + """ + Test that setting a limit value with an invalid method raises an error. + """ + statement = KustoKQLStatement("foo", "kustokql") + + with pytest.raises( + SupersetParseError, + match="Kusto KQL only supports the FORCE_LIMIT method.", + ): + statement.set_limit_value(10, method) + + @pytest.mark.parametrize( "sql, engine, expected", [ @@ -1670,6 +1774,15 @@ def test_as_cte(sql: str, engine: str, expected: str) -> None: ) AS t """.strip(), ), + ( + "SELECT t.foo FROM some_table AS t", + {}, + """ +SELECT + t.foo +FROM some_table AS t + """.strip(), + ), ( "SELECT t.foo FROM some_table AS t WHERE bar = 'baz'", {Table("some_table", "schema1", "catalog1"): "id = 42"}, @@ -1947,6 +2060,17 @@ def test_rls_subquery_transformer( assert statement.format() == expected +def test_rls_invalid_method(mocker: MockerFixture) -> None: + """ + Test that an invalid RLS method raises an error. + """ + statement = SQLStatement("SELECT 1", "postgresql") + predicates = mocker.MagicMock() + + with pytest.raises(ValueError, match="Invalid RLS method: invalid"): + statement.apply_rls("catalog1", "schema1", predicates, "invalid") # type: ignore + + @pytest.mark.parametrize( "sql, rules, expected", [ @@ -2171,6 +2295,17 @@ def test_rls_subquery_transformer( ) """.strip(), ), + ( + "SELECT * FROM table JOIN other_table", + {Table("other_table", "schema1", "catalog1"): "id = 42"}, + """ +SELECT + * +FROM table +JOIN other_table + ON other_table.id = 42 + """.strip(), + ), ( """ SELECT * @@ -2237,6 +2372,18 @@ def test_rls_subquery_transformer( other_table.id = 42 """.strip(), ), + ( + "INSERT INTO some_table (col1, col2) VALUES (1, 2)", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + """ +INSERT INTO some_table ( + col1, + col2 +) +VALUES + (1, 2) + """.strip(), + ), ], ) def test_rls_predicate_transformer( @@ -2427,7 +2574,7 @@ def test_sanitize_clause(sql: str, expected: str | Exception, engine: str) -> No ], ) @pytest.mark.parametrize( - "macro,expected", + "macro, expected", [ ( "latest_partition('foo.bar')", @@ -2464,7 +2611,7 @@ def test_extract_tables_from_jinja_sql( assert ( extract_tables_from_jinja_sql( sql=f"'{{{{ {engine}.{macro} }}}}'", - database=mocker.Mock(), + database=mocker.MagicMock(backend=engine), ) == expected ) @@ -2475,10 +2622,154 @@ def test_extract_tables_from_jinja_sql_disabled(mocker: MockerFixture) -> None: """ Test the function when the feature flag is disabled. """ - database = mocker.Mock() + database = mocker.MagicMock() database.db_engine_spec.engine = "mssql" assert extract_tables_from_jinja_sql( sql="SELECT 1 FROM t", database=database, ) == {Table("t")} + + +def test_extract_tables_from_jinja_sql_invalid_function(mocker: MockerFixture) -> None: + """ + Test the function with an invalid function. + """ + database = mocker.MagicMock(backend="postgresql") + + processor = JinjaTemplateProcessor(database) + processor.env.globals["my_table"] = lambda: "t" + mocker.patch( + "superset.jinja_context.get_template_processor", + return_value=processor, + ) + + assert extract_tables_from_jinja_sql( + sql="SELECT * FROM {{ my_table() }}", + database=database, + ) == {Table("t")} + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT * FROM users", "postgresql", True), + ("WITH cte AS (SELECT * FROM users) SELECT * FROM cte", "postgresql", True), + ("CREATE TABLE users AS SELECT * FROM users", "postgresql", False), + ("ALTER TABLE users ADD COLUMN age INT", "postgresql", False), + ("SET @value = 42", "postgresql", False), + ], +) +def test_sqlstatement_is_select(sql: str, engine: str, expected: bool) -> None: + """ + Test the `SQLStatement.is_select()` method. + """ + assert SQLStatement(sql, engine).is_select() == expected + + +@pytest.mark.parametrize( + "kql, expected", + [ + ("StormEvents | take 10", True), + ("StormEvents | limit 20", True), + ("StormEvents | where State == 'FL' | summarize count()", True), + ("StormEvents | where name has 'limit 10'", True), + ("AnotherTable | take 5", True), + ("datatable(x:int) [1, 2, 3] | take 100", True), + (".create table StormEvents (x:int)", False), + (".ingest inline into table StormEvents <| StormEvents | take 10", False), + ], +) +def test_kqlstatement_is_select(kql: str, expected: bool) -> None: + """ + Test the `KustoKQLStatement.is_select()` method. + """ + assert KustoKQLStatement(kql, "kustokql").is_select() == expected + + +def test_remove_quotes() -> None: + """ + Test the `remove_quotes` helper function. + """ + assert remove_quotes(None) is None + assert remove_quotes('"foo"') == "foo" + assert remove_quotes("'foo'") == "foo" + assert remove_quotes("`foo`") == "foo" + assert remove_quotes("'foo`") == "'foo`" + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT * FROM table", "postgresql", False), + ("SELECT VERSION()", "postgresql", True), + ("SELECT query_to_xml()", "postgresql", True), + ("WITH cte AS (SELECT * FROM table) SELECT * FROM cte", "postgresql", False), + ( + """ +SELECT * +FROM query_to_xml('SELECT * from some_table WHERE id = 42') + """, + "postgresql", + True, + ), + ("Table | limit 10", "kustokql", False), + ], +) +def test_check_functions_present(sql: str, engine: str, expected: bool) -> None: + """ + Check the `check_functions_present` method. + """ + functions = {"version", "query_to_xml"} + assert SQLScript(sql, engine).check_functions_present(functions) == expected + + +@pytest.mark.parametrize( + "kql, expected", + [ + ( + "StormEvents | take 10", + [ + (KQLTokenType.WORD, "StormEvents"), + (KQLTokenType.WHITESPACE, " "), + (KQLTokenType.OTHER, "|"), + (KQLTokenType.WHITESPACE, " "), + (KQLTokenType.WORD, "take"), + (KQLTokenType.WHITESPACE, " "), + (KQLTokenType.NUMBER, "10"), + ], + ), + ("'test'", [(KQLTokenType.STRING, "'test'")]), + ("```test```", [(KQLTokenType.STRING, "```test```")]), + ], +) +def test_tokenize_kql(kql: str, expected: list[tuple[KQLTokenType, str]]) -> None: + """ + Test the `tokenize_kql` function. + """ + assert tokenize_kql(kql) == expected + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("a = 1", "postgresql", False), + ("(SELECT * FROM table)", "postgresql", True), + ("SELECT * FROM table", "postgresql", False), + ("SELECT * FROM (SELECT 1)", "postgresql", True), + ("SELECT * FROM (SELECT 1) AS subquery", "postgresql", True), + ("WITH cte AS (SELECT 1) SELECT * FROM cte", "postgresql", True), + ("SELECT * FROM table WHERE EXISTS (SELECT 1)", "postgresql", True), + ("SELECT * FROM table WHERE NOT EXISTS (SELECT 1)", "postgresql", True), + ( + "SELECT * FROM table WHERE id IN (SELECT id FROM other_table)", + "postgresql", + True, + ), + ], +) +def test_has_subquery(sql: str, engine: str, expected: bool) -> None: + """ + Test the `has_subquery` method. + """ + assert SQLStatement(sql, engine).has_subquery() == expected