diff --git a/requirements.txt b/requirements.txt index 9192868524e9..b2a406cd661c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -76,6 +76,7 @@ six==1.11.0 # via bleach, cryptography, flask-jwt-extended, flask- sqlalchemy-utils==0.33.11 sqlalchemy==1.3.5 sqlparse==0.2.4 +tabulate==0.8.3 urllib3==1.24.3 # via requests, selenium vine==1.1.4 # via amqp webencodings==0.5.1 # via bleach diff --git a/setup.py b/setup.py index 420375b249ee..527e03633376 100644 --- a/setup.py +++ b/setup.py @@ -105,6 +105,7 @@ def get_git_sha(): 'sqlalchemy>=1.3.5,<2.0', 'sqlalchemy-utils>=0.33.2', 'sqlparse', + 'tabulate>=0.8.3', 'wtforms-json', ], extras_require={ diff --git a/superset/cli.py b/superset/cli.py index 6691b0148f8f..b05471abaf2d 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -17,20 +17,32 @@ # under the License. # pylint: disable=C,R,W from datetime import datetime +import json import logging from subprocess import Popen -from sys import stdout +from sys import exit, stdout +import tarfile +import tempfile import click from colorama import Fore, Style from pathlib2 import Path +import requests import yaml from superset import ( app, appbuilder, data, db, security_manager, ) +from superset.data.helpers import ( + download_url_to_blob_url, get_examples_file_list, get_examples_uris, + list_examples_table, +) +from superset.exceptions import DashboardNotFoundException, ExampleNotFoundException from superset.utils import ( - core as utils, dashboard_import_export, dict_import_export) + core as utils, dashboard_import_export, dict_import_export, +) + +logging.getLogger('urllib3').setLevel(logging.WARNING) config = app.config celery_app = utils.get_celery_app(config) @@ -48,7 +60,7 @@ def make_shell_context(): @app.cli.command() def init(): """Inits the Superset application""" - utils.get_or_create_main_db() + utils.get_or_create_db_by_name(db_name='main') appbuilder.add_permissions(update_perms=True) security_manager.sync_role_definitions() @@ -124,10 +136,254 @@ def load_examples_run(load_test_data): @app.cli.command() @click.option('--load-test-data', '-t', is_flag=True, help='Load additional test data') def load_examples(load_test_data): - """Loads a set of Slices and Dashboards and a supporting dataset """ + """Loads a set of charts and dashboards and a supporting dataset""" load_examples_run(load_test_data) +def exclusive(ctx_params, exclusive_params, error_message): + """Provide exclusive option grouping""" + if sum([1 if ctx_params[p] else 0 for p in exclusive_params]) > 1: + raise click.UsageError(error_message) + + +@app.cli.group() +def examples(): + """Manages example dashboards/datasets""" + pass + + +@examples.command('export') +@click.option( + '--dashboard-id', '-i', default=None, type=int, + help='Specify a single dashboard id to export') +@click.option( + '--dashboard-title', '-t', default=None, + help='Specify a single dashboard title to export') +@click.option( + '--description', '-d', help='Description of new example', required=True) +@click.option( + '--example-title', '-e', help='Title for new example', required=False) +@click.option( + '--file-name', '-f', default='dashboard.tar.gz', + help='Specify export file name. 
Defaults to dashboard.tar.gz') +@click.option( + '--license', '-l', '_license', default='Apache 2.0', + help='License of the example dashboard') +@click.option( + '--url', '-u', default=None, help='URL of dataset home page') +def export_example(dashboard_id, dashboard_title, description, example_title, + file_name, _license, url): + """Export example dashboard/datasets tarball""" + if not (dashboard_id or dashboard_title): + raise click.UsageError('must supply --dashboard-id/-i or --dashboard-title/-t') + exclusive( + click.get_current_context().params, + ['dashboard_id', 'dashboard_title'], + 'options --dashboard-id/-i and --dashboard-title/-t mutually exclusive') + + # Export into a temporary directory and then tarball that directory + with tempfile.TemporaryDirectory() as tmp_dir_name: + + try: + data = dashboard_import_export.export_dashboards( + db.session, + dashboard_ids=[dashboard_id], + dashboard_titles=[dashboard_title], + export_data=True, + export_data_dir=tmp_dir_name, + description=description, + export_title=example_title or dashboard_title, + _license=_license, + url=url, + strip_database=True) + + dashboard_slug = dashboard_import_export.get_slug( + db.session, + dashboard_id=dashboard_id, + dashboard_title=dashboard_title) + + out_path = f'{tmp_dir_name}/dashboard.json' + + with open(out_path, 'w') as data_stream: + data_stream.write(data) + + with tarfile.open(file_name, 'w:gz') as tar: + tar.add(tmp_dir_name, arcname=f'{dashboard_slug}') + + click.secho(str(f'Exported example to {file_name}'), fg='blue') + + except DashboardNotFoundException as e: + click.secho(str(e), fg='red') + exit(1) + + +@examples.command('list') +@click.option( + '--examples-repo', '-r', + help='Full name of Github repository containing examples, ex: ' + + "'apache-superset/examples-data'", + default=None) +@click.option( + '--examples-tag', '-r', + help="Tag or branch of Github repository containing examples. Defaults to 'master'", + default='master') +@click.option( + '--full-fields', '-ff', is_flag=True, default=False, help='Print full length fields') +def _list_examples(examples_repo, examples_tag, full_fields): + """List example dashboards/datasets""" + click.echo( + list_examples_table(examples_repo, examples_tag=examples_tag, + full_fields=full_fields)) + + +@examples.command('import') +@click.option( + '--database-uri', '-d', help='Database URI to import example to', + default=config.get('SQLALCHEMY_EXAMPLES_URI')) +@click.option( + '--examples-repo', '-r', + help='Full name of Github repository containing examples, ex: ' + + "'apache-superset/examples-data'", + default=None) +@click.option( + '--examples-tag', '-r', + help="Tag or branch of Github repository containing examples. 
Defaults to 'master'", + default='master') +@click.option( + '--example-title', '-e', help='Title of example to import', required=True) +def import_example(example_title, examples_repo, examples_tag, database_uri): + """Import an example dashboard/dataset""" + + # First fetch the example information from Github + examples_repos = [(examples_repo, examples_tag)] \ + if examples_repo else config.get('EXAMPLE_REPOS_TAGS') + examples_repos_uris = [(r[0], r[1]) + get_examples_uris(r[0], r[1]) + for r in examples_repos] + examples_files = get_examples_file_list(examples_repos_uris) + + # Github authentication via a Personal Access Token for rate limit problems + headers = None + token = config.get('GITHUB_AUTH_TOKEN') + if token: + headers = {'Authorization': 'token %s' % token} + + import_example_json = None + import_data_info = None + for example_file in examples_files: + + metadata_download_url = example_file['metadata_file']['download_url'] + example_metadata_json = requests.get(metadata_download_url, + headers=headers).content + # Cheaply load json without generating objects + example_metadata = json.loads(example_metadata_json) + if example_metadata['description']['title'] == example_title: + import_example_json = example_metadata_json + import_data_info = example_file['data_files'] + logging.info( + f"Will import example '{example_title}' from {metadata_download_url}") + break + + if not import_example_json: + e = ExampleNotFoundException(f'Example {example_title} not found!') + click.secho(str(e), fg='red') + exit(1) + + # Parse data to get file download_urls -> blob_urls + example_metadata = json.loads(import_example_json, + object_hook=utils.decode_dashboards) + + # The given download url won't work for data files, need a blob url + data_blob_urls = {} + for ex_file in example_metadata['files']: + github_info = [t for t in import_data_info + if t['name'] == ex_file['file_name']][0] + blob_url = download_url_to_blob_url(github_info['download_url']) + data_blob_urls[github_info['name']] = blob_url + + try: + dashboard_import_export.import_dashboards( + db.session, + import_example_json, + is_example=True, + data_blob_urls=data_blob_urls, + database_uri=database_uri) + except Exception as e: + logging.error(f"Error importing example dashboard '{example_title}'!") + logging.exception(e) + + +@examples.command('remove') +@click.option( + '--example-title', '-e', help='Title of example to remove', required=True) +@click.option( + '--database-uri', '-d', help='Database URI to remove example from', + default=config.get('SQLALCHEMY_EXAMPLES_URI')) +@click.option( + '--examples-repo', '-r', + help='Full name of Github repository containing examples, ex: ' + + "'apache-superset/examples-data'", + default=None) +@click.option( + '--examples-tag', '-r', + help="Tag or branch of Github repository containing examples. 
Defaults to 'master'", + default='master') +def remove_example(example_title, database_uri, examples_repo, examples_tag): + """Remove an example dashboard/dataset""" + + # First fetch the example information from Github + examples_repos = [(examples_repo, examples_tag)] \ + if examples_repo else config.get('EXAMPLE_REPOS_TAGS') + examples_repos_uris = [(r[0], r[1]) + get_examples_uris(r[0], r[1]) + for r in examples_repos] + examples_files = get_examples_file_list(examples_repos_uris) + + # Github authentication via a Personal Access Token for rate limit problems + headers = None + token = config.get('GITHUB_AUTH_TOKEN') + if token: + headers = {'Authorization': 'token %s' % token} + + # temporary - substitute url provided + db_name = 'superset' + + import_example_data = None + for example_file in examples_files: + + metadata_download_url = example_file['metadata_file']['download_url'] + example_metadata_json = requests.get(metadata_download_url, + headers=headers).content + # Cheaply load json without generating objects + example_metadata = json.loads(example_metadata_json) + if example_metadata['description']['title'] == example_title: + import_example_data = json.loads(example_metadata_json) + logging.info( + f"Will remove example '{example_title}' from '{db_name}'") + break + + logging.info(import_example_data['files']) + + # Get the dashboard and associated records + dashboard_title = \ + import_example_data['dashboards'][0]['__Dashboard__']['dashboard_title'] + logging.info(f'Got dashboard title {dashboard_title} for removal...') + + utils.get_or_create_db_by_name(db_name='main') + session = db.session() + + try: + dashboard_import_export.remove_dashboard( + session, + import_example_data, + dashboard_title, + database_uri=database_uri, + ) + except DashboardNotFoundException as e: + logging.exception(e) + click.secho( + f'Example {example_title} associated dashboard {dashboard_title} not found!', + fg='red') + + @app.cli.command() @click.option('--datasource', '-d', help='Specify which datasource name to load, if ' 'omitted, all datasources will be refreshed') @@ -156,7 +412,7 @@ def refresh_druid(datasource, merge): @app.cli.command() @click.option( - '--path', '-p', + '--path', '-p', required=True, help='Path to a single JSON file or path containing multiple JSON files' 'files to import (*.json)') @click.option( @@ -177,7 +433,7 @@ def import_dashboards(path, recursive): try: with f.open() as data_stream: dashboard_import_export.import_dashboards( - db.session, data_stream) + db.session, data_stream.read()) except Exception as e: logging.error('Error when importing dashboard from file %s', f) logging.error(e) @@ -190,9 +446,31 @@ def import_dashboards(path, recursive): @click.option( '--print_stdout', '-p', is_flag=True, default=False, help='Print JSON to stdout') -def export_dashboards(print_stdout, dashboard_file): - """Export dashboards to JSON""" - data = dashboard_import_export.export_dashboards(db.session) +@click.option( + '--dashboard-ids', '-i', default=None, type=int, multiple=True, + help='Specify dashboard id to export') +@click.option( + '--dashboard-titles', '-t', default=None, multiple=True, + help='Specify dashboard title to export') +@click.option( + '--export-data', '-x', default=None, is_flag=True, + help="Export the dashboard's data tables as CSV files.") +@click.option( + '--export-data-dir', '-d', default=config.get('DASHBOARD_EXPORT_DIR'), + help="Specify export directory path. 
Defaults to '/tmp'") +def export_dashboards(print_stdout, dashboard_file, dashboard_ids, + dashboard_titles, export_data, export_data_dir): + """Export dashboards to JSON and optionally tables to CSV""" + try: + data = dashboard_import_export.export_dashboards( + db.session, + dashboard_ids=dashboard_ids, + dashboard_titles=dashboard_titles, + export_data=export_data, + export_data_dir=export_data_dir) + except DashboardNotFoundException as e: + click.secho(str(e), fg='red') + exit(1) if print_stdout or not dashboard_file: print(data) if dashboard_file: @@ -369,7 +647,7 @@ def load_test_users_run(): gamma_sqllab_role = security_manager.add_role('gamma_sqllab') for perm in security_manager.find_role('Gamma').permissions: security_manager.add_permission_role(gamma_sqllab_role, perm) - utils.get_or_create_main_db() + utils.get_or_create_db_by_name(db_name='main') db_perm = utils.get_main_database(security_manager.get_session).perm security_manager.add_permission_view_menu('database_access', db_perm) db_pvm = security_manager.find_permission_view_menu( diff --git a/superset/config.py b/superset/config.py index 68ce8b5bf96f..a56921727885 100644 --- a/superset/config.py +++ b/superset/config.py @@ -80,7 +80,12 @@ # SQLALCHEMY_DATABASE_URI = 'mysql://myapp@localhost/myapp' # SQLALCHEMY_DATABASE_URI = 'postgresql://root:password@localhost/myapp' -# In order to hook up a custom password store for all SQLACHEMY connections +# The SQLAlchemy connection string for incoming examples +SQLALCHEMY_EXAMPLES_URI = SQLALCHEMY_DATABASE_URI +# SQLALCHEMY_EXAMPLES_URI = 'mysql://myapp@localhost/examples' +# SQLALCHEMY_EXAMPLES_URI = 'postgresql://root:password@localhost/examples' + +# In order to hook up a custom password store for all SQLALCHEMY connections # implement a function that takes a single argument of type 'sqla.engine.url', # returns a password and set SQLALCHEMY_CUSTOM_PASSWORD_STORE. 
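# A minimal illustrative sketch of such a custom password store, not part of
# this patch; the `lookup_secret` helper and its keying by host name are
# assumptions:
#
#     def example_password_store(sqla_url):
#         return lookup_secret(sqla_url.host)
#
#     SQLALCHEMY_CUSTOM_PASSWORD_STORE = example_password_store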
# @@ -601,6 +606,7 @@ class CeleryConfig(object): # Send user to a link where they can report bugs BUG_REPORT_URL = None + # Send user to a link where they can read more about Superset DOCUMENTATION_URL = None @@ -628,6 +634,17 @@ class CeleryConfig(object): 'force_https_permanent': False, } +# Default dashboard export directory +DASHBOARD_EXPORT_DIR = '/tmp' + +# Tuple format: Gitub repo full name, tag/branch +EXAMPLE_REPOS_TAGS = [ + ('apache-superset/examples-data', 'v0.0.6'), +] + +# Github Authorization Token - in case the examples commands exceed rate limits +GITHUB_AUTH_TOKEN = None + try: if CONFIG_PATH_ENV_VAR in os.environ: # Explicitly import config module that is not in pythonpath; useful diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 2e8bc25adff0..a25212143fa7 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -24,11 +24,11 @@ from sqlalchemy.orm import foreign, relationship from superset.models.core import Slice -from superset.models.helpers import AuditMixinNullable, ImportMixin +from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.utils import core as utils -class BaseDatasource(AuditMixinNullable, ImportMixin): +class BaseDatasource(AuditMixinNullable, ImportExportMixin): """A common interface to objects that are queryable (tables and datasources)""" @@ -341,7 +341,7 @@ def update_from_object(self, obj): obj.get('columns'), self.columns, self.column_class, 'column_name') -class BaseColumn(AuditMixinNullable, ImportMixin): +class BaseColumn(AuditMixinNullable, ImportExportMixin): """Interface for column""" __tablename__ = None # {connector_name}_column @@ -404,7 +404,7 @@ def data(self): return {s: getattr(self, s) for s in attrs if hasattr(self, s)} -class BaseMetric(AuditMixinNullable, ImportMixin): +class BaseMetric(AuditMixinNullable, ImportExportMixin): """Interface for Metrics""" diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index c71bc8061962..def02689da49 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -51,7 +51,7 @@ from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.exceptions import MetricPermException, SupersetException from superset.models.helpers import ( - AuditMixinNullable, ImportMixin, QueryResult, + AuditMixinNullable, ImportExportMixin, QueryResult, ) from superset.utils import core as utils, import_datasource from superset.utils.core import ( @@ -87,7 +87,7 @@ def __init__(self, name, post_aggregator): self.post_aggregator = post_aggregator -class DruidCluster(Model, AuditMixinNullable, ImportMixin): +class DruidCluster(Model, AuditMixinNullable, ImportExportMixin): """ORM object referencing the Druid clusters""" diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index a2e392b79fc3..511b2a9b9263 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -929,13 +929,14 @@ def fetch_metadata(self): db.session.commit() @classmethod - def import_obj(cls, i_datasource, import_time=None): + def import_obj(cls, i_datasource, import_time=None, substitute_db_name=None): """Imports the datasource from the object to the database. Metrics and columns and datasource will be overrided if exists. This function can be used to import/export dashboards between multiple superset instances. Audit metadata isn't copies over. 
""" + def lookup_sqlatable(table): return db.session.query(SqlaTable).join(Database).filter( SqlaTable.table_name == table.table_name, @@ -944,11 +945,13 @@ def lookup_sqlatable(table): ).first() def lookup_database(table): + db_name = table.params_dict['database_name'] return db.session.query(Database).filter_by( - database_name=table.params_dict['database_name']).one() + database_name=db_name).one() + return import_datasource.import_datasource( - db.session, i_datasource, lookup_database, lookup_sqlatable, - import_time) + db.session, i_datasource, lookup_database, + lookup_sqlatable, import_time) @classmethod def query_datasources_by_name( diff --git a/superset/data/bart_lines.py b/superset/data/bart_lines.py index f4e0b1f09cfc..b9ce8173d715 100644 --- a/superset/data/bart_lines.py +++ b/superset/data/bart_lines.py @@ -21,7 +21,7 @@ from sqlalchemy import String, Text from superset import db -from superset.utils.core import get_or_create_main_db +from superset.utils.core import get_or_create_db_by_name from .helpers import TBL, get_example_data @@ -50,7 +50,7 @@ def load_bart_lines(): if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = 'BART lines' - tbl.database = get_or_create_main_db() + tbl.database = get_or_create_db_by_name(db_name='main') db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() diff --git a/superset/data/birth_names.py b/superset/data/birth_names.py index 85de0196b1cb..dbb5140fd69d 100644 --- a/superset/data/birth_names.py +++ b/superset/data/birth_names.py @@ -23,7 +23,7 @@ from superset import db, security_manager from superset.connectors.sqla.models import SqlMetric, TableColumn -from superset.utils.core import get_or_create_main_db +from superset.utils.core import get_or_create_db_by_name from .helpers import ( config, Dash, @@ -61,7 +61,7 @@ def load_birth_names(): if not obj: obj = TBL(table_name='birth_names') obj.main_dttm_col = 'ds' - obj.database = get_or_create_main_db() + obj.database = get_or_create_db_by_name(db_name='main') obj.filter_select_enabled = True if not any(col.column_name == 'num_california' for col in obj.columns): diff --git a/superset/data/country_map.py b/superset/data/country_map.py index 303b85d0f8e7..a796f202b78d 100644 --- a/superset/data/country_map.py +++ b/superset/data/country_map.py @@ -68,7 +68,7 @@ def load_country_map_data(): if not obj: obj = TBL(table_name='birth_france_by_region') obj.main_dttm_col = 'dttm' - obj.database = utils.get_or_create_main_db() + obj.database = utils.get_or_create_db_by_name(db_name='main') if not any(col.metric_name == 'avg__2004' for col in obj.metrics): col = str(column('2004').compile(db.engine)) obj.metrics.append(SqlMetric( diff --git a/superset/data/energy.py b/superset/data/energy.py index 293dca556b02..fceaebb2f934 100644 --- a/superset/data/energy.py +++ b/superset/data/energy.py @@ -52,7 +52,7 @@ def load_energy(): if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = 'Energy consumption' - tbl.database = utils.get_or_create_main_db() + tbl.database = utils.get_or_create_db_by_name(db_name='main') if not any(col.metric_name == 'sum__value' for col in tbl.metrics): col = str(column('value').compile(db.engine)) diff --git a/superset/data/flights.py b/superset/data/flights.py index 25112444fc29..556b8336fc7b 100644 --- a/superset/data/flights.py +++ b/superset/data/flights.py @@ -54,7 +54,7 @@ def load_flights(): if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = 'Random set of flights in the US' - tbl.database = utils.get_or_create_main_db() + 
tbl.database = utils.get_or_create_db_by_name(db_name='main') db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() diff --git a/superset/data/helpers.py b/superset/data/helpers.py index f876dc9105ad..1092827bf826 100644 --- a/superset/data/helpers.py +++ b/superset/data/helpers.py @@ -14,21 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + """Loads datasets, dashboards and slices in a new superset instance""" # pylint: disable=C,R,W +from datetime import datetime from io import BytesIO import json import os +import re import zlib import requests +from tabulate import tabulate from superset import app, db from superset.connectors.connector_registry import ConnectorRegistry +from superset.exceptions import BadGithubUrlConvertException from superset.models import core as models -BASE_URL = 'https://github.com/apache-superset/examples-data/blob/master/' - # Shortcuts DB = models.Database Slice = models.Slice @@ -68,10 +71,130 @@ def get_slice_json(defaults, **kwargs): return json.dumps(d, indent=4, sort_keys=True) +def get_examples_uris(repo_name, tag): + """Given a full Github repo name return the base urls to the contents and blog APIs""" + contents_uri = f'https://api.github.com/repos/{repo_name}/contents/?ref={tag}' + blob_uri = f'https://github.com/{repo_name}/blob/{tag}/' + return contents_uri, blob_uri + + +def download_url_to_blob_url(download_url): + """Get a download link for a large file in the examples-data repo + + Example input: + 'https://raw.githubusercontent.com/rjurney/examples-data/v0.0.4/world_health/wb_health_population.csv.gz' + Example output: + 'https://github.com/rjurney/examples-data/raw/v0.0.4/world_health/wb_health_population.csv.gz' + """ + + hits = re.search( + 'https://raw.githubusercontent.com/(.+?/.+?)/(.+?)/(.*)', download_url) + if len(hits.groups()) < 3: + raise BadGithubUrlConvertException(f'Bad input url: {download_url}') + + blob_url = f'https://github.com/{hits.group(1)}/raw/{hits.group(2)}/{hits.group(3)}' + return blob_url + + def get_example_data(filepath, is_gzip=True, make_bytes=False): - content = requests.get(f'{BASE_URL}{filepath}?raw=true').content + """Get the examples data for the legacy examples""" + examples_repos_uris = \ + [get_examples_uris(r[0], r[1]) for r in config.get('EXAMPLE_REPOS_TAGS')] + contents_uri, blob_uri = examples_repos_uris[0] + content = requests.get(f'{blob_uri}/{filepath}?raw=true').content if is_gzip: - content = zlib.decompress(content, zlib.MAX_WBITS|16) + content = zlib.decompress(content, zlib.MAX_WBITS | 16) if make_bytes: content = BytesIO(content) return content + + +def get_examples_file_list(examples_repos_uris, examples_tag='master'): + """Use the Github get contents API to list available examples""" + examples = [] + for (repo_name, repo_tag, contents_uri, blob_uri) in examples_repos_uris: + + # Github authentication via a Personal Access Token for rate limit problems + headers = None + token = config.get('GITHUB_AUTH_TOKEN') + if token: + headers = {'Authorization': 'token %s' % token} + + content = json.loads(requests.get(contents_uri, headers=headers).content) + dirs = [x for x in content if x['type'] == 'dir'] # examples are in sub-dirs + + for _dir in dirs: + link = _dir['_links']['self'] + sub_content = json.loads(requests.get(link, headers=headers).content) + dashboard_info = list(filter( + lambda x: x['name'] == 'dashboard.json', sub_content))[0] + data_files = list(filter( + lambda x: 
x['name'].endswith('.csv.gz'), sub_content)) + examples.append({ + 'repo_name': repo_name, + 'repo_tag': repo_tag, + 'metadata_file': dashboard_info, + 'data_files': data_files, + }) + + return examples + + +def list_examples_table(examples_repo, examples_tag='master', full_fields=True): + """Turn a list of available examples into a PrettyTable""" + + # Write a pretty table to stdout + headers = ['Title', 'Description', 'Size (MB)', 'Rows', + 'Files', 'URL', 'Created Date', 'Repository', 'Tag'] + + # Optionally replace the default examples repo with a specified one + examples_repos_uris = [(r[0], r[1]) + get_examples_uris(r[0], r[1]) + for r in config.get('EXAMPLE_REPOS_TAGS')] + + # Replace the configured repos with the examples repo specified + if examples_repo: + examples_repos_uris = [ + (examples_repo, + examples_tag) + + get_examples_uris(examples_repo, examples_tag), + ] + + file_info_list = get_examples_file_list(examples_repos_uris) + + def shorten(val, length): + result = val + if len(val) > length: + result = val[0:length] + '...' + return result + + def date_format(iso_date): + dt = datetime.strptime(iso_date, '%Y-%m-%dT%H:%M:%S.%f') + return dt.isoformat() + + rows = [] + for file_info in file_info_list: + + d = json.loads( + requests.get( + file_info['metadata_file']['download_url']).content)['description'] + + if not full_fields: + file_info['repo_name'] = shorten(file_info['repo_name'], 30) + file_info['repo_tag'] = shorten(file_info['repo_tag'], 20) + d['description'] = shorten(d['description'], 50) + d['url'] = shorten(d['url'], 30) + + row = [ + d['title'], + d['description'], + d['total_size_mb'], + d['total_rows'], + d['file_count'], + d['url'], + date_format(d['created_at']), + file_info['repo_name'], + file_info['repo_tag'], + ] + rows.append(row) + + return '\n' + tabulate(rows, headers=headers) + '\n' diff --git a/superset/data/long_lat.py b/superset/data/long_lat.py index 18f477cfa404..e1ae194d3758 100644 --- a/superset/data/long_lat.py +++ b/superset/data/long_lat.py @@ -79,7 +79,7 @@ def load_long_lat_data(): if not obj: obj = TBL(table_name='long_lat') obj.main_dttm_col = 'datetime' - obj.database = utils.get_or_create_main_db() + obj.database = utils.get_or_create_db_by_name(db_name='main') db.session.merge(obj) db.session.commit() obj.fetch_metadata() diff --git a/superset/data/multiformat_time_series.py b/superset/data/multiformat_time_series.py index 58ff7fbb0d32..98b78760eed2 100644 --- a/superset/data/multiformat_time_series.py +++ b/superset/data/multiformat_time_series.py @@ -61,7 +61,7 @@ def load_multiformat_time_series(): if not obj: obj = TBL(table_name='multiformat_time_series') obj.main_dttm_col = 'ds' - obj.database = utils.get_or_create_main_db() + obj.database = utils.get_or_create_db_by_name(db_name='main') dttm_and_expr_dict = { 'ds': [None, None], 'ds2': [None, None], diff --git a/superset/data/paris.py b/superset/data/paris.py index 2ed3f8eaea07..f5fcbd7892e4 100644 --- a/superset/data/paris.py +++ b/superset/data/paris.py @@ -48,7 +48,7 @@ def load_paris_iris_geojson(): if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = 'Map of Paris' - tbl.database = utils.get_or_create_main_db() + tbl.database = utils.get_or_create_db_by_name(db_name='main') db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() diff --git a/superset/data/random_time_series.py b/superset/data/random_time_series.py index ee7450a63405..4c149a3bf622 100644 --- a/superset/data/random_time_series.py +++ b/superset/data/random_time_series.py @@ -52,7 +52,7 @@ 
def load_random_time_series_data(): if not obj: obj = TBL(table_name='random_time_series') obj.main_dttm_col = 'ds' - obj.database = utils.get_or_create_main_db() + obj.database = utils.get_or_create_db_by_name(db_name='main') db.session.merge(obj) db.session.commit() obj.fetch_metadata() diff --git a/superset/data/sf_population_polygons.py b/superset/data/sf_population_polygons.py index 2248a48dafec..7ed5c8816d88 100644 --- a/superset/data/sf_population_polygons.py +++ b/superset/data/sf_population_polygons.py @@ -48,7 +48,7 @@ def load_sf_population_polygons(): if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = 'Population density of San Francisco' - tbl.database = utils.get_or_create_main_db() + tbl.database = utils.get_or_create_db_by_name(db_name='main') db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() diff --git a/superset/data/unicode_test_data.py b/superset/data/unicode_test_data.py index 03c00a7b07fc..e978c4df920e 100644 --- a/superset/data/unicode_test_data.py +++ b/superset/data/unicode_test_data.py @@ -64,7 +64,7 @@ def load_unicode_test_data(): if not obj: obj = TBL(table_name='unicode_test') obj.main_dttm_col = 'dttm' - obj.database = utils.get_or_create_main_db() + obj.database = utils.get_or_create_db_by_name(db_name='main') db.session.merge(obj) db.session.commit() obj.fetch_metadata() diff --git a/superset/data/world_bank.py b/superset/data/world_bank.py index d5370ecbb87e..ec394c170ad6 100644 --- a/superset/data/world_bank.py +++ b/superset/data/world_bank.py @@ -67,7 +67,7 @@ def load_world_bank_health_n_pop(): tbl = TBL(table_name=tbl_name) tbl.description = utils.readfile(os.path.join(DATA_FOLDER, 'countries.md')) tbl.main_dttm_col = 'year' - tbl.database = utils.get_or_create_main_db() + tbl.database = utils.get_or_create_db_by_name(db_name='main') tbl.filter_select_enabled = True metrics = [ diff --git a/superset/exceptions.py b/superset/exceptions.py index ae491fc8d731..928d785c4b35 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -54,3 +54,15 @@ class SupersetTemplateException(SupersetException): class SpatialException(SupersetException): pass + + +class DashboardNotFoundException(SupersetException): + pass + + +class ExampleNotFoundException(SupersetException): + pass + + +class BadGithubUrlConvertException(SupersetException): + pass diff --git a/superset/migrations/README b/superset/migrations/README index 98e4f9c44eff..2500aa1bcf72 100755 --- a/superset/migrations/README +++ b/superset/migrations/README @@ -1 +1 @@ -Generic single-database configuration. \ No newline at end of file +Generic single-database configuration. diff --git a/superset/migrations/versions/e5200a951e62_add_dashboards_uuid.py b/superset/migrations/versions/e5200a951e62_add_dashboards_uuid.py new file mode 100644 index 000000000000..b26cabe81c77 --- /dev/null +++ b/superset/migrations/versions/e5200a951e62_add_dashboards_uuid.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Adds uuid columns to all classes with ImportExportMixin: dashboards, datasources, dbs, slices, tables, dashboard_email_schedules, slice_email_schedules + +Revision ID: e5200a951e62 +Revises: e9df189e5c7e +Create Date: 2019-05-08 13:42:48.479145 + +""" +import uuid + +from alembic import op +from sqlalchemy import CHAR, Column, Integer +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy_utils.types.uuid import UUIDType + +from superset import db + +# revision identifiers, used by Alembic. +revision = 'e5200a951e62' +down_revision = 'd7c1a0d6f2da' + +Base = declarative_base() + + +def get_uuid(): + return str(uuid.uuid4()) + + +class Dashboard(Base): + __tablename__ = 'dashboards' + id = Column(Integer, primary_key=True) + uuid = Column(UUIDType(binary=False), default=get_uuid) + + +class Datasource(Base): + __tablename__ = 'datasources' + id = Column(Integer, primary_key=True) + uuid = Column(UUIDType(binary=False), default=get_uuid) + + +class Database(Base): + __tablename__ = 'dbs' + id = Column(Integer, primary_key=True) + uuid = Column(UUIDType(binary=False), default=get_uuid) + + +class DruidCluster(Base): + __tablename__ = 'clusters' + id = Column(Integer, primary_key=True) + uuid = Column(UUIDType(binary=False), default=get_uuid) + + +class DruidMetric(Base): + __tablename__ = 'metrics' + id = Column(Integer, primary_key=True) + uuid = Column(UUIDType(binary=False), default=get_uuid) + + +class Slice(Base): + __tablename__ = 'slices' + id = Column(Integer, primary_key=True) + uuid = Column(UUIDType(binary=False), default=get_uuid) + + +class SqlaTable(Base): + __tablename__ = 'tables' + id = Column(Integer, primary_key=True) + uuid = Column(UUIDType(binary=False), default=get_uuid) + + +class SqlMetric(Base): + __tablename__ = 'sql_metrics' + id = Column(Integer, primary_key=True) + uuid = Column(UUIDType(binary=False), default=get_uuid) + + +class TableColumn(Base): + __tablename__ = 'table_columns' + id = Column(Integer, primary_key=True) + uuid = Column(UUIDType(binary=False), default=get_uuid) + + +def upgrade(): + bind = op.get_bind() + session = db.Session(bind=bind) + db_type = session.bind.dialect.name + + def add_uuid_column(col_name, _type): + """Add a uuid column to a given table""" + with op.batch_alter_table(col_name) as batch_op: + batch_op.add_column(Column('uuid', UUIDType(binary=False), default=get_uuid)) + for s in session.query(_type): + s.uuid = get_uuid() + session.merge(s) + + if db_type != 'postgresql': + with op.batch_alter_table(col_name) as batch_op: + batch_op.alter_column('uuid', existing_type=CHAR(32), + new_column_name='uuid', nullable=False) + batch_op.create_unique_constraint('uq_uuid', ['uuid']) + + session.commit() + + add_uuid_column('dashboards', Dashboard) + add_uuid_column('datasources', Datasource) + add_uuid_column('dbs', Database) + add_uuid_column('clusters', DruidCluster) + add_uuid_column('metrics', DruidMetric) + add_uuid_column('slices', Slice) + add_uuid_column('sql_metrics', SqlMetric) + add_uuid_column('tables', SqlaTable) + add_uuid_column('table_columns', TableColumn) + + session.close() + + +def 
downgrade(): + with op.batch_alter_table('dashboards') as batch_op: + batch_op.drop_column('uuid') + + with op.batch_alter_table('datasources') as batch_op: + batch_op.drop_column('uuid') + + with op.batch_alter_table('dbs') as batch_op: + batch_op.drop_column('uuid') + + with op.batch_alter_table('clusters') as batch_op: + batch_op.drop_column('uuid') + + with op.batch_alter_table('metrics') as batch_op: + batch_op.drop_column('uuid') + + with op.batch_alter_table('slices') as batch_op: + batch_op.drop_column('uuid') + + with op.batch_alter_table('sql_metrics') as batch_op: + batch_op.drop_column('uuid') + + with op.batch_alter_table('tables') as batch_op: + batch_op.drop_column('uuid') + + with op.batch_alter_table('table_columns') as batch_op: + batch_op.drop_column('uuid') diff --git a/superset/models/core.py b/superset/models/core.py index b379af7caba2..30f6e5646ccc 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -22,6 +22,7 @@ import functools import json import logging +import os import textwrap from typing import List @@ -42,13 +43,14 @@ from sqlalchemy.orm.session import make_transient from sqlalchemy.pool import NullPool from sqlalchemy.schema import UniqueConstraint +from sqlalchemy.sql import select, text from sqlalchemy_utils import EncryptedType import sqlparse from superset import app, db, db_engine_specs, security_manager from superset.connectors.connector_registry import ConnectorRegistry from superset.legacy import update_time_range -from superset.models.helpers import AuditMixinNullable, ImportMixin +from superset.models.helpers import AuditMixinNullable, ImportExportMixin from superset.models.tags import ChartUpdater, DashboardUpdater, FavStarUpdater from superset.models.user_attributes import UserAttribute from superset.utils import ( @@ -144,7 +146,7 @@ class CssTemplate(Model, AuditMixinNullable): Column('slice_id', Integer, ForeignKey('slices.id'))) -class Slice(Model, AuditMixinNullable, ImportMixin): +class Slice(Model, AuditMixinNullable, ImportExportMixin): """A slice is essentially a report or a view on data""" @@ -350,11 +352,11 @@ def import_obj(cls, slc_to_import, slc_to_override, import_time=None): session = db.session make_transient(slc_to_import) slc_to_import.dashboards = [] - slc_to_import.alter_params( - remote_id=slc_to_import.id, import_time=import_time) + slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time) - slc_to_import = slc_to_import.copy() + slc_to_import = slc_to_import.copy() # Resets id to None params = slc_to_import.params_dict + slc_to_import.datasource_id = ConnectorRegistry.get_datasource_by_name( session, slc_to_import.datasource_type, params['datasource_name'], params['schema'], params['database_name']).id @@ -394,7 +396,7 @@ def url(self): ) -class Dashboard(Model, AuditMixinNullable, ImportMixin): +class Dashboard(Model, AuditMixinNullable, ImportExportMixin): """The dashboard object!""" @@ -551,15 +553,14 @@ def alter_positions(dashboard, old_to_new_slc_id_dict): new_timed_refresh_immune_slices = [] new_expanded_slices = {} i_params_dict = dashboard_to_import.params_dict - remote_id_slice_map = { - slc.params_dict['remote_id']: slc + uuid_slice_map = { + slc.uuid: slc for slc in session.query(Slice).all() - if 'remote_id' in slc.params_dict } for slc in slices: logging.info('Importing slice {} from the dashboard: {}'.format( slc.to_json(), dashboard_to_import.dashboard_title)) - remote_slc = remote_id_slice_map.get(slc.id) + remote_slc = uuid_slice_map.get(slc.uuid) new_slc_id = 
Slice.import_obj(slc, remote_slc, import_time=import_time) old_to_new_slc_id_dict[slc.id] = new_slc_id # update json metadata that deals with slice ids @@ -580,14 +581,15 @@ def alter_positions(dashboard, old_to_new_slc_id_dict): # override the dashboard existing_dashboard = None for dash in session.query(Dashboard).all(): - if ('remote_id' in dash.params_dict and - dash.params_dict['remote_id'] == - dashboard_to_import.id): + if ('uuid' in dash.params_dict and + dash.params_dict['uuid'] == + dashboard_to_import.uuid): existing_dashboard = dash dashboard_to_import.id = None alter_positions(dashboard_to_import, old_to_new_slc_id_dict) - dashboard_to_import.alter_params(import_time=import_time) + dashboard_to_import.alter_params(import_time=import_time, + uuid=dashboard_to_import.uuid) if new_expanded_slices: dashboard_to_import.alter_params( expanded_slices=new_expanded_slices) @@ -617,7 +619,8 @@ def alter_positions(dashboard, old_to_new_slc_id_dict): return copied_dash.id @classmethod - def export_dashboards(cls, dashboard_ids): + def export_dashboards(cls, dashboard_ids, export_data=False, + export_data_dir=None): copied_dashboards = [] datasource_ids = set() for dashboard_id in dashboard_ids: @@ -652,13 +655,66 @@ def export_dashboards(cls, dashboard_ids): make_transient(eager_datasource) eager_datasources.append(eager_datasource) - return json.dumps({ - 'dashboards': copied_dashboards, - 'datasources': eager_datasources, - }, cls=utils.DashboardEncoder, indent=4) + files = [] + total_file_size = 0 + total_file_rows = 0 + if export_data and export_data_dir: + + for data_table in eager_datasources: + engine = data_table.database.get_sqla_engine() + columns = [c.get_sqla_col() for c in data_table.columns] + + qry = ( + select(columns) + .select_from(text(data_table.name)) + ) + qry.compile(engine) + sql = '{}'.format( + qry.compile(engine), + ) + + df = pd.read_sql_query(sql=sql, con=engine) + row_count = len(df.index) + 1 # plus one for header + + file_name = f'{data_table.name}.csv.gz' + file_path = f'{export_data_dir}/{file_name}' + + if not os.path.exists(export_data_dir): + os.makedirs(export_data_dir) + df.to_csv(file_path, compression='gzip') + + file_size = os.path.getsize(file_path) + + table_record = { + 'file_name': file_name, + 'rows': row_count, + 'size': file_size, + 'table_name': data_table.name, + } + + total_file_rows += row_count + total_file_size += file_size + + files.append(table_record) + + # Partially fill out the bibliography + desc = { + 'total_size': total_file_size, + 'total_size_mb': round(total_file_size / (1024.0 * 1024.0), 2), + 'total_rows': total_file_rows, + 'file_count': len(files), + 'created_at': datetime.now().isoformat(), + } + + return { + 'description': desc, + 'dashboards': [o.export_to_json_serializable() for o in copied_dashboards], + 'datasources': [o.export_to_json_serializable() for o in eager_datasources], + 'files': files, + } -class Database(Model, AuditMixinNullable, ImportMixin): +class Database(Model, AuditMixinNullable, ImportExportMixin): """An ORM object that stores Database related information""" diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 78b438d9f8f1..6088d22ab470 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -20,6 +20,7 @@ import json import logging import re +import uuid from flask import escape, Markup from flask_appbuilder.models.decorators import renders @@ -29,9 +30,10 @@ from sqlalchemy import and_, or_, UniqueConstraint from sqlalchemy.ext.declarative import 
declared_attr from sqlalchemy.orm.exc import MultipleResultsFound +from sqlalchemy_utils.types.uuid import UUIDType import yaml -from superset.utils.core import QueryStatus +from superset.utils.core import QueryStatus, SQLAJsonEncoder def json_to_dict(json_str): @@ -43,7 +45,11 @@ def json_to_dict(json_str): return {} -class ImportMixin(object): +def get_uuid(): + return str(uuid.uuid4()) + + +class ImportExportMixin(): export_parent = None # The name of the attribute # with the SQL Alchemy back reference @@ -56,6 +62,12 @@ class ImportMixin(object): # The names of the attributes # that are available for import and export + uuid = sa.Column(UUIDType(binary=False), unique=True, default=get_uuid) + + @classmethod + def export_fields_with_uuid(cls): + return list(cls.export_fields) + ['uuid'] + @classmethod def _parent_foreign_key_mappings(cls): """Get a mapping of foreign name to the local name of foreign keys""" @@ -86,8 +98,8 @@ def formatter(c): str(c.type), c.default.arg) if c.default else str(c.type)) schema = {c.name: formatter(c) for c in cls.__table__.columns - if (c.name in cls.export_fields and - c.name not in parent_excludes)} + if (c.name in cls.export_fields_with_uuid() and + c.name not in parent_excludes)} if recursive: for c in cls.export_children: child_class = cls.__mapper__.relationships[c].argument.class_ @@ -100,7 +112,7 @@ def import_from_dict(cls, session, dict_rep, parent=None, recursive=True, sync=[]): """Import obj from a dictionary""" parent_refs = cls._parent_foreign_key_mappings() - export_fields = set(cls.export_fields) | set(parent_refs.keys()) + export_fields = set(cls.export_fields_with_uuid()) | set(parent_refs.keys()) new_children = {c: dict_rep.get(c) for c in cls.export_children if c in dict_rep} unique_constrains = cls._unique_constrains() @@ -183,6 +195,10 @@ def import_from_dict(cls, session, dict_rep, parent=None, return obj + def export_to_json_serializable(self, recursive=True): + """Export obj to be serializable in json""" + return SQLAJsonEncoder.encode(self) + def export_to_dict(self, recursive=True, include_parent_ref=False, include_defaults=False): """Export obj to dictionary""" @@ -192,9 +208,10 @@ def export_to_dict(self, recursive=True, include_parent_ref=False, parent_ref = cls.__mapper__.relationships.get(cls.export_parent) if parent_ref: parent_excludes = {c.name for c in parent_ref.local_columns} + dict_rep = {c.name: getattr(self, c.name) for c in cls.__table__.columns - if (c.name in self.export_fields and + if (c.name in ImportExportMixin.export_fields_with_uuid() and c.name not in parent_excludes and (include_defaults or ( getattr(self, c.name) is not None and @@ -212,13 +229,13 @@ def export_to_dict(self, recursive=True, include_parent_ref=False, include_defaults=include_defaults, ) for child in getattr(self, c) ], - key=lambda k: sorted(k.items())) + key=lambda k: sorted(str(k.items()))) return dict_rep def override(self, obj): """Overrides the plain fields of the dashboard.""" - for field in obj.__class__.export_fields: + for field in obj.__class__.export_fields_with_uuid(): setattr(self, field, getattr(obj, field)) def copy(self): diff --git a/superset/models/schedules.py b/superset/models/schedules.py index fdd6636699f4..0dcae0b74ae8 100644 --- a/superset/models/schedules.py +++ b/superset/models/schedules.py @@ -27,7 +27,7 @@ from sqlalchemy.orm import relationship from superset import security_manager -from superset.models.helpers import AuditMixinNullable, ImportMixin +from superset.models.helpers import AuditMixinNullable, 
ImportExportMixin metadata = Model.metadata # pylint: disable=no-member @@ -77,7 +77,7 @@ def user(self): class DashboardEmailSchedule(Model, AuditMixinNullable, - ImportMixin, + ImportExportMixin, EmailSchedule): __tablename__ = 'dashboard_email_schedules' dashboard_id = Column(Integer, ForeignKey('dashboards.id')) @@ -90,7 +90,7 @@ class DashboardEmailSchedule(Model, class SliceEmailSchedule(Model, AuditMixinNullable, - ImportMixin, + ImportExportMixin, EmailSchedule): __tablename__ = 'slice_email_schedules' slice_id = Column(Integer, ForeignKey('slices.id')) diff --git a/superset/utils/core.py b/superset/utils/core.py index c2ad7a7d762f..6de47211b9b4 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -285,17 +285,31 @@ def decode_dashboards(o): return o -class DashboardEncoder(json.JSONEncoder): +class SQLAJsonEncoder(): + # pylint: disable=E0202 - def default(self, o): - try: - vals = { - k: v for k, v in o.__dict__.items() if k != '_sa_instance_state'} + @classmethod + def encode(cls, o): + if isinstance(o, uuid.UUID): + return str(o) + if isinstance(o, datetime): + return {'__datetime__': o.replace(microsecond=0).isoformat()} + if isinstance(o, list): + return [SQLAJsonEncoder.encode(i) for i in o] + if hasattr(o, '__dict__'): + vals = {} + if 'uuid' in o.__dict__.keys(): + vals['uuid'] = o.__dict__['uuid'] + for k, v in o.__dict__.items(): + if k == '_sa_instance_state': + continue + elif k.startswith('json') or k.endswith('json'): + vals[k] = v + else: + vals[k] = SQLAJsonEncoder.encode(v) return {'__{}__'.format(o.__class__.__name__): vals} - except Exception: - if type(o) == datetime: - return {'__datetime__': o.replace(microsecond=0).isoformat()} - return json.JSONEncoder.default(self, o) + else: + return o def parse_human_timedelta(s: str) -> timedelta: @@ -887,19 +901,23 @@ def user_label(user: User) -> Optional[str]: return None -def get_or_create_main_db(): +def get_or_create_db_by_name(db_name='main', database_uri=None): from superset import conf, db from superset.models import core as models - logging.info('Creating database reference') + logging.info(f'Creating database reference {db_name}') dbobj = get_main_database(db.session) if not dbobj: dbobj = models.Database( - database_name='main', + database_name=db_name, allow_csv_upload=True, expose_in_sqllab=True, ) - dbobj.set_sqlalchemy_uri(conf.get('SQLALCHEMY_DATABASE_URI')) + if db_name == 'examples': + database_uri = database_uri or conf.get('SQLALCHEMY_EXAMPLES_URI') + else: + database_uri = conf.get('SQLALCHEMY_DATABASE_URI') + dbobj.set_sqlalchemy_uri(database_uri) db.session.add(dbobj) db.session.commit() return dbobj @@ -914,6 +932,15 @@ def get_main_database(session): ) +def get_examples_database(session): + from superset.models import core as models + return ( + session.query(models.Database) + .filter_by(database_name='examples') + .first() + ) + + def is_adhoc_metric(metric) -> bool: return ( isinstance(metric, dict) and diff --git a/superset/utils/dashboard_import_export.py b/superset/utils/dashboard_import_export.py index 5f6782d9387e..7c64b82bd979 100644 --- a/superset/utils/dashboard_import_export.py +++ b/superset/utils/dashboard_import_export.py @@ -17,33 +17,211 @@ # pylint: disable=C,R,W import json import logging +import os +import re +import shutil +import tempfile import time -from superset.models.core import Dashboard -from superset.utils.core import decode_dashboards +from flask_appbuilder import Model +import pandas as pd +import requests +from sqlalchemy import Column, Integer 
+from sqlalchemy.engine.url import make_url +from sqlalchemy.orm.exc import NoResultFound +from superset import app, db +from superset.connectors.connector_registry import ConnectorRegistry +from superset.exceptions import DashboardNotFoundException, SupersetException +from superset.models.core import Dashboard, Database +from superset.utils.core import ( + decode_dashboards, get_or_create_db_by_name, +) -def import_dashboards(session, data_stream, import_time=None): + +def import_dashboard(session, data, import_time): + """Import a Dashboard from exported data""" + for dashboard in data['dashboards']: + Dashboard.import_obj( + dashboard, import_time=import_time) + + +def import_datasources(data, import_time, substitute_db_name=None): + """Import any data sources in Dashboard file""" + for table in data['datasources']: + if table.type == 'table': + type(table).import_obj(table, import_time=import_time, + substitute_db_name=substitute_db_name) + else: + raise SupersetException('Druid datasources not supported!') + + +def table_to_sql(path, table_name, engine): + """Take a file and load it into a table""" + logging.info(f'Import data from file {path} into table {table_name}') + + try: + df = pd.read_csv(path, parse_dates=True, infer_datetime_format=True, + compression='infer') + df.to_sql( + table_name, + engine.get_sqla_engine(), + if_exists='replace', + chunksize=500, + index=False) + except (pd.errors.ParserError, pd.errors.OutOfBoundsDatetime, + pd.errors.EmptyDataError) as e: + logging.exception(e) + raise SupersetException('Error reading table into database!') + + +def import_files_to_table(data, is_example=False, data_blob_urls=None): + """Import any files in this exported Dashboard""" + if isinstance(data, dict) and data.get('files', []): + engine = get_or_create_db_by_name(db_name='main') + + if is_example: + with tempfile.TemporaryDirectory() as tmpdir: + for file_info in data['files']: + + # Get the github info for the file + blob_file_path = f'{tmpdir}{os.path.sep}{file_info["file_name"]}' + blob_url = data_blob_urls[file_info['file_name']] + + response = requests.get(blob_url, stream=True) + with open(blob_file_path, 'wb') as out_file: + shutil.copyfileobj(response.raw, out_file) + del response + + table_to_sql(blob_file_path, file_info['table_name'], engine) + else: + for table in data['files']: + table_to_sql(table['file_name'], table['table_name'], engine) + + +def import_dashboards(session, data, is_example=False, data_blob_urls=None, + database_uri=None, import_time=None): """Imports dashboards from a stream to databases""" current_tt = int(time.time()) import_time = current_tt if import_time is None else import_time - data = json.loads(data_stream.read(), object_hook=decode_dashboards) - # TODO: import DRUID datasources - for table in data['datasources']: - type(table).import_obj(table, import_time=import_time) + + data = json.loads(data, object_hook=decode_dashboards) + + # substitute_db_name = get_db_name(database_uri) if database_uri else \ + # get_or_create_db_by_name(db_name='examples').database_name + substitute_db_name = get_db_name(database_uri) if database_uri else \ + get_or_create_db_by_name(db_name='main').database_name + + import_files_to_table(data, is_example=True, data_blob_urls=data_blob_urls) session.commit() - for dashboard in data['dashboards']: - Dashboard.import_obj( - dashboard, import_time=import_time) + + import_datasources(data, import_time, substitute_db_name=substitute_db_name) + import_dashboard(session, data, import_time) session.commit() -def 
export_dashboards(session): +def get_db_name(uri): + """Get the DB name from the URI string""" + db_name = make_url(uri).database + if uri.startswith('sqlite'): + db_name = re.match('(?s:.*)/(.+?).db$', db_name).group(1) + return db_name + + +def get_default_example_db(): + """Get the optional substitute database for example import""" + uri = app.config.get('SQLALCHEMY_EXAMPLES_URI') + db_name = get_db_name(uri) + + return db.session.query(Database).filter_by( + database_name=db_name).one() + + +def export_dashboards(session, dashboard_ids=None, dashboard_titles=None, + export_data=False, export_data_dir=None, description=None, + export_title=None, _license='Apache 2.0', url=None, + strip_database=False): """Returns all dashboards metadata as a json dump""" logging.info('Starting export') - dashboards = session.query(Dashboard) - dashboard_ids = [] - for dashboard in dashboards: - dashboard_ids.append(dashboard.id) - data = Dashboard.export_dashboards(dashboard_ids) - return data + export_dashboard_ids = [] + + session = db.session() if not session else session + query = session.query(Dashboard) + if dashboard_ids or dashboard_titles: + query = query.filter(Dashboard.id.in_(dashboard_ids) | + Dashboard.dashboard_title.in_(dashboard_titles)) + + export_dashboard_ids = [d.id for d in query.all()] + + data = {} + if not export_dashboard_ids: + logging.error('No dashboards found!') + raise DashboardNotFoundException('No dashboards found!') + else: + data = Dashboard.export_dashboards(export_dashboard_ids, + export_data, export_data_dir) + + if export_title: + data['description']['title'] = export_title + if description: + data['description']['description'] = description + if url: + data['description']['url'] = url + data['description']['license'] = _license + + export_json = json.dumps(data, indent=4, sort_keys=True) + + return export_json + + +def get_slug(session, dashboard_id=None, dashboard_title=None): + """Get the slug for the name of the directory inside the tarballed example""" + session = session or db.session() + query = session.query(Dashboard) + slug = None + if dashboard_id or dashboard_title: + query = query.filter((Dashboard.id == dashboard_id) | + (Dashboard.dashboard_title == dashboard_title)) + dashboard = query.first() + slug = getattr(dashboard, 'slug', None) + return slug + + +def remove_dashboard(session, import_example_data, dashboard_title, database_uri=None, + primary_key=Column('id', Integer, primary_key=True)): + """Remove a dashboard based on id or title""" + + session = db.session() if not session else session + logging.info(session.query(Dashboard).all()) + + try: + dashboard = session.query(Dashboard).filter( + Dashboard.dashboard_title == dashboard_title, + ).one() + + session.delete(dashboard) + session.commit() + except NoResultFound: + raise DashboardNotFoundException('Dashboard not found!') + + # Remove the associated table metadata + SqlaTable = ConnectorRegistry.sources['table'] + for f in import_example_data['files']: + t = session.query(SqlaTable).filter( + SqlaTable.table_name == f['table_name'], + ).one() + session.delete(t) + session.commit() + + # Now delete the physical data table + # examples_engine = get_or_create_db_by_name(db_name='examples') + examples_engine = get_or_create_db_by_name(db_name='main') + sqla_engine = examples_engine.get_sqla_engine() + + # Create a model class on the fly to do a cross-platform table drop + class DropTable(Model): + __tablename__ = f['table_name'] + id = primary_key + + table = DropTable() + 
table.__table__.drop(sqla_engine) diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py index ca57c6b177d4..374a575684df 100644 --- a/superset/utils/dict_import_export.py +++ b/superset/utils/dict_import_export.py @@ -72,11 +72,6 @@ def import_from_dict(session, data, sync=[]): for database in data.get(DATABASES_KEY, []): Database.import_from_dict(session, database, sync=sync) - logging.info('Importing %d %s', - len(data.get(DRUID_CLUSTERS_KEY, [])), - DRUID_CLUSTERS_KEY) - for datasource in data.get(DRUID_CLUSTERS_KEY, []): - DruidCluster.import_from_dict(session, datasource, sync=sync) session.commit() else: logging.info('Supplied object is not a dictionary.') diff --git a/tests/cli_tests.py b/tests/cli_tests.py new file mode 100644 index 000000000000..7b7022b0de8f --- /dev/null +++ b/tests/cli_tests.py @@ -0,0 +1,292 @@ +import gzip +import json +import logging +import os +import struct +import tarfile +import tempfile + +import pandas as pd + +from superset import app, cli, db # noqa: F401 +from superset.connectors.connector_registry import ConnectorRegistry +from superset.models.core import Dashboard, Database, Slice +from superset.utils.core import get_or_create_db_by_name +from tests.base_tests import SupersetTestCase + +config = app.config + + +class SupersetCliTestCase(SupersetTestCase): + + @classmethod + def setUp(cls): + cls.runner = app.test_cli_runner() + + @classmethod + def get_uncompressed_size(cls, file_path): + """Last 4 bytes of a gzip file contain uncompressed size""" + with open(file_path, 'rb') as f: + f.seek(-4, 2) + return struct.unpack('I', f.read(4))[0] + + @classmethod + def gzip_file_line_count(cls, file_path): + """Get the line count of a gzip'd CSV file""" + with gzip.open(file_path, 'r') as f: + for i, l in enumerate(f): + pass + return i + 1 + + def test_version(self): + """Test `superset version`""" + version_result = self.runner.invoke(app.cli, ['version']) + # Version result should contain version string + self.assertTrue(config.get('VERSION_STRING') in version_result.output) + + def test_export_all_test_dashboards(self): + """Test `superset export_dashboards`""" + self.runner.invoke(app.cli, ['load_examples']) + result = self.runner.invoke(app.cli, ['export_dashboards']) + data = json.loads(result.output) + + # Should export at least all 5 test dashboards + self.assertGreaterEqual(len(data['dashboards']), 5) + + def test_export_dashboard_by_id(self): + """Test `superset export_dashboards -i 3`""" + self.runner.invoke(app.cli, ['load_examples']) + result = self.runner.invoke(app.cli, ['export_dashboards', '-i', '5']) + data = json.loads(result.output) + + # Should export 1 dashboard with matching id + ids = list(map(lambda d: d['__Dashboard__']['id'], data['dashboards'])) + self.assertEqual(len(ids), 1) + self.assertEqual(ids[0], 5) + + def test_export_dashboard_by_title(self): + """Test `superset export_dashboards -t World's Bank Data`""" + self.runner.invoke(app.cli, ['load_examples']) + result = self.runner.invoke( + app.cli, ['export_dashboards', '-t', "World's Bank Data"]) + data = json.loads(result.output) + + # Should export 1 dashboard with matching id + ids = list(map( + lambda d: d['__Dashboard__']['dashboard_title'], data['dashboards'])) + self.assertEqual(len(ids), 1) + self.assertEqual(ids[0], "World's Bank Data") + + def test_examples_menu(self): + """Test `superset examples` menu""" + result = self.runner.invoke(app.cli, ['examples']) + self.assertIn('import', result.output) + self.assertIn('list', 
+        self.assertIn('remove', result.output)
+        self.assertIn('export', result.output)
+
+    def test_examples_list(self):
+        """Test `superset examples list`"""
+        result = self.runner.invoke(app.cli, ['examples', 'list'])
+
+        found = False
+        for i, line in enumerate(result.output.split('\n')):
+            # Skip the table header
+            if i < 3:
+                continue
+            # Odd lines have data
+            if (i % 2) != 0:
+                row = line[1:-1]
+                parts = [part.strip() for part in row.split('|')]
+                if parts[0] == 'World Bank Health Nutrition and Population Stats':
+                    found = True
+
+        # Did we find the example in the list?
+        self.assertEqual(found, True)
+
+    def test_examples_import(self):
+        """Test `superset examples import`"""
+        self.runner.invoke(
+            app.cli,
+            [
+                'examples', 'import', '-e',
+                'World Bank Health Nutrition and Population Stats',
+            ],
+        )
+
+        # Did the dashboard get imported to the main DB?
+        dashboard = db.session.query(Dashboard).filter(
+            Dashboard.dashboard_title.in_(["World's Bank Data"])).one()
+        self.assertEqual(dashboard.dashboard_title, "World's Bank Data")
+
+        # Temporary - substitute default
+        db_name = 'main'
+
+        # Did the data table get imported?
+        SqlaTable = ConnectorRegistry.sources['table']
+        table = (
+            db.session.query(SqlaTable)
+            .join(Database)
+            .filter(
+                Database.database_name == db_name,
+                SqlaTable.table_name == 'wb_health_population')
+        ).one()
+        self.assertEqual(table.name, 'wb_health_population')
+
+        # Did all rows get imported?
+        df = pd.read_sql('SELECT * FROM wb_health_population',
+                         get_or_create_db_by_name(db_name='main').get_sqla_engine())
+        self.assertEqual(len(df.index), 11770)
+
+    def test_examples_import_duplicate(self):
+        """Test `superset examples import` when existing dashboard/diff uuid is present"""
+        # Load a pre-existing "World's Bank" Dashboard via `superset load_examples`
+        self.runner.invoke(
+            app.cli,
+            ['load_examples'],
+        )
+        # Load the same dashboard, which carries different uuids
+        self.runner.invoke(
+            app.cli,
+            [
+                'examples', 'import', '-e',
+                'World Bank Health Nutrition and Population Stats',
+            ],
+        )
+
+        # Did the dashboard get imported to the main DB more than once?
+        dashboards = db.session.query(Dashboard).filter(
+            Dashboard.dashboard_title.in_(["World's Bank Data"])).all()
+        self.assertEqual(len(dashboards), 2)
+
+        # Did the slices get imported to the main DB more than once?
+        slices = db.session.query(Slice).filter(
+            Slice.slice_name.in_(["World's Population"]),
+        ).all()
+        self.assertEqual(len(slices), 2)
+
+    def test_examples_import_duplicate_uuid(self):
+        """Test `superset examples import` when existing dashboard/same uuid is present"""
+        # Load a pre-existing "World's Bank" Dashboard
+        self.runner.invoke(
+            app.cli,
+            [
+                'examples', 'import', '-e',
+                'World Bank Health Nutrition and Population Stats',
+            ],
+        )
+        # Load the same dashboard again; this time the uuids are identical
+        self.runner.invoke(
+            app.cli,
+            [
+                'examples', 'import', '-e',
+                'World Bank Health Nutrition and Population Stats',
+            ],
+        )
+
+        # Did the dashboard get imported to the main DB just once?
+        dashboards = db.session.query(Dashboard).filter(
+            Dashboard.dashboard_title.in_(["World's Bank Data"])).all()
+        self.assertEqual(len(dashboards), 1)
+
+        # Did the slices get imported just once?
+        slices = db.session.query(Slice).filter(
+            Slice.slice_name.in_(["World's Population"]),
+        ).all()
+        self.assertEqual(len(slices), 1)
+
+    def test_examples_remove(self):
+        """Test `superset examples remove`"""
+        # First add the example...
+        self.runner.invoke(
+            app.cli,
+            [
+                'examples', 'import', '-e',
+                'World Bank Health Nutrition and Population Stats',
+            ],
+        )
+
+        # Then remove the example...
+        self.runner.invoke(
+            app.cli,
+            [
+                'examples', 'remove', '-e',
+                'World Bank Health Nutrition and Population Stats',
+            ],
+        )
+
+        # Is the dashboard still in the main db?
+        total = db.session.query(Dashboard).filter(
+            Dashboard.dashboard_title.in_(["World's Bank Data"])).count()
+        self.assertEqual(total, 0)
+
+        # Is the data table gone?
+        db_name = 'main'
+
+        # Did the data table get removed?
+        SqlaTable = ConnectorRegistry.sources['table']
+        total = (
+            db.session.query(SqlaTable)
+            .join(Database)
+            .filter(
+                Database.database_name == db_name,
+                SqlaTable.table_name == 'wb_health_population')
+        ).count()
+        self.assertEqual(total, 0)
+
+    def test_examples_export(self):
+        """Test `superset examples export`"""
+        # self.runner.invoke(app.cli, ['load_examples'])
+        result = self.runner.invoke(
+            app.cli,
+            [
+                'examples', 'export', '-e',
+                'World Bank Health Nutrition and Population Stats', '-t',
+                "World's Bank Data", '-d',
+                'Health Nutrition and Population Statistics database provides key ' +
+                'health, nutrition and population statistics gathered from a ' +
+                'variety of international and national sources. Themes include ' +
+                'global surgery, health financing, HIV/AIDS, immunization, ' +
+                'infectious diseases, medical resources and usage, noncommunicable ' +
+                'diseases, nutrition, population dynamics, reproductive health, ' +
+                'universal health coverage, and water and sanitation.',
+                '-l', 'Apache 2.0', '-u',
+                'https://datacatalog.worldbank.org/dataset/' +
+                'health-nutrition-and-population-statistics',
+            ])
+        logging.info(result.output)
+
+        # Inspect the tarball
+        with tarfile.open('dashboard.tar.gz', 'r:gz') as tar:
+
+            # Extract all exported files to a temporary directory
+            out_d = tempfile.TemporaryDirectory()
+
+            tar.extractall(out_d.name)
+            world_health_path = f'{out_d.name}{os.path.sep}world_health{os.path.sep}'
+
+            # Check the Dashboard metadata export
+            json_f = open(f'{world_health_path}/dashboard.json', 'r')
+            dashboard = json.loads(json_f.read())
+            desc = dashboard['description']
+            self.assertEqual(
+                desc['title'], 'World Bank Health Nutrition and Population Stats')
+            self.assertEqual(
+                desc['description'][0:30], 'Health Nutrition and Populatio')
+
+            # Check the exported data file size against the metadata size
+            data_file_path = f'{world_health_path}/wb_health_population.csv.gz'
+
+            file_size = SupersetCliTestCase.get_uncompressed_size(data_file_path)
+            file_size = os.path.getsize(data_file_path)
+            self.assertEqual(
+                desc['total_size'],
+                file_size)
+
+            # Check the data export row count against the example's description metadata
+            self.assertEqual(
+                desc['total_rows'],
+                SupersetCliTestCase.gzip_file_line_count(data_file_path))
+
+            out_d.cleanup()
diff --git a/tests/dict_import_export_tests.py b/tests/dict_import_export_tests.py
index f1f93fa64f89..dba6205841dd 100644
--- a/tests/dict_import_export_tests.py
+++ b/tests/dict_import_export_tests.py
@@ -21,9 +21,6 @@
 import yaml
 
 from superset import db
-from superset.connectors.druid.models import (
-    DruidColumn, DruidDatasource, DruidMetric,
-)
 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
 from superset.utils.core import get_main_database
 from .base_tests import SupersetTestCase
@@ -45,9 +42,6 @@ def delete_imports(cls):
         for table in session.query(SqlaTable):
             if DBREF in table.params_dict:
                 session.delete(table)
-        for datasource in session.query(DruidDatasource):
-            if DBREF in datasource.params_dict:
-                session.delete(datasource)
         session.commit()
 
     @classmethod
@@ -87,36 +81,6 @@ def create_table(
             table.metrics.append(SqlMetric(metric_name=metric_name, expression=''))
         return table, dict_rep
 
-    def create_druid_datasource(
-            self, name, id=0, cols_names=[], metric_names=[]):
-        name = '{0}{1}'.format(NAME_PREFIX, name)
-        cluster_name = 'druid_test'
-        params = {DBREF: id, 'database_name': cluster_name}
-        dict_rep = {
-            'cluster_name': cluster_name,
-            'datasource_name': name,
-            'id': id,
-            'params': json.dumps(params),
-            'columns': [{'column_name': c} for c in cols_names],
-            'metrics': [{'metric_name': c, 'json': '{}'} for c in metric_names],
-        }
-
-        datasource = DruidDatasource(
-            id=id,
-            datasource_name=name,
-            cluster_name=cluster_name,
-            params=json.dumps(params),
-        )
-        for col_name in cols_names:
-            datasource.columns.append(DruidColumn(column_name=col_name))
-        for metric_name in metric_names:
-            datasource.metrics.append(DruidMetric(metric_name=metric_name))
-        return datasource, dict_rep
-
-    def get_datasource(self, datasource_id):
-        return db.session.query(DruidDatasource).filter_by(
-            id=datasource_id).first()
-
     def get_table_by_name(self, name):
         return db.session.query(SqlaTable).filter_by(
             table_name=name).first()
@@ -257,106 +221,6 @@ def test_import_table_override_identical(self):
         self.yaml_compare(imported_copy_table.export_to_dict(),
                           imported_table.export_to_dict())
 
-    def test_import_druid_no_metadata(self):
-        datasource, dict_datasource = self.create_druid_datasource(
-            'pure_druid', id=ID_PREFIX + 1)
-        imported_cluster = DruidDatasource.import_from_dict(db.session,
-                                                            dict_datasource)
-        db.session.commit()
-        imported = self.get_datasource(imported_cluster.id)
-        self.assert_datasource_equals(datasource, imported)
-
-    def test_import_druid_1_col_1_met(self):
-        datasource, dict_datasource = self.create_druid_datasource(
-            'druid_1_col_1_met', id=ID_PREFIX + 2,
-            cols_names=['col1'], metric_names=['metric1'])
-        imported_cluster = DruidDatasource.import_from_dict(db.session,
-                                                            dict_datasource)
-        db.session.commit()
-        imported = self.get_datasource(imported_cluster.id)
-        self.assert_datasource_equals(datasource, imported)
-        self.assertEquals(
-            {DBREF: ID_PREFIX + 2, 'database_name': 'druid_test'},
-            json.loads(imported.params))
-
-    def test_import_druid_2_col_2_met(self):
-        datasource, dict_datasource = self.create_druid_datasource(
-            'druid_2_col_2_met', id=ID_PREFIX + 3, cols_names=['c1', 'c2'],
-            metric_names=['m1', 'm2'])
-        imported_cluster = DruidDatasource.import_from_dict(db.session,
-                                                            dict_datasource)
-        db.session.commit()
-        imported = self.get_datasource(imported_cluster.id)
-        self.assert_datasource_equals(datasource, imported)
-
-    def test_import_druid_override_append(self):
-        datasource, dict_datasource = self.create_druid_datasource(
-            'druid_override', id=ID_PREFIX + 3, cols_names=['col1'],
-            metric_names=['m1'])
-        imported_cluster = DruidDatasource.import_from_dict(db.session,
-                                                            dict_datasource)
-        db.session.commit()
-        table_over, table_over_dict = self.create_druid_datasource(
-            'druid_override', id=ID_PREFIX + 3,
-            cols_names=['new_col1', 'col2', 'col3'],
-            metric_names=['new_metric1'])
-        imported_over_cluster = DruidDatasource.import_from_dict(
-            db.session,
-            table_over_dict)
-        db.session.commit()
-        imported_over = self.get_datasource(imported_over_cluster.id)
-        self.assertEquals(imported_cluster.id, imported_over.id)
-        expected_datasource, _ = self.create_druid_datasource(
-            'druid_override', id=ID_PREFIX + 3,
-            metric_names=['new_metric1', 'm1'],
-            cols_names=['col1', 'new_col1', 'col2', 'col3'])
-        self.assert_datasource_equals(expected_datasource, imported_over)
-
-    def test_import_druid_override_sync(self):
-        datasource, dict_datasource = self.create_druid_datasource(
-            'druid_override', id=ID_PREFIX + 3, cols_names=['col1'],
-            metric_names=['m1'])
-        imported_cluster = DruidDatasource.import_from_dict(
-            db.session,
-            dict_datasource)
-        db.session.commit()
-        table_over, table_over_dict = self.create_druid_datasource(
-            'druid_override', id=ID_PREFIX + 3,
-            cols_names=['new_col1', 'col2', 'col3'],
-            metric_names=['new_metric1'])
-        imported_over_cluster = DruidDatasource.import_from_dict(
-            session=db.session,
-            dict_rep=table_over_dict,
-            sync=['metrics', 'columns'])  # syncing metrics and columns
-        db.session.commit()
-        imported_over = self.get_datasource(imported_over_cluster.id)
-        self.assertEquals(imported_cluster.id, imported_over.id)
-        expected_datasource, _ = self.create_druid_datasource(
-            'druid_override', id=ID_PREFIX + 3,
-            metric_names=['new_metric1'],
-            cols_names=['new_col1', 'col2', 'col3'])
-        self.assert_datasource_equals(expected_datasource, imported_over)
-
-    def test_import_druid_override_identical(self):
-        datasource, dict_datasource = self.create_druid_datasource(
-            'copy_cat', id=ID_PREFIX + 4,
-            cols_names=['new_col1', 'col2', 'col3'],
-            metric_names=['new_metric1'])
-        imported = DruidDatasource.import_from_dict(session=db.session,
-                                                    dict_rep=dict_datasource)
-        db.session.commit()
-        copy_datasource, dict_cp_datasource = self.create_druid_datasource(
-            'copy_cat', id=ID_PREFIX + 4,
-            cols_names=['new_col1', 'col2', 'col3'],
-            metric_names=['new_metric1'])
-        imported_copy = DruidDatasource.import_from_dict(db.session,
-                                                         dict_cp_datasource)
-        db.session.commit()
-
-        self.assertEquals(imported.id, imported_copy.id)
-        self.assert_datasource_equals(
-            copy_datasource, self.get_datasource(imported.id))
-
 
 if __name__ == '__main__':
     unittest.main()
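As a usage note (not part of the patch), the exported bundle can be inspected by hand; the `dashboard.tar.gz` file name and the `world_health` slug below are the same assumptions `test_examples_export` makes above:

    import json
    import tarfile
    import tempfile

    # Unpack the exported tarball and print the description metadata recorded by
    # `superset examples export`; the layout mirrors what test_examples_export
    # checks: <slug>/dashboard.json plus one gzipped CSV per exported table.
    with tarfile.open('dashboard.tar.gz', 'r:gz') as tar, \
            tempfile.TemporaryDirectory() as out_dir:
        tar.extractall(out_dir)
        with open(f'{out_dir}/world_health/dashboard.json') as f:
            desc = json.load(f)['description']
        print(desc['title'], desc['license'], desc['total_rows'], desc['total_size'])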