diff --git a/superset/charts/api.py b/superset/charts/api.py index 41a6e779bddc..882e6596ac0c 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -17,12 +17,15 @@ import logging from flask import g, request, Response -from flask_appbuilder.api import expose, protect, safe +from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface +from flask_babel import ngettext +from superset.charts.commands.bulk_delete import BulkDeleteChartCommand from superset.charts.commands.create import CreateChartCommand from superset.charts.commands.delete import DeleteChartCommand from superset.charts.commands.exceptions import ( + ChartBulkDeleteFailedError, ChartCreateFailedError, ChartDeleteFailedError, ChartForbiddenError, @@ -32,7 +35,12 @@ ) from superset.charts.commands.update import UpdateChartCommand from superset.charts.filters import ChartFilter -from superset.charts.schemas import ChartPostSchema, ChartPutSchema +from superset.charts.schemas import ( + ChartPostSchema, + ChartPutSchema, + get_delete_ids_schema, +) +from superset.constants import RouteMethod from superset.models.slice import Slice from superset.views.base_api import BaseSupersetModelRestApi @@ -45,6 +53,11 @@ class ChartRestApi(BaseSupersetModelRestApi): resource_name = "chart" allow_browser_login = True + include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { + RouteMethod.EXPORT, + RouteMethod.RELATED, + "bulk_delete", # not using RouteMethod since locally defined + } class_permission_name = "SliceModelView" show_columns = [ "slice_name", @@ -269,3 +282,61 @@ def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ except ChartDeleteFailedError as e: logger.error(f"Error deleting model {self.__class__.__name__}: {e}") return self.response_422(message=str(e)) + + @expose("/", methods=["DELETE"]) + @protect() + @safe + @rison(get_delete_ids_schema) + def bulk_delete(self, **kwargs) -> Response: # pylint: disable=arguments-differ + """Delete bulk Charts + --- + delete: + description: >- + Deletes multiple Charts in a bulk operation + parameters: + - in: query + name: q + content: + application/json: + schema: + type: array + items: + type: integer + responses: + 200: + description: Charts bulk delete + content: + application/json: + schema: + type: object + properties: + message: + type: string + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + item_ids = kwargs["rison"] + try: + BulkDeleteChartCommand(g.user, item_ids).run() + return self.response( + 200, + message=ngettext( + f"Deleted %(num)d chart", + f"Deleted %(num)d charts", + num=len(item_ids), + ), + ) + except ChartNotFoundError: + return self.response_404() + except ChartForbiddenError: + return self.response_403() + except ChartBulkDeleteFailedError as e: + return self.response_422(message=str(e)) diff --git a/superset/charts/commands/bulk_delete.py b/superset/charts/commands/bulk_delete.py new file mode 100644 index 000000000000..f8ded64fd9fc --- /dev/null +++ b/superset/charts/commands/bulk_delete.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import List, Optional + +from flask_appbuilder.security.sqla.models import User + +from superset.charts.commands.exceptions import ( + ChartBulkDeleteFailedError, + ChartForbiddenError, + ChartNotFoundError, +) +from superset.charts.dao import ChartDAO +from superset.commands.base import BaseCommand +from superset.commands.exceptions import DeleteFailedError +from superset.exceptions import SupersetSecurityException +from superset.models.slice import Slice +from superset.views.base import check_ownership + +logger = logging.getLogger(__name__) + + +class BulkDeleteChartCommand(BaseCommand): + def __init__(self, user: User, model_ids: List[int]): + self._actor = user + self._model_ids = model_ids + self._models: Optional[List[Slice]] = None + + def run(self): + self.validate() + try: + ChartDAO.bulk_delete(self._models) + except DeleteFailedError as e: + logger.exception(e.exception) + raise ChartBulkDeleteFailedError() + + def validate(self) -> None: + # Validate/populate model exists + self._models = ChartDAO.find_by_ids(self._model_ids) + if not self._models or len(self._models) != len(self._model_ids): + raise ChartNotFoundError() + # Check ownership + for model in self._models: + try: + check_ownership(model) + except SupersetSecurityException: + raise ChartForbiddenError() diff --git a/superset/charts/commands/exceptions.py b/superset/charts/commands/exceptions.py index b8e3d81022f5..9e0f79df81dc 100644 --- a/superset/charts/commands/exceptions.py +++ b/superset/charts/commands/exceptions.py @@ -79,3 +79,7 @@ class ChartDeleteFailedError(DeleteFailedError): class ChartForbiddenError(ForbiddenError): message = _("Changing this chart is forbidden") + + +class ChartBulkDeleteFailedError(CreateFailedError): + message = _("Charts could not be deleted.") diff --git a/superset/charts/dao.py b/superset/charts/dao.py index 6732c96f6bab..ec56b5d0b314 100644 --- a/superset/charts/dao.py +++ b/superset/charts/dao.py @@ -15,9 +15,13 @@ # specific language governing permissions and limitations # under the License. import logging +from typing import List + +from sqlalchemy.exc import SQLAlchemyError from superset.charts.filters import ChartFilter from superset.dao.base import BaseDAO +from superset.extensions import db from superset.models.slice import Slice logger = logging.getLogger(__name__) @@ -26,3 +30,23 @@ class ChartDAO(BaseDAO): model_cls = Slice base_filter = ChartFilter + + @staticmethod + def bulk_delete(models: List[Slice], commit=True): + item_ids = [model.id for model in models] + # bulk delete, first delete related data + for model in models: + model.owners = [] + model.dashboards = [] + db.session.merge(model) + # bulk delete itself + try: + db.session.query(Slice).filter(Slice.id.in_(item_ids)).delete( + synchronize_session="fetch" + ) + if commit: + db.session.commit() + except SQLAlchemyError as e: + if commit: + db.session.rollback() + raise e diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index fd8b32094a52..9a965a43e598 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -21,6 +21,8 @@ from superset.exceptions import SupersetException from superset.utils import core as utils +get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}} + def validate_json(value): try: diff --git a/superset/dao/base.py b/superset/dao/base.py index 4f19efc4f8fb..f0133515cb21 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, Optional +from typing import Dict, List, Optional from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.sqla import Model @@ -48,7 +48,7 @@ class BaseDAO: @classmethod def find_by_id(cls, model_id: int) -> Model: """ - Retrives a model by id, if defined applies `base_filter` + Find a model by id, if defined applies `base_filter` """ query = db.session.query(cls.model_cls) if cls.base_filter: @@ -58,6 +58,22 @@ def find_by_id(cls, model_id: int) -> Model: ).apply(query, None) return query.filter_by(id=model_id).one_or_none() + @classmethod + def find_by_ids(cls, model_ids: List[int]) -> List[Model]: + """ + Find a List of models by a list of ids, if defined applies `base_filter` + """ + id_col = getattr(cls.model_cls, "id", None) + if id_col is None: + return [] + query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids)) + if cls.base_filter: + data_model = SQLAInterface(cls.model_cls, db.session) + query = cls.base_filter( # pylint: disable=not-callable + "id", data_model + ).apply(query, None) + return query.all() + @classmethod def create(cls, properties: Dict, commit=True) -> Optional[Model]: """ diff --git a/superset/dashboards/dao.py b/superset/dashboards/dao.py index e635750bec69..86684bbbd318 100644 --- a/superset/dashboards/dao.py +++ b/superset/dashboards/dao.py @@ -17,7 +17,6 @@ import logging from typing import List -from flask_appbuilder.models.sqla.interface import SQLAInterface from sqlalchemy.exc import SQLAlchemyError from superset.dao.base import BaseDAO @@ -32,13 +31,6 @@ class DashboardDAO(BaseDAO): model_cls = Dashboard base_filter = DashboardFilter - @staticmethod - def find_by_ids(model_ids: List[int]) -> List[Dashboard]: - query = db.session.query(Dashboard).filter(Dashboard.id.in_(model_ids)) - data_model = SQLAInterface(Dashboard, db.session) - query = DashboardFilter("id", data_model).apply(query, None) - return query.all() - @staticmethod def validate_slug_uniqueness(slug: str) -> bool: if not slug: diff --git a/tests/chart_api_tests.py b/tests/chart_api_tests.py index eef8419d76c0..0afda14f24db 100644 --- a/tests/chart_api_tests.py +++ b/tests/chart_api_tests.py @@ -19,6 +19,7 @@ from typing import List, Optional import prison +from sqlalchemy.sql import func from superset import db, security_manager from superset.connectors.connector_registry import ConnectorRegistry @@ -81,6 +82,40 @@ def test_delete_chart(self): model = db.session.query(Slice).get(chart_id) self.assertEqual(model, None) + def test_delete_bulk_charts(self): + """ + Chart API: Test delete bulk + """ + admin_id = self.get_user("admin").id + chart_count = 4 + chart_ids = list() + for chart_name_index in range(chart_count): + chart_ids.append( + self.insert_chart(f"title{chart_name_index}", [admin_id], 1).id + ) + self.login(username="admin") + argument = chart_ids + uri = f"api/v1/chart/?q={prison.dumps(argument)}" + rv = self.client.delete(uri) + self.assertEqual(rv.status_code, 200) + response = json.loads(rv.data.decode("utf-8")) + expected_response = {"message": f"Deleted {chart_count} charts"} + self.assertEqual(response, expected_response) + for chart_id in chart_ids: + model = db.session.query(Slice).get(chart_id) + self.assertEqual(model, None) + + def test_delete_bulk_chart_bad_request(self): + """ + Chart API: Test delete bulk bad request + """ + chart_ids = [1, "a"] + self.login(username="admin") + argument = chart_ids + uri = f"api/v1/chart/?q={prison.dumps(argument)}" + rv = self.client.delete(uri) + self.assertEqual(rv.status_code, 400) + def test_delete_not_found_chart(self): """ Chart API: Test not found delete @@ -91,6 +126,18 @@ def test_delete_not_found_chart(self): rv = self.client.delete(uri) self.assertEqual(rv.status_code, 404) + def test_delete_bulk_charts_not_found(self): + """ + Chart API: Test delete bulk not found + """ + max_id = db.session.query(func.max(Slice.id)).scalar() + chart_ids = [max_id + 1, max_id + 2] + self.login(username="admin") + argument = chart_ids + uri = f"api/v1/chart/?q={prison.dumps(argument)}" + rv = self.client.delete(uri) + self.assertEqual(rv.status_code, 404) + def test_delete_chart_admin_not_owned(self): """ Chart API: Test admin delete not owned @@ -105,6 +152,31 @@ def test_delete_chart_admin_not_owned(self): model = db.session.query(Slice).get(chart_id) self.assertEqual(model, None) + def test_delete_bulk_chart_admin_not_owned(self): + """ + Chart API: Test admin delete bulk not owned + """ + gamma_id = self.get_user("gamma").id + chart_count = 4 + chart_ids = list() + for chart_name_index in range(chart_count): + chart_ids.append( + self.insert_chart(f"title{chart_name_index}", [gamma_id], 1).id + ) + + self.login(username="admin") + argument = chart_ids + uri = f"api/v1/chart/?q={prison.dumps(argument)}" + rv = self.client.delete(uri) + response = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 200) + expected_response = {"message": f"Deleted {chart_count} charts"} + self.assertEqual(response, expected_response) + + for chart_id in chart_ids: + model = db.session.query(Slice).get(chart_id) + self.assertEqual(model, None) + def test_delete_chart_not_owned(self): """ Chart API: Test delete try not owned @@ -125,6 +197,53 @@ def test_delete_chart_not_owned(self): db.session.delete(user_alpha2) db.session.commit() + def test_delete_bulk_chart_not_owned(self): + """ + Chart API: Test delete bulk try not owned + """ + user_alpha1 = self.create_user( + "alpha1", "password", "Alpha", email="alpha1@superset.org" + ) + user_alpha2 = self.create_user( + "alpha2", "password", "Alpha", email="alpha2@superset.org" + ) + + chart_count = 4 + charts = list() + for chart_name_index in range(chart_count): + charts.append( + self.insert_chart(f"title{chart_name_index}", [user_alpha1.id], 1) + ) + + owned_chart = self.insert_chart("title_owned", [user_alpha2.id], 1) + + self.login(username="alpha2", password="password") + + # verify we can't delete not owned charts + arguments = [chart.id for chart in charts] + uri = f"api/v1/chart/?q={prison.dumps(arguments)}" + rv = self.client.delete(uri) + self.assertEqual(rv.status_code, 403) + response = json.loads(rv.data.decode("utf-8")) + expected_response = {"message": "Forbidden"} + self.assertEqual(response, expected_response) + + # # nothing is deleted in bulk with a list of owned and not owned charts + arguments = [chart.id for chart in charts] + [owned_chart.id] + uri = f"api/v1/chart/?q={prison.dumps(arguments)}" + rv = self.client.delete(uri) + self.assertEqual(rv.status_code, 403) + response = json.loads(rv.data.decode("utf-8")) + expected_response = {"message": "Forbidden"} + self.assertEqual(response, expected_response) + + for chart in charts: + db.session.delete(chart) + db.session.delete(owned_chart) + db.session.delete(user_alpha1) + db.session.delete(user_alpha2) + db.session.commit() + def test_create_chart(self): """ Chart API: Test create chart diff --git a/tests/dashboards/api_tests.py b/tests/dashboards/api_tests.py index 1c8a050f2a77..e00adcb44c22 100644 --- a/tests/dashboards/api_tests.py +++ b/tests/dashboards/api_tests.py @@ -379,7 +379,7 @@ def test_delete_bulk_dashboard_not_owned(self): expected_response = {"message": "Forbidden"} self.assertEqual(response, expected_response) - # nothing is delete in bulk with a list of owned and not owned dashboards + # nothing is deleted in bulk with a list of owned and not owned dashboards arguments = [dashboard.id for dashboard in dashboards] + [owned_dashboard.id] uri = f"api/v1/dashboard/?q={prison.dumps(arguments)}" rv = self.client.delete(uri)