From d9c2b77ff8e5afa00ef75ed5ad8cc6dcad37239d Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Mon, 26 Sep 2016 09:10:37 -0700 Subject: [PATCH 01/13] Implement import / export dashboard functionality. --- caravel/models.py | 109 +++++++++++++++++- .../templates/caravel/export_dashboards.html | 24 ++++ .../templates/caravel/import_dashboards.html | 13 +++ caravel/views.py | 75 +++++++++++- tests/core_tests.py | 33 ++++++ 5 files changed, 249 insertions(+), 5 deletions(-) create mode 100644 caravel/templates/caravel/export_dashboards.html create mode 100644 caravel/templates/caravel/import_dashboards.html diff --git a/caravel/models.py b/caravel/models.py index f160e038d960..6c9f7beeb08c 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -7,6 +7,7 @@ import functools import json import logging +import pickle import re import textwrap from collections import namedtuple @@ -18,6 +19,9 @@ import requests import sqlalchemy as sqla from sqlalchemy.engine.url import make_url +from sqlalchemy.orm import joinedload +from sqlalchemy.orm.session import make_transient + import sqlparse from dateutil.parser import parse @@ -41,6 +45,7 @@ from sqlalchemy.ext.compiler import compiles from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import backref, relationship +from sqlalchemy.orm.session import make_transient from sqlalchemy.sql import table, literal_column, text, column from sqlalchemy.sql.expression import ColumnClause, TextAsFrom from sqlalchemy_utils import EncryptedType @@ -283,6 +288,47 @@ def get_viz(self, url_params_multidict=None): slice_=self ) + @classmethod + def import_slice(cls, slice): + make_transient(slice) + slice.id = None + + session = db.session + existing_slice = ( + session.query(Slice) + .filter_by(slice_name=slice.slice_name) + .first() + ) + + # no overrides for the slices + # TODO: consider comparing the object and overriding if needed. 
+ if existing_slice: + return existing_slice + + id_pos = slice.perm.find('(id') + if id_pos == -1: + logging.warning('slice {} has malformed perm {}'.format( + slice.slice_name, slice.perf)) + return None + + perm_prefix = slice.perm[:id_pos] + datasource_class = SourceRegistry.sources[ + slice.datasource_type] + datasources = session.query(datasource_class).all() + slice_datasource = [d for d in datasources if perm_prefix in d.perm][0] + slice.datasource_id = slice_datasource.id + # generate correct perms for the slice + slice.perm = '{}(id:{})'.format(perm_prefix, slice_datasource.id) + session.add(slice) + session.commit() + + created_slice = ( + session.query(Slice) + .filter_by(slice_name=slice.slice_name) + .first() + ) + return created_slice + def set_perm(mapper, connection, target): # noqa src_class = target.cls_model @@ -369,6 +415,68 @@ def json_data(self): } return json.dumps(d) + @classmethod + def import_dashboard(cls, dashboard_to_import): + session = db.session + # TODO: provide nice error message about failed imports + slices_to_attach = [] + for slice in dashboard_to_import.slices: + imported_slice = Slice.import_slice(slice) + if not imported_slice: + break + slices_to_attach.append(imported_slice) + logging.info( + slice.slice_name + ' belongs to the dashboard ' + + dashboard_to_import.dashboard_title) + # Skip dashboard creation if the slice wasn't imported + if not slices_to_attach: + logging.warning('Dashboard {} import failed'.format( + dashboard_to_import.dashboard_title)) + return + + existing_dashboard = ( + session.query(Dashboard) + .filter_by( + dashboard_title=dashboard_to_import.dashboard_title) + .first() + ) + + # Override the dashboard + if existing_dashboard: + # TODO: add nice notification about the overriden dashboard + session.delete(existing_dashboard) + session.commit() + + # session.add(dashboard_to_import) causes sqlachemy failures + # related to the attached users / slices. 
Creating new object + # allows to avoid conflicts in the sql alchemy state. + new_dash = Dashboard( + dashboard_title=dashboard_to_import.dashboard_title, + position_json=dashboard_to_import.position_json, + json_metadata=dashboard_to_import.json_metadata, + description=dashboard_to_import.description, + css=dashboard_to_import.css, + slug=dashboard_to_import.slug, + slices=slices_to_attach, + ) + session.add(new_dash) + session.commit() + + @classmethod + def export_dashboards(cls, dashboard_ids): + eager_dashboards = [] + for dashboard_id in dashboard_ids: + eager_dashboard = db.session.query(Dashboard).options( + joinedload('slices')).filter_by(id=dashboard_id).first() + db.session.expunge(eager_dashboard) + make_transient(eager_dashboard) + eager_dashboard.created_by_fk = None + eager_dashboard.changed_by_fk = None + eager_dashboard.owners = [] + eager_dashboards.append(eager_dashboard) + return pickle.dumps(eager_dashboards) + + class Queryable(object): @@ -405,7 +513,6 @@ def explore_url(self): else: return "/caravel/explore/{obj.type}/{obj.id}/".format(obj=self) - class Database(Model, AuditMixinNullable): """An ORM object that stores Database related information""" diff --git a/caravel/templates/caravel/export_dashboards.html b/caravel/templates/caravel/export_dashboards.html new file mode 100644 index 000000000000..4704b80794eb --- /dev/null +++ b/caravel/templates/caravel/export_dashboards.html @@ -0,0 +1,24 @@ +{% extends "caravel/basic.html" %} +{% block title %}{{ _("Export") }}{% endblock %} +{% block body %} + {% include "caravel/flash_wrapper.html" %} +
+ {{ _("Export the dashboards") }} +

+ {{ _("You are about to export dashboards: %(dashboard_names)s ", + dashboard_names=dashboard_names) }} +

+
+ + +
+
+{% endblock %} \ No newline at end of file diff --git a/caravel/templates/caravel/import_dashboards.html b/caravel/templates/caravel/import_dashboards.html new file mode 100644 index 000000000000..4f5082886094 --- /dev/null +++ b/caravel/templates/caravel/import_dashboards.html @@ -0,0 +1,13 @@ +{% extends "caravel/basic.html" %} +{% block title %}{{ _("Import") }}{% endblock %} +{% block body %} + {% include "caravel/flash_wrapper.html" %} +
+ Import the dashboards. +

Import the dashboards.

+
+

+ +

