diff --git a/caravel/__init__.py b/caravel/__init__.py
index 1fcfe43acb99..bb74091a714a 100644
--- a/caravel/__init__.py
+++ b/caravel/__init__.py
@@ -82,6 +82,12 @@ def checkout(dbapi_con, con_record, con_proxy):
 if app.config.get('ENABLE_PROXY_FIX'):
     app.wsgi_app = ProxyFix(app.wsgi_app)
 
+if app.config.get('UPLOAD_FOLDER'):
+    try:
+        os.makedirs(app.config.get('UPLOAD_FOLDER'))
+    except OSError:
+        pass
+
 
 class MyIndexView(IndexView):
     @expose('/')
diff --git a/caravel/data/__init__.py b/caravel/data/__init__.py
index 974ce2e898ba..0f9e2a308efa 100644
--- a/caravel/data/__init__.py
+++ b/caravel/data/__init__.py
@@ -1058,7 +1058,7 @@ def load_multiformat_time_series_data():
         'string0': ['%Y-%m-%d %H:%M:%S.%f', None],
         'string3': ['%Y/%m/%d%H:%M:%S.%f', None],
     }
-    for col in obj.table_columns:
+    for col in obj.columns:
         dttm_and_expr = dttm_and_expr_dict[col.column_name]
         col.python_date_format = dttm_and_expr[0]
         col.dbatabase_expr = dttm_and_expr[1]
@@ -1069,7 +1069,7 @@ def load_multiformat_time_series_data():
     tbl = obj
 
     print("Creating some slices")
-    for i, col in enumerate(tbl.table_columns):
+    for i, col in enumerate(tbl.columns):
         slice_data = {
             "granularity_sqla": col.column_name,
             "datasource_id": "8",
diff --git a/caravel/migrations/versions/b46fa1b0b39e_add_params_to_tables.py b/caravel/migrations/versions/b46fa1b0b39e_add_params_to_tables.py
new file mode 100644
index 000000000000..9d02ec5b4b10
--- /dev/null
+++ b/caravel/migrations/versions/b46fa1b0b39e_add_params_to_tables.py
@@ -0,0 +1,28 @@
+"""Add params to the tables table.
+
+Revision ID: b46fa1b0b39e
+Revises: ef8843b41dac
+Create Date: 2016-10-05 11:30:31.748238
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = 'b46fa1b0b39e'
+down_revision = 'ef8843b41dac'
+
+from alembic import op
+import logging
+import sqlalchemy as sa
+
+
+def upgrade():
+    op.add_column('tables',
+                  sa.Column('params', sa.Text(), nullable=True))
+
+
+def downgrade():
+    try:
+        op.drop_column('tables', 'params')
+    except Exception as e:
+        logging.warning(str(e))
+
diff --git a/caravel/models.py b/caravel/models.py
index f160e038d960..706fc49a3841 100644
--- a/caravel/models.py
+++ b/caravel/models.py
@@ -7,6 +7,7 @@
 import functools
 import json
 import logging
+import pickle
 import re
 import textwrap
 from collections import namedtuple
@@ -18,6 +19,8 @@
 import requests
 import sqlalchemy as sqla
 from sqlalchemy.engine.url import make_url
+from sqlalchemy.orm import subqueryload
+
 import sqlparse
 from dateutil.parser import parse
@@ -41,6 +44,7 @@
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.ext.declarative import declared_attr
 from sqlalchemy.orm import backref, relationship
+from sqlalchemy.orm.session import make_transient
 from sqlalchemy.sql import table, literal_column, text, column
 from sqlalchemy.sql.expression import ColumnClause, TextAsFrom
 from sqlalchemy_utils import EncryptedType
@@ -70,6 +74,31 @@ def __init__(self, name, field_names, function):
         self.name = name
 
 
+class ImportMixin(object):
+    def override(self, obj):
+        """Overrides the plain (non-relationship) fields with obj's values."""
+        for field in obj.__class__.export_fields:
+            setattr(self, field, getattr(obj, field))
+
+    def copy(self):
+        """Creates a copy of the object without its relationships."""
+        new_obj = self.__class__()
+        new_obj.override(self)
+        return new_obj
+
+    def alter_params(self, **kwargs):
+        d = self.params_dict
+        d.update(kwargs)
+        self.params = json.dumps(d)
+
+    @property
+    def params_dict(self):
+        if self.params:
+            return json.loads(self.params)
+        else:
+            return {}
+
+
 class AuditMixinNullable(AuditMixin):
 
     """Altering the AuditMixin to use nullable fields
@@ -149,7 +178,7 @@ class CssTemplate(Model, AuditMixinNullable):
     )
 
 
-class Slice(Model, AuditMixinNullable):
+class Slice(Model, AuditMixinNullable, ImportMixin):
 
     """A slice is essentially a report or a view on data"""
 
@@ -166,6 +195,9 @@ class Slice(Model, AuditMixinNullable):
     perm = Column(String(2000))
     owners = relationship("User", secondary=slice_user)
 
+    export_fields = ('slice_name', 'datasource_type', 'datasource_name',
+                     'viz_type', 'params', 'cache_timeout')
+
     def __repr__(self):
         return self.slice_name
 
@@ -283,6 +315,42 @@ def get_viz(self, url_params_multidict=None):
             slice_=self
         )
 
+    @classmethod
+    def import_obj(cls, slc_to_import, import_time=None):
+        """Inserts or overrides slc in the database.
+
+        remote_id and import_time fields in params_dict are set to track the
+        slice origin and ensure correct overrides for multiple imports.
+        Slice.perm is used to find the datasources and connect them.
+        """
+        session = db.session
+        make_transient(slc_to_import)
+        slc_to_import.dashboards = []
+        slc_to_import.alter_params(
+            remote_id=slc_to_import.id, import_time=import_time)
+
+        # find if the slice was already imported
+        slc_to_override = None
+        for slc in session.query(Slice).all():
+            if ('remote_id' in slc.params_dict and
+                    slc.params_dict['remote_id'] == slc_to_import.id):
+                slc_to_override = slc
+
+        slc_to_import.id = None
+        params = slc_to_import.params_dict
+        slc_to_import.datasource_id = SourceRegistry.get_datasource_by_name(
+            session, slc_to_import.datasource_type, params['datasource_name'],
+            params['schema'], params['database_name']).id
+        if slc_to_override:
+            slc_to_override.override(slc_to_import)
+            session.flush()
+            return slc_to_override.id
+        else:
+            session.add(slc_to_import)
+            logging.info('Final slice: {}'.format(slc_to_import.to_json()))
+            session.flush()
+            return slc_to_import.id
+
 
 def set_perm(mapper, connection, target):  # noqa
     src_class = target.cls_model
@@ -309,7 +377,7 @@ def set_perm(mapper, connection, target):  # noqa
     )
 
 
-class Dashboard(Model, AuditMixinNullable):
+class Dashboard(Model, AuditMixinNullable, ImportMixin):
 
     """The dashboard object!"""
 
@@ -325,6 +393,9 @@ class Dashboard(Model, AuditMixinNullable):
         'Slice', secondary=dashboard_slices, backref='dashboards')
     owners = relationship("User", secondary=dashboard_user)
 
+    export_fields = ('dashboard_title', 'position_json', 'json_metadata',
+                     'description', 'css', 'slug', 'slices')
+
     def __repr__(self):
         return self.dashboard_title
 
@@ -340,13 +411,6 @@ def url(self):
     def datasources(self):
         return {slc.datasource for slc in self.slices}
 
-    @property
-    def metadata_dejson(self):
-        if self.json_metadata:
-            return json.loads(self.json_metadata)
-        else:
-            return {}
-
     @property
     def sqla_metadata(self):
         metadata = MetaData(bind=self.get_sqla_engine())
@@ -361,7 +425,7 @@ def dashboard_link(self):
     def json_data(self):
         d = {
             'id': self.id,
-            'metadata': self.metadata_dejson,
+            'metadata': self.params_dict,
             'dashboard_title': self.dashboard_title,
             'slug': self.slug,
             'slices': [slc.data for slc in self.slices],
@@ -369,6 +433,107 @@ def json_data(self):
         }
         return json.dumps(d)
 
+    @property
+    def params(self):
+        return self.json_metadata
+
+    @params.setter
+    def params(self, value):
+        self.json_metadata = value
+
+    @classmethod
+    def import_obj(cls, dashboard_to_import, import_time=None):
+        """Imports the dashboard from the object to the database.
+
+        Once the dashboard is imported, the json_metadata field is extended to
+        store remote_id and import_time. These are used to decide whether the
+        dashboard has to be overridden or just copied over. Slices that belong
+        to this dashboard will be wired to existing tables. This function can
+        be used to import/export dashboards between multiple caravel
+        instances. Audit metadata isn't copied over.
+        """
+        logging.info('Started import of the dashboard: {}'
+                     .format(dashboard_to_import.to_json()))
+        session = db.session
+        logging.info('Dashboard has {} slices'
+                     .format(len(dashboard_to_import.slices)))
+        # copy the slices collection, as Slice.import_obj will mutate the
+        # slice and will remove the existing dashboard - slice association
+        slices = copy(dashboard_to_import.slices)
+        slice_ids = set()
+        for slc in slices:
+            logging.info('Importing slice {} from the dashboard: {}'.format(
+                slc.to_json(), dashboard_to_import.dashboard_title))
+            slice_ids.add(Slice.import_obj(slc, import_time=import_time))
+
+        # override the dashboard
+        existing_dashboard = None
+        for dash in session.query(Dashboard).all():
+            if ('remote_id' in dash.params_dict and
+                    dash.params_dict['remote_id'] ==
+                    dashboard_to_import.id):
+                existing_dashboard = dash
+
+        dashboard_to_import.id = None
+        dashboard_to_import.alter_params(import_time=import_time)
+        new_slices = session.query(Slice).filter(
+            Slice.id.in_(slice_ids)).all()
+
+        if existing_dashboard:
+            existing_dashboard.override(dashboard_to_import)
+            existing_dashboard.slices = new_slices
+            session.flush()
+            return existing_dashboard.id
+        else:
+            # session.add(dashboard_to_import) causes SQLAlchemy failures
+            # related to the attached users / slices. Creating a new object
+            # avoids conflicts in the SQLAlchemy state.
+            copied_dash = dashboard_to_import.copy()
+            copied_dash.slices = new_slices
+            session.add(copied_dash)
+            session.flush()
+            return copied_dash.id
+
+    @classmethod
+    def export_dashboards(cls, dashboard_ids):
+        copied_dashboards = []
+        datasource_ids = set()
+        for dashboard_id in dashboard_ids:
+            # make sure that dashboard_id is an integer
+            dashboard_id = int(dashboard_id)
+            copied_dashboard = (
+                db.session.query(Dashboard)
+                .options(subqueryload(Dashboard.slices))
+                .filter_by(id=dashboard_id).first()
+            )
+            make_transient(copied_dashboard)
+            for slc in copied_dashboard.slices:
+                datasource_ids.add((slc.datasource_id, slc.datasource_type))
+                # add extra params for the import
+                slc.alter_params(
+                    remote_id=slc.id,
+                    datasource_name=slc.datasource.name,
+                    schema=slc.datasource.schema,
+                    database_name=slc.datasource.database.database_name,
+                )
+            copied_dashboard.alter_params(remote_id=dashboard_id)
+            copied_dashboards.append(copied_dashboard)
+
+        eager_datasources = []
+        for datasource_id, datasource_type in datasource_ids:
+            eager_datasource = SourceRegistry.get_eager_datasource(
+                db.session, datasource_type, datasource_id)
+            eager_datasource.alter_params(
+                remote_id=eager_datasource.id,
+                database_name=eager_datasource.database.database_name,
+            )
+            make_transient(eager_datasource)
+            eager_datasources.append(eager_datasource)
+
+        return pickle.dumps({
+            'dashboards': copied_dashboards,
+            'datasources': eager_datasources,
+        })
+
 
 class Queryable(object):
 
@@ -433,6 +598,10 @@ class Database(Model, AuditMixinNullable):
     def __repr__(self):
         return self.database_name
 
+    @property
+    def name(self):
+        return self.database_name
+
     @property
     def backend(self):
         url = make_url(self.sqlalchemy_uri_decrypted)
@@ -665,7 +834,7 @@ def perm(self):
             "[{obj.database_name}].(id:{obj.id})").format(obj=self)
 
 
-class SqlaTable(Model, Queryable, AuditMixinNullable):
+class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin):
 
     """An ORM object for SqlAlchemy table references"""
 
@@ -689,9 +858,13 @@ class SqlaTable(Model, Queryable, AuditMixinNullable):
     cache_timeout = Column(Integer)
     schema = Column(String(255))
     sql = Column(Text)
-    table_columns = relationship("TableColumn", back_populates="table")
+    params = Column(Text)
 
     baselink = "tablemodelview"
+    export_fields = (
+        'table_name', 'main_dttm_col', 'description', 'default_endpoint',
+        'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema',
+        'sql', 'params')
 
     __table_args__ = (
         sqla.UniqueConstraint(
@@ -773,7 +946,7 @@ def time_column_grains(self):
         }
 
     def get_col(self, col_name):
-        columns = self.table_columns
+        columns = self.columns
         for col in columns:
             if col_name == col.column_name:
                 return col
@@ -1062,8 +1235,67 @@ def fetch_metadata(self):
         if not self.main_dttm_col:
             self.main_dttm_col = any_date_col
 
+    @classmethod
+    def import_obj(cls, datasource_to_import, import_time=None):
+        """Imports the datasource from the object to the database.
+
+        Metrics, columns and the datasource itself will be overridden if they
+        already exist. This function can be used to import/export dashboards
+        between multiple caravel instances. Audit metadata isn't copied over.
+        """
+        session = db.session
+        make_transient(datasource_to_import)
+        logging.info('Started import of the datasource: {}'
+                     .format(datasource_to_import.to_json()))
+
+        datasource_to_import.id = None
+        database_name = datasource_to_import.params_dict['database_name']
+        datasource_to_import.database_id = session.query(Database).filter_by(
+            database_name=database_name).one().id
+        datasource_to_import.alter_params(import_time=import_time)
+
+        # override the datasource
+        datasource = (
+            session.query(SqlaTable).join(Database)
+            .filter(
+                SqlaTable.table_name == datasource_to_import.table_name,
+                SqlaTable.schema == datasource_to_import.schema,
+                Database.id == datasource_to_import.database_id,
+            )
+            .first()
+        )
+
+        if datasource:
+            datasource.override(datasource_to_import)
+            session.flush()
+        else:
+            datasource = datasource_to_import.copy()
+            session.add(datasource)
+            session.flush()
 
-class SqlMetric(Model, AuditMixinNullable):
+        for m in datasource_to_import.metrics:
+            new_m = m.copy()
+            new_m.table_id = datasource.id
+            logging.info('Importing metric {} from the datasource: {}'.format(
+                new_m.to_json(), datasource_to_import.full_name))
+            imported_m = SqlMetric.import_obj(new_m)
+            if imported_m not in datasource.metrics:
+                datasource.metrics.append(imported_m)
+
+        for c in datasource_to_import.columns:
+            new_c = c.copy()
+            new_c.table_id = datasource.id
+            logging.info('Importing column {} from the datasource: {}'.format(
+                new_c.to_json(), datasource_to_import.full_name))
+            imported_c = TableColumn.import_obj(new_c)
+            if imported_c not in datasource.columns:
+                datasource.columns.append(imported_c)
+        db.session.flush()
+
+        return datasource.id
+
+
+class SqlMetric(Model, AuditMixinNullable, ImportMixin):
 
     """ORM object for metrics, each table can have multiple metrics"""
 
@@ -1082,6 +1314,10 @@ class SqlMetric(Model, AuditMixinNullable):
     is_restricted = Column(Boolean, default=False, nullable=True)
     d3format = Column(String(128))
 
+    export_fields = (
+        'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression',
+        'description', 'is_restricted', 'd3format')
+
     @property
     def sqla_col(self):
         name = self.metric_name
@@ -1094,8 +1330,28 @@ def perm(self):
         ).format(obj=self, parent_name=self.table.full_name) if self.table else None
 
+    @classmethod
+    def import_obj(cls, metric_to_import):
+        session = db.session
+        make_transient(metric_to_import)
+        metric_to_import.id = None
+
+        # find if the metric was already imported
+        existing_metric = session.query(SqlMetric).filter(
+            SqlMetric.table_id == metric_to_import.table_id,
+            SqlMetric.metric_name == metric_to_import.metric_name).first()
+        metric_to_import.table = None
+        if existing_metric:
+            existing_metric.override(metric_to_import)
+            session.flush()
+            return existing_metric
+
+        session.add(metric_to_import)
+        session.flush()
+        return metric_to_import
+
 
-class TableColumn(Model, AuditMixinNullable):
+class TableColumn(Model, AuditMixinNullable, ImportMixin):
 
     """ORM object for table columns, each table can have multiple columns"""
 
@@ -1125,6 +1381,12 @@ class TableColumn(Model, AuditMixinNullable):
     num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG')
     date_types = ('DATE', 'TIME')
     str_types = ('VARCHAR', 'STRING', 'CHAR')
+    export_fields = (
+        'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active',
+        'type', 'groupby', 'count_distinct', 'sum', 'max', 'min',
+        'filterable', 'expression', 'description', 'python_date_format',
+        'database_expression'
+    )
 
     def __repr__(self):
         return self.column_name
@@ -1150,6 +1412,27 @@ def sqla_col(self):
         col = literal_column(self.expression).label(name)
         return col
 
+    @classmethod
+    def import_obj(cls, column_to_import):
+        session = db.session
+        make_transient(column_to_import)
+        column_to_import.id = None
+        column_to_import.table = None
+
+        # find if the column was already imported
+        existing_column = session.query(TableColumn).filter(
+            TableColumn.table_id == column_to_import.table_id,
+            TableColumn.column_name == column_to_import.column_name).first()
+        if existing_column:
+            existing_column.override(column_to_import)
+            session.flush()
+            return existing_column
+
+        session.add(column_to_import)
+        session.flush()
+        return column_to_import
+
     def dttm_sql_literal(self, dttm):
         """Convert datetime object to string
 
@@ -1234,6 +1517,10 @@ def refresh_datasources(self):
     def perm(self):
         return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self)
 
+    @property
+    def name(self):
+        return self.cluster_name
+
 
 class DruidDatasource(Model, AuditMixinNullable, Queryable):
 
@@ -1262,6 +1549,10 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable):
     offset = Column(Integer, default=0)
     cache_timeout = Column(Integer)
 
+    @property
+    def database(self):
+        return self.cluster
+
     @property
     def metrics_combo(self):
         return sorted(
diff --git a/caravel/source_registry.py b/caravel/source_registry.py
index dc1a170d7c68..6176c9c0c96b 100644
--- a/caravel/source_registry.py
+++ b/caravel/source_registry.py
@@ -1,3 +1,4 @@
+from sqlalchemy.orm import subqueryload
 
 
 class SourceRegistry(object):
@@ -20,3 +21,30 @@ def get_datasource(cls, datasource_type, datasource_id, session):
             .filter_by(id=datasource_id)
             .one()
         )
+
+    @classmethod
+    def get_datasource_by_name(cls, session, datasource_type, datasource_name,
+                               schema, database_name):
+        datasource_class = SourceRegistry.sources[datasource_type]
+        datasources = session.query(datasource_class).all()
+        db_ds = [d for d in datasources if d.database.name == database_name and
+                 d.name == datasource_name and d.schema == schema]
+        return db_ds[0]
+
+    @classmethod
+    def get_eager_datasource(cls, session, datasource_type, datasource_id):
+        """Returns datasource with columns and metrics."""
+        datasource_class = SourceRegistry.sources[datasource_type]
+        if datasource_type == 'table':
+            return (
+                session.query(datasource_class)
+                .options(
+                    subqueryload(datasource_class.columns),
+                    subqueryload(datasource_class.metrics)
+                )
+                .filter_by(id=datasource_id)
+                .one()
+            )
+        # TODO: support druid datasources.
+        return session.query(datasource_class).filter_by(
+            id=datasource_id).first()
diff --git a/caravel/templates/caravel/export_dashboards.html b/caravel/templates/caravel/export_dashboards.html
new file mode 100644
index 000000000000..83a0c840e09f
--- /dev/null
+++ b/caravel/templates/caravel/export_dashboards.html
@@ -0,0 +1,6 @@
+
diff --git a/caravel/templates/caravel/import_dashboards.html b/caravel/templates/caravel/import_dashboards.html
new file mode 100644
index 000000000000..8365bbe0f7be
--- /dev/null
+++ b/caravel/templates/caravel/import_dashboards.html
@@ -0,0 +1,23 @@
+{% extends "caravel/basic.html" %}
+
+# TODO: move the libs required by flask into the common.js from welcome.js.
+{% block head_js %}
+  {{ super() }}
+  {% with filename="welcome" %}
+    {% include "caravel/partials/_script_tag.html" %}
+  {% endwith %}
+{% endblock %}
+
+{% block title %}{{ _("Import") }}{% endblock %}
+{% block body %}
+  {% include "caravel/flash_wrapper.html" %}
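
Note (not part of the patch): a minimal sketch of how the model-level export/import API added above might be exercised end to end. It assumes a configured Caravel application context on both the exporting and the importing instance, placeholder dashboard ids, and SQLA table datasources (Druid is still a TODO in get_eager_datasource); the UI entry points are the templates added in this patch.

    import pickle
    import time

    from caravel import db, models

    # Source instance: pickle the selected dashboards plus the datasources
    # their slices reference (Dashboard.export_dashboards gathers both).
    payload = models.Dashboard.export_dashboards(dashboard_ids=[1, 2])

    # Target instance: import datasources first so Slice.import_obj can
    # resolve them by datasource_name / schema / database_name, then import
    # the dashboards themselves.
    data = pickle.loads(payload)
    import_time = int(time.time())
    for datasource in data['datasources']:
        models.SqlaTable.import_obj(datasource, import_time=import_time)
    db.session.commit()
    for dashboard in data['dashboards']:
        models.Dashboard.import_obj(dashboard, import_time=import_time)
    db.session.commit()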