Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion superset/commands/dataset/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from sqlalchemy.exc import SQLAlchemyError

from superset.commands.base import BaseCommand, CreateMixin
from superset.commands.dataset.exceptions import (
Expand All @@ -43,7 +44,16 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()

@transaction(on_error=partial(on_error, reraise=DatasetCreateFailedError))
@transaction(
on_error=partial(
on_error,
catches=(
SQLAlchemyError,
SupersetSecurityException,
),
reraise=DatasetCreateFailedError,
)
)
def run(self) -> Model:
self.validate()

Expand Down
17 changes: 15 additions & 2 deletions superset/commands/dataset/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional

from flask_babel import lazy_gettext as _
from marshmallow.validate import ValidationError

Expand All @@ -26,6 +28,7 @@
ImportFailedError,
UpdateFailedError,
)
from superset.exceptions import SupersetSecurityException
from superset.sql_parse import Table


Expand Down Expand Up @@ -159,11 +162,21 @@ class DatasetInvalidError(CommandInvalidError):
message = _("Dataset parameters are invalid.")


class DatasetCreateFailedError(CreateFailedError):
class DatasetSQLStatementErrorMixin(CommandException):
def __init__(
self,
ex: Optional[Exception] = None,
) -> None:
if isinstance(ex, SupersetSecurityException):
self.message = str(ex)
super().__init__()


class DatasetCreateFailedError(CreateFailedError, DatasetSQLStatementErrorMixin):
message = _("Dataset could not be created.")


class DatasetUpdateFailedError(UpdateFailedError):
class DatasetUpdateFailedError(UpdateFailedError, DatasetSQLStatementErrorMixin):
message = _("Dataset could not be updated.")


Expand Down
1 change: 1 addition & 0 deletions superset/commands/dataset/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
catches=(
SQLAlchemyError,
ValueError,
SupersetSecurityException,
),
reraise=DatasetUpdateFailedError,
)
Expand Down
2 changes: 1 addition & 1 deletion superset/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def on_error(
logger.exception(ex.exception)

if reraise:
raise reraise() from ex
raise reraise(ex) from ex
else:
raise ex

Expand Down
22 changes: 22 additions & 0 deletions tests/integration_tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,28 @@ def test_create_dataset_with_sql(self):
db.session.delete(model)
db.session.commit()

@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_create_dataset_with_sql_validate_select_sql(self):
"""
Dataset API: Test create dataset with invalid select sql statement
"""
if backend() == "sqlite":
return

energy_usage_ds = self.get_energy_usage_dataset()
self.login(username="admin")
table_data = {
"database": energy_usage_ds.database_id,
"table_name": "energy_usage_virtual",
"sql": "insert into energy_usage select * from energy_usage",
}
if schema := get_example_default_schema():
table_data["schema"] = schema
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
assert rv.status_code == 422
data = json.loads(rv.data.decode("utf-8"))
assert data == {"message": "Only `SELECT` statements are allowed"}

@unittest.skip("test is failing stochastically")
def test_create_dataset_same_name_different_schema(self):
if backend() == "sqlite":
Expand Down