diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py index a2d81e548bfb..81c3343b505a 100644 --- a/superset/commands/dataset/create.py +++ b/superset/commands/dataset/create.py @@ -20,6 +20,7 @@ from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError +from sqlalchemy.exc import SQLAlchemyError from superset.commands.base import BaseCommand, CreateMixin from superset.commands.dataset.exceptions import ( @@ -43,7 +44,16 @@ class CreateDatasetCommand(CreateMixin, BaseCommand): def __init__(self, data: dict[str, Any]): self._properties = data.copy() - @transaction(on_error=partial(on_error, reraise=DatasetCreateFailedError)) + @transaction( + on_error=partial( + on_error, + catches=( + SQLAlchemyError, + SupersetSecurityException, + ), + reraise=DatasetCreateFailedError, + ) + ) def run(self) -> Model: self.validate() diff --git a/superset/commands/dataset/exceptions.py b/superset/commands/dataset/exceptions.py index 83b5436c233a..bb20d711d07b 100644 --- a/superset/commands/dataset/exceptions.py +++ b/superset/commands/dataset/exceptions.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import Optional + from flask_babel import lazy_gettext as _ from marshmallow.validate import ValidationError @@ -26,6 +28,7 @@ ImportFailedError, UpdateFailedError, ) +from superset.exceptions import SupersetSecurityException from superset.sql_parse import Table @@ -159,11 +162,21 @@ class DatasetInvalidError(CommandInvalidError): message = _("Dataset parameters are invalid.") -class DatasetCreateFailedError(CreateFailedError): +class DatasetSQLStatementErrorMixin(CommandException): + def __init__( + self, + ex: Optional[Exception] = None, + ) -> None: + if isinstance(ex, SupersetSecurityException): + self.message = str(ex) + super().__init__() + + +class DatasetCreateFailedError(CreateFailedError, DatasetSQLStatementErrorMixin): message = _("Dataset could not be created.") -class DatasetUpdateFailedError(UpdateFailedError): +class DatasetUpdateFailedError(UpdateFailedError, DatasetSQLStatementErrorMixin): message = _("Dataset could not be updated.") diff --git a/superset/commands/dataset/update.py b/superset/commands/dataset/update.py index 14d1c5ef4470..1337a4e94624 100644 --- a/superset/commands/dataset/update.py +++ b/superset/commands/dataset/update.py @@ -67,6 +67,7 @@ def __init__( catches=( SQLAlchemyError, ValueError, + SupersetSecurityException, ), reraise=DatasetUpdateFailedError, ) diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index 844a8f063c1b..9e94a16de5d9 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -231,7 +231,7 @@ def on_error( logger.exception(ex.exception) if reraise: - raise reraise() from ex + raise reraise(ex) from ex else: raise ex diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 37de6e87c27a..e58db4c71f91 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -754,6 +754,28 @@ def test_create_dataset_with_sql(self): db.session.delete(model) db.session.commit() + @pytest.mark.usefixtures("load_energy_table_with_slice") + def test_create_dataset_with_sql_validate_select_sql(self): + """ + Dataset API: Test create dataset with invalid select sql statement + """ + if backend() == "sqlite": + return + + energy_usage_ds = self.get_energy_usage_dataset() + self.login(username="admin") + table_data = { + "database": energy_usage_ds.database_id, + "table_name": "energy_usage_virtual", + "sql": "insert into energy_usage select * from energy_usage", + } + if schema := get_example_default_schema(): + table_data["schema"] = schema + rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post") + assert rv.status_code == 422 + data = json.loads(rv.data.decode("utf-8")) + assert data == {"message": "Only `SELECT` statements are allowed"} + @unittest.skip("test is failing stochastically") def test_create_dataset_same_name_different_schema(self): if backend() == "sqlite":