+
+{% endblock %} \ No newline at end of file diff --git a/caravel/views.py b/caravel/views.py index 212cf0709b7a..0d1d8afa0f15 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -5,6 +5,8 @@ import json import logging +import os +import pickle import re import sys import time @@ -15,8 +17,9 @@ import sqlalchemy as sqla from flask import ( - g, request, redirect, flash, Response, render_template, Markup) -from flask_appbuilder import ModelView, CompactCRUDMixin, BaseView, expose + g, request, make_response, redirect, flash, Response, render_template, + Markup, url_for, send_file) +from flask_appbuilder import ModelView, CompactCRUDMixin, BaseView, expose, filemanager from flask_appbuilder.actions import action from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access, has_access_api @@ -26,6 +29,8 @@ from flask_appbuilder.models.sqla.filters import BaseFilter from sqlalchemy import create_engine +from werkzeug import secure_filename +from werkzeug.datastructures import ImmutableMultiDict from werkzeug.routing import BaseConverter from werkzeug.datastructures import ImmutableMultiDict from wtforms.validators import ValidationError @@ -867,15 +872,56 @@ def pre_update(self, obj): def pre_delete(self, obj): check_ownership(obj) + @action("mulexport", "Export", "Export dashboards?", "fa-database", single=True) + def mulexport(self, items): + # TODO: find a way to download file and redirect. 
+ ids = ''.join('&id={}'.format(d.id) for d in items) + dashboard_names = ''.join( + '&dashboard_name={}'.format(d.dashboard_title) for d in items) + return redirect('/dashboardmodelview/export_dashboards_form?{}{}' + .format(ids[1:], dashboard_names)) + + @expose("/export_dashboards") + def download_file(self): + ids = request.args.getlist('id') + payload = models.Dashboard.export_dashboards(ids) + return Response( + payload, + status=200, + headers=generate_download_headers("pickle"), + mimetype="application/text") + + @expose("/export_dashboards_form") + def download_dashboards(self): + ids = request.args.getlist('id') + args = ''.join('&id={}'.format(id) for id in ids) + dashboard_names = request.args.getlist('dashboard_name') + concat_dashboard_names = ', '.join(dashboard_names) + download_url = ( + '/dashboardmodelview/export_dashboards?{}'.format(args[1:]) + ) + return self.render_template( + 'caravel/export_dashboards.html', + dashboard_names=concat_dashboard_names, + download_url=download_url, + dashboards_url='/dashboardmodelview/list' + ) + appbuilder.add_view( DashboardModelView, "Dashboards", - label=__("Dashboards"), + label=__("Dashboards list"), icon="fa-dashboard", - category="", + category='Dashboards', category_icon='',) +appbuilder.add_link( + 'Import Dashboards', + href='/caravel/import_dashboards', + icon="fa-flask", + category='Dashboards') + class DashboardModelViewAsync(DashboardModelView): # noqa list_columns = ['dashboard_link', 'creator', 'modified', 'dashboard_title'] @@ -1126,6 +1172,7 @@ def get_viz( return viz_obj @has_access +<<<<<<< f70d301f0d37fe6c4f94cd7d6026a72a56087d34 @expose("/slice//") def slice(self, slice_id): viz_obj = self.get_viz(slice_id) @@ -1149,6 +1196,26 @@ def explore_json(self, datasource_type, datasource_id): status=200, mimetype="application/json") + @expose("/import_dashboards", methods=['GET', 'POST']) + @log_this + def import_dashboards(self): + """Overrides the dashboards using pickled instances from the 
file.""" + if request.method == 'POST': + file = request.files['file'] + if file: + filename = secure_filename(file.filename) + filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename) + file.save(filepath) + for d in pickle.load(open(filepath, 'rb')): + models.Dashboard.import_dashboard(d) + os.remove(filepath) + return redirect('/dashboardmodelview/list/') + return self.render_template('caravel/import_dashboards.html') + + @has_access + @expose("/explore////") + @expose("/explore///") + @expose("/datasource///") # Legacy url @log_this @has_access @expose("/explore///") diff --git a/tests/core_tests.py b/tests/core_tests.py index bfd9700aaed0..7fb32530b252 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -8,6 +8,7 @@ import doctest import json import io +import pickle import random import unittest @@ -452,6 +453,38 @@ def test_only_owners_can_save(self): db.session.commit() self.test_save_dash('alpha') + def test_export_dashboards(self): + births_dash = db.session.query(models.Dashboard).filter_by( + slug='births').first() + + # export 1 dashboard + export_dash_url_1 = ( + '/dashboardmodelview/export_dashboards?id={}' + .format(births_dash.id) + ) + resp_1 = self.client.get(export_dash_url_1) + exported_dashboards_1 = pickle.loads(resp_1.data) + self.assertEquals(1, len(exported_dashboards_1)) + self.assertEquals('births', exported_dashboards_1[0].slug) + self.assertEquals('Births', exported_dashboards_1[0].dashboard_title) + self.assertTrue(exported_dashboards_1[0].position_json) + self.assertEquals(10, len(exported_dashboards_1[0].slices)) + + # export 2 dashboards + world_health_dash = db.session.query(models.Dashboard).filter_by( + slug='world_health').first() + export_dash_url_2 = ( + '/dashboardmodelview/export_dashboards?id={}&id={}' + .format(births_dash.id, world_health_dash.id) + ) + resp_2 = self.client.get(export_dash_url_2) + exported_dashboards_2 = pickle.loads(resp_2.data) + self.assertEquals(2, len(exported_dashboards_2)) + 
self.assertEquals('births', exported_dashboards_2[0].slug) + self.assertEquals(10, len(exported_dashboards_2[0].slices)) + self.assertEquals('world_health', exported_dashboards_2[1].slug) + self.assertEquals(10, len(exported_dashboards_2[1].slices)) + if __name__ == '__main__': unittest.main() From 0b9a9f3ae1143447e6b23db4c231d69a78bc6ba7 Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Tue, 27 Sep 2016 14:36:46 -0700 Subject: [PATCH 02/13] Address comments from discussion. --- caravel/__init__.py | 6 + caravel/models.py | 127 ++++++++---- .../templates/caravel/export_dashboards.html | 30 +-- caravel/views.py | 48 ++--- tests/core_tests.py | 188 +++++++++++++++++- 5 files changed, 299 insertions(+), 100 deletions(-) diff --git a/caravel/__init__.py b/caravel/__init__.py index 1fcfe43acb99..814b60ebceff 100644 --- a/caravel/__init__.py +++ b/caravel/__init__.py @@ -82,6 +82,12 @@ def checkout(dbapi_con, con_record, con_proxy): if app.config.get('ENABLE_PROXY_FIX'): app.wsgi_app = ProxyFix(app.wsgi_app) +if app.config.get('UPLOAD_FOLDER'): + try: + os.makedirs(app.config.get('UPLOAD_FOLDER')) + except OSError, e: + pass + class MyIndexView(IndexView): @expose('/') diff --git a/caravel/models.py b/caravel/models.py index 6c9f7beeb08c..5eb11a4d42e8 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -289,45 +289,64 @@ def get_viz(self, url_params_multidict=None): ) @classmethod - def import_slice(cls, slice): - make_transient(slice) - slice.id = None - + def import_slice(cls, slc, import_time=None): session = db.session - existing_slice = ( - session.query(Slice) - .filter_by(slice_name=slice.slice_name) - .first() - ) - - # no overrides for the slices - # TODO: consider comparing the object and overriding if needed. 
- if existing_slice: - return existing_slice - - id_pos = slice.perm.find('(id') + make_transient(slc) + slc.dashboards = [] + id_pos = slc.perm.find('(id') if slc.perm else -1 if id_pos == -1: logging.warning('slice {} has malformed perm {}'.format( - slice.slice_name, slice.perf)) + slc.slice_name, slc.perm)) return None - perm_prefix = slice.perm[:id_pos] - datasource_class = SourceRegistry.sources[ - slice.datasource_type] + perm_prefix = slc.perm[:id_pos] + datasource_class = SourceRegistry.sources[slc.datasource_type] datasources = session.query(datasource_class).all() - slice_datasource = [d for d in datasources if perm_prefix in d.perm][0] - slice.datasource_id = slice_datasource.id - # generate correct perms for the slice - slice.perm = '{}(id:{})'.format(perm_prefix, slice_datasource.id) - session.add(slice) - session.commit() + if not datasources: + logging.warning('Datasource {} was not found for the slice {}.' + .format(slc.datasource_id, slc.slice_name)) + return None - created_slice = ( + params_dict = {} + if slc.params: + params_dict = json.loads(slc.params) + if 'remote_id' not in params_dict: + params_dict['remote_id'] = slc.id + params_dict['import_time'] = import_time + slc.params = json.dumps(params_dict) + + # find if the slice was already imported + slc_to_override = ( session.query(Slice) - .filter_by(slice_name=slice.slice_name) + .filter(Slice.params.ilike('%"remote_id": {}%'.format(slc.id))) .first() ) - return created_slice + + slc.id = None + if slc_to_override: + slc_params = json.loads(slc_to_override.params) + if slc_params.get('import_time') == import_time: + # slice was imported in the current batch, no changes needed: + return slc_to_override.id + slc.id = slc_to_override.id + # delete old version of the slice + session.delete(slc_to_override) + session.commit() + + # add slice pointing to the right datasources + slc_dss = [d for d in datasources if perm_prefix in d.perm] + if not slc_dss: + logging.warning( + 'Datasources were 
not found for the slice {} with perm {} ' + .format(slc.slice_name, slc.perm)) + return None + slc.datasource_id = slc_dss[0].id + # generate correct perms for the slice + slc.perm = '{}(id:{})'.format(perm_prefix, slc.datasource_id) + session.add(slc) + logging.info('Imported the slice: {}'.format(slc.to_json())) + session.flush() + return slc.id def set_perm(mapper, connection, target): # noqa @@ -416,17 +435,26 @@ def json_data(self): return json.dumps(d) @classmethod - def import_dashboard(cls, dashboard_to_import): - session = db.session + def import_dashboard(cls, dashboard_to_import, import_time=None): + logging.info('Started import of the dashboard: {}' + .format(dashboard_to_import.to_json())) + make_transient(dashboard_to_import) # TODO: provide nice error message about failed imports slices_to_attach = [] - for slice in dashboard_to_import.slices: - imported_slice = Slice.import_slice(slice) - if not imported_slice: - break - slices_to_attach.append(imported_slice) + session = db.session + logging.info('Dashboard has {} slices' + .format(len(dashboard_to_import.slices))) + slices = copy(dashboard_to_import.slices) + for slc in slices: + logging.info('Importing slice {} from the dashboard: {}'.format( + slc.to_json(), dashboard_to_import.dashboard_title)) + slc_id = Slice.import_slice(slc, import_time=import_time) + if not slc_id: + continue + slices_to_attach.append( + session.query(Slice).filter_by(id=slc_id).first()) logging.info( - slice.slice_name + ' belongs to the dashboard ' + + slc.slice_name + ' belongs to the dashboard ' + dashboard_to_import.dashboard_title) # Skip dashboard creation if the slice wasn't imported if not slices_to_attach: @@ -434,16 +462,27 @@ def import_dashboard(cls, dashboard_to_import): dashboard_to_import.dashboard_title)) return + if dashboard_to_import.json_metadata: + json_metadata_dict = json.loads(dashboard_to_import.json_metadata) + else: + json_metadata_dict = {} + if 'remote_id' not in json_metadata_dict: + 
json_metadata_dict['remote_id'] = dashboard_to_import.id + json_metadata_dict['import_time'] = import_time + + # Override the dashboard existing_dashboard = ( session.query(Dashboard) - .filter_by( - dashboard_title=dashboard_to_import.dashboard_title) + .filter(Dashboard.json_metadata.ilike( + '%"remote_id": {}%'.format(dashboard_to_import.id))) .first() ) - - # Override the dashboard + dashboard_to_import.id = None if existing_dashboard: # TODO: add nice notification about the overriden dashboard + logging.info('Dashboard already exist and will be overridden: {}' + .format(existing_dashboard.to_json())) + session.delete(existing_dashboard) session.commit() @@ -453,7 +492,7 @@ def import_dashboard(cls, dashboard_to_import): new_dash = Dashboard( dashboard_title=dashboard_to_import.dashboard_title, position_json=dashboard_to_import.position_json, - json_metadata=dashboard_to_import.json_metadata, + json_metadata=json.dumps(json_metadata_dict), description=dashboard_to_import.description, css=dashboard_to_import.css, slug=dashboard_to_import.slug, @@ -461,6 +500,8 @@ def import_dashboard(cls, dashboard_to_import): ) session.add(new_dash) session.commit() + logging.info('Imported the dashboard: {}'.format(new_dash.to_json())) + return new_dash.id @classmethod def export_dashboards(cls, dashboard_ids): @@ -477,7 +518,6 @@ def export_dashboards(cls, dashboard_ids): return pickle.dumps(eager_dashboards) - class Queryable(object): """A common interface to objects that are queryable (tables and datasources)""" @@ -513,6 +553,7 @@ def explore_url(self): else: return "/caravel/explore/{obj.type}/{obj.id}/".format(obj=self) + class Database(Model, AuditMixinNullable): """An ORM object that stores Database related information""" diff --git a/caravel/templates/caravel/export_dashboards.html b/caravel/templates/caravel/export_dashboards.html index 4704b80794eb..83a0c840e09f 100644 --- a/caravel/templates/caravel/export_dashboards.html +++ 
b/caravel/templates/caravel/export_dashboards.html @@ -1,24 +1,6 @@ -{% extends "caravel/basic.html" %} -{% block title %}{{ _("Export") }}{% endblock %} -{% block body %} - {% include "caravel/flash_wrapper.html" %} -
- {{ _("Export the dashboards") }} -

- {{ _("You are about to export dashboards: %(dashboard_names)s ", - dashboard_names=dashboard_names) }} -

-
- - -
-
-{% endblock %} \ No newline at end of file + diff --git a/caravel/views.py b/caravel/views.py index 0d1d8afa0f15..623c76dc5a81 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -19,7 +19,7 @@ from flask import ( g, request, make_response, redirect, flash, Response, render_template, Markup, url_for, send_file) -from flask_appbuilder import ModelView, CompactCRUDMixin, BaseView, expose, filemanager +from flask_appbuilder import ModelView, CompactCRUDMixin, BaseView, expose from flask_appbuilder.actions import action from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.security.decorators import has_access, has_access_api @@ -872,38 +872,22 @@ def pre_update(self, obj): def pre_delete(self, obj): check_ownership(obj) - @action("mulexport", "Export", "Export dashboards?", "fa-database", single=True) + @action("mulexport", "Export", "Export dashboards?", "fa-database") def mulexport(self, items): - # TODO: find a way to download file and redirect. ids = ''.join('&id={}'.format(d.id) for d in items) - dashboard_names = ''.join( - '&dashboard_name={}'.format(d.dashboard_title) for d in items) - return redirect('/dashboardmodelview/export_dashboards_form?{}{}' - .format(ids[1:], dashboard_names)) - - @expose("/export_dashboards") - def download_file(self): - ids = request.args.getlist('id') - payload = models.Dashboard.export_dashboards(ids) - return Response( - payload, - status=200, - headers=generate_download_headers("pickle"), - mimetype="application/text") + return redirect( + '/dashboardmodelview/export_dashboards_form?{}'.format(ids[1:])) @expose("/export_dashboards_form") def download_dashboards(self): - ids = request.args.getlist('id') - args = ''.join('&id={}'.format(id) for id in ids) - dashboard_names = request.args.getlist('dashboard_name') - concat_dashboard_names = ', '.join(dashboard_names) - download_url = ( - '/dashboardmodelview/export_dashboards?{}'.format(args[1:]) - ) + if request.args.get('action') == 'go': + 
ids = request.args.getlist('id') + return Response( + models.Dashboard.export_dashboards(ids), + headers=generate_download_headers("pickle"), + mimetype="application/text") return self.render_template( 'caravel/export_dashboards.html', - dashboard_names=concat_dashboard_names, - download_url=download_url, dashboards_url='/dashboardmodelview/list' ) @@ -1201,13 +1185,15 @@ def explore_json(self, datasource_type, datasource_id): def import_dashboards(self): """Overrides the dashboards using pickled instances from the file.""" if request.method == 'POST': - file = request.files['file'] - if file: - filename = secure_filename(file.filename) + f = request.files['file'] + if f: + filename = secure_filename(f.filename) filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename) - file.save(filepath) + f.save(filepath) + current_tt = int(time.time()) for d in pickle.load(open(filepath, 'rb')): - models.Dashboard.import_dashboard(d) + models.Dashboard.import_dashboard( + d, import_time=current_tt) os.remove(filepath) return redirect('/dashboardmodelview/list/') return self.render_template('caravel/import_dashboards.html') diff --git a/tests/core_tests.py b/tests/core_tests.py index 7fb32530b252..b6c471e7e71a 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -459,7 +459,7 @@ def test_export_dashboards(self): # export 1 dashboard export_dash_url_1 = ( - '/dashboardmodelview/export_dashboards?id={}' + '/dashboardmodelview/export_dashboards_form?id={}&action=go' .format(births_dash.id) ) resp_1 = self.client.get(export_dash_url_1) @@ -474,7 +474,7 @@ def test_export_dashboards(self): world_health_dash = db.session.query(models.Dashboard).filter_by( slug='world_health').first() export_dash_url_2 = ( - '/dashboardmodelview/export_dashboards?id={}&id={}' + '/dashboardmodelview/export_dashboards_form?id={}&id={}&action=go' .format(births_dash.id, world_health_dash.id) ) resp_2 = self.client.get(export_dash_url_2) @@ -485,6 +485,190 @@ def 
test_export_dashboards(self): self.assertEquals('world_health', exported_dashboards_2[1].slug) self.assertEquals(10, len(exported_dashboards_2[1].slices)) + def test_import_slice(self): + session = db.session + + def get_slice(name, ds_id=666, slc_id=1, + perm='[main].[wb_health_population](id:666)'): + return models.Slice( + slice_name=name, + datasource_type='table', + viz_type='bubble', + params='{"num_period_compare": "10"}', + # used to locate the datasource + perm=perm, + datasource_id=ds_id, + id=slc_id + ) + + def find_slice(slc_id): + return session.query(models.Slice).filter_by(id=slc_id).first() + + # Case 1. Import slice with different datasource id + slc_id_1 = models.Slice.import_slice( + get_slice('Import Me 1'), import_time=1989) + slc_1 = find_slice(slc_id_1) + table_id = db.session.query(models.SqlaTable).filter_by( + table_name='wb_health_population').first().id + self.assertEquals(table_id, slc_1.datasource_id) + self.assertEquals( + '[main].[wb_health_population](id:{})'.format(table_id), + slc_1.perm) + self.assertEquals('Import Me 1', slc_1.slice_name) + self.assertEquals('table', slc_1.datasource_type) + self.assertEquals('bubble', slc_1.viz_type) + self.assertEquals( + '{"num_period_compare": "10",' + ' "remote_id": 1, ' + '"import_time": 1989}', + slc_1.params) + + # Case 2. Import slice with same datasource id + slc_id_2 = models.Slice.import_slice(get_slice( + 'Import Me 2', ds_id=table_id, slc_id=2, + perm='[main].[wb_health_population](id:{})'.format(table_id))) + slc_2 = find_slice(slc_id_2) + self.assertEquals(table_id, slc_2.datasource_id) + self.assertEquals( + '[main].[wb_health_population](id:{})'.format(table_id), + slc_2.perm) + + # Case 3. Import slice with no permissions + slice_3 = get_slice('Import Me 3', slc_id=3) + slice_3.perm = None + slc_id_3 = models.Slice.import_slice(slice_3) + self.assertIsNone(slc_id_3) + + # Case 4. 
Import slice with non existent datasource + slc_id_4 = models.Slice.import_slice(get_slice( + 'Import Me 4', slc_id=4, perm='[main].[non_existent](id:999)')) + self.assertIsNone(slc_id_4) + + # Case 5. Slice already exists, no override + slc_id_5 = models.Slice.import_slice( + get_slice('Import Me 1'), import_time=1989) + slc_5 = find_slice(slc_id_5) + self.assertEquals(table_id, slc_5.datasource_id) + self.assertEquals( + '[main].[wb_health_population](id:{})'.format(table_id), + slc_5.perm) + self.assertEquals('Import Me 1', slc_5.slice_name) + self.assertEquals('table', slc_5.datasource_type) + self.assertEquals('bubble', slc_5.viz_type) + self.assertEquals( + '{"num_period_compare": "10",' + ' "remote_id": 1, ' + '"import_time": 1989}', + slc_5.params) + + # Case 6. Older version of slice already exists, override + slc_id_6 = models.Slice.import_slice( + get_slice('Import Me New 1'), import_time=1990) + slc_6 = find_slice(slc_id_6) + self.assertEquals(table_id, slc_6.datasource_id) + self.assertEquals( + '[main].[wb_health_population](id:{})'.format(table_id), + slc_5.perm) + self.assertEquals('Import Me New 1', slc_6.slice_name) + self.assertEquals('table', slc_6.datasource_type) + self.assertEquals('bubble', slc_6.viz_type) + self.assertEquals( + '{"num_period_compare": "10",' + ' "remote_id": 1, ' + '"import_time": 1990}', + slc_6.params) + + session.delete(slc_2) + session.delete(slc_6) + session.commit() + + def test_import_dashboard(self): + session = db.session + + def get_dashboard(title, id=0, slcs=[]): + return models.Dashboard( + id=id, + dashboard_title=title, + slices=slcs, + position_json='{"size_y": 2, "size_x": 2}', + slug='{}_imported'.format(title.lower()), + ) + + def find_dash(dash_id): + return session.query(models.Dashboard).filter_by( + id=dash_id).first() + + def get_slc(name, table, slc_id): + return models.Slice( + id=slc_id, + slice_name=name, + datasource_type='table', + perm='[main].[{}](id:{})'.format(table, slc_id) + ) + + # Case 
1. Import empty dashboard + empty_dash = get_dashboard('empty_dashboard', id=1001) + imported_dash_id_1 = models.Dashboard.import_dashboard( + empty_dash, import_time=1989) + # Dashboard import fails if there are no slices + self.assertIsNone(find_dash(imported_dash_id_1)) + + # Case 2. Import dashboard with 1 new slice. + dash_with_1_slice = get_dashboard( + 'dash_with_1_slice', + slcs=[get_slc('health_slc', 'wb_health_population', 1001)], + id=1002) + imported_dash_id_2 = models.Dashboard.import_dashboard( + dash_with_1_slice, import_time=1990) + imported_dash_2 = find_dash(imported_dash_id_2) + self.assertEquals('dash_with_1_slice', imported_dash_2.dashboard_title) + self.assertEquals(1, len(imported_dash_2.slices)) + self.assertEquals('wb_health_population', + imported_dash_2.slices[0].datasource.table_name) + self.assertEquals('{"remote_id": 1002, "import_time": 1990}', + imported_dash_2.json_metadata) + + # Case 3. Import dashboard with 2 new slices. + dash_with_2_slices = get_dashboard( + 'dash_with_2_slices', + slcs=[get_slc('energy_slc', 'energy_usage', 1002), + get_slc('birth_slc', 'birth_names', 1003)], + id=1003) + imported_dash_id_3 = models.Dashboard.import_dashboard( + dash_with_2_slices, import_time=1991) + imported_dash_3 = find_dash(imported_dash_id_3) + self.assertEquals('dash_with_2_slices', imported_dash_3.dashboard_title) + self.assertEquals(2, len(imported_dash_3.slices)) + self.assertEquals('{"remote_id": 1003, "import_time": 1991}', + imported_dash_3.json_metadata) + imported_dash_slice_ids_3 = [s.id for s in imported_dash_3.slices] + + # Case 4. 
Override the dashboard and 2 slices + dash_with_2_slices_new = get_dashboard( + 'dash_with_2_slices_new', + slcs=[get_slc('energy_slc', 'energy_usage', 1002), + get_slc('birth_slc', 'birth_names', 1003)], + id=1003) + imported_dash_id_4 = models.Dashboard.import_dashboard( + dash_with_2_slices_new, import_time=1992) + imported_dash_4 = find_dash(imported_dash_id_4) + self.assertEquals('dash_with_2_slices_new', + imported_dash_4.dashboard_title) + self.assertEquals(2, len(imported_dash_4.slices)) + self.assertEquals( + imported_dash_slice_ids_3, [s.id for s in imported_dash_4.slices]) + self.assertEquals('{"remote_id": 1003, "import_time": 1992}', + imported_dash_4.json_metadata) + + # cleanup + slc_to_delete = set(imported_dash_2.slices) + slc_to_delete.update(imported_dash_4.slices) + for slc in slc_to_delete: + session.delete(slc) + session.delete(imported_dash_2) + session.delete(imported_dash_4) + session.commit() + if __name__ == '__main__': unittest.main() From 8d260617f5f0853d7b1e1720fbe47681faf6aa76 Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Tue, 27 Sep 2016 14:55:51 -0700 Subject: [PATCH 03/13] Add function descriptions. --- caravel/models.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/caravel/models.py b/caravel/models.py index 5eb11a4d42e8..7a3532b1b671 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -290,6 +290,12 @@ def get_viz(self, url_params_multidict=None): @classmethod def import_slice(cls, slc, import_time=None): + """Inserts or overrides slc in the database. + + remote_id and import_time fields in params_dict are set to track the + slice origin and ensure correct overrides for multiple imports. + Slice.perm is used to find the datasources and connect them. + """ session = db.session make_transient(slc) slc.dashboards = [] @@ -436,6 +442,15 @@ def json_data(self): @classmethod def import_dashboard(cls, dashboard_to_import, import_time=None): + """Imports the dashboard from the object to the database. 
+ + Once dashboard is imported, json_metadata field is extended and stores + remote_id and import_time. It helps to decide if the dashboard has to + be overridden or just copies over. Slices that belong to this dashboard + will be wired to existing tables using Slice.perm field to derive the + datasource name. This function can be used to import/export dashboards + between multiple caravel instances. Audit metadata isn't copies over. + """ logging.info('Started import of the dashboard: {}' .format(dashboard_to_import.to_json())) make_transient(dashboard_to_import) @@ -444,6 +459,8 @@ def import_dashboard(cls, dashboard_to_import, import_time=None): session = db.session logging.info('Dashboard has {} slices' .format(len(dashboard_to_import.slices))) + # copy slices object as Slice.import_slice will mutate the slice + # and will remove the existing dashboard - slice association slices = copy(dashboard_to_import.slices) for slc in slices: logging.info('Importing slice {} from the dashboard: {}'.format( From 3a3e0d66d3cd225a2a6e62a868dcae0524572ef9 Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Tue, 27 Sep 2016 15:05:44 -0700 Subject: [PATCH 04/13] Minor fixes --- caravel/__init__.py | 2 +- caravel/templates/caravel/import_dashboards.html | 2 +- caravel/views.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/caravel/__init__.py b/caravel/__init__.py index 814b60ebceff..bb74091a714a 100644 --- a/caravel/__init__.py +++ b/caravel/__init__.py @@ -85,7 +85,7 @@ def checkout(dbapi_con, con_record, con_proxy): if app.config.get('UPLOAD_FOLDER'): try: os.makedirs(app.config.get('UPLOAD_FOLDER')) - except OSError, e: + except OSError: pass diff --git a/caravel/templates/caravel/import_dashboards.html b/caravel/templates/caravel/import_dashboards.html index 4f5082886094..af3c5e70fbf2 100644 --- a/caravel/templates/caravel/import_dashboards.html +++ b/caravel/templates/caravel/import_dashboards.html @@ -10,4 +10,4 @@

Import the dashboards.

-{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/caravel/views.py b/caravel/views.py index 623c76dc5a81..b4fddc6e7e01 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -18,7 +18,7 @@ from flask import ( g, request, make_response, redirect, flash, Response, render_template, - Markup, url_for, send_file) + Markup, url_for) from flask_appbuilder import ModelView, CompactCRUDMixin, BaseView, expose from flask_appbuilder.actions import action from flask_appbuilder.models.sqla.interface import SQLAInterface @@ -895,7 +895,7 @@ def download_dashboards(self): appbuilder.add_view( DashboardModelView, "Dashboards", - label=__("Dashboards list"), + label=__("Dashboards List"), icon="fa-dashboard", category='Dashboards', category_icon='',) From 0035fe3d944238c4886146d350f230be8b15c8df Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Tue, 27 Sep 2016 15:56:33 -0700 Subject: [PATCH 05/13] Fix tests for python 3. --- caravel/models.py | 3 +- .../templates/caravel/import_dashboards.html | 3 +- caravel/views.py | 2 +- tests/core_tests.py | 37 +++++++++---------- 4 files changed, 22 insertions(+), 23 deletions(-) diff --git a/caravel/models.py b/caravel/models.py index 7a3532b1b671..8d82fbd47c7c 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -299,6 +299,7 @@ def import_slice(cls, slc, import_time=None): session = db.session make_transient(slc) slc.dashboards = [] + # perms are named [{db_name}].[{datasource_name}](id:{id}) id_pos = slc.perm.find('(id') if slc.perm else -1 if id_pos == -1: logging.warning('slice {} has malformed perm {}'.format( @@ -516,7 +517,7 @@ def import_dashboard(cls, dashboard_to_import, import_time=None): slices=slices_to_attach, ) session.add(new_dash) - session.commit() + session.flush() logging.info('Imported the dashboard: {}'.format(new_dash.to_json())) return new_dash.id diff --git a/caravel/templates/caravel/import_dashboards.html b/caravel/templates/caravel/import_dashboards.html index 
af3c5e70fbf2..b36901d2906f 100644 --- a/caravel/templates/caravel/import_dashboards.html +++ b/caravel/templates/caravel/import_dashboards.html @@ -5,9 +5,10 @@
Import the dashboards.

Import the dashboards.

-
+

+

