diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 502c8ac82b49..004b2dd53dc2 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -65,7 +65,6 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import DisallowedSQLFunction, OAuth2Error, OAuth2RedirectError from superset.sql.parse import BaseSQLStatement, SQLScript, Table -from superset.sql_parse import ParsedQuery from superset.superset_typing import ( OAuth2ClientConfig, OAuth2State, @@ -1221,8 +1220,8 @@ def get_limit_from_sql(cls, sql: str) -> int | None: :param sql: SQL query :return: Value of limit clause in query """ - parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine) - return parsed_query.limit + script = SQLScript(sql, engine=cls.engine) + return script.statements[-1].get_limit_value() @classmethod def set_or_update_query_limit(cls, sql: str, limit: int) -> str: @@ -2088,14 +2087,6 @@ def update_params_from_encrypted_extra( # pylint: disable=invalid-name logger.error(ex, exc_info=True) raise - @classmethod - def is_select_query(cls, parsed_query: ParsedQuery) -> bool: - """ - Determine if the statement should be considered as SELECT statement. - Some query dialects do not contain "SELECT" word in queries (eg. Kusto) - """ - return parsed_query.is_select() - @classmethod def get_column_spec( # pylint: disable=unused-argument cls, diff --git a/superset/db_engine_specs/kusto.py b/superset/db_engine_specs/kusto.py index 59c3b1f2313a..a48c83182a5b 100644 --- a/superset/db_engine_specs/kusto.py +++ b/superset/db_engine_specs/kusto.py @@ -28,7 +28,6 @@ SupersetDBAPIOperationalError, SupersetDBAPIProgrammingError, ) -from superset.sql_parse import ParsedQuery from superset.utils.core import GenericDataType @@ -155,10 +154,6 @@ def convert_dttm( return None - @classmethod - def is_select_query(cls, parsed_query: ParsedQuery) -> bool: - return not parsed_query.sql.startswith(".") - @classmethod def parse_sql(cls, sql: str) -> list[str]: """ diff --git a/superset/sql_lab.py b/superset/sql_lab.py index c157896fcc6e..f4ca12de9533 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -53,7 +53,7 @@ from superset.models.core import Database from superset.models.sql_lab import Query from superset.result_set import SupersetResultSet -from superset.sql.parse import SQLStatement, Table +from superset.sql.parse import SQLScript, SQLStatement, Table from superset.sql_parse import ( CtasMethod, insert_rls_as_subquery, @@ -263,6 +263,7 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca ) raise SupersetErrorsException(errors) + original_sql = sql if apply_ctas: if not query.tmp_table_name: start_dttm = datetime.fromtimestamp(query.start_time) @@ -277,7 +278,7 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca query.select_as_cta_used = True # Do not apply limit to the CTA queries when SQLLAB_CTAS_NO_LIMIT is set to true - if db_engine_spec.is_select_query(parsed_query) and not ( + if not SQLScript(original_sql, db_engine_spec.engine).has_mutation() and not ( query.select_as_cta_used and SQLLAB_CTAS_NO_LIMIT ): if SQL_MAX_ROW and (not query.limit or query.limit > SQL_MAX_ROW): @@ -553,7 +554,7 @@ def execute_sql_statements( # noqa: C901 # Commit the connection so CTA queries will create the table and any DML. should_commit = ( - not db_engine_spec.is_select_query(parsed_query) # check if query is DML + SQLScript(rendered_query, db_engine_spec.engine).has_mutation() or apply_ctas ) if should_commit: diff --git a/tests/unit_tests/db_engine_specs/test_kusto.py b/tests/unit_tests/db_engine_specs/test_kusto.py index e8759f38cf4f..d3a49f86e9c2 100644 --- a/tests/unit_tests/db_engine_specs/test_kusto.py +++ b/tests/unit_tests/db_engine_specs/test_kusto.py @@ -23,7 +23,6 @@ from superset.db_engine_specs.kusto import KustoKqlEngineSpec from superset.sql.parse import SQLScript -from superset.sql_parse import ParsedQuery from tests.unit_tests.db_engine_specs.utils import assert_convert_dttm from tests.unit_tests.fixtures.common import dttm # noqa: F401 @@ -53,26 +52,6 @@ def test_sql_has_mutation(sql: str, expected: bool) -> None: ) -@pytest.mark.parametrize( - "kql,expected", - [ - ("tbl | limit 100", True), - ("let foo = 1; tbl | where bar == foo", True), - (".show tables", False), - ], -) -def test_kql_is_select_query(kql: str, expected: bool) -> None: - """ - Make sure that KQL dialect consider only statements that do not start with "." (dot) - as a SELECT statements - """ - - from superset.db_engine_specs.kusto import KustoKqlEngineSpec - - parsed_query = ParsedQuery(kql) - assert KustoKqlEngineSpec.is_select_query(parsed_query) == expected - - @pytest.mark.parametrize( "kql,expected", [ diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 21c9a95247bd..89d7dbcb06d0 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -52,7 +52,6 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: database.apply_limit_to_sql.return_value = "SELECT 42 AS answer LIMIT 2" database.mutate_sql_based_on_config.return_value = "SELECT 42 AS answer LIMIT 2" db_engine_spec = database.db_engine_spec - db_engine_spec.is_select_query.return_value = True db_engine_spec.fetch_data.return_value = [(42,)] cursor = mocker.MagicMock() @@ -95,7 +94,6 @@ def test_execute_sql_statement_with_rls( database.apply_limit_to_sql.return_value = sql_statement_with_rls_and_limit database.mutate_sql_based_on_config.return_value = sql_statement_with_rls_and_limit db_engine_spec = database.db_engine_spec - db_engine_spec.is_select_query.return_value = True db_engine_spec.fetch_data.return_value = [(42,)] cursor = mocker.MagicMock() @@ -140,7 +138,6 @@ def test_execute_sql_statement_exceeds_payload_limit(mocker: MockerFixture) -> N query = mocker.MagicMock() query.limit = 1 query.database = mocker.MagicMock() - query.database.db_engine_spec.is_select_query.return_value = True query.database.cache_timeout = 100 query.status = "RUNNING" query.select_as_cta = False @@ -193,7 +190,6 @@ def test_execute_sql_statement_within_payload_limit(mocker: MockerFixture) -> No query = mocker.MagicMock() query.limit = 1 query.database = mocker.MagicMock() - query.database.db_engine_spec.is_select_query.return_value = True query.database.cache_timeout = 100 query.status = "RUNNING" query.select_as_cta = False