diff --git a/superset/datasets/api.py b/superset/datasets/api.py index 0a2f1fe7e20e..cab625b16995 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -16,8 +16,9 @@ # under the License. import logging +import yaml from flask import g, request, Response -from flask_appbuilder.api import expose, protect, safe +from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from superset.connectors.sqla.models import SqlaTable @@ -35,8 +36,12 @@ ) from superset.datasets.commands.refresh import RefreshDatasetCommand from superset.datasets.commands.update import UpdateDatasetCommand -from superset.datasets.schemas import DatasetPostSchema, DatasetPutSchema -from superset.views.base import DatasourceFilter +from superset.datasets.schemas import ( + DatasetPostSchema, + DatasetPutSchema, + get_export_ids_schema, +) +from superset.views.base import DatasourceFilter, generate_download_headers from superset.views.base_api import BaseSupersetModelRestApi from superset.views.database.filters import DatabaseFilter @@ -51,10 +56,11 @@ class DatasetRestApi(BaseSupersetModelRestApi): allow_browser_login = True class_permission_name = "TableModelView" - include_route_methods = ( - RouteMethod.REST_MODEL_VIEW_CRUD_SET | {RouteMethod.RELATED} | {"refresh"} - ) - + include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { + RouteMethod.EXPORT, + RouteMethod.RELATED, + "refresh", + } list_columns = [ "database_name", "changed_by_name", @@ -278,6 +284,58 @@ def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ logger.error(f"Error deleting model {self.__class__.__name__}: {e}") return self.response_422(message=str(e)) + @expose("/export/", methods=["GET"]) + @protect() + @safe + @rison(get_export_ids_schema) + def export(self, **kwargs): + """Export dashboards + --- + get: + description: >- + Exports multiple datasets and downloads them as YAML files + parameters: + - in: query + name: q + content: + application/json: + schema: + type: array + items: + type: integer + responses: + 200: + description: Dataset export + content: + text/plain: + schema: + type: string + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 404: + $ref: '#/components/responses/404' + 500: + $ref: '#/components/responses/500' + """ + requested_ids = kwargs["rison"] + query = self.datamodel.session.query(SqlaTable).filter( + SqlaTable.id.in_(requested_ids) + ) + query = self._base_filters.apply_all(query) + items = query.all() + ids = [item.id for item in items] + if len(ids) != len(requested_ids): + return self.response_404() + + data = [t.export_to_dict() for t in items] + return Response( + yaml.safe_dump(data), + headers=generate_download_headers("yaml"), + mimetype="application/text", + ) + @expose("//refresh", methods=["PUT"]) @protect() @safe diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index b5b74e56d63f..0c1e67586f7a 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -20,6 +20,8 @@ from marshmallow import fields, Schema, ValidationError from marshmallow.validate import Length +get_export_ids_schema = {"type": "array", "items": {"type": "integer"}} + def validate_python_date_format(value): regex = re.compile( diff --git a/tests/dataset_api_tests.py b/tests/dataset_api_tests.py index 48b3fbde36e2..77e056680c8a 100644 --- a/tests/dataset_api_tests.py +++ b/tests/dataset_api_tests.py @@ -20,6 +20,7 @@ from unittest.mock import patch import prison +import yaml from sqlalchemy.sql import func from superset import db, security_manager @@ -31,6 +32,8 @@ ) from superset.models.core import Database from superset.utils.core import get_example_database +from superset.utils.dict_import_export import export_to_dict +from superset.views.base import generate_download_headers from tests.base_tests import SupersetTestCase @@ -680,3 +683,63 @@ def test_dataset_item_refresh_not_owned(self): db.session.delete(dataset) db.session.commit() + + def test_export_dataset(self): + """ + Dataset API: Test export dataset + :return: + """ + birth_names_dataset = self.get_birth_names_dataset() + + argument = [birth_names_dataset.id] + uri = f"api/v1/dataset/export/?q={prison.dumps(argument)}" + + self.login(username="admin") + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 200) + self.assertEqual( + rv.headers["Content-Disposition"], + generate_download_headers("yaml")["Content-Disposition"], + ) + + cli_export = export_to_dict( + session=db.session, + recursive=True, + back_references=False, + include_defaults=False, + ) + cli_export_tables = cli_export["databases"][0]["tables"] + expected_response = [] + for export_table in cli_export_tables: + if export_table["table_name"] == "birth_names": + expected_response = export_table + break + ui_export = yaml.safe_load(rv.data.decode("utf-8")) + self.assertEqual(ui_export[0], expected_response) + + def test_export_dataset_not_found(self): + """ + Dataset API: Test export dataset not found + :return: + """ + max_id = db.session.query(func.max(SqlaTable.id)).scalar() + # Just one does not exist and we get 404 + argument = [max_id + 1, 1] + uri = f"api/v1/dataset/export/?q={prison.dumps(argument)}" + self.login(username="admin") + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 404) + + def test_export_dataset_gamma(self): + """ + Dataset API: Test export dataset has gamma + :return: + """ + birth_names_dataset = self.get_birth_names_dataset() + + argument = [birth_names_dataset.id] + uri = f"api/v1/dataset/export/?q={prison.dumps(argument)}" + + self.login(username="gamma") + rv = self.client.get(uri) + self.assertEqual(rv.status_code, 401)