{% endblock %} diff --git a/caravel/views.py b/caravel/views.py index b4fddc6e7e01..c08ac741156a 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -903,7 +903,7 @@ def download_dashboards(self): appbuilder.add_link( 'Import Dashboards', href='/caravel/import_dashboards', - icon="fa-flask", + icon="fa-cloud-upload", category='Dashboards') diff --git a/tests/core_tests.py b/tests/core_tests.py index b6c471e7e71a..aeb8557e5db7 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -518,10 +518,8 @@ def find_slice(slc_id): self.assertEquals('table', slc_1.datasource_type) self.assertEquals('bubble', slc_1.viz_type) self.assertEquals( - '{"num_period_compare": "10",' - ' "remote_id": 1, ' - '"import_time": 1989}', - slc_1.params) + {"num_period_compare": "10", "remote_id": 1, "import_time": 1989}, + json.loads(slc_1.params)) # Case 2. Import slice with same datasource id slc_id_2 = models.Slice.import_slice(get_slice( @@ -556,10 +554,8 @@ def find_slice(slc_id): self.assertEquals('table', slc_5.datasource_type) self.assertEquals('bubble', slc_5.viz_type) self.assertEquals( - '{"num_period_compare": "10",' - ' "remote_id": 1, ' - '"import_time": 1989}', - slc_5.params) + {"num_period_compare": "10", "remote_id": 1, "import_time": 1989}, + json.loads(slc_5.params)) # Case 6. 
Older version of slice already exists, override slc_id_6 = models.Slice.import_slice( @@ -573,10 +569,8 @@ def find_slice(slc_id): self.assertEquals('table', slc_6.datasource_type) self.assertEquals('bubble', slc_6.viz_type) self.assertEquals( - '{"num_period_compare": "10",' - ' "remote_id": 1, ' - '"import_time": 1990}', - slc_6.params) + {"num_period_compare": "10", "remote_id": 1, "import_time": 1990}, + json.loads(slc_6.params)) session.delete(slc_2) session.delete(slc_6) @@ -625,8 +619,9 @@ def get_slc(name, table, slc_id): self.assertEquals(1, len(imported_dash_2.slices)) self.assertEquals('wb_health_population', imported_dash_2.slices[0].datasource.table_name) - self.assertEquals('{"remote_id": 1002, "import_time": 1990}', - imported_dash_2.json_metadata) + self.assertEquals( + {"remote_id": 1002, "import_time": 1990}, + json.loads(imported_dash_2.json_metadata)) # Case 3. Import dashboard with 2 new slices. dash_with_2_slices = get_dashboard( @@ -639,9 +634,10 @@ def get_slc(name, table, slc_id): imported_dash_3 = find_dash(imported_dash_id_3) self.assertEquals('dash_with_2_slices', imported_dash_3.dashboard_title) self.assertEquals(2, len(imported_dash_3.slices)) - self.assertEquals('{"remote_id": 1003, "import_time": 1991}', - imported_dash_3.json_metadata) - imported_dash_slice_ids_3 = [s.id for s in imported_dash_3.slices] + self.assertEquals( + {"remote_id": 1003, "import_time": 1991}, + json.loads(imported_dash_3.json_metadata)) + imported_dash_slice_ids_3 = {s.id for s in imported_dash_3.slices} # Case 4. 
Override the dashboard and 2 slices dash_with_2_slices_new = get_dashboard( @@ -656,9 +652,10 @@ def get_slc(name, table, slc_id): imported_dash_4.dashboard_title) self.assertEquals(2, len(imported_dash_4.slices)) self.assertEquals( - imported_dash_slice_ids_3, [s.id for s in imported_dash_4.slices]) - self.assertEquals('{"remote_id": 1003, "import_time": 1992}', - imported_dash_4.json_metadata) + imported_dash_slice_ids_3, {s.id for s in imported_dash_4.slices}) + self.assertEquals( + {"remote_id": 1003, "import_time": 1992}, + json.loads(imported_dash_4.json_metadata)) # cleanup slc_to_delete = set(imported_dash_2.slices) From 3e506bd744b7479f21956856081531677bd94785 Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Tue, 4 Oct 2016 11:25:33 -0700 Subject: [PATCH 06/13] Export datasources. --- caravel/models.py | 206 +++++++++++++++--------------- caravel/source_registry.py | 27 ++++ caravel/views.py | 10 +- tests/core_tests.py | 215 ------------------------------- tests/import_export_tests.py | 240 +++++++++++++++++++++++++++++++++++ 5 files changed, 376 insertions(+), 322 deletions(-) create mode 100644 tests/import_export_tests.py diff --git a/caravel/models.py b/caravel/models.py index 8d82fbd47c7c..9d88a36f933f 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -20,7 +20,6 @@ import sqlalchemy as sqla from sqlalchemy.engine.url import make_url from sqlalchemy.orm import joinedload -from sqlalchemy.orm.session import make_transient import sqlparse from dateutil.parser import parse @@ -288,6 +287,24 @@ def get_viz(self, url_params_multidict=None): slice_=self ) + def alter_params(self, **kwargs): + params_dict = {} + if self.params: + params_dict = json.loads(self.params) + for key, value in kwargs.iteritems(): + params_dict[key] = value + self.params = json.dumps(params_dict) + + def override(self, slice): + """Overrides the plain fields of the dashboard.""" + for field in Slice.export_fields(): + setattr(self, field, getattr(slice, field)) + + 
@classmethod + def export_fields(cls): + return ('slice_name', 'datasource_type', 'datasource_name', 'viz_type', + 'params', 'cache_timeout') + @classmethod def import_slice(cls, slc, import_time=None): """Inserts or overrides slc in the database. @@ -299,61 +316,28 @@ def import_slice(cls, slc, import_time=None): session = db.session make_transient(slc) slc.dashboards = [] - # perms are named [{db_name}].[{datasource_name}](id:{id}) - id_pos = slc.perm.find('(id') if slc.perm else -1 - if id_pos == -1: - logging.warning('slice {} has malformed perm {}'.format( - slc.slice_name, slc.perm)) - return None - - perm_prefix = slc.perm[:id_pos] - datasource_class = SourceRegistry.sources[slc.datasource_type] - datasources = session.query(datasource_class).all() - if not datasources: - logging.warning('Datasource {} was not found for the slice {}.' - .format(slc.datasource_id, slc.slice_name)) - return None - - params_dict = {} - if slc.params: - params_dict = json.loads(slc.params) - if 'remote_id' not in params_dict: - params_dict['remote_id'] = slc.id - params_dict['import_time'] = import_time - slc.params = json.dumps(params_dict) + slc.alter_params(remote_id=slc.id, import_time=import_time) # find if the slice was already imported slc_to_override = ( session.query(Slice) - .filter(Slice.params.ilike('%"remote_id": {}%'.format(slc.id))) + .filter(Slice.params.ilike('%"remote_id": {},%'.format(slc.id))) .first() ) slc.id = None + params = json.loads(slc.params) + slc.datasource_id = SourceRegistry.get_datasource_by_name( + session, slc.datasource_type, params['datasource_name'], + params['schema'], params['database_name']).id if slc_to_override: - slc_params = json.loads(slc_to_override.params) - if slc_params.get('import_time') == import_time: - # slice was imported in the current batch, no changes needed: - return slc_to_override.id - slc.id = slc_to_override.id - # delete old version of the slice - session.delete(slc_to_override) - session.commit() - - # add slice 
pointing to the right datasources - slc_dss = [d for d in datasources if perm_prefix in d.perm] - if not slc_dss: - logging.warning( - 'Datasources were not found for the slice {} with perm {} ' - .format(slc.slice_name, slc.perm)) - return None - slc.datasource_id = slc_dss[0].id - # generate correct perms for the slice - slc.perm = '{}(id:{})'.format(perm_prefix, slc.datasource_id) - session.add(slc) - logging.info('Imported the slice: {}'.format(slc.to_json())) - session.flush() - return slc.id + slc_to_override.override(slc) + session.flush() + return slc_to_override.id + else: + session.add(slc) + session.flush() + return slc.id def set_perm(mapper, connection, target): # noqa @@ -441,6 +425,14 @@ def json_data(self): } return json.dumps(d) + def alter_json_metadata(self, **kwargs): + json_metadata_dict = {} + if self.json_metadata: + json_metadata_dict = json.loads(self.json_metadata) + for key, value in kwargs.iteritems(): + json_metadata_dict[key] = value + self.json_metadata = json.dumps(json_metadata_dict) + @classmethod def import_dashboard(cls, dashboard_to_import, import_time=None): """Imports the dashboard from the object to the database. 
@@ -454,86 +446,86 @@ def import_dashboard(cls, dashboard_to_import, import_time=None): """ logging.info('Started import of the dashboard: {}' .format(dashboard_to_import.to_json())) - make_transient(dashboard_to_import) + # make_transient(dashboard_to_import) # TODO: provide nice error message about failed imports - slices_to_attach = [] session = db.session logging.info('Dashboard has {} slices' .format(len(dashboard_to_import.slices))) # copy slices object as Slice.import_slice will mutate the slice # and will remove the existing dashboard - slice association slices = copy(dashboard_to_import.slices) + slice_ids = set() for slc in slices: logging.info('Importing slice {} from the dashboard: {}'.format( slc.to_json(), dashboard_to_import.dashboard_title)) - slc_id = Slice.import_slice(slc, import_time=import_time) - if not slc_id: - continue - slices_to_attach.append( - session.query(Slice).filter_by(id=slc_id).first()) - logging.info( - slc.slice_name + ' belongs to the dashboard ' + - dashboard_to_import.dashboard_title) - # Skip dashboard creation if the slice wasn't imported - if not slices_to_attach: - logging.warning('Dashboard {} import failed'.format( - dashboard_to_import.dashboard_title)) - return + slice_ids.add(Slice.import_slice(slc, import_time=import_time)) - if dashboard_to_import.json_metadata: - json_metadata_dict = json.loads(dashboard_to_import.json_metadata) - else: - json_metadata_dict = {} - if 'remote_id' not in json_metadata_dict: - json_metadata_dict['remote_id'] = dashboard_to_import.id - json_metadata_dict['import_time'] = import_time + dashboard_to_import.alter_json_metadata( + remote_id=dashboard_to_import.id, import_time=import_time) + dashboard_to_import.id = None + dashboard_to_import.slices = session.query(Slice).filter( + Slice.id.in_(slice_ids)).all() # Override the dashboard existing_dashboard = ( session.query(Dashboard) .filter(Dashboard.json_metadata.ilike( - '%"remote_id": {}%'.format(dashboard_to_import.id))) + 
'%"remote_id": {},%'.format(dashboard_to_import.id))) .first() ) - dashboard_to_import.id = None + if existing_dashboard: - # TODO: add nice notification about the overriden dashboard - logging.info('Dashboard already exist and will be overridden: {}' - .format(existing_dashboard.to_json())) - - session.delete(existing_dashboard) - session.commit() - - # session.add(dashboard_to_import) causes sqlachemy failures - # related to the attached users / slices. Creating new object - # allows to avoid conflicts in the sql alchemy state. - new_dash = Dashboard( - dashboard_title=dashboard_to_import.dashboard_title, - position_json=dashboard_to_import.position_json, - json_metadata=json.dumps(json_metadata_dict), - description=dashboard_to_import.description, - css=dashboard_to_import.css, - slug=dashboard_to_import.slug, - slices=slices_to_attach, - ) - session.add(new_dash) - session.flush() - logging.info('Imported the dashboard: {}'.format(new_dash.to_json())) - return new_dash.id + existing_dashboard.override(dashboard_to_import) + session.flush() + return existing_dashboard.id + else: + # session.add(dashboard_to_import) causes sqlachemy failures + # related to the attached users / slices. Creating new object + # allows to avoid conflicts in the sql alchemy state. 
+ copied_dash = dashboard_to_import.copy() + session.add(copied_dash) + session.flush() + return copied_dash.id + + def override(self, dashboard): + """Overrides the plain fields of the dashboard.""" + for field in Dashboard.export_fields(): + setattr(self, field, getattr(dashboard, field)) + + def copy(self): + """Creates a copy of the dashboard without relationships.""" + dashboard = Dashboard() + dashboard.override(self) + return dashboard + + @classmethod + def export_fields(cls): + return ('dashboard_title', 'position_json', 'json_metadata', + 'description', 'css', 'slug', 'slices') @classmethod def export_dashboards(cls, dashboard_ids): eager_dashboards = [] + datasource_set = set() for dashboard_id in dashboard_ids: eager_dashboard = db.session.query(Dashboard).options( joinedload('slices')).filter_by(id=dashboard_id).first() - db.session.expunge(eager_dashboard) - make_transient(eager_dashboard) - eager_dashboard.created_by_fk = None - eager_dashboard.changed_by_fk = None - eager_dashboard.owners = [] - eager_dashboards.append(eager_dashboard) - return pickle.dumps(eager_dashboards) + for slc in eager_dashboard.slices: + if slc.datasource not in datasource_set: + datasource_set.add(slc.datasource) + # add extra params for the import + slc.alter_params( + remote_id=slc.id, + datasource_name=slc.datasource.name, + schema=slc.datasource.name, + database_name=slc.datasource.database.database_name, + ) + + eager_dashboards.append(eager_dashboard.copy()) + return pickle.dumps({ + 'dashboards': eager_dashboards, + 'datasources': list(datasource_set), + }) class Queryable(object): @@ -599,6 +591,10 @@ class Database(Model, AuditMixinNullable): def __repr__(self): return self.database_name + @property + def name(self): + return self.database_name + @property def backend(self): url = make_url(self.sqlalchemy_uri_decrypted) @@ -1400,6 +1396,10 @@ def refresh_datasources(self): def perm(self): return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self) + @property + 
def name(self): + return self.cluster_name + class DruidDatasource(Model, AuditMixinNullable, Queryable): @@ -1428,6 +1428,10 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable): offset = Column(Integer, default=0) cache_timeout = Column(Integer) + @property + def database(self): + return self.cluster + @property def metrics_combo(self): return sorted( diff --git a/caravel/source_registry.py b/caravel/source_registry.py index dc1a170d7c68..d3a1efdefa37 100644 --- a/caravel/source_registry.py +++ b/caravel/source_registry.py @@ -1,3 +1,4 @@ +from sqlalchemy.orm import joinedload class SourceRegistry(object): @@ -20,3 +21,29 @@ def get_datasource(cls, datasource_type, datasource_id, session): .filter_by(id=datasource_id) .one() ) + + @classmethod + def get_datasource_by_name(cls, session, datasource_type, datasource_name, + schema, database_name): + datasource_class = SourceRegistry.sources[datasource_type] + datasources = session.query(datasource_class).all() + db_ds = [d for d in datasources if d.database.name == database_name and + d.name == datasource_name and schema == schema] + return db_ds[0] + + + @classmethod + def get_eager_datasource(cls, session, datasource_type, datasource_id): + """Returns datasource with columns and metrics.""" + datasource_class = SourceRegistry.sources[datasource_type] + if datasource_type == 'table': + return ( + session.query(datasource_class) + .options(joinedload('table_columns') + .joinedload('sql_metrics')) + .filter_by(id=datasource_id) + .one() + ) + # TODO: support druid datasources. 
+ return session.query(datasource_class).filter_by( + id=datasource_id).first() diff --git a/caravel/views.py b/caravel/views.py index c08ac741156a..61fbdd4b5d89 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -1083,9 +1083,8 @@ def approve(self): role_to_extend = request.args.get('role_to_extend') session = db.session - datasource_class = SourceRegistry.sources[datasource_type] - datasource = session.query(datasource_class).filter_by( - id=datasource_id).first() + datasource = SourceRegistry.get_datasource( + datasource_type, datasource_id, session) if not datasource: flash(DATASOURCE_MISSING_ERR, "alert") @@ -1156,7 +1155,6 @@ def get_viz( return viz_obj @has_access -<<<<<<< f70d301f0d37fe6c4f94cd7d6026a72a56087d34 @expose("/slice//") def slice(self, slice_id): viz_obj = self.get_viz(slice_id) @@ -1191,9 +1189,9 @@ def import_dashboards(self): filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename) f.save(filepath) current_tt = int(time.time()) - for d in pickle.load(open(filepath, 'rb')): + for data in pickle.load(open(filepath, 'rb')): models.Dashboard.import_dashboard( - d, import_time=current_tt) + data['dashboards'], import_time=current_tt) os.remove(filepath) return redirect('/dashboardmodelview/list/') return self.render_template('caravel/import_dashboards.html') diff --git a/tests/core_tests.py b/tests/core_tests.py index aeb8557e5db7..bd09bfd8ca22 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -8,7 +8,6 @@ import doctest import json import io -import pickle import random import unittest @@ -453,219 +452,5 @@ def test_only_owners_can_save(self): db.session.commit() self.test_save_dash('alpha') - def test_export_dashboards(self): - births_dash = db.session.query(models.Dashboard).filter_by( - slug='births').first() - - # export 1 dashboard - export_dash_url_1 = ( - '/dashboardmodelview/export_dashboards_form?id={}&action=go' - .format(births_dash.id) - ) - resp_1 = self.client.get(export_dash_url_1) - exported_dashboards_1 = 
pickle.loads(resp_1.data) - self.assertEquals(1, len(exported_dashboards_1)) - self.assertEquals('births', exported_dashboards_1[0].slug) - self.assertEquals('Births', exported_dashboards_1[0].dashboard_title) - self.assertTrue(exported_dashboards_1[0].position_json) - self.assertEquals(10, len(exported_dashboards_1[0].slices)) - - # export 2 dashboards - world_health_dash = db.session.query(models.Dashboard).filter_by( - slug='world_health').first() - export_dash_url_2 = ( - '/dashboardmodelview/export_dashboards_form?id={}&id={}&action=go' - .format(births_dash.id, world_health_dash.id) - ) - resp_2 = self.client.get(export_dash_url_2) - exported_dashboards_2 = pickle.loads(resp_2.data) - self.assertEquals(2, len(exported_dashboards_2)) - self.assertEquals('births', exported_dashboards_2[0].slug) - self.assertEquals(10, len(exported_dashboards_2[0].slices)) - self.assertEquals('world_health', exported_dashboards_2[1].slug) - self.assertEquals(10, len(exported_dashboards_2[1].slices)) - - def test_import_slice(self): - session = db.session - - def get_slice(name, ds_id=666, slc_id=1, - perm='[main].[wb_health_population](id:666)'): - return models.Slice( - slice_name=name, - datasource_type='table', - viz_type='bubble', - params='{"num_period_compare": "10"}', - # used to locate the datasource - perm=perm, - datasource_id=ds_id, - id=slc_id - ) - - def find_slice(slc_id): - return session.query(models.Slice).filter_by(id=slc_id).first() - - # Case 1. 
Import slice with different datasource id - slc_id_1 = models.Slice.import_slice( - get_slice('Import Me 1'), import_time=1989) - slc_1 = find_slice(slc_id_1) - table_id = db.session.query(models.SqlaTable).filter_by( - table_name='wb_health_population').first().id - self.assertEquals(table_id, slc_1.datasource_id) - self.assertEquals( - '[main].[wb_health_population](id:{})'.format(table_id), - slc_1.perm) - self.assertEquals('Import Me 1', slc_1.slice_name) - self.assertEquals('table', slc_1.datasource_type) - self.assertEquals('bubble', slc_1.viz_type) - self.assertEquals( - {"num_period_compare": "10", "remote_id": 1, "import_time": 1989}, - json.loads(slc_1.params)) - - # Case 2. Import slice with same datasource id - slc_id_2 = models.Slice.import_slice(get_slice( - 'Import Me 2', ds_id=table_id, slc_id=2, - perm='[main].[wb_health_population](id:{})'.format(table_id))) - slc_2 = find_slice(slc_id_2) - self.assertEquals(table_id, slc_2.datasource_id) - self.assertEquals( - '[main].[wb_health_population](id:{})'.format(table_id), - slc_2.perm) - - # Case 3. Import slice with no permissions - slice_3 = get_slice('Import Me 3', slc_id=3) - slice_3.perm = None - slc_id_3 = models.Slice.import_slice(slice_3) - self.assertIsNone(slc_id_3) - - # Case 4. Import slice with non existent datasource - slc_id_4 = models.Slice.import_slice(get_slice( - 'Import Me 4', slc_id=4, perm='[main].[non_existent](id:999)')) - self.assertIsNone(slc_id_4) - - # Case 5. 
Slice already exists, no override - slc_id_5 = models.Slice.import_slice( - get_slice('Import Me 1'), import_time=1989) - slc_5 = find_slice(slc_id_5) - self.assertEquals(table_id, slc_5.datasource_id) - self.assertEquals( - '[main].[wb_health_population](id:{})'.format(table_id), - slc_5.perm) - self.assertEquals('Import Me 1', slc_5.slice_name) - self.assertEquals('table', slc_5.datasource_type) - self.assertEquals('bubble', slc_5.viz_type) - self.assertEquals( - {"num_period_compare": "10", "remote_id": 1, "import_time": 1989}, - json.loads(slc_5.params)) - - # Case 6. Older version of slice already exists, override - slc_id_6 = models.Slice.import_slice( - get_slice('Import Me New 1'), import_time=1990) - slc_6 = find_slice(slc_id_6) - self.assertEquals(table_id, slc_6.datasource_id) - self.assertEquals( - '[main].[wb_health_population](id:{})'.format(table_id), - slc_5.perm) - self.assertEquals('Import Me New 1', slc_6.slice_name) - self.assertEquals('table', slc_6.datasource_type) - self.assertEquals('bubble', slc_6.viz_type) - self.assertEquals( - {"num_period_compare": "10", "remote_id": 1, "import_time": 1990}, - json.loads(slc_6.params)) - - session.delete(slc_2) - session.delete(slc_6) - session.commit() - - def test_import_dashboard(self): - session = db.session - - def get_dashboard(title, id=0, slcs=[]): - return models.Dashboard( - id=id, - dashboard_title=title, - slices=slcs, - position_json='{"size_y": 2, "size_x": 2}', - slug='{}_imported'.format(title.lower()), - ) - - def find_dash(dash_id): - return session.query(models.Dashboard).filter_by( - id=dash_id).first() - - def get_slc(name, table, slc_id): - return models.Slice( - id=slc_id, - slice_name=name, - datasource_type='table', - perm='[main].[{}](id:{})'.format(table, slc_id) - ) - - # Case 1. 
Import empty dashboard - empty_dash = get_dashboard('empty_dashboard', id=1001) - imported_dash_id_1 = models.Dashboard.import_dashboard( - empty_dash, import_time=1989) - # Dashboard import fails if there are no slices - self.assertIsNone(find_dash(imported_dash_id_1)) - - # Case 2. Import dashboard with 1 new slice. - dash_with_1_slice = get_dashboard( - 'dash_with_1_slice', - slcs=[get_slc('health_slc', 'wb_health_population', 1001)], - id=1002) - imported_dash_id_2 = models.Dashboard.import_dashboard( - dash_with_1_slice, import_time=1990) - imported_dash_2 = find_dash(imported_dash_id_2) - self.assertEquals('dash_with_1_slice', imported_dash_2.dashboard_title) - self.assertEquals(1, len(imported_dash_2.slices)) - self.assertEquals('wb_health_population', - imported_dash_2.slices[0].datasource.table_name) - self.assertEquals( - {"remote_id": 1002, "import_time": 1990}, - json.loads(imported_dash_2.json_metadata)) - - # Case 3. Import dashboard with 2 new slices. - dash_with_2_slices = get_dashboard( - 'dash_with_2_slices', - slcs=[get_slc('energy_slc', 'energy_usage', 1002), - get_slc('birth_slc', 'birth_names', 1003)], - id=1003) - imported_dash_id_3 = models.Dashboard.import_dashboard( - dash_with_2_slices, import_time=1991) - imported_dash_3 = find_dash(imported_dash_id_3) - self.assertEquals('dash_with_2_slices', imported_dash_3.dashboard_title) - self.assertEquals(2, len(imported_dash_3.slices)) - self.assertEquals( - {"remote_id": 1003, "import_time": 1991}, - json.loads(imported_dash_3.json_metadata)) - imported_dash_slice_ids_3 = {s.id for s in imported_dash_3.slices} - - # Case 4. 
Override the dashboard and 2 slices - dash_with_2_slices_new = get_dashboard( - 'dash_with_2_slices_new', - slcs=[get_slc('energy_slc', 'energy_usage', 1002), - get_slc('birth_slc', 'birth_names', 1003)], - id=1003) - imported_dash_id_4 = models.Dashboard.import_dashboard( - dash_with_2_slices_new, import_time=1992) - imported_dash_4 = find_dash(imported_dash_id_4) - self.assertEquals('dash_with_2_slices_new', - imported_dash_4.dashboard_title) - self.assertEquals(2, len(imported_dash_4.slices)) - self.assertEquals( - imported_dash_slice_ids_3, {s.id for s in imported_dash_4.slices}) - self.assertEquals( - {"remote_id": 1003, "import_time": 1992}, - json.loads(imported_dash_4.json_metadata)) - - # cleanup - slc_to_delete = set(imported_dash_2.slices) - slc_to_delete.update(imported_dash_4.slices) - for slc in slc_to_delete: - session.delete(slc) - session.delete(imported_dash_2) - session.delete(imported_dash_4) - session.commit() - - if __name__ == '__main__': unittest.main() diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py new file mode 100644 index 000000000000..558479a2adc6 --- /dev/null +++ b/tests/import_export_tests.py @@ -0,0 +1,240 @@ +"""Unit tests for Caravel""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import json +import pickle +import unittest + +from caravel import db, models + +from .base_tests import CaravelTestCase + + +class ImportExportTests(CaravelTestCase): + """Testing export import functionality for dashboards""" + + examples_loaded = False + + def __init__(self, *args, **kwargs): + super(ImportExportTests, self).__init__(*args, **kwargs) + + def test_export_dashboards(self): + births_dash = db.session.query(models.Dashboard).filter_by( + slug='births').first() + + # export 1 dashboard + export_dash_url_1 = ( + '/dashboardmodelview/export_dashboards_form?id={}&action=go' + .format(births_dash.id) + ) + 
resp_1 = self.client.get(export_dash_url_1) + exported_dashboards_1 = pickle.loads(resp_1.data)['dashboards'] + self.assertEquals(1, len(exported_dashboards_1)) + self.assertEquals('births', exported_dashboards_1[0].slug) + self.assertEquals('Births', exported_dashboards_1[0].dashboard_title) + self.assertTrue(exported_dashboards_1[0].position_json) + self.assertEquals(9, len(exported_dashboards_1[0].slices)) + + # export 2 dashboards + world_health_dash = db.session.query(models.Dashboard).filter_by( + slug='world_health').first() + export_dash_url_2 = ( + '/dashboardmodelview/export_dashboards_form?id={}&id={}&action=go' + .format(births_dash.id, world_health_dash.id) + ) + resp_2 = self.client.get(export_dash_url_2) + exported_dashboards_2 = pickle.loads(resp_2.data)['dashboards'] + self.assertEquals(2, len(exported_dashboards_2)) + self.assertEquals('births', exported_dashboards_2[0].slug) + self.assertEquals(9, len(exported_dashboards_2[0].slices)) + self.assertEquals('world_health', exported_dashboards_2[1].slug) + self.assertEquals(10, len(exported_dashboards_2[1].slices)) + + def test_import_slice(self): + session = db.session + + def get_slice(name, ds_id=666, slc_id=1, + db_name='main', + table_name='wb_health_population'): + params = { + 'num_period_compare': '10', + 'remote_id': slc_id, + 'datasource_name': table_name, + 'database_name': db_name, + 'schema': '', + } + return models.Slice( + slice_name=name, + datasource_type='table', + viz_type='bubble', + params=json.dumps(params), + datasource_id=ds_id, + id=slc_id + ) + + def find_slice(slc_id): + return session.query(models.Slice).filter_by(id=slc_id).first() + + # Case 1. 
Import slice with different datasource id + slc_id_1 = models.Slice.import_slice( + get_slice('Import Me 1'), import_time=1989) + slc_1 = find_slice(slc_id_1) + table_id = db.session.query(models.SqlaTable).filter_by( + table_name='wb_health_population').first().id + self.assertEquals(table_id, slc_1.datasource_id) + self.assertEquals( + '[main].[wb_health_population](id:{})'.format(table_id), + slc_1.perm) + self.assertEquals('Import Me 1', slc_1.slice_name) + self.assertEquals('table', slc_1.datasource_type) + self.assertEquals('bubble', slc_1.viz_type) + self.assertEquals( + {'num_period_compare': '10', + 'datasource_name': 'wb_health_population', + 'database_name': 'main', + 'remote_id': 1, + 'schema': '', + 'import_time': 1989,}, + json.loads(slc_1.params)) + + # Case 2. Import slice with same datasource id + slc_id_2 = models.Slice.import_slice(get_slice( + 'Import Me 2', ds_id=table_id, slc_id=2)) + slc_2 = find_slice(slc_id_2) + self.assertEquals(table_id, slc_2.datasource_id) + self.assertEquals( + '[main].[wb_health_population](id:{})'.format(table_id), + slc_2.perm) + + # Case 3. Import slice with non existent datasource + with self.assertRaises(IndexError): + models.Slice.import_slice(get_slice( + 'Import Me 3', slc_id=3, table_name='non_existent')) + + # Case 4. 
Older version of slice already exists, override + slc_id_4 = models.Slice.import_slice( + get_slice('Import Me New 1'), import_time=1990) + slc_4 = find_slice(slc_id_4) + self.assertEquals(table_id, slc_4.datasource_id) + self.assertEquals( + '[main].[wb_health_population](id:{})'.format(table_id), + slc_4.perm) + self.assertEquals('Import Me New 1', slc_4.slice_name) + self.assertEquals('table', slc_4.datasource_type) + self.assertEquals('bubble', slc_4.viz_type) + self.assertEquals( + {'num_period_compare': '10', + 'datasource_name': 'wb_health_population', + 'database_name': 'main', + 'remote_id': 1, + 'schema': '', + 'import_time': 1990 + }, + json.loads(slc_4.params)) + + session.delete(slc_2) + session.delete(slc_4) + session.commit() + + def test_import_dashboard(self): + session = db.session + + def get_dashboard(title, id=0, slcs=[]): + return models.Dashboard( + id=id, + dashboard_title=title, + slices=slcs, + position_json='{"size_y": 2, "size_x": 2}', + slug='{}_imported'.format(title.lower()), + ) + + def find_dash(dash_id): + return session.query(models.Dashboard).filter_by( + id=dash_id).first() + + def get_slc(name, table_name, slc_id): + params = { + 'remote_id': slc_id, + 'datasource_name': table_name, + 'database_name': 'main', + 'schema': '', + } + return models.Slice( + id=slc_id, + slice_name=name, + datasource_type='table', + params=json.dumps(params) + ) + + # Case 1. Import empty dashboard can be imported + empty_dash = get_dashboard('empty_dashboard', id=1001) + imported_dash_id_1 = models.Dashboard.import_dashboard( + empty_dash, import_time=1989) + # Dashboard import fails if there are no slices + self.assertEquals( + 'empty_dashboard', find_dash(imported_dash_id_1).dashboard_title) + + # Case 2. Import dashboard with 1 new slice. 
+ dash_with_1_slice = get_dashboard( + 'dash_with_1_slice', + slcs=[get_slc('health_slc', 'wb_health_population', 1001)], + id=1002) + imported_dash_id_2 = models.Dashboard.import_dashboard( + dash_with_1_slice, import_time=1990) + imported_dash_2 = find_dash(imported_dash_id_2) + self.assertEquals('dash_with_1_slice', imported_dash_2.dashboard_title) + self.assertEquals(1, len(imported_dash_2.slices)) + self.assertEquals('wb_health_population', + imported_dash_2.slices[0].datasource.table_name) + self.assertEquals( + {"remote_id": 1002, "import_time": 1990}, + json.loads(imported_dash_2.json_metadata)) + + # Case 3. Import dashboard with 2 new slices. + dash_with_2_slices = get_dashboard( + 'dash_with_2_slices', + slcs=[get_slc('energy_slc', 'energy_usage', 1002), + get_slc('birth_slc', 'birth_names', 1003)], + id=1003) + imported_dash_id_3 = models.Dashboard.import_dashboard( + dash_with_2_slices, import_time=1991) + imported_dash_3 = find_dash(imported_dash_id_3) + self.assertEquals('dash_with_2_slices', imported_dash_3.dashboard_title) + self.assertEquals(2, len(imported_dash_3.slices)) + self.assertEquals( + {"remote_id": 1003, "import_time": 1991}, + json.loads(imported_dash_3.json_metadata)) + imported_dash_slice_ids_3 = {s.id for s in imported_dash_3.slices} + + # Case 4. 
Override the dashboard and 2 slices + dash_with_2_slices_new = get_dashboard( + 'dash_with_2_slices_new', + slcs=[get_slc('energy_slc', 'energy_usage', 1002), + get_slc('birth_slc', 'birth_names', 1003)], + id=1003) + imported_dash_id_4 = models.Dashboard.import_dashboard( + dash_with_2_slices_new, import_time=1992) + imported_dash_4 = find_dash(imported_dash_id_4) + self.assertEquals('dash_with_2_slices_new', + imported_dash_4.dashboard_title) + self.assertEquals(2, len(imported_dash_4.slices)) + self.assertEquals( + imported_dash_slice_ids_3, {s.id for s in imported_dash_4.slices}) + self.assertEquals( + {"remote_id": 1003, "import_time": 1992}, + json.loads(imported_dash_4.json_metadata)) + + # cleanup + slc_to_delete = set(imported_dash_2.slices) + slc_to_delete.update(imported_dash_4.slices) + for slc in slc_to_delete: + session.delete(slc) + session.delete(imported_dash_2) + session.delete(imported_dash_4) + session.commit() + +if __name__ == '__main__': + unittest.main() From 6a8ee226e20bd30e7901fc4a9e2d55dc15d5a05d Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Wed, 5 Oct 2016 15:32:29 -0700 Subject: [PATCH 07/13] Implement tables import. 
--- caravel/data/__init__.py | 4 +- caravel/migrations/versions/b46fa1b0b39e_.py | 28 ++ caravel/models.py | 268 ++++++++++++++----- caravel/source_registry.py | 10 +- caravel/views.py | 10 +- tests/import_export_tests.py | 204 ++++++++++++-- 6 files changed, 420 insertions(+), 104 deletions(-) create mode 100644 caravel/migrations/versions/b46fa1b0b39e_.py diff --git a/caravel/data/__init__.py b/caravel/data/__init__.py index 974ce2e898ba..0f9e2a308efa 100644 --- a/caravel/data/__init__.py +++ b/caravel/data/__init__.py @@ -1058,7 +1058,7 @@ def load_multiformat_time_series_data(): 'string0': ['%Y-%m-%d %H:%M:%S.%f', None], 'string3': ['%Y/%m/%d%H:%M:%S.%f', None], } - for col in obj.table_columns: + for col in obj.columns: dttm_and_expr = dttm_and_expr_dict[col.column_name] col.python_date_format = dttm_and_expr[0] col.dbatabase_expr = dttm_and_expr[1] @@ -1069,7 +1069,7 @@ def load_multiformat_time_series_data(): tbl = obj print("Creating some slices") - for i, col in enumerate(tbl.table_columns): + for i, col in enumerate(tbl.columns): slice_data = { "granularity_sqla": col.column_name, "datasource_id": "8", diff --git a/caravel/migrations/versions/b46fa1b0b39e_.py b/caravel/migrations/versions/b46fa1b0b39e_.py new file mode 100644 index 000000000000..372fdbed8a52 --- /dev/null +++ b/caravel/migrations/versions/b46fa1b0b39e_.py @@ -0,0 +1,28 @@ +"""Add json_metadata to the tables table. + +Revision ID: b46fa1b0b39e +Revises: ef8843b41dac +Create Date: 2016-10-05 11:30:31.748238 + +""" + +# revision identifiers, used by Alembic. 
+revision = 'b46fa1b0b39e' +down_revision = 'ef8843b41dac' + +from alembic import op +import logging +import sqlalchemy as sa + + +def upgrade(): + op.add_column('tables', + sa.Column('json_metadata', sa.Text(), nullable=True)) + + +def downgrade(): + try: + op.drop_column('tables', 'json_metadata') + except Exception as e: + logging.warning(str(e)) + diff --git a/caravel/models.py b/caravel/models.py index 9d88a36f933f..4a68cc908eb3 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -19,7 +19,7 @@ import requests import sqlalchemy as sqla from sqlalchemy.engine.url import make_url -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import subqueryload import sqlparse from dateutil.parser import parse @@ -74,6 +74,19 @@ def __init__(self, name, field_names, function): self.name = name +class ImportMixin(object): + def override(self, obj): + """Overrides the plain fields of the dashboard.""" + for field in obj.__class__.export_fields: + setattr(self, field, getattr(obj, field)) + + def copy(self): + """Creates a copy of the dashboard without relationships.""" + new_obj = self.__class__() + new_obj.override(self) + return new_obj + + class AuditMixinNullable(AuditMixin): """Altering the AuditMixin to use nullable fields @@ -153,7 +166,7 @@ class CssTemplate(Model, AuditMixinNullable): ) -class Slice(Model, AuditMixinNullable): +class Slice(Model, AuditMixinNullable, ImportMixin): """A slice is essentially a report or a view on data""" @@ -170,6 +183,9 @@ class Slice(Model, AuditMixinNullable): perm = Column(String(2000)) owners = relationship("User", secondary=slice_user) + export_fields = ('slice_name', 'datasource_type', 'datasource_name', + 'viz_type', 'params', 'cache_timeout') + def __repr__(self): return self.slice_name @@ -291,22 +307,12 @@ def alter_params(self, **kwargs): params_dict = {} if self.params: params_dict = json.loads(self.params) - for key, value in kwargs.iteritems(): - params_dict[key] = value + for key in kwargs: + 
params_dict[key] = kwargs[key] self.params = json.dumps(params_dict) - def override(self, slice): - """Overrides the plain fields of the dashboard.""" - for field in Slice.export_fields(): - setattr(self, field, getattr(slice, field)) - @classmethod - def export_fields(cls): - return ('slice_name', 'datasource_type', 'datasource_name', 'viz_type', - 'params', 'cache_timeout') - - @classmethod - def import_slice(cls, slc, import_time=None): + def import_obj(cls, slc, import_time=None): """Inserts or overrides slc in the database. remote_id and import_time fields in params_dict are set to track the @@ -365,7 +371,7 @@ def set_perm(mapper, connection, target): # noqa ) -class Dashboard(Model, AuditMixinNullable): +class Dashboard(Model, AuditMixinNullable, ImportMixin): """The dashboard object!""" @@ -381,6 +387,9 @@ class Dashboard(Model, AuditMixinNullable): 'Slice', secondary=dashboard_slices, backref='dashboards') owners = relationship("User", secondary=dashboard_user) + export_fields = ('dashboard_title', 'position_json', 'json_metadata', + 'description', 'css', 'slug', 'slices') + def __repr__(self): return self.dashboard_title @@ -429,25 +438,23 @@ def alter_json_metadata(self, **kwargs): json_metadata_dict = {} if self.json_metadata: json_metadata_dict = json.loads(self.json_metadata) - for key, value in kwargs.iteritems(): - json_metadata_dict[key] = value + for key in kwargs: + json_metadata_dict[key] = kwargs[key] self.json_metadata = json.dumps(json_metadata_dict) @classmethod - def import_dashboard(cls, dashboard_to_import, import_time=None): + def import_obj(cls, dashboard_to_import, import_time=None): """Imports the dashboard from the object to the database. - Once dashboard is imported, json_metadata field is extended and stores - remote_id and import_time. It helps to decide if the dashboard has to - be overridden or just copies over. 
Slices that belong to this dashboard - will be wired to existing tables using Slice.perm field to derive the - datasource name. This function can be used to import/export dashboards - between multiple caravel instances. Audit metadata isn't copies over. + Once dashboard is imported, json_metadata field is extended and stores + remote_id and import_time. It helps to decide if the dashboard has to + be overridden or just copies over. Slices that belong to this + dashboard will be wired to existing tables. This function can be used + to import/export dashboards between multiple caravel instances. + Audit metadata isn't copies over. """ logging.info('Started import of the dashboard: {}' .format(dashboard_to_import.to_json())) - # make_transient(dashboard_to_import) - # TODO: provide nice error message about failed imports session = db.session logging.info('Dashboard has {} slices' .format(len(dashboard_to_import.slices))) @@ -458,15 +465,9 @@ def import_dashboard(cls, dashboard_to_import, import_time=None): for slc in slices: logging.info('Importing slice {} from the dashboard: {}'.format( slc.to_json(), dashboard_to_import.dashboard_title)) - slice_ids.add(Slice.import_slice(slc, import_time=import_time)) - - dashboard_to_import.alter_json_metadata( - remote_id=dashboard_to_import.id, import_time=import_time) - dashboard_to_import.id = None - dashboard_to_import.slices = session.query(Slice).filter( - Slice.id.in_(slice_ids)).all() + slice_ids.add(Slice.import_obj(slc, import_time=import_time)) - # Override the dashboard + # override the dashboard existing_dashboard = ( session.query(Dashboard) .filter(Dashboard.json_metadata.ilike( @@ -474,8 +475,13 @@ def import_dashboard(cls, dashboard_to_import, import_time=None): .first() ) + dashboard_to_import.id = None + dashboard_to_import.alter_json_metadata(import_time=import_time) + new_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all() + if existing_dashboard: 
existing_dashboard.override(dashboard_to_import) + existing_dashboard.slices = new_slices session.flush() return existing_dashboard.id else: @@ -483,36 +489,26 @@ def import_dashboard(cls, dashboard_to_import, import_time=None): # related to the attached users / slices. Creating new object # allows to avoid conflicts in the sql alchemy state. copied_dash = dashboard_to_import.copy() + copied_dash.slices = new_slices session.add(copied_dash) session.flush() return copied_dash.id - def override(self, dashboard): - """Overrides the plain fields of the dashboard.""" - for field in Dashboard.export_fields(): - setattr(self, field, getattr(dashboard, field)) - - def copy(self): - """Creates a copy of the dashboard without relationships.""" - dashboard = Dashboard() - dashboard.override(self) - return dashboard - - @classmethod - def export_fields(cls): - return ('dashboard_title', 'position_json', 'json_metadata', - 'description', 'css', 'slug', 'slices') - @classmethod def export_dashboards(cls, dashboard_ids): - eager_dashboards = [] - datasource_set = set() + copied_dashboards = [] + datasource_ids = set() for dashboard_id in dashboard_ids: - eager_dashboard = db.session.query(Dashboard).options( - joinedload('slices')).filter_by(id=dashboard_id).first() - for slc in eager_dashboard.slices: - if slc.datasource not in datasource_set: - datasource_set.add(slc.datasource) + # make sure that dashboard_id is an integer + dashboard_id = int(dashboard_id) + copied_dashboard = ( + db.session.query(Dashboard) + .options(subqueryload(Dashboard.slices)) + .filter_by(id=dashboard_id).first() + ) + make_transient(copied_dashboard) + for slc in copied_dashboard.slices: + datasource_ids.add((slc.datasource_id, slc.datasource_type)) # add extra params for the import slc.alter_params( remote_id=slc.id, @@ -520,11 +516,23 @@ def export_dashboards(cls, dashboard_ids): schema=slc.datasource.name, database_name=slc.datasource.database.database_name, ) + 
copied_dashboard.alter_json_metadata(remote_id=dashboard_id) + copied_dashboards.append(copied_dashboard) + + eager_datasources = [] + for dashboard_id, dashboard_type in datasource_ids: + eager_datasource = SourceRegistry.get_eager_datasource( + db.session, dashboard_type, dashboard_id) + eager_datasource.alter_json_metadata( + remote_id=eager_datasource.id, + database_name=eager_datasource.database.database_name, + ) + make_transient(eager_datasource) + eager_datasources.append(eager_datasource) - eager_dashboards.append(eager_dashboard.copy()) return pickle.dumps({ - 'dashboards': eager_dashboards, - 'datasources': list(datasource_set), + 'dashboards': copied_dashboards, + 'datasources': eager_datasources, }) @@ -827,7 +835,7 @@ def perm(self): "[{obj.database_name}].(id:{obj.id})").format(obj=self) -class SqlaTable(Model, Queryable, AuditMixinNullable): +class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin): """An ORM object for SqlAlchemy table references""" @@ -851,9 +859,13 @@ class SqlaTable(Model, Queryable, AuditMixinNullable): cache_timeout = Column(Integer) schema = Column(String(255)) sql = Column(Text) - table_columns = relationship("TableColumn", back_populates="table") + json_metadata = Column(Text) baselink = "tablemodelview" + export_fields = ( + 'table_name', 'main_dttm_col', 'description', 'default_endpoint', + 'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema', + 'sql', 'json_metadata') __table_args__ = ( sqla.UniqueConstraint( @@ -935,7 +947,7 @@ def time_column_grains(self): } def get_col(self, col_name): - columns = self.table_columns + columns = self.columns for col in columns: if col_name == col.column_name: return col @@ -1224,8 +1236,81 @@ def fetch_metadata(self): if not self.main_dttm_col: self.main_dttm_col = any_date_col + def alter_json_metadata(self, **kwargs): + json_metadata_dict = {} + if self.json_metadata: + json_metadata_dict = json.loads(self.json_metadata) + for key in kwargs: + 
json_metadata_dict[key] = kwargs[key] + self.json_metadata = json.dumps(json_metadata_dict) + + @property + def dict_metadata(self): + if self.json_metadata: + return json.loads(self.json_metadata) + return {} + + @classmethod + def import_obj(cls, datasource_to_import, import_time=None): + """Imports the datasource from the object to the database. + + Metrics and columns and datasource will be overrided if exists. + This function can be used to import/export dashboards between multiple + caravel instances. Audit metadata isn't copies over. + """ + session = db.session + make_transient(datasource_to_import) + logging.info('Started import of the datasource: {}' + .format(datasource_to_import.to_json())) + + datasource_to_import.id = None + database_name = datasource_to_import.dict_metadata['database_name'] + datasource_to_import.database_id = session.query(Database).filter_by( + database_name=database_name).one().id + datasource_to_import.alter_json_metadata(import_time=import_time) + + # override the datasource + datasource = ( + session.query(SqlaTable).join(Database) + .filter( + SqlaTable.table_name == datasource_to_import.table_name, + SqlaTable.schema == datasource_to_import.schema, + Database.id == datasource_to_import.database_id, + ) + .first() + ) + + if datasource: + datasource.override(datasource_to_import) + session.flush() + else: + datasource = datasource_to_import.copy() + session.add(datasource) + session.flush() + + for m in datasource_to_import.metrics: + new_m = m.copy() + new_m.table_id = datasource.id + logging.info('Importing metric {} from the datasource: {}'.format( + new_m.to_json(), datasource_to_import.full_name)) + imported_m = SqlMetric.import_obj(new_m) + if imported_m not in datasource.metrics: + datasource.metrics.append(imported_m) -class SqlMetric(Model, AuditMixinNullable): + for c in datasource_to_import.columns: + new_c = c.copy() + new_c.table_id = datasource.id + logging.info('Importing column {} from the datasource: 
{}'.format( + new_c.to_json(), datasource_to_import.full_name)) + imported_c = TableColumn.import_obj(new_c) + if imported_c not in datasource.columns: + datasource.columns.append(imported_c) + db.session.flush() + + return datasource.id + + +class SqlMetric(Model, AuditMixinNullable, ImportMixin): """ORM object for metrics, each table can have multiple metrics""" @@ -1244,6 +1329,10 @@ class SqlMetric(Model, AuditMixinNullable): is_restricted = Column(Boolean, default=False, nullable=True) d3format = Column(String(128)) + export_fields = ( + 'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression', + 'description', 'is_restricted', 'd3format') + @property def sqla_col(self): name = self.metric_name @@ -1256,8 +1345,28 @@ def perm(self): ).format(obj=self, parent_name=self.table.full_name) if self.table else None + @classmethod + def import_obj(cls, metric_to_import): + session = db.session + make_transient(metric_to_import) + metric_to_import.id = None + + # find if the column was already imported + existing_metric = session.query(SqlMetric).filter( + SqlMetric.table_id == metric_to_import.table_id, + SqlMetric.metric_name == metric_to_import.metric_name).first() + metric_to_import.table = None + if existing_metric: + existing_metric.override(metric_to_import) + session.flush() + return existing_metric + + session.add(metric_to_import) + session.flush() + return metric_to_import -class TableColumn(Model, AuditMixinNullable): + +class TableColumn(Model, AuditMixinNullable, ImportMixin): """ORM object for table columns, each table can have multiple columns""" @@ -1287,6 +1396,12 @@ class TableColumn(Model, AuditMixinNullable): num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG') date_types = ('DATE', 'TIME') str_types = ('VARCHAR', 'STRING', 'CHAR') + export_fields = ( + 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', + 'type', 'groupby', 'count_distinct', 'sum', 'max', 'min', + 'filterable', 'expression', 'description', 
'python_date_format', + 'database_expression' + ) def __repr__(self): return self.column_name @@ -1312,6 +1427,27 @@ def sqla_col(self): col = literal_column(self.expression).label(name) return col + @classmethod + def import_obj(cls, column_to_import): + session = db.session + make_transient(column_to_import) + column_to_import.id = None + column_to_import.table = None + + # find if the column was already imported + existing_column = session.query(TableColumn).filter( + TableColumn.table_id == column_to_import.table_id, + TableColumn.column_name == column_to_import.column_name).first() + column_to_import.table = None + if existing_column: + existing_column.override(column_to_import) + session.flush() + return existing_column + + session.add(column_to_import) + session.flush() + return column_to_import + def dttm_sql_literal(self, dttm): """Convert datetime object to string diff --git a/caravel/source_registry.py b/caravel/source_registry.py index d3a1efdefa37..3b1457bb5200 100644 --- a/caravel/source_registry.py +++ b/caravel/source_registry.py @@ -1,4 +1,5 @@ -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import subqueryload + class SourceRegistry(object): @@ -31,7 +32,6 @@ def get_datasource_by_name(cls, session, datasource_type, datasource_name, d.name == datasource_name and schema == schema] return db_ds[0] - @classmethod def get_eager_datasource(cls, session, datasource_type, datasource_id): """Returns datasource with columns and metrics.""" @@ -39,8 +39,10 @@ def get_eager_datasource(cls, session, datasource_type, datasource_id): if datasource_type == 'table': return ( session.query(datasource_class) - .options(joinedload('table_columns') - .joinedload('sql_metrics')) + .options( + subqueryload(datasource_class.columns), + subqueryload(datasource_class.metrics) + ) .filter_by(id=datasource_id) .one() ) diff --git a/caravel/views.py b/caravel/views.py index 61fbdd4b5d89..99e928bfb543 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ 
-1189,10 +1189,14 @@ def import_dashboards(self): filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename) f.save(filepath) current_tt = int(time.time()) - for data in pickle.load(open(filepath, 'rb')): - models.Dashboard.import_dashboard( - data['dashboards'], import_time=current_tt) + data = pickle.load(open(filepath, 'rb')) + for table in data['datasources']: + models.SqlaTable.import_obj(table, import_time=current_tt) + for dashboard in data['dashboards']: + models.Dashboard.import_obj( + dashboard, import_time=current_tt) os.remove(filepath) + db.session.commit() return redirect('/dashboardmodelview/list/') return self.render_template('caravel/import_dashboards.html') diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index 558479a2adc6..5e981928d522 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -35,15 +35,27 @@ def test_export_dashboards(self): self.assertEquals(1, len(exported_dashboards_1)) self.assertEquals('births', exported_dashboards_1[0].slug) self.assertEquals('Births', exported_dashboards_1[0].dashboard_title) + self.assertEquals( + births_dash.id, + json.loads(exported_dashboards_1[0].json_metadata)['remote_id'] + ) + self.assertTrue(exported_dashboards_1[0].position_json) self.assertEquals(9, len(exported_dashboards_1[0].slices)) + exported_tables_1 = pickle.loads(resp_1.data)['datasources'] + self.assertEquals(1, len(exported_tables_1)) + self.assertEquals( + 1, len([c for c in [t.columns for t in exported_tables_1]])) + self.assertEquals( + 1, len([m for m in [t.metrics for t in exported_tables_1]])) + # export 2 dashboards world_health_dash = db.session.query(models.Dashboard).filter_by( slug='world_health').first() export_dash_url_2 = ( '/dashboardmodelview/export_dashboards_form?id={}&id={}&action=go' - .format(births_dash.id, world_health_dash.id) + .format(births_dash.id, world_health_dash.id) ) resp_2 = self.client.get(export_dash_url_2) exported_dashboards_2 = 
pickle.loads(resp_2.data)['dashboards'] @@ -51,12 +63,26 @@ def test_export_dashboards(self): self.assertEquals('births', exported_dashboards_2[0].slug) self.assertEquals(9, len(exported_dashboards_2[0].slices)) self.assertEquals('world_health', exported_dashboards_2[1].slug) + self.assertEquals( + world_health_dash.id, + json.loads(exported_dashboards_2[1].json_metadata)['remote_id'] + ) + self.assertEquals(10, len(exported_dashboards_2[1].slices)) + exported_tables_2 = pickle.loads(resp_2.data)['datasources'] + self.assertEquals(2, len(exported_tables_2)) + table_names = [t.table_name for t in exported_tables_2] + self.assertEquals(len(set(table_names)), len(table_names)) + self.assertEquals( + 2, len([c for c in [t.columns for t in exported_tables_2]])) + self.assertEquals( + 2, len([m for m in [t.metrics for t in exported_tables_2]])) + def test_import_slice(self): session = db.session - def get_slice(name, ds_id=666, slc_id=1, + def create_slice(name, ds_id=666, slc_id=1, db_name='main', table_name='wb_health_population'): params = { @@ -75,13 +101,13 @@ def get_slice(name, ds_id=666, slc_id=1, id=slc_id ) - def find_slice(slc_id): + def get_slice(slc_id): return session.query(models.Slice).filter_by(id=slc_id).first() # Case 1. Import slice with different datasource id - slc_id_1 = models.Slice.import_slice( - get_slice('Import Me 1'), import_time=1989) - slc_1 = find_slice(slc_id_1) + slc_id_1 = models.Slice.import_obj( + create_slice('Import Me 1'), import_time=1989) + slc_1 = get_slice(slc_id_1) table_id = db.session.query(models.SqlaTable).filter_by( table_name='wb_health_population').first().id self.assertEquals(table_id, slc_1.datasource_id) @@ -101,9 +127,9 @@ def find_slice(slc_id): json.loads(slc_1.params)) # Case 2. 
Import slice with same datasource id - slc_id_2 = models.Slice.import_slice(get_slice( + slc_id_2 = models.Slice.import_obj(create_slice( 'Import Me 2', ds_id=table_id, slc_id=2)) - slc_2 = find_slice(slc_id_2) + slc_2 = get_slice(slc_id_2) self.assertEquals(table_id, slc_2.datasource_id) self.assertEquals( '[main].[wb_health_population](id:{})'.format(table_id), @@ -111,13 +137,13 @@ def find_slice(slc_id): # Case 3. Import slice with non existent datasource with self.assertRaises(IndexError): - models.Slice.import_slice(get_slice( + models.Slice.import_obj(create_slice( 'Import Me 3', slc_id=3, table_name='non_existent')) # Case 4. Older version of slice already exists, override - slc_id_4 = models.Slice.import_slice( - get_slice('Import Me New 1'), import_time=1990) - slc_4 = find_slice(slc_id_4) + slc_id_4 = models.Slice.import_obj( + create_slice('Import Me New 1'), import_time=1990) + slc_4 = get_slice(slc_id_4) self.assertEquals(table_id, slc_4.datasource_id) self.assertEquals( '[main].[wb_health_population](id:{})'.format(table_id), @@ -134,6 +160,8 @@ def find_slice(slc_id): 'import_time': 1990 }, json.loads(slc_4.params)) + # override doesn't change the id + self.assertEquals(slc_id_1, slc_id_4) session.delete(slc_2) session.delete(slc_4) @@ -142,22 +170,24 @@ def find_slice(slc_id): def test_import_dashboard(self): session = db.session - def get_dashboard(title, id=0, slcs=[]): + def create_dashboard(title, id=0, slcs=[]): + json_metadata = {'remote_id': id} return models.Dashboard( id=id, dashboard_title=title, slices=slcs, position_json='{"size_y": 2, "size_x": 2}', slug='{}_imported'.format(title.lower()), + json_metadata=json.dumps(json_metadata) ) - def find_dash(dash_id): + def get_dash(dash_id): return session.query(models.Dashboard).filter_by( id=dash_id).first() def get_slc(name, table_name, slc_id): params = { - 'remote_id': slc_id, + 'remote_id': "'{}'".format(slc_id), 'datasource_name': table_name, 'database_name': 'main', 'schema': '', @@ 
-170,21 +200,20 @@ def get_slc(name, table_name, slc_id): ) # Case 1. Import empty dashboard can be imported - empty_dash = get_dashboard('empty_dashboard', id=1001) - imported_dash_id_1 = models.Dashboard.import_dashboard( + empty_dash = create_dashboard('empty_dashboard', id=1001) + imported_dash_id_1 = models.Dashboard.import_obj( empty_dash, import_time=1989) - # Dashboard import fails if there are no slices - self.assertEquals( - 'empty_dashboard', find_dash(imported_dash_id_1).dashboard_title) + imported_dash_1 = get_dash(imported_dash_id_1) + self.assertEquals('empty_dashboard', imported_dash_1.dashboard_title) # Case 2. Import dashboard with 1 new slice. - dash_with_1_slice = get_dashboard( + dash_with_1_slice = create_dashboard( 'dash_with_1_slice', slcs=[get_slc('health_slc', 'wb_health_population', 1001)], id=1002) - imported_dash_id_2 = models.Dashboard.import_dashboard( + imported_dash_id_2 = models.Dashboard.import_obj( dash_with_1_slice, import_time=1990) - imported_dash_2 = find_dash(imported_dash_id_2) + imported_dash_2 = get_dash(imported_dash_id_2) self.assertEquals('dash_with_1_slice', imported_dash_2.dashboard_title) self.assertEquals(1, len(imported_dash_2.slices)) self.assertEquals('wb_health_population', @@ -194,14 +223,14 @@ def get_slc(name, table_name, slc_id): json.loads(imported_dash_2.json_metadata)) # Case 3. Import dashboard with 2 new slices. 
- dash_with_2_slices = get_dashboard( + dash_with_2_slices = create_dashboard( 'dash_with_2_slices', slcs=[get_slc('energy_slc', 'energy_usage', 1002), get_slc('birth_slc', 'birth_names', 1003)], id=1003) - imported_dash_id_3 = models.Dashboard.import_dashboard( + imported_dash_id_3 = models.Dashboard.import_obj( dash_with_2_slices, import_time=1991) - imported_dash_3 = find_dash(imported_dash_id_3) + imported_dash_3 = get_dash(imported_dash_id_3) self.assertEquals('dash_with_2_slices', imported_dash_3.dashboard_title) self.assertEquals(2, len(imported_dash_3.slices)) self.assertEquals( @@ -210,14 +239,14 @@ def get_slc(name, table_name, slc_id): imported_dash_slice_ids_3 = {s.id for s in imported_dash_3.slices} # Case 4. Override the dashboard and 2 slices - dash_with_2_slices_new = get_dashboard( + dash_with_2_slices_new = create_dashboard( 'dash_with_2_slices_new', slcs=[get_slc('energy_slc', 'energy_usage', 1002), get_slc('birth_slc', 'birth_names', 1003)], id=1003) - imported_dash_id_4 = models.Dashboard.import_dashboard( + imported_dash_id_4 = models.Dashboard.import_obj( dash_with_2_slices_new, import_time=1992) - imported_dash_4 = find_dash(imported_dash_id_4) + imported_dash_4 = get_dash(imported_dash_id_4) self.assertEquals('dash_with_2_slices_new', imported_dash_4.dashboard_title) self.assertEquals(2, len(imported_dash_4.slices)) @@ -226,15 +255,132 @@ def get_slc(name, table_name, slc_id): self.assertEquals( {"remote_id": 1003, "import_time": 1992}, json.loads(imported_dash_4.json_metadata)) + # override doesn't change the id + self.assertEquals(imported_dash_id_3, imported_dash_id_4) # cleanup slc_to_delete = set(imported_dash_2.slices) slc_to_delete.update(imported_dash_4.slices) for slc in slc_to_delete: session.delete(slc) + + session.delete(imported_dash_1) session.delete(imported_dash_2) session.delete(imported_dash_4) session.commit() + def test_import_tables(self): + session = db.session + + def create_table( + name, schema='', id=0, 
cols_names=[], metric_names=[]): + json_metadata = {'remote_id': id, 'database_name': 'main'} + table = models.SqlaTable( + id=id, + schema=schema, + table_name=name, + json_metadata=json.dumps(json_metadata) + ) + for col_name in cols_names: + table.columns.append( + models.TableColumn(column_name=col_name)) + for metric_name in metric_names: + table.metrics.append(models.SqlMetric(metric_name=metric_name)) + return table + + def get_table(table_id): + return session.query(models.SqlaTable).filter_by( + id=table_id).first() + + # Case 1. Import table with no metrics and columns + pure_table = create_table('pure_table', id=1001) + imported_table_1_id = models.SqlaTable.import_obj( + pure_table, import_time=1989) + imported_table_1 = get_table(imported_table_1_id) + self.assertEquals('pure_table', imported_table_1.table_name) + + # Case 2. Import table with 1 column and metric + table_2 = create_table( + 'table_2', id=1002, cols_names=["col1"], metric_names=["metric1"]) + imported_table_id_2 = models.SqlaTable.import_obj( + table_2, import_time=1990) + imported_table_2 = get_table(imported_table_id_2) + self.assertEquals('table_2', imported_table_2.table_name) + self.assertEquals(1, len(imported_table_2.columns)) + self.assertEquals('col1', imported_table_2.columns[0].column_name) + self.assertEquals(1, len(imported_table_2.metrics)) + self.assertEquals('metric1', imported_table_2.metrics[0].metric_name) + self.assertEquals( + {'remote_id': 1002, 'import_time': 1990, 'database_name': 'main'}, + json.loads(imported_table_2.json_metadata)) + + # Case 3. 
Import table with 2 metrics and 2 columns + table_3 = create_table( + 'table_3', id=1003, cols_names=['col1', 'col2'], + metric_names=['metric1', 'metric2']) + imported_table_id_3 = models.SqlaTable.import_obj( + table_3, import_time=1991) + + imported_table_3 = get_table(imported_table_id_3) + self.assertEquals('table_3', imported_table_3.table_name) + self.assertEquals(2, len(imported_table_3.columns)) + self.assertEquals( + {'col1', 'col2'}, + set([c.column_name for c in imported_table_3.columns])) + self.assertEquals(2, len(imported_table_3.metrics)) + self.assertEquals( + {'metric1', 'metric2'}, + set([m.metric_name for m in imported_table_3.metrics])) + imported_table_3_id = imported_table_3.id + + # Case 4. Override table with different metrics and columns + table_4 = create_table( + 'table_3', id=1003, cols_names=['new_col1', 'col2', 'col3'], + metric_names=['new_metric1']) + imported_table_id_4 = models.SqlaTable.import_obj( + table_4, import_time=1992) + + imported_table_4 = get_table(imported_table_id_4) + self.assertEquals('table_3', imported_table_4.table_name) + self.assertEquals(4, len(imported_table_4.columns)) + # metrics and columns are never deleted, only appended + self.assertEquals( + {'col1', 'new_col1', 'col2', 'col3'}, + set([c.column_name for c in imported_table_4.columns])) + self.assertEquals(3, len(imported_table_4.metrics)) + self.assertEquals( + {'new_metric1', 'metric1', 'metric2'}, + set([m.metric_name for m in imported_table_4.metrics])) + # override doesn't change the id + self.assertEquals(imported_table_3_id, imported_table_4.id) + + # Case 5. 
Override with the identical table + table_5 = create_table( + 'table_3', id=1003, cols_names=['new_col1', 'col2', 'col3'], + metric_names=['new_metric1']) + imported_table_id_5 = models.SqlaTable.import_obj( + table_5, import_time=1993) + + imported_table_5 = get_table(imported_table_id_5) + self.assertEquals('table_3', imported_table_5.table_name) + self.assertEquals(4, len(imported_table_5.columns)) + self.assertEquals( + {'col1', 'new_col1', 'col2', 'col3'}, + set([c.column_name for c in imported_table_5.columns])) + self.assertEquals(3, len(imported_table_5.metrics)) + self.assertEquals( + {'new_metric1', 'metric1', 'metric2'}, + set([m.metric_name for m in imported_table_5.metrics])) + + # cleanup + # imported_table_3 and 4 are overriden + for table in [imported_table_1, imported_table_2, imported_table_5]: + for metric in table.metrics: + session.delete(metric) + for column in table.columns: + session.delete(column) + session.delete(table) + session.commit() + if __name__ == '__main__': unittest.main() From 045dd1c97f5e9f2c174945ad12140e3a7f804ebf Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Fri, 7 Oct 2016 11:01:43 -0700 Subject: [PATCH 08/13] Json.loads does not support trailing commas. --- caravel/models.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/caravel/models.py b/caravel/models.py index 4a68cc908eb3..16e4403a8439 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -275,6 +275,13 @@ def slice_link(self): name = escape(self.slice_name) return Markup('{name}'.format(**locals())) + @property + def params_dict(self): + if self.params: + return json.loads(self.params) + else: + return {} + def get_viz(self, url_params_multidict=None): """Creates :py:class:viz.BaseViz object from the url_params_multidict. 
@@ -325,11 +332,11 @@ def import_obj(cls, slc, import_time=None): slc.alter_params(remote_id=slc.id, import_time=import_time) # find if the slice was already imported - slc_to_override = ( - session.query(Slice) - .filter(Slice.params.ilike('%"remote_id": {},%'.format(slc.id))) - .first() - ) + slc_to_override = None + for s in session.query(Slice).all(): + if ('remote_id' in s.params_dict and + s.params_dict['remote_id'] == slc.id): + slc_to_override = s slc.id = None params = json.loads(slc.params) @@ -468,12 +475,12 @@ def import_obj(cls, dashboard_to_import, import_time=None): slice_ids.add(Slice.import_obj(slc, import_time=import_time)) # override the dashboard - existing_dashboard = ( - session.query(Dashboard) - .filter(Dashboard.json_metadata.ilike( - '%"remote_id": {},%'.format(dashboard_to_import.id))) - .first() - ) + existing_dashboard = None + for dash in session.query(Dashboard).all(): + if ('remote_id' in dash.metadata_dejson and + dash.metadata_dejson['remote_id'] == + dashboard_to_import.id): + existing_dashboard = dash dashboard_to_import.id = None dashboard_to_import.alter_json_metadata(import_time=import_time) From 0cb12310ce9674be60fba363178acea948999e27 Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Fri, 7 Oct 2016 16:32:53 -0700 Subject: [PATCH 09/13] Improve alter_dict func --- caravel/models.py | 33 +++++++++++++++------------------ caravel/source_registry.py | 1 - 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/caravel/models.py b/caravel/models.py index 16e4403a8439..f49b369c8707 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -311,12 +311,9 @@ def get_viz(self, url_params_multidict=None): ) def alter_params(self, **kwargs): - params_dict = {} - if self.params: - params_dict = json.loads(self.params) - for key in kwargs: - params_dict[key] = kwargs[key] - self.params = json.dumps(params_dict) + d = self.params_dict + d.update(kwargs) + self.params = json.dumps(d) @classmethod def import_obj(cls, slc, 
import_time=None): @@ -441,13 +438,16 @@ def json_data(self): } return json.dumps(d) + @property + def dict_metadata(self): + if not self.json_metadata: + return {} + return json.loads(self.json_metadata) + def alter_json_metadata(self, **kwargs): - json_metadata_dict = {} - if self.json_metadata: - json_metadata_dict = json.loads(self.json_metadata) - for key in kwargs: - json_metadata_dict[key] = kwargs[key] - self.json_metadata = json.dumps(json_metadata_dict) + d = self.dict_metadata + d.update(kwargs) + self.json_metadata = json.dumps(d) @classmethod def import_obj(cls, dashboard_to_import, import_time=None): @@ -1244,12 +1244,9 @@ def fetch_metadata(self): self.main_dttm_col = any_date_col def alter_json_metadata(self, **kwargs): - json_metadata_dict = {} - if self.json_metadata: - json_metadata_dict = json.loads(self.json_metadata) - for key in kwargs: - json_metadata_dict[key] = kwargs[key] - self.json_metadata = json.dumps(json_metadata_dict) + d = self.dict_metadata + d.update(kwargs) + self.json_metadata = json.dumps(d) @property def dict_metadata(self): diff --git a/caravel/source_registry.py b/caravel/source_registry.py index 3b1457bb5200..6176c9c0c96b 100644 --- a/caravel/source_registry.py +++ b/caravel/source_registry.py @@ -1,7 +1,6 @@ from sqlalchemy.orm import subqueryload - class SourceRegistry(object): """ Central Registry for all available datasource engines""" From 06c18b1a6aa33bb831de69ece4480e53a62b19fa Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Mon, 10 Oct 2016 11:29:56 -0700 Subject: [PATCH 10/13] Resolve comments. 
--- caravel/models.py | 33 ++++++++++---------- caravel/views.py | 58 +++++++++++++++++------------------- tests/import_export_tests.py | 18 +++++------ 3 files changed, 52 insertions(+), 57 deletions(-) diff --git a/caravel/models.py b/caravel/models.py index f49b369c8707..7bf8f78d1834 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -316,7 +316,7 @@ def alter_params(self, **kwargs): self.params = json.dumps(d) @classmethod - def import_obj(cls, slc, import_time=None): + def import_obj(cls, slc_to_import, import_time=None): """Inserts or overrides slc in the database. remote_id and import_time fields in params_dict are set to track the @@ -324,30 +324,31 @@ def import_obj(cls, slc, import_time=None): Slice.perm is used to find the datasources and connect them. """ session = db.session - make_transient(slc) - slc.dashboards = [] - slc.alter_params(remote_id=slc.id, import_time=import_time) + make_transient(slc_to_import) + slc_to_import.dashboards = [] + slc_to_import.alter_params( + remote_id=slc_to_import.id, import_time=import_time) # find if the slice was already imported slc_to_override = None - for s in session.query(Slice).all(): - if ('remote_id' in s.params_dict and - s.params_dict['remote_id'] == slc.id): - slc_to_override = s - - slc.id = None - params = json.loads(slc.params) - slc.datasource_id = SourceRegistry.get_datasource_by_name( - session, slc.datasource_type, params['datasource_name'], + for slc in session.query(Slice).all(): + if ('remote_id' in slc.params_dict and + slc.params_dict['remote_id'] == slc_to_import.id): + slc_to_override = slc + + slc_to_import.id = None + params = slc_to_import.params_dict + slc_to_import.datasource_id = SourceRegistry.get_datasource_by_name( + session, slc_to_import.datasource_type, params['datasource_name'], params['schema'], params['database_name']).id if slc_to_override: - slc_to_override.override(slc) + slc_to_override.override(slc_to_import) session.flush() return slc_to_override.id else: - 
session.add(slc) + session.add(slc_to_import) session.flush() - return slc.id + return slc_to_import.id def set_perm(mapper, connection, target): # noqa diff --git a/caravel/views.py b/caravel/views.py index 99e928bfb543..9d6c22bb9ef3 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -32,7 +32,6 @@ from werkzeug import secure_filename from werkzeug.datastructures import ImmutableMultiDict from werkzeug.routing import BaseConverter -from werkzeug.datastructures import ImmutableMultiDict from wtforms.validators import ValidationError import caravel @@ -538,6 +537,16 @@ def pre_update(self, db): self.pre_add(db) +appbuilder.add_link( + 'Import Dashboards', + label=__("Import Dashboards"), + href='/caravel/import_dashboards', + icon="fa-cloud-upload", + category='Manage', + category_label=__("Manage"), + category_icon='fa-wrench',) + + appbuilder.add_view( DatabaseView, "Databases", @@ -663,7 +672,6 @@ class AccessRequestsModelView(CaravelModelView, DeleteMixin): category_label=__("Security"), icon='fa-table',) - appbuilder.add_separator("Sources") @@ -895,18 +903,11 @@ def download_dashboards(self): appbuilder.add_view( DashboardModelView, "Dashboards", - label=__("Dashboards List"), + label=__("Dashboards"), icon="fa-dashboard", - category='Dashboards', + category='', category_icon='',) -appbuilder.add_link( - 'Import Dashboards', - href='/caravel/import_dashboards', - icon="fa-cloud-upload", - category='Dashboards') - - class DashboardModelViewAsync(DashboardModelView): # noqa list_columns = ['dashboard_link', 'creator', 'modified', 'dashboard_title'] label_columns = { @@ -1182,28 +1183,23 @@ def explore_json(self, datasource_type, datasource_id): @log_this def import_dashboards(self): """Overrides the dashboards using pickled instances from the file.""" - if request.method == 'POST': - f = request.files['file'] - if f: - filename = secure_filename(f.filename) - filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename) - f.save(filepath) - current_tt = 
int(time.time()) - data = pickle.load(open(filepath, 'rb')) - for table in data['datasources']: - models.SqlaTable.import_obj(table, import_time=current_tt) - for dashboard in data['dashboards']: - models.Dashboard.import_obj( - dashboard, import_time=current_tt) - os.remove(filepath) - db.session.commit() - return redirect('/dashboardmodelview/list/') + f = request.files.get('file') + if request.method == 'POST' and f: + filename = secure_filename(f.filename) + filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename) + f.save(filepath) + current_tt = int(time.time()) + data = pickle.load(open(filepath, 'rb')) + for table in data['datasources']: + models.SqlaTable.import_obj(table, import_time=current_tt) + for dashboard in data['dashboards']: + models.Dashboard.import_obj( + dashboard, import_time=current_tt) + os.remove(filepath) + db.session.commit() + return redirect('/dashboardmodelview/list/') return self.render_template('caravel/import_dashboards.html') - @has_access - @expose("/explore////") - @expose("/explore///") - @expose("/datasource///") # Legacy url @log_this @has_access @expose("/explore///") diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index 5e981928d522..03f275a8a5fc 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -16,8 +16,6 @@ class ImportExportTests(CaravelTestCase): """Testing export import functionality for dashboards""" - examples_loaded = False - def __init__(self, *args, **kwargs): super(ImportExportTests, self).__init__(*args, **kwargs) @@ -58,26 +56,26 @@ def test_export_dashboards(self): .format(births_dash.id, world_health_dash.id) ) resp_2 = self.client.get(export_dash_url_2) - exported_dashboards_2 = pickle.loads(resp_2.data)['dashboards'] + exported_dashboards_2 = sorted(pickle.loads(resp_2.data)['dashboards']) self.assertEquals(2, len(exported_dashboards_2)) - self.assertEquals('births', exported_dashboards_2[0].slug) self.assertEquals(9, 
len(exported_dashboards_2[0].slices)) + self.assertEquals('births', exported_dashboards_2[0].slug) + + self.assertEquals(10, len(exported_dashboards_2[1].slices)) self.assertEquals('world_health', exported_dashboards_2[1].slug) self.assertEquals( world_health_dash.id, json.loads(exported_dashboards_2[1].json_metadata)['remote_id'] ) - self.assertEquals(10, len(exported_dashboards_2[1].slices)) - - exported_tables_2 = pickle.loads(resp_2.data)['datasources'] + exported_tables_2 = sorted(pickle.loads(resp_2.data)['datasources']) self.assertEquals(2, len(exported_tables_2)) table_names = [t.table_name for t in exported_tables_2] self.assertEquals(len(set(table_names)), len(table_names)) self.assertEquals( - 2, len([c for c in [t.columns for t in exported_tables_2]])) + 14, len([c for c in t.columns for t in exported_tables_2])) self.assertEquals( - 2, len([m for m in [t.metrics for t in exported_tables_2]])) + 8, len([m for m in t.metrics for t in exported_tables_2])) def test_import_slice(self): session = db.session @@ -190,7 +188,7 @@ def get_slc(name, table_name, slc_id): 'remote_id': "'{}'".format(slc_id), 'datasource_name': table_name, 'database_name': 'main', - 'schema': '', + 'schema': 'test', } return models.Slice( id=slc_id, From c78a38a395693b153c2617cbc86d254284ca35e2 Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Mon, 10 Oct 2016 22:52:13 -0700 Subject: [PATCH 11/13] Refactor tests --- caravel/models.py | 1 + tests/import_export_tests.py | 652 +++++++++++++++++------------------ 2 files changed, 319 insertions(+), 334 deletions(-) diff --git a/caravel/models.py b/caravel/models.py index 7bf8f78d1834..114d8d37678b 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -347,6 +347,7 @@ def import_obj(cls, slc_to_import, import_time=None): return slc_to_override.id else: session.add(slc_to_import) + logging.info('Final slice: {}'.format(slc_to_import.to_json())) session.flush() return slc_to_import.id diff --git a/tests/import_export_tests.py 
b/tests/import_export_tests.py index 03f275a8a5fc..979be0ba660b 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -4,6 +4,8 @@ from __future__ import print_function from __future__ import unicode_literals +from sqlalchemy.orm.session import make_transient + import json import pickle import unittest @@ -19,366 +21,348 @@ class ImportExportTests(CaravelTestCase): def __init__(self, *args, **kwargs): super(ImportExportTests, self).__init__(*args, **kwargs) - def test_export_dashboards(self): - births_dash = db.session.query(models.Dashboard).filter_by( - slug='births').first() + @classmethod + def delete_imports(cls): + # Imported data clean up + session = db.session + for slc in session.query(models.Slice): + if 'remote_id' in slc.params_dict: + session.delete(slc) + for dash in session.query(models.Dashboard): + if 'remote_id' in dash.metadata_dejson: + session.delete(dash) + for table in session.query(models.SqlaTable): + if 'remote_id' in table.dict_metadata: + session.delete(table) + session.commit() - # export 1 dashboard - export_dash_url_1 = ( - '/dashboardmodelview/export_dashboards_form?id={}&action=go' - .format(births_dash.id) - ) - resp_1 = self.client.get(export_dash_url_1) - exported_dashboards_1 = pickle.loads(resp_1.data)['dashboards'] - self.assertEquals(1, len(exported_dashboards_1)) - self.assertEquals('births', exported_dashboards_1[0].slug) - self.assertEquals('Births', exported_dashboards_1[0].dashboard_title) - self.assertEquals( - births_dash.id, - json.loads(exported_dashboards_1[0].json_metadata)['remote_id'] + @classmethod + def setUpClass(cls): + cls.delete_imports() + + @classmethod + def tearDownClass(cls): + cls.delete_imports() + + def create_slice(self, name, ds_id=None, id=None, db_name='main', + table_name='wb_health_population'): + params = { + 'num_period_compare': '10', + 'remote_id': id, + 'datasource_name': table_name, + 'database_name': db_name, + 'schema': '', + } + + if table_name and not ds_id: 
+ table = self.get_table_by_name(table_name) + if table: + ds_id = table.id + + return models.Slice( + slice_name=name, + datasource_type='table', + viz_type='bubble', + params=json.dumps(params), + datasource_id=ds_id, + id=id ) - self.assertTrue(exported_dashboards_1[0].position_json) - self.assertEquals(9, len(exported_dashboards_1[0].slices)) + def create_dashboard(self, title, id=0, slcs=[]): + json_metadata = {'remote_id': id} + return models.Dashboard( + id=id, + dashboard_title=title, + slices=slcs, + position_json='{"size_y": 2, "size_x": 2}', + slug='{}_imported'.format(title.lower()), + json_metadata=json.dumps(json_metadata) + ) - exported_tables_1 = pickle.loads(resp_1.data)['datasources'] - self.assertEquals(1, len(exported_tables_1)) + def create_table(self, name, schema='', id=0, cols_names=[], metric_names=[]): + json_metadata = {'remote_id': id, 'database_name': 'main'} + table = models.SqlaTable( + id=id, + schema=schema, + table_name=name, + json_metadata=json.dumps(json_metadata) + ) + for col_name in cols_names: + table.columns.append( + models.TableColumn(column_name=col_name)) + for metric_name in metric_names: + table.metrics.append(models.SqlMetric(metric_name=metric_name)) + return table + + def get_slice(self, slc_id): + return db.session.query(models.Slice).filter_by(id=slc_id).first() + + def get_dash(self, dash_id): + return db.session.query(models.Dashboard).filter_by( + id=dash_id).first() + + def get_dash_by_slug(self, dash_slug): + return db.session.query(models.Dashboard).filter_by( + slug=dash_slug).first() + + def get_table(self, table_id): + return db.session.query(models.SqlaTable).filter_by( + id=table_id).first() + + def get_table_by_name(self, name): + return db.session.query(models.SqlaTable).filter_by( + table_name=name).first() + + def assert_dash_equals(self, expected_dash, actual_dash): + self.assertEquals(expected_dash.slug, actual_dash.slug) self.assertEquals( - 1, len([c for c in [t.columns for t in 
exported_tables_1]])) + expected_dash.dashboard_title, actual_dash.dashboard_title) self.assertEquals( - 1, len([m for m in [t.metrics for t in exported_tables_1]])) - - # export 2 dashboards - world_health_dash = db.session.query(models.Dashboard).filter_by( - slug='world_health').first() - export_dash_url_2 = ( - '/dashboardmodelview/export_dashboards_form?id={}&id={}&action=go' - .format(births_dash.id, world_health_dash.id) - ) - resp_2 = self.client.get(export_dash_url_2) - exported_dashboards_2 = sorted(pickle.loads(resp_2.data)['dashboards']) - self.assertEquals(2, len(exported_dashboards_2)) - self.assertEquals(9, len(exported_dashboards_2[0].slices)) - self.assertEquals('births', exported_dashboards_2[0].slug) - - self.assertEquals(10, len(exported_dashboards_2[1].slices)) - self.assertEquals('world_health', exported_dashboards_2[1].slug) + expected_dash.position_json, actual_dash.position_json) self.assertEquals( - world_health_dash.id, - json.loads(exported_dashboards_2[1].json_metadata)['remote_id'] - ) - - exported_tables_2 = sorted(pickle.loads(resp_2.data)['datasources']) - self.assertEquals(2, len(exported_tables_2)) - table_names = [t.table_name for t in exported_tables_2] - self.assertEquals(len(set(table_names)), len(table_names)) + len(expected_dash.slices), len(actual_dash.slices)) + expected_slices = sorted( + expected_dash.slices, key=lambda s: s.slice_name) + actual_slices = sorted( + actual_dash.slices, key=lambda s: s.slice_name) + for e_slc, a_slc in zip(expected_slices, actual_slices): + self.assert_slice_equals(e_slc, a_slc) + + def assert_table_equals(self, expected_ds, actual_ds): + self.assertEquals(expected_ds.table_name, actual_ds.table_name) + self.assertEquals(expected_ds.main_dttm_col, actual_ds.main_dttm_col) + self.assertEquals(expected_ds.schema, actual_ds.schema) + self.assertEquals(len(expected_ds.metrics), len(actual_ds.metrics)) + self.assertEquals(len(expected_ds.columns), len(actual_ds.columns)) self.assertEquals( - 14, 
len([c for c in t.columns for t in exported_tables_2])) + set([c.column_name for c in expected_ds.columns]), + set([c.column_name for c in actual_ds.columns])) self.assertEquals( - 8, len([m for m in t.metrics for t in exported_tables_2])) + set([m.metric_name for m in expected_ds.metrics]), + set([m.metric_name for m in actual_ds.metrics])) - def test_import_slice(self): - session = db.session - - def create_slice(name, ds_id=666, slc_id=1, - db_name='main', - table_name='wb_health_population'): - params = { - 'num_period_compare': '10', - 'remote_id': slc_id, - 'datasource_name': table_name, - 'database_name': db_name, - 'schema': '', - } - return models.Slice( - slice_name=name, - datasource_type='table', - viz_type='bubble', - params=json.dumps(params), - datasource_id=ds_id, - id=slc_id - ) - - def get_slice(slc_id): - return session.query(models.Slice).filter_by(id=slc_id).first() - - # Case 1. Import slice with different datasource id - slc_id_1 = models.Slice.import_obj( - create_slice('Import Me 1'), import_time=1989) - slc_1 = get_slice(slc_id_1) - table_id = db.session.query(models.SqlaTable).filter_by( - table_name='wb_health_population').first().id - self.assertEquals(table_id, slc_1.datasource_id) + def assert_slice_equals(self, expected_slc, actual_slc): + self.assertEquals(actual_slc.datasource.perm, actual_slc.perm) + self.assertEquals(expected_slc.slice_name, actual_slc.slice_name) self.assertEquals( - '[main].[wb_health_population](id:{})'.format(table_id), - slc_1.perm) - self.assertEquals('Import Me 1', slc_1.slice_name) - self.assertEquals('table', slc_1.datasource_type) - self.assertEquals('bubble', slc_1.viz_type) + expected_slc.datasource_type, actual_slc.datasource_type) + self.assertEquals(expected_slc.viz_type, actual_slc.viz_type) self.assertEquals( - {'num_period_compare': '10', - 'datasource_name': 'wb_health_population', - 'database_name': 'main', - 'remote_id': 1, - 'schema': '', - 'import_time': 1989,}, - json.loads(slc_1.params)) 
- - # Case 2. Import slice with same datasource id - slc_id_2 = models.Slice.import_obj(create_slice( - 'Import Me 2', ds_id=table_id, slc_id=2)) - slc_2 = get_slice(slc_id_2) - self.assertEquals(table_id, slc_2.datasource_id) - self.assertEquals( - '[main].[wb_health_population](id:{})'.format(table_id), - slc_2.perm) + json.loads(expected_slc.params), json.loads(actual_slc.params)) - # Case 3. Import slice with non existent datasource - with self.assertRaises(IndexError): - models.Slice.import_obj(create_slice( - 'Import Me 3', slc_id=3, table_name='non_existent')) - - # Case 4. Older version of slice already exists, override - slc_id_4 = models.Slice.import_obj( - create_slice('Import Me New 1'), import_time=1990) - slc_4 = get_slice(slc_id_4) - self.assertEquals(table_id, slc_4.datasource_id) + def test_export_1_dashboard(self): + birth_dash = self.get_dash_by_slug('births') + export_dash_url = ( + '/dashboardmodelview/export_dashboards_form?id={}&action=go' + .format(birth_dash.id) + ) + resp = self.client.get(export_dash_url) + exported_dashboards = pickle.loads(resp.data)['dashboards'] + self.assert_dash_equals(birth_dash, exported_dashboards[0]) self.assertEquals( - '[main].[wb_health_population](id:{})'.format(table_id), - slc_4.perm) - self.assertEquals('Import Me New 1', slc_4.slice_name) - self.assertEquals('table', slc_4.datasource_type) - self.assertEquals('bubble', slc_4.viz_type) + birth_dash.id, + json.loads(exported_dashboards[0].json_metadata)['remote_id']) + + exported_tables = pickle.loads(resp.data)['datasources'] + self.assertEquals(1, len(exported_tables)) + self.assert_table_equals( + self.get_table_by_name('birth_names'), exported_tables[0]) + + def test_export_2_dashboards(self): + birth_dash = self.get_dash_by_slug('births') + world_health_dash = self.get_dash_by_slug('world_health') + export_dash_url = ( + '/dashboardmodelview/export_dashboards_form?id={}&id={}&action=go' + .format(birth_dash.id, world_health_dash.id)) + resp = 
self.client.get(export_dash_url) + exported_dashboards = sorted(pickle.loads(resp.data)['dashboards'], + key=lambda d: d.dashboard_title) + self.assertEquals(2, len(exported_dashboards)) + self.assert_dash_equals(birth_dash, exported_dashboards[0]) self.assertEquals( - {'num_period_compare': '10', - 'datasource_name': 'wb_health_population', - 'database_name': 'main', - 'remote_id': 1, - 'schema': '', - 'import_time': 1990 - }, - json.loads(slc_4.params)) - # override doesn't change the id - self.assertEquals(slc_id_1, slc_id_4) - - session.delete(slc_2) - session.delete(slc_4) - session.commit() + birth_dash.id, + json.loads(exported_dashboards[0].json_metadata)['remote_id'] + ) - def test_import_dashboard(self): - session = db.session + self.assert_dash_equals(world_health_dash, exported_dashboards[1]) + self.assertEquals( + world_health_dash.id, + json.loads(exported_dashboards[1].json_metadata)['remote_id'] + ) - def create_dashboard(title, id=0, slcs=[]): - json_metadata = {'remote_id': id} - return models.Dashboard( - id=id, - dashboard_title=title, - slices=slcs, - position_json='{"size_y": 2, "size_x": 2}', - slug='{}_imported'.format(title.lower()), - json_metadata=json.dumps(json_metadata) - ) - - def get_dash(dash_id): - return session.query(models.Dashboard).filter_by( - id=dash_id).first() - - def get_slc(name, table_name, slc_id): - params = { - 'remote_id': "'{}'".format(slc_id), - 'datasource_name': table_name, - 'database_name': 'main', - 'schema': 'test', - } - return models.Slice( - id=slc_id, - slice_name=name, - datasource_type='table', - params=json.dumps(params) - ) - - # Case 1. 
Import empty dashboard can be imported - empty_dash = create_dashboard('empty_dashboard', id=1001) - imported_dash_id_1 = models.Dashboard.import_obj( + exported_tables = sorted( + pickle.loads(resp.data)['datasources'], key=lambda t: t.table_name) + self.assertEquals(2, len(exported_tables)) + self.assert_table_equals( + self.get_table_by_name('birth_names'), exported_tables[0]) + self.assert_table_equals( + self.get_table_by_name('wb_health_population'), exported_tables[1]) + + def test_import_1_slice(self): + expected_slice = self.create_slice('Import Me', id=10001); + slc_id = models.Slice.import_obj(expected_slice, import_time=1989) + self.assert_slice_equals(expected_slice, self.get_slice(slc_id)) + + table_id = self.get_table_by_name('wb_health_population').id + self.assertEquals(table_id, self.get_slice(slc_id).datasource_id) + + def test_import_2_slices_for_same_table(self): + table_id = self.get_table_by_name('wb_health_population').id + # table_id != 666, import func will have to find the table + slc_1 = self.create_slice('Import Me 1', ds_id=666, id=10002) + slc_id_1 = models.Slice.import_obj(slc_1) + slc_2 = self.create_slice('Import Me 2', ds_id=666, id=10003) + slc_id_2 = models.Slice.import_obj(slc_2) + + imported_slc_1 = self.get_slice(slc_id_1) + imported_slc_2 = self.get_slice(slc_id_2) + self.assertEquals(table_id, imported_slc_1.datasource_id) + self.assert_slice_equals(slc_1, imported_slc_1) + + self.assertEquals(table_id, imported_slc_2.datasource_id) + self.assert_slice_equals(slc_2, imported_slc_2) + + def test_import_slices_for_non_existent_table(self): + with self.assertRaises(IndexError): + models.Slice.import_obj(self.create_slice( + 'Import Me 3', id=10004, table_name='non_existent')) + + def test_import_slices_override(self): + slc = self.create_slice('Import Me New', id=10005) + slc_1_id = models.Slice.import_obj(slc, import_time=1990) + slc.slice_name = 'Import Me New' + slc_2_id = models.Slice.import_obj( + 
self.create_slice('Import Me New', id=10005), import_time=1990) + self.assertEquals(slc_1_id, slc_2_id) + imported_slc = self.get_slice(slc_2_id) + self.assert_slice_equals(slc, imported_slc) + + def test_import_empty_dashboard(self): + empty_dash = self.create_dashboard('empty_dashboard', id=10001) + imported_dash_id = models.Dashboard.import_obj( empty_dash, import_time=1989) - imported_dash_1 = get_dash(imported_dash_id_1) - self.assertEquals('empty_dashboard', imported_dash_1.dashboard_title) - - # Case 2. Import dashboard with 1 new slice. - dash_with_1_slice = create_dashboard( - 'dash_with_1_slice', - slcs=[get_slc('health_slc', 'wb_health_population', 1001)], - id=1002) - imported_dash_id_2 = models.Dashboard.import_obj( + imported_dash = self.get_dash(imported_dash_id) + self.assert_dash_equals(empty_dash, imported_dash) + + def test_import_dashboard_1_slice(self): + slc = self.create_slice('health_slc', id=10006) + dash_with_1_slice = self.create_dashboard( + 'dash_with_1_slice', slcs=[slc], id=10002) + imported_dash_id = models.Dashboard.import_obj( dash_with_1_slice, import_time=1990) - imported_dash_2 = get_dash(imported_dash_id_2) - self.assertEquals('dash_with_1_slice', imported_dash_2.dashboard_title) - self.assertEquals(1, len(imported_dash_2.slices)) - self.assertEquals('wb_health_population', - imported_dash_2.slices[0].datasource.table_name) - self.assertEquals( - {"remote_id": 1002, "import_time": 1990}, - json.loads(imported_dash_2.json_metadata)) - - # Case 3. Import dashboard with 2 new slices. 
- dash_with_2_slices = create_dashboard( - 'dash_with_2_slices', - slcs=[get_slc('energy_slc', 'energy_usage', 1002), - get_slc('birth_slc', 'birth_names', 1003)], - id=1003) - imported_dash_id_3 = models.Dashboard.import_obj( + imported_dash = self.get_dash(imported_dash_id) + + expected_dash = self.create_dashboard( + 'dash_with_1_slice', slcs=[slc], id=10002) + make_transient(expected_dash) + self.assert_dash_equals(expected_dash, imported_dash) + self.assertEquals({"remote_id": 10002, "import_time": 1990}, + json.loads(imported_dash.json_metadata)) + + def test_import_dashboard_2_slices(self): + e_slc = self.create_slice('e_slc', id=10007, table_name='energy_usage') + b_slc = self.create_slice('b_slc', id=10008, table_name='birth_names') + dash_with_2_slices = self.create_dashboard( + 'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003) + imported_dash_id = models.Dashboard.import_obj( dash_with_2_slices, import_time=1991) - imported_dash_3 = get_dash(imported_dash_id_3) - self.assertEquals('dash_with_2_slices', imported_dash_3.dashboard_title) - self.assertEquals(2, len(imported_dash_3.slices)) - self.assertEquals( - {"remote_id": 1003, "import_time": 1991}, - json.loads(imported_dash_3.json_metadata)) - imported_dash_slice_ids_3 = {s.id for s in imported_dash_3.slices} - - # Case 4. 
Override the dashboard and 2 slices - dash_with_2_slices_new = create_dashboard( - 'dash_with_2_slices_new', - slcs=[get_slc('energy_slc', 'energy_usage', 1002), - get_slc('birth_slc', 'birth_names', 1003)], - id=1003) - imported_dash_id_4 = models.Dashboard.import_obj( - dash_with_2_slices_new, import_time=1992) - imported_dash_4 = get_dash(imported_dash_id_4) - self.assertEquals('dash_with_2_slices_new', - imported_dash_4.dashboard_title) - self.assertEquals(2, len(imported_dash_4.slices)) - self.assertEquals( - imported_dash_slice_ids_3, {s.id for s in imported_dash_4.slices}) - self.assertEquals( - {"remote_id": 1003, "import_time": 1992}, - json.loads(imported_dash_4.json_metadata)) - # override doesn't change the id - self.assertEquals(imported_dash_id_3, imported_dash_id_4) - - # cleanup - slc_to_delete = set(imported_dash_2.slices) - slc_to_delete.update(imported_dash_4.slices) - for slc in slc_to_delete: - session.delete(slc) - - session.delete(imported_dash_1) - session.delete(imported_dash_2) - session.delete(imported_dash_4) - session.commit() - - def test_import_tables(self): - session = db.session + imported_dash = self.get_dash(imported_dash_id) + + expected_dash = self.create_dashboard( + 'dash_with_2_slices', slcs=[e_slc, b_slc], id=10003) + make_transient(expected_dash) + self.assert_dash_equals(imported_dash, expected_dash) + self.assertEquals({"remote_id": 10003, "import_time": 1991}, + json.loads(imported_dash.json_metadata)) + + def test_import_override_dashboard_2_slices(self): + e_slc = self.create_slice('e_slc', id=10009, table_name='energy_usage') + b_slc = self.create_slice('b_slc', id=10010, table_name='birth_names') + dash_to_import = self.create_dashboard( + 'override_dashboard', slcs=[e_slc, b_slc], id=10004) + imported_dash_id_1 = models.Dashboard.import_obj( + dash_to_import, import_time=1992) + + # create new instances of the slices + e_slc = self.create_slice( + 'e_slc', id=10009, table_name='energy_usage') + b_slc = 
self.create_slice( + 'b_slc', id=10010, table_name='birth_names') + c_slc = self.create_slice('c_slc', id=10011, table_name='birth_names') + dash_to_import_override = self.create_dashboard( + 'override_dashboard_new', slcs=[e_slc, b_slc, c_slc], id=10004) + imported_dash_id_2 = models.Dashboard.import_obj( + dash_to_import_override, import_time=1992) - def create_table( - name, schema='', id=0, cols_names=[], metric_names=[]): - json_metadata = {'remote_id': id, 'database_name': 'main'} - table = models.SqlaTable( - id=id, - schema=schema, - table_name=name, - json_metadata=json.dumps(json_metadata) - ) - for col_name in cols_names: - table.columns.append( - models.TableColumn(column_name=col_name)) - for metric_name in metric_names: - table.metrics.append(models.SqlMetric(metric_name=metric_name)) - return table - - def get_table(table_id): - return session.query(models.SqlaTable).filter_by( - id=table_id).first() - - # Case 1. Import table with no metrics and columns - pure_table = create_table('pure_table', id=1001) - imported_table_1_id = models.SqlaTable.import_obj( - pure_table, import_time=1989) - imported_table_1 = get_table(imported_table_1_id) - self.assertEquals('pure_table', imported_table_1.table_name) - - # Case 2. Import table with 1 column and metric - table_2 = create_table( - 'table_2', id=1002, cols_names=["col1"], metric_names=["metric1"]) - imported_table_id_2 = models.SqlaTable.import_obj( - table_2, import_time=1990) - imported_table_2 = get_table(imported_table_id_2) - self.assertEquals('table_2', imported_table_2.table_name) - self.assertEquals(1, len(imported_table_2.columns)) - self.assertEquals('col1', imported_table_2.columns[0].column_name) - self.assertEquals(1, len(imported_table_2.metrics)) - self.assertEquals('metric1', imported_table_2.metrics[0].metric_name) - self.assertEquals( - {'remote_id': 1002, 'import_time': 1990, 'database_name': 'main'}, - json.loads(imported_table_2.json_metadata)) - - # Case 3. 
Import table with 2 metrics and 2 columns - table_3 = create_table( - 'table_3', id=1003, cols_names=['col1', 'col2'], - metric_names=['metric1', 'metric2']) - imported_table_id_3 = models.SqlaTable.import_obj( - table_3, import_time=1991) - - imported_table_3 = get_table(imported_table_id_3) - self.assertEquals('table_3', imported_table_3.table_name) - self.assertEquals(2, len(imported_table_3.columns)) - self.assertEquals( - {'col1', 'col2'}, - set([c.column_name for c in imported_table_3.columns])) - self.assertEquals(2, len(imported_table_3.metrics)) + # override doesn't change the id + self.assertEquals(imported_dash_id_1, imported_dash_id_2) + expected_dash = self.create_dashboard( + 'override_dashboard_new', slcs=[e_slc, b_slc, c_slc], id=10004) + make_transient(expected_dash) + imported_dash = self.get_dash(imported_dash_id_2) + self.assert_dash_equals(expected_dash, imported_dash) + self.assertEquals({"remote_id": 10004, "import_time": 1992}, + json.loads(imported_dash.json_metadata)) + + def test_import_table_no_metadata(self): + table = self.create_table('pure_table', id=10001) + imported_t_id = models.SqlaTable.import_obj(table, import_time=1989) + imported_table = self.get_table(imported_t_id) + self.assert_table_equals(table, imported_table) + + def test_import_table_1_col_1_met(self): + table = self.create_table( + 'table_1_col_1_met', id=10002, + cols_names=["col1"], metric_names=["metric1"]) + imported_t_id = models.SqlaTable.import_obj(table, import_time=1990) + imported_table = self.get_table(imported_t_id) + self.assert_table_equals(table, imported_table) self.assertEquals( - {'metric1', 'metric2'}, - set([m.metric_name for m in imported_table_3.metrics])) - imported_table_3_id = imported_table_3.id - - # Case 4. 
Override table with different metrics and columns - table_4 = create_table( - 'table_3', id=1003, cols_names=['new_col1', 'col2', 'col3'], + {'remote_id': 10002, 'import_time': 1990, 'database_name': 'main'}, + json.loads(imported_table.json_metadata)) + + def test_import_table_2_col_2_met(self): + table = self.create_table( + 'table_2_col_2_met', id=10003, cols_names=['c1', 'c2'], + metric_names=['m1', 'm2']) + imported_t_id = models.SqlaTable.import_obj(table, import_time=1991) + + imported_table = self.get_table(imported_t_id) + self.assert_table_equals(table, imported_table) + + def test_import_table_override(self): + table = self.create_table( + 'table_override', id=10003, cols_names=['col1'], + metric_names=['m1']) + imported_t_id = models.SqlaTable.import_obj(table, import_time=1991) + + table_over = self.create_table( + 'table_override', id=10003, cols_names=['new_col1', 'col2', 'col3'], metric_names=['new_metric1']) - imported_table_id_4 = models.SqlaTable.import_obj( - table_4, import_time=1992) - - imported_table_4 = get_table(imported_table_id_4) - self.assertEquals('table_3', imported_table_4.table_name) - self.assertEquals(4, len(imported_table_4.columns)) - # metrics and columns are never deleted, only appended - self.assertEquals( - {'col1', 'new_col1', 'col2', 'col3'}, - set([c.column_name for c in imported_table_4.columns])) - self.assertEquals(3, len(imported_table_4.metrics)) - self.assertEquals( - {'new_metric1', 'metric1', 'metric2'}, - set([m.metric_name for m in imported_table_4.metrics])) - # override doesn't change the id - self.assertEquals(imported_table_3_id, imported_table_4.id) + imported_table_over_id = models.SqlaTable.import_obj( + table_over, import_time=1992) + + imported_table_over = self.get_table(imported_table_over_id) + self.assertEquals(imported_t_id, imported_table_over.id) + expected_table = self.create_table( + 'table_override', id=10003, metric_names=['new_metric1', 'm1'], + cols_names=['col1', 'new_col1', 'col2', 
'col3']) + self.assert_table_equals(expected_table, imported_table_over) + + def test_import_table_override_identical(self): + table = self.create_table( + 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], + metric_names=['new_metric1']) + imported_t_id = models.SqlaTable.import_obj(table, import_time=1993) - # Case 5. Override with the identical table - table_5 = create_table( - 'table_3', id=1003, cols_names=['new_col1', 'col2', 'col3'], + copy_table = self.create_table( + 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], metric_names=['new_metric1']) - imported_table_id_5 = models.SqlaTable.import_obj( - table_5, import_time=1993) + imported_t_id_copy = models.SqlaTable.import_obj( + copy_table, import_time=1994) - imported_table_5 = get_table(imported_table_id_5) - self.assertEquals('table_3', imported_table_5.table_name) - self.assertEquals(4, len(imported_table_5.columns)) - self.assertEquals( - {'col1', 'new_col1', 'col2', 'col3'}, - set([c.column_name for c in imported_table_5.columns])) - self.assertEquals(3, len(imported_table_5.metrics)) - self.assertEquals( - {'new_metric1', 'metric1', 'metric2'}, - set([m.metric_name for m in imported_table_5.metrics])) - - # cleanup - # imported_table_3 and 4 are overriden - for table in [imported_table_1, imported_table_2, imported_table_5]: - for metric in table.metrics: - session.delete(metric) - for column in table.columns: - session.delete(column) - session.delete(table) - session.commit() + self.assertEquals(imported_t_id, imported_t_id_copy) + self.assert_table_equals(copy_table, self.get_table(imported_t_id)) if __name__ == '__main__': unittest.main() From 784d8853b0fa203143b745d2a97b4d006a27021a Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Tue, 11 Oct 2016 09:35:58 -0700 Subject: [PATCH 12/13] Move params_dict and alter_params to the ImportMixin --- ...y => b46fa1b0b39e_add_params_to_tables.py} | 4 +- caravel/models.py | 75 +++++++------------ caravel/views.py | 3 +- 
tests/import_export_tests.py | 10 +-- 4 files changed, 36 insertions(+), 56 deletions(-) rename caravel/migrations/versions/{b46fa1b0b39e_.py => b46fa1b0b39e_add_params_to_tables.py} (77%) diff --git a/caravel/migrations/versions/b46fa1b0b39e_.py b/caravel/migrations/versions/b46fa1b0b39e_add_params_to_tables.py similarity index 77% rename from caravel/migrations/versions/b46fa1b0b39e_.py rename to caravel/migrations/versions/b46fa1b0b39e_add_params_to_tables.py index 372fdbed8a52..9d02ec5b4b10 100644 --- a/caravel/migrations/versions/b46fa1b0b39e_.py +++ b/caravel/migrations/versions/b46fa1b0b39e_add_params_to_tables.py @@ -17,12 +17,12 @@ def upgrade(): op.add_column('tables', - sa.Column('json_metadata', sa.Text(), nullable=True)) + sa.Column('params', sa.Text(), nullable=True)) def downgrade(): try: - op.drop_column('tables', 'json_metadata') + op.drop_column('tables', 'params') except Exception as e: logging.warning(str(e)) diff --git a/caravel/models.py b/caravel/models.py index 114d8d37678b..706fc49a3841 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -86,6 +86,18 @@ def copy(self): new_obj.override(self) return new_obj + def alter_params(self, **kwargs): + d = self.params_dict + d.update(kwargs) + self.params = json.dumps(d) + + @property + def params_dict(self): + if self.params: + return json.loads(self.params) + else: + return {} + class AuditMixinNullable(AuditMixin): @@ -275,13 +287,6 @@ def slice_link(self): name = escape(self.slice_name) return Markup('{name}'.format(**locals())) - @property - def params_dict(self): - if self.params: - return json.loads(self.params) - else: - return {} - def get_viz(self, url_params_multidict=None): """Creates :py:class:viz.BaseViz object from the url_params_multidict. 
@@ -310,11 +315,6 @@ def get_viz(self, url_params_multidict=None): slice_=self ) - def alter_params(self, **kwargs): - d = self.params_dict - d.update(kwargs) - self.params = json.dumps(d) - @classmethod def import_obj(cls, slc_to_import, import_time=None): """Inserts or overrides slc in the database. @@ -411,13 +411,6 @@ def url(self): def datasources(self): return {slc.datasource for slc in self.slices} - @property - def metadata_dejson(self): - if self.json_metadata: - return json.loads(self.json_metadata) - else: - return {} - @property def sqla_metadata(self): metadata = MetaData(bind=self.get_sqla_engine()) @@ -432,7 +425,7 @@ def dashboard_link(self): def json_data(self): d = { 'id': self.id, - 'metadata': self.metadata_dejson, + 'metadata': self.params_dict, 'dashboard_title': self.dashboard_title, 'slug': self.slug, 'slices': [slc.data for slc in self.slices], @@ -441,15 +434,12 @@ def json_data(self): return json.dumps(d) @property - def dict_metadata(self): - if not self.json_metadata: - return {} - return json.loads(self.json_metadata) + def params(self): + return self.json_metadata - def alter_json_metadata(self, **kwargs): - d = self.dict_metadata - d.update(kwargs) - self.json_metadata = json.dumps(d) + @params.setter + def params(self, value): + self.json_metadata = value @classmethod def import_obj(cls, dashboard_to_import, import_time=None): @@ -479,13 +469,13 @@ def import_obj(cls, dashboard_to_import, import_time=None): # override the dashboard existing_dashboard = None for dash in session.query(Dashboard).all(): - if ('remote_id' in dash.metadata_dejson and - dash.metadata_dejson['remote_id'] == + if ('remote_id' in dash.params_dict and + dash.params_dict['remote_id'] == dashboard_to_import.id): existing_dashboard = dash dashboard_to_import.id = None - dashboard_to_import.alter_json_metadata(import_time=import_time) + dashboard_to_import.alter_params(import_time=import_time) new_slices = 
session.query(Slice).filter(Slice.id.in_(slice_ids)).all() if existing_dashboard: @@ -525,14 +515,14 @@ def export_dashboards(cls, dashboard_ids): schema=slc.datasource.name, database_name=slc.datasource.database.database_name, ) - copied_dashboard.alter_json_metadata(remote_id=dashboard_id) + copied_dashboard.alter_params(remote_id=dashboard_id) copied_dashboards.append(copied_dashboard) eager_datasources = [] for dashboard_id, dashboard_type in datasource_ids: eager_datasource = SourceRegistry.get_eager_datasource( db.session, dashboard_type, dashboard_id) - eager_datasource.alter_json_metadata( + eager_datasource.alter_params( remote_id=eager_datasource.id, database_name=eager_datasource.database.database_name, ) @@ -868,13 +858,13 @@ class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin): cache_timeout = Column(Integer) schema = Column(String(255)) sql = Column(Text) - json_metadata = Column(Text) + params = Column(Text) baselink = "tablemodelview" export_fields = ( 'table_name', 'main_dttm_col', 'description', 'default_endpoint', 'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema', - 'sql', 'json_metadata') + 'sql', 'params') __table_args__ = ( sqla.UniqueConstraint( @@ -1245,17 +1235,6 @@ def fetch_metadata(self): if not self.main_dttm_col: self.main_dttm_col = any_date_col - def alter_json_metadata(self, **kwargs): - d = self.dict_metadata - d.update(kwargs) - self.json_metadata = json.dumps(d) - - @property - def dict_metadata(self): - if self.json_metadata: - return json.loads(self.json_metadata) - return {} - @classmethod def import_obj(cls, datasource_to_import, import_time=None): """Imports the datasource from the object to the database. 
@@ -1270,10 +1249,10 @@ def import_obj(cls, datasource_to_import, import_time=None): .format(datasource_to_import.to_json())) datasource_to_import.id = None - database_name = datasource_to_import.dict_metadata['database_name'] + database_name = datasource_to_import.params_dict['database_name'] datasource_to_import.database_id = session.query(Database).filter_by( database_name=database_name).one().id - datasource_to_import.alter_json_metadata(import_time=import_time) + datasource_to_import.alter_params(import_time=import_time) # override the datasource datasource = ( diff --git a/caravel/views.py b/caravel/views.py index 9d6c22bb9ef3..ea3fed604191 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -908,6 +908,7 @@ def download_dashboards(self): category='', category_icon='',) + class DashboardModelViewAsync(DashboardModelView): # noqa list_columns = ['dashboard_link', 'creator', 'modified', 'dashboard_title'] label_columns = { @@ -1529,7 +1530,7 @@ def save_dash(self, dashboard_id): dash.slices = [o for o in dash.slices if o.id in slice_ids] positions = sorted(data['positions'], key=lambda x: int(x['slice_id'])) dash.position_json = json.dumps(positions, indent=4, sort_keys=True) - md = dash.metadata_dejson + md = dash.params_dict if 'filter_immune_slices' not in md: md['filter_immune_slices'] = [] if 'filter_immune_slice_fields' not in md: diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index 979be0ba660b..454dbbb73b53 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -29,10 +29,10 @@ def delete_imports(cls): if 'remote_id' in slc.params_dict: session.delete(slc) for dash in session.query(models.Dashboard): - if 'remote_id' in dash.metadata_dejson: + if 'remote_id' in dash.params_dict: session.delete(dash) for table in session.query(models.SqlaTable): - if 'remote_id' in table.dict_metadata: + if 'remote_id' in table.params_dict: session.delete(table) session.commit() @@ -80,12 +80,12 @@ def 
create_dashboard(self, title, id=0, slcs=[]): ) def create_table(self, name, schema='', id=0, cols_names=[], metric_names=[]): - json_metadata = {'remote_id': id, 'database_name': 'main'} + params = {'remote_id': id, 'database_name': 'main'} table = models.SqlaTable( id=id, schema=schema, table_name=name, - json_metadata=json.dumps(json_metadata) + params=json.dumps(params) ) for col_name in cols_names: table.columns.append( models.TableColumn(column_name=col_name)) for metric_name in metric_names: table.metrics.append(models.SqlMetric(metric_name=metric_name)) return table @@ -319,7 +319,7 @@ def test_import_table_1_col_1_met(self): self.assert_table_equals(table, imported_table) self.assertEquals( {'remote_id': 10002, 'import_time': 1990, 'database_name': 'main'}, - json.loads(imported_table.json_metadata)) + json.loads(imported_table.params)) def test_import_table_2_col_2_met(self): table = self.create_table( From 813b3a39a0c2d60811e28eafcbddffdbc6d140fd Mon Sep 17 00:00:00 2001 From: Bogdan Kyryliuk Date: Tue, 11 Oct 2016 16:20:48 -0700 Subject: [PATCH 13/13] Fix flask menus. --- caravel/templates/caravel/import_dashboards.html | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/caravel/templates/caravel/import_dashboards.html b/caravel/templates/caravel/import_dashboards.html index b36901d2906f..8365bbe0f7be 100644 --- a/caravel/templates/caravel/import_dashboards.html +++ b/caravel/templates/caravel/import_dashboards.html @@ -1,4 +1,13 @@ {% extends "caravel/basic.html" %} + +{# TODO: move the libs required by flask into the common.js from welcome.js. #} +{% block head_js %} + {{ super() }} + {% with filename="welcome" %} + {% include "caravel/partials/_script_tag.html" %} + {% endwith %} +{% endblock %} + {% block title %}{{ _("Import") }}{% endblock %} {% block body %} {% include "caravel/flash_wrapper.html" %}