diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index d661aaf96173..9749d076750d 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1467,7 +1467,7 @@ def adhoc_metric_to_sqla(
 
             if not processed:
                 try:
-                    expression = self._process_sql_expression(
+                    expression = self._process_select_expression(
                         expression=expression,
                         database_id=self.database_id,
                         engine=self.database.backend,
@@ -1502,7 +1502,7 @@ def adhoc_column_to_sqla(  # pylint: disable=too-many-locals
         """
         label = utils.get_column_name(col)
         try:
-            expression = self._process_sql_expression(
+            expression = self._process_select_expression(
                 expression=col["sqlExpression"],
                 database_id=self.database_id,
                 engine=self.database.backend,
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 5e7b5d9861bb..88230d7bc904 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -871,6 +871,53 @@ def _process_sql_expression(  # pylint: disable=too-many-arguments
             raise QueryObjectValidationError(ex.message) from ex
         return expression
 
+    def _process_select_expression(
+        self,
+        expression: Optional[str],
+        database_id: int,
+        engine: str,
+        schema: str,
+        template_processor: Optional[BaseTemplateProcessor],
+    ) -> Optional[str]:
+        """
+        Validate and process an adhoc expression used as a column or metric.
+
+        This requires prefixing the expression with a dummy SELECT statement, so it can
+        be properly parsed and validated. The dummy prefix is stripped again before
+        the processed expression is returned.
+
+        :param expression: the adhoc SQL expression to process
+        :param database_id: forwarded to ``_process_sql_expression``
+        :param engine: forwarded to ``_process_sql_expression``
+        :param schema: forwarded to ``_process_sql_expression``
+        :param template_processor: forwarded to ``_process_sql_expression``
+        :returns: the processed expression without the dummy prefix, or ``None``
+            when processing yields nothing
+        """
+        if expression:
+            expression = f"SELECT {expression}"
+
+        processed = self._process_sql_expression(
+            expression=expression,
+            database_id=database_id,
+            engine=engine,
+            schema=schema,
+            template_processor=template_processor,
+        )
+        if not processed:
+            return None
+
+        # ``re.split`` returns a one-element list when no ``SELECT`` prefix is
+        # present; taking the last element avoids the ``ValueError`` a two-name
+        # unpacking would raise and returns such output unchanged.
+        parts = re.split(
+            r"SELECT\s+",
+            processed,
+            maxsplit=1,
+            flags=re.IGNORECASE,
+        )
+        return parts[-1].strip()
+
     def _process_orderby_expression(
         self,
         expression: Optional[str],
@@ -1200,7 +1247,7 @@ def adhoc_metric_to_sqla(
             expression = metric.get("sqlExpression")
 
             if not processed:
-                expression = self._process_sql_expression(
+                expression = self._process_select_expression(
                     expression=metric["sqlExpression"],
                     database_id=self.database_id,
                     engine=self.database.backend,
@@ -1888,12 +1935,12 @@ def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-ma
                             template_processor=template_processor,
                         )
                     else:
-                        selected = validate_adhoc_subquery(
-                            selected,
-                            self.database,
-                            self.catalog,
-                            self.schema,
-                            self.database.db_engine_spec.engine,
+                        selected = self._process_select_expression(
+                            expression=selected,
+                            database_id=self.database_id,
+                            engine=self.database.backend,
+                            schema=self.schema,
+                            template_processor=template_processor,
                         )
                     outer = literal_column(f"({selected})")
                     outer = self.make_sqla_column_compatible(outer, selected)
@@ -1917,12 +1964,12 @@ def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-ma
                         _sql = quote(selected)
                     _column_label = selected
 
-                    selected = validate_adhoc_subquery(
-                        _sql,
-                        self.database,
-                        self.catalog,
-                        self.schema,
-                        self.database.db_engine_spec.engine,
+                    selected = self._process_select_expression(
+                        expression=_sql,
+                        database_id=self.database_id,
+                        engine=self.database.backend,
+                        schema=self.schema,
+                        template_processor=template_processor,
                     )
 
                     select_exprs.append(
@@ -2196,7 +2243,7 @@ def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-ma
         if extras:
             where = extras.get("where")
             if where:
-                where = self._process_sql_expression(
+                where = self._process_select_expression(
                     expression=where,
                     database_id=self.database_id,
                     engine=self.database.backend,
@@ -2206,7 +2253,7 @@ def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-ma
                 where_clause_and += [self.text(where)]
             having = extras.get("having")
             if having:
-                having = self._process_sql_expression(
+                having = self._process_select_expression(
                     expression=having,
                     database_id=self.database_id,
                     engine=self.database.backend,
diff --git a/tests/unit_tests/models/helpers_test.py b/tests/unit_tests/models/helpers_test.py
index 95862c29544a..8154e5013649 100644
--- a/tests/unit_tests/models/helpers_test.py
+++ b/tests/unit_tests/models/helpers_test.py
@@ -798,3 +798,363 @@ def test_process_orderby_expression_with_template_processor(
     assert call_args["template_processor"] is template_processor
 
     assert result == "processed_column DESC"
+
+
+def test_process_select_expression_basic(
+    mocker: MockerFixture,
+    database: Database,
+) -> None:
+    """
+    Test basic SELECT expression processing.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+    )
+
+    # Mock _process_sql_expression to return a processed SELECT statement
+    mocker.patch.object(
+        table,
+        "_process_sql_expression",
+        return_value="SELECT COUNT(*)",
+    )
+
+    result = table._process_select_expression(
+        expression="COUNT(*)",
+        database_id=database.id,
+        engine="sqlite",
+        schema="",
+        template_processor=None,
+    )
+
+    assert result == "COUNT(*)"
+
+
+def test_process_select_expression_with_case_insensitive_select(
+    mocker: MockerFixture,
+    database: Database,
+) -> None:
+    """
+    Test SELECT expression processing with case-insensitive matching.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+    )
+
+    # Mock with lowercase "select"
+    mocker.patch.object(
+        table,
+        "_process_sql_expression",
+        return_value="select column_name",
+    )
+
+    result = table._process_select_expression(
+        expression="column_name",
+        database_id=database.id,
+        engine="sqlite",
+        schema="",
+        template_processor=None,
+    )
+
+    assert result == "column_name"
+
+
+def test_process_select_expression_complex(
+    mocker: MockerFixture,
+    database: Database,
+) -> None:
+    """
+    Test SELECT expression with complex expressions.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+    )
+
+    complex_select = "CASE WHEN status = 'active' THEN 1 ELSE 0 END"
+    mocker.patch.object(
+        table,
+        "_process_sql_expression",
+        return_value=f"SELECT {complex_select}",
+    )
+
+    result = table._process_select_expression(
+        expression=complex_select,
+        database_id=database.id,
+        engine="sqlite",
+        schema="",
+        template_processor=None,
+    )
+
+    assert result == complex_select
+
+
+def test_process_select_expression_none(
+    mocker: MockerFixture,
+    database: Database,
+) -> None:
+    """
+    Test SELECT expression processing with None expression.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+    )
+
+    # Mock should return None when input is None
+    mocker.patch.object(
+        table,
+        "_process_sql_expression",
+        return_value=None,
+    )
+
+    result = table._process_select_expression(
+        expression=None,
+        database_id=database.id,
+        engine="sqlite",
+        schema="",
+        template_processor=None,
+    )
+
+    assert result is None
+
+
+def test_process_select_expression_empty_string(
+    mocker: MockerFixture,
+    database: Database,
+) -> None:
+    """
+    Test SELECT expression processing with empty string.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+    )
+
+    # Mock should return None for empty string
+    mocker.patch.object(
+        table,
+        "_process_sql_expression",
+        return_value=None,
+    )
+
+    result = table._process_select_expression(
+        expression="",
+        database_id=database.id,
+        engine="sqlite",
+        schema="",
+        template_processor=None,
+    )
+
+    assert result is None
+
+
+def test_process_select_expression_strips_whitespace(
+    mocker: MockerFixture,
+    database: Database,
+) -> None:
+    """
+    Test that SELECT expression processing strips leading/trailing whitespace.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+    )
+
+    # Mock with extra whitespace after SELECT
+    mocker.patch.object(
+        table,
+        "_process_sql_expression",
+        return_value="SELECT column_name ",
+    )
+
+    result = table._process_select_expression(
+        expression="column_name",
+        database_id=database.id,
+        engine="sqlite",
+        schema="",
+        template_processor=None,
+    )
+
+    assert result == "column_name"
+
+
+def test_process_select_expression_with_template_processor(
+    mocker: MockerFixture,
+    database: Database,
+) -> None:
+    """
+    Test SELECT expression with template processor.
+    """
+    from unittest.mock import Mock
+
+    from superset.connectors.sqla.models import SqlaTable
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+    )
+
+    # Create a mock template processor
+    template_processor = Mock()
+
+    # Mock the _process_sql_expression to verify it receives the prefixed expression
+    mock_process = mocker.patch.object(
+        table,
+        "_process_sql_expression",
+        return_value="SELECT processed_expression",
+    )
+
+    result = table._process_select_expression(
+        expression="some_expression",
+        database_id=database.id,
+        engine="sqlite",
+        schema="",
+        template_processor=template_processor,
+    )
+
+    # Verify _process_sql_expression was called with SELECT prefix
+    mock_process.assert_called_once()
+    call_args = mock_process.call_args[1]
+    assert call_args["expression"] == "SELECT some_expression"
+    assert call_args["template_processor"] is template_processor
+
+    assert result == "processed_expression"
+
+
+def test_process_select_expression_distinct_column(
+    mocker: MockerFixture,
+    database: Database,
+) -> None:
+    """
+    Test SELECT expression with DISTINCT keyword (e.g., "distinct owners").
+
+    This test ensures that expressions like "distinct owners" used in adhoc
+    metrics or columns are properly parsed and validated.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+    )
+
+    # Mock _process_sql_expression to return a processed SELECT with DISTINCT
+    mocker.patch.object(
+        table,
+        "_process_sql_expression",
+        return_value="SELECT DISTINCT owners",
+    )
+
+    result = table._process_select_expression(
+        expression="distinct owners",
+        database_id=database.id,
+        engine="sqlite",
+        schema="",
+        template_processor=None,
+    )
+
+    assert result == "DISTINCT owners"
+
+
+def test_process_select_expression_without_select_prefix(
+    mocker: MockerFixture,
+    database: Database,
+) -> None:
+    """
+    Test that processed output lacking the SELECT prefix is returned as-is.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+    )
+
+    # Mock output in which the dummy SELECT prefix did not survive processing
+    mocker.patch.object(
+        table,
+        "_process_sql_expression",
+        return_value="column_name",
+    )
+
+    result = table._process_select_expression(
+        expression="column_name",
+        database_id=database.id,
+        engine="sqlite",
+        schema="",
+        template_processor=None,
+    )
+
+    assert result == "column_name"
+
+
+def test_process_select_expression_end_to_end(database: Database) -> None:
+    """
+    End-to-end test that verifies the regex split works with real sqlglot processing.
+
+    This test does NOT mock _process_sql_expression, allowing the full flow
+    through sqlglot parsing and validation to ensure the regex extraction works.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+
+    table = SqlaTable(
+        database=database,
+        schema=None,
+        table_name="t",
+    )
+
+    # Test various real-world expressions
+    test_cases = [
+        # (input, expected_output)
+        ("COUNT(*)", "COUNT(*)"),
+        ("DISTINCT owners", "DISTINCT owners"),
+        ("column_name", "column_name"),
+        (
+            "CASE WHEN status = 'active' THEN 1 ELSE 0 END",
+            "CASE WHEN status = 'active' THEN 1 ELSE 0 END",
+        ),
+        ("SUM(amount) / COUNT(*)", "SUM(amount) / COUNT(*)"),
+        ("UPPER(name)", "UPPER(name)"),
+    ]
+
+    for expression, expected in test_cases:
+        result = table._process_select_expression(
+            expression=expression,
+            database_id=database.id,
+            engine="sqlite",
+            schema="",
+            template_processor=None,
+        )
+        # sqlglot may normalize the SQL slightly, so we check the result exists
+        # and doesn't contain the SELECT prefix
+        assert result is not None, f"Failed to process: {expression}"
+        assert not result.upper().startswith("SELECT"), (
+            f"Result still has SELECT prefix: {result}"
+        )
+        # The result should contain the core expression (case-insensitive check)
+        assert expected.replace(" ", "").lower() in result.replace(" ", "").lower(), (
+            f"Expected '{expected}' to be in result '{result}' for input '{expression}'"
+        )