From 82e994aa4fdeaddbad1eec60049b1b6753498e52 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Sat, 7 Jan 2017 14:01:54 -0800 Subject: [PATCH] Revert "Druid dashboard import/export. " --- superset/import_util.py | 74 -- .../versions/1296d28ec131_druid_exports.py | 22 - superset/models.py | 868 +++++++++--------- superset/source_registry.py | 25 +- superset/views.py | 8 +- tests/import_export_tests.py | 151 +-- 6 files changed, 463 insertions(+), 685 deletions(-) delete mode 100644 superset/import_util.py delete mode 100644 superset/migrations/versions/1296d28ec131_druid_exports.py diff --git a/superset/import_util.py b/superset/import_util.py deleted file mode 100644 index e71623d57926..000000000000 --- a/superset/import_util.py +++ /dev/null @@ -1,74 +0,0 @@ -import logging -from sqlalchemy.orm.session import make_transient - - -def import_datasource( - session, - i_datasource, - lookup_database, - lookup_datasource, - import_time): - """Imports the datasource from the object to the database. - - Metrics and columns and datasource will be overrided if exists. - This function can be used to import/export dashboards between multiple - superset instances. Audit metadata isn't copies over. - """ - make_transient(i_datasource) - logging.info('Started import of the datasource: {}'.format( - i_datasource.to_json())) - - i_datasource.id = None - i_datasource.database_id = lookup_database(i_datasource).id - i_datasource.alter_params(import_time=import_time) - - # override the datasource - datasource = lookup_datasource(i_datasource) - - if datasource: - datasource.override(i_datasource) - session.flush() - else: - datasource = i_datasource.copy() - session.add(datasource) - session.flush() - - for m in i_datasource.metrics: - new_m = m.copy() - new_m.table_id = datasource.id - logging.info('Importing metric {} from the datasource: {}'.format( - new_m.to_json(), i_datasource.full_name)) - imported_m = i_datasource.metric_cls.import_obj(new_m) - if (imported_m.metric_name not in - [m.metric_name for m in datasource.metrics]): - datasource.metrics.append(imported_m) - - for c in i_datasource.columns: - new_c = c.copy() - new_c.table_id = datasource.id - logging.info('Importing column {} from the datasource: {}'.format( - new_c.to_json(), i_datasource.full_name)) - imported_c = i_datasource.column_cls.import_obj(new_c) - if (imported_c.column_name not in - [c.column_name for c in datasource.columns]): - datasource.columns.append(imported_c) - session.flush() - return datasource.id - - -def import_simple_obj(session, i_obj, lookup_obj): - make_transient(i_obj) - i_obj.id = None - i_obj.table = None - - # find if the column was already imported - existing_column = lookup_obj(i_obj) - i_obj.table = None - if existing_column: - existing_column.override(i_obj) - session.flush() - return existing_column - - session.add(i_obj) - session.flush() - return i_obj diff --git a/superset/migrations/versions/1296d28ec131_druid_exports.py b/superset/migrations/versions/1296d28ec131_druid_exports.py deleted file mode 100644 index 38d9d750f963..000000000000 --- a/superset/migrations/versions/1296d28ec131_druid_exports.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Adds params to the datasource (druid) table - -Revision ID: 1296d28ec131 -Revises: e46f2d27a08e -Create Date: 2016-12-06 17:40:40.389652 - -""" - -# revision identifiers, used by Alembic. 
-revision = '1296d28ec131' -down_revision = '6414e83d82b7' - -from alembic import op -import sqlalchemy as sa - - -def upgrade(): - op.add_column('datasources', sa.Column('params', sa.String(length=1000), nullable=True)) - - -def downgrade(): - op.drop_column('datasources', 'params') diff --git a/superset/models.py b/superset/models.py index df0ada3b952a..cef46a4c57a8 100644 --- a/superset/models.py +++ b/superset/models.py @@ -52,9 +52,7 @@ from werkzeug.datastructures import ImmutableMultiDict -from superset import ( - app, db, db_engine_specs, get_session, utils, sm, import_util -) +from superset import app, db, db_engine_specs, get_session, utils, sm from superset.source_registry import SourceRegistry from superset.viz import viz_types from superset.jinja_context import get_template_processor @@ -97,13 +95,6 @@ def set_perm(mapper, connection, target): # noqa ) -def set_related_perm(mapper, connection, target): # noqa - src_class = target.cls_model - id_ = target.datasource_id - ds = db.session.query(src_class).filter_by(id=int(id_)).first() - target.perm = ds.perm - - class JavascriptPostAggregator(Postaggregator): def __init__(self, name, field_names, function): self.post_aggregator = { @@ -135,9 +126,7 @@ def alter_params(self, **kwargs): @property def params_dict(self): if self.params: - params = re.sub(",[ \t\r\n]+}", "}", self.params) - params = re.sub(",[ \t\r\n]+\]", "]", params) - return json.loads(params) + return json.loads(self.params) else: return {} @@ -402,11 +391,18 @@ def import_obj(cls, slc_to_import, import_time=None): slc_to_override.override(slc_to_import) session.flush() return slc_to_override.id - session.add(slc_to_import) - logging.info('Final slice: {}'.format(slc_to_import.to_json())) - session.flush() - return slc_to_import.id + else: + session.add(slc_to_import) + logging.info('Final slice: {}'.format(slc_to_import.to_json())) + session.flush() + return slc_to_import.id + +def set_related_perm(mapper, connection, target): # noqa + src_class = target.cls_model + id_ = target.datasource_id + ds = db.session.query(src_class).filter_by(id=int(id_)).first() + target.perm = ds.perm sqla.event.listen(Slice, 'before_insert', set_related_perm) sqla.event.listen(Slice, 'before_update', set_related_perm) @@ -619,7 +615,7 @@ def export_dashboards(cls, dashboard_ids): remote_id=slc.id, datasource_name=slc.datasource.name, schema=slc.datasource.name, - database_name=slc.datasource.database.name, + database_name=slc.datasource.database.database_name, ) copied_dashboard.alter_params(remote_id=dashboard_id) copied_dashboards.append(copied_dashboard) @@ -630,7 +626,7 @@ def export_dashboards(cls, dashboard_ids): db.session, dashboard_type, dashboard_id) eager_datasource.alter_params( remote_id=eager_datasource.id, - database_name=eager_datasource.database.name, + database_name=eager_datasource.database.database_name, ) make_transient(eager_datasource) eager_datasources.append(eager_datasource) @@ -902,168 +898,6 @@ def get_perm(self): sqla.event.listen(Database, 'after_update', set_perm) -class TableColumn(Model, AuditMixinNullable, ImportMixin): - - """ORM object for table columns, each table can have multiple columns""" - - __tablename__ = 'table_columns' - id = Column(Integer, primary_key=True) - table_id = Column(Integer, ForeignKey('tables.id')) - table = relationship( - 'SqlaTable', - backref=backref('columns', cascade='all, delete-orphan'), - foreign_keys=[table_id]) - column_name = Column(String(255)) - verbose_name = Column(String(1024)) - is_dttm = Column(Boolean, 
default=False) - is_active = Column(Boolean, default=True) - type = Column(String(32), default='') - groupby = Column(Boolean, default=False) - count_distinct = Column(Boolean, default=False) - sum = Column(Boolean, default=False) - avg = Column(Boolean, default=False) - max = Column(Boolean, default=False) - min = Column(Boolean, default=False) - filterable = Column(Boolean, default=False) - expression = Column(Text, default='') - description = Column(Text, default='') - python_date_format = Column(String(255)) - database_expression = Column(String(255)) - - num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG') - date_types = ('DATE', 'TIME') - str_types = ('VARCHAR', 'STRING', 'CHAR') - export_fields = ( - 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', - 'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min', - 'filterable', 'expression', 'description', 'python_date_format', - 'database_expression' - ) - - def __repr__(self): - return self.column_name - - @property - def isnum(self): - return any([t in self.type.upper() for t in self.num_types]) - - @property - def is_time(self): - return any([t in self.type.upper() for t in self.date_types]) - - @property - def is_string(self): - return any([t in self.type.upper() for t in self.str_types]) - - @property - def sqla_col(self): - name = self.column_name - if not self.expression: - col = column(self.column_name).label(name) - else: - col = literal_column(self.expression).label(name) - return col - - def get_time_filter(self, start_dttm, end_dttm): - col = self.sqla_col.label('__time') - return and_( - col >= text(self.dttm_sql_literal(start_dttm)), - col <= text(self.dttm_sql_literal(end_dttm)), - ) - - def get_timestamp_expression(self, time_grain): - """Getting the time component of the query""" - expr = self.expression or self.column_name - if not self.expression and not time_grain: - return column(expr, type_=DateTime).label(DTTM_ALIAS) - if time_grain: - pdf = self.python_date_format - if pdf in ('epoch_s', 'epoch_ms'): - # if epoch, translate to DATE using db specific conf - db_spec = self.table.database.db_engine_spec - if pdf == 'epoch_s': - expr = db_spec.epoch_to_dttm().format(col=expr) - elif pdf == 'epoch_ms': - expr = db_spec.epoch_ms_to_dttm().format(col=expr) - grain = self.table.database.grains_dict().get(time_grain, '{col}') - expr = grain.function.format(col=expr) - return literal_column(expr, type_=DateTime).label(DTTM_ALIAS) - - @classmethod - def import_obj(cls, i_column): - def lookup_obj(lookup_column): - return db.session.query(TableColumn).filter( - TableColumn.table_id == lookup_column.table_id, - TableColumn.column_name == lookup_column.column_name).first() - return import_util.import_simple_obj(db.session, i_column, lookup_obj) - - def dttm_sql_literal(self, dttm): - """Convert datetime object to a SQL expression string - - If database_expression is empty, the internal dttm - will be parsed as the string with the pattern that - the user inputted (python_date_format) - If database_expression is not empty, the internal dttm - will be parsed as the sql sentence for the database to convert - """ - - tf = self.python_date_format or '%Y-%m-%d %H:%M:%S.%f' - if self.database_expression: - return self.database_expression.format(dttm.strftime('%Y-%m-%d %H:%M:%S')) - elif tf == 'epoch_s': - return str((dttm - datetime(1970, 1, 1)).total_seconds()) - elif tf == 'epoch_ms': - return str((dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0) - else: - s = 
self.table.database.db_engine_spec.convert_dttm( - self.type, dttm) - return s or "'{}'".format(dttm.strftime(tf)) - - -class SqlMetric(Model, AuditMixinNullable, ImportMixin): - - """ORM object for metrics, each table can have multiple metrics""" - - __tablename__ = 'sql_metrics' - id = Column(Integer, primary_key=True) - metric_name = Column(String(512)) - verbose_name = Column(String(1024)) - metric_type = Column(String(32)) - table_id = Column(Integer, ForeignKey('tables.id')) - table = relationship( - 'SqlaTable', - backref=backref('metrics', cascade='all, delete-orphan'), - foreign_keys=[table_id]) - expression = Column(Text) - description = Column(Text) - is_restricted = Column(Boolean, default=False, nullable=True) - d3format = Column(String(128)) - - export_fields = ( - 'metric_name', 'verbose_name', 'metric_type', 'table_id', 'expression', - 'description', 'is_restricted', 'd3format') - - @property - def sqla_col(self): - name = self.metric_name - return literal_column(self.expression).label(name) - - @property - def perm(self): - return ( - "{parent_name}.[{obj.metric_name}](id:{obj.id})" - ).format(obj=self, - parent_name=self.table.full_name) if self.table else None - - @classmethod - def import_obj(cls, i_metric): - def lookup_obj(lookup_metric): - return db.session.query(SqlMetric).filter( - SqlMetric.table_id == lookup_metric.table_id, - SqlMetric.metric_name == lookup_metric.metric_name).first() - return import_util.import_simple_obj(db.session, i_metric, lookup_obj) - - class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin): """An ORM object for SqlAlchemy table references""" @@ -1093,8 +927,6 @@ class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin): perm = Column(String(1000)) baselink = "tablemodelview" - column_cls = TableColumn - metric_cls = SqlMetric export_fields = ( 'table_name', 'main_dttm_col', 'description', 'default_endpoint', 'database_id', 'is_featured', 'offset', 'cache_timeout', 'schema', @@ -1533,108 +1365,140 @@ def fetch_metadata(self): self.main_dttm_col = any_date_col @classmethod - def import_obj(cls, i_datasource, import_time=None): + def import_obj(cls, datasource_to_import, import_time=None): """Imports the datasource from the object to the database. Metrics and columns and datasource will be overrided if exists. This function can be used to import/export dashboards between multiple superset instances. Audit metadata isn't copies over. 
""" - def lookup_sqlatable(table): - return db.session.query(SqlaTable).join(Database).filter( - SqlaTable.table_name == table.table_name, - SqlaTable.schema == table.schema, - Database.id == table.database_id, - ).first() + session = db.session + make_transient(datasource_to_import) + logging.info('Started import of the datasource: {}' + .format(datasource_to_import.to_json())) + + datasource_to_import.id = None + database_name = datasource_to_import.params_dict['database_name'] + datasource_to_import.database_id = session.query(Database).filter_by( + database_name=database_name).one().id + datasource_to_import.alter_params(import_time=import_time) + + # override the datasource + datasource = ( + session.query(SqlaTable).join(Database) + .filter( + SqlaTable.table_name == datasource_to_import.table_name, + SqlaTable.schema == datasource_to_import.schema, + Database.id == datasource_to_import.database_id, + ) + .first() + ) + + if datasource: + datasource.override(datasource_to_import) + session.flush() + else: + datasource = datasource_to_import.copy() + session.add(datasource) + session.flush() - def lookup_database(table): - return db.session.query(Database).filter_by( - database_name=table.params_dict['database_name']).one() - return import_util.import_datasource( - db.session, i_datasource, lookup_database, lookup_sqlatable, - import_time) + for m in datasource_to_import.metrics: + new_m = m.copy() + new_m.table_id = datasource.id + logging.info('Importing metric {} from the datasource: {}'.format( + new_m.to_json(), datasource_to_import.full_name)) + imported_m = SqlMetric.import_obj(new_m) + if imported_m not in datasource.metrics: + datasource.metrics.append(imported_m) + + for c in datasource_to_import.columns: + new_c = c.copy() + new_c.table_id = datasource.id + logging.info('Importing column {} from the datasource: {}'.format( + new_c.to_json(), datasource_to_import.full_name)) + imported_c = TableColumn.import_obj(new_c) + if imported_c not in datasource.columns: + datasource.columns.append(imported_c) + db.session.flush() + + return datasource.id sqla.event.listen(SqlaTable, 'after_insert', set_perm) sqla.event.listen(SqlaTable, 'after_update', set_perm) -class DruidCluster(Model, AuditMixinNullable): - - """ORM object referencing the Druid clusters""" +class SqlMetric(Model, AuditMixinNullable, ImportMixin): - __tablename__ = 'clusters' - type = "druid" + """ORM object for metrics, each table can have multiple metrics""" + __tablename__ = 'sql_metrics' id = Column(Integer, primary_key=True) - cluster_name = Column(String(250), unique=True) - coordinator_host = Column(String(255)) - coordinator_port = Column(Integer) - coordinator_endpoint = Column( - String(255), default='druid/coordinator/v1/metadata') - broker_host = Column(String(255)) - broker_port = Column(Integer) - broker_endpoint = Column(String(255), default='druid/v2') - metadata_last_refreshed = Column(DateTime) - cache_timeout = Column(Integer) + metric_name = Column(String(512)) + verbose_name = Column(String(1024)) + metric_type = Column(String(32)) + table_id = Column(Integer, ForeignKey('tables.id')) + table = relationship( + 'SqlaTable', + backref=backref('metrics', cascade='all, delete-orphan'), + foreign_keys=[table_id]) + expression = Column(Text) + description = Column(Text) + is_restricted = Column(Boolean, default=False, nullable=True) + d3format = Column(String(128)) - def __repr__(self): - return self.cluster_name + export_fields = ( + 'metric_name', 'verbose_name', 'metric_type', 'table_id', 
'expression', + 'description', 'is_restricted', 'd3format') - def get_pydruid_client(self): - cli = PyDruid( - "http://{0}:{1}/".format(self.broker_host, self.broker_port), - self.broker_endpoint) - return cli + @property + def sqla_col(self): + name = self.metric_name + return literal_column(self.expression).label(name) - def get_datasources(self): - endpoint = ( - "http://{obj.coordinator_host}:{obj.coordinator_port}/" - "{obj.coordinator_endpoint}/datasources" - ).format(obj=self) + @property + def perm(self): + return ( + "{parent_name}.[{obj.metric_name}](id:{obj.id})" + ).format(obj=self, + parent_name=self.table.full_name) if self.table else None - return json.loads(requests.get(endpoint).text) + @classmethod + def import_obj(cls, metric_to_import): + session = db.session + make_transient(metric_to_import) + metric_to_import.id = None + + # find if the column was already imported + existing_metric = session.query(SqlMetric).filter( + SqlMetric.table_id == metric_to_import.table_id, + SqlMetric.metric_name == metric_to_import.metric_name).first() + metric_to_import.table = None + if existing_metric: + existing_metric.override(metric_to_import) + session.flush() + return existing_metric - def get_druid_version(self): - endpoint = ( - "http://{obj.coordinator_host}:{obj.coordinator_port}/status" - ).format(obj=self) - return json.loads(requests.get(endpoint).text)['version'] + session.add(metric_to_import) + session.flush() + return metric_to_import - def refresh_datasources(self, datasource_name=None, merge_flag=False): - """Refresh metadata of all datasources in the cluster - If ``datasource_name`` is specified, only that datasource is updated - """ - self.druid_version = self.get_druid_version() - for datasource in self.get_datasources(): - if datasource not in config.get('DRUID_DATA_SOURCE_BLACKLIST'): - if not datasource_name or datasource_name == datasource: - DruidDatasource.sync_to_db(datasource, self, merge_flag) - - @property - def perm(self): - return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self) - - @property - def name(self): - return self.cluster_name +class TableColumn(Model, AuditMixinNullable, ImportMixin): -class DruidColumn(Model, AuditMixinNullable, ImportMixin): - """ORM model for storing Druid datasource column metadata""" + """ORM object for table columns, each table can have multiple columns""" - __tablename__ = 'columns' + __tablename__ = 'table_columns' id = Column(Integer, primary_key=True) - datasource_name = Column( - String(255), - ForeignKey('datasources.datasource_name')) - # Setting enable_typechecks=False disables polymorphic inheritance. 
- datasource = relationship( - 'DruidDatasource', + table_id = Column(Integer, ForeignKey('tables.id')) + table = relationship( + 'SqlaTable', backref=backref('columns', cascade='all, delete-orphan'), - enable_typechecks=False) + foreign_keys=[table_id]) column_name = Column(String(255)) + verbose_name = Column(String(1024)) + is_dttm = Column(Boolean, default=False) is_active = Column(Boolean, default=True) - type = Column(String(32)) + type = Column(String(32), default='') groupby = Column(Boolean, default=False) count_distinct = Column(Boolean, default=False) sum = Column(Boolean, default=False) @@ -1642,13 +1506,19 @@ class DruidColumn(Model, AuditMixinNullable, ImportMixin): max = Column(Boolean, default=False) min = Column(Boolean, default=False) filterable = Column(Boolean, default=False) - description = Column(Text) - dimension_spec_json = Column(Text) + expression = Column(Text, default='') + description = Column(Text, default='') + python_date_format = Column(String(255)) + database_expression = Column(String(255)) + num_types = ('DOUBLE', 'FLOAT', 'INT', 'BIGINT', 'LONG') + date_types = ('DATE', 'TIME') + str_types = ('VARCHAR', 'STRING', 'CHAR') export_fields = ( - 'datasource_name', 'column_name', 'is_active', 'type', 'groupby', - 'count_distinct', 'sum', 'avg', 'max', 'min', 'filterable', - 'description', 'dimension_spec_json' + 'table_id', 'column_name', 'verbose_name', 'is_dttm', 'is_active', + 'type', 'groupby', 'count_distinct', 'sum', 'avg', 'max', 'min', + 'filterable', 'expression', 'description', 'python_date_format', + 'database_expression' ) def __repr__(self): @@ -1656,142 +1526,135 @@ def __repr__(self): @property def isnum(self): - return self.type in ('LONG', 'DOUBLE', 'FLOAT', 'INT') + return any([t in self.type.upper() for t in self.num_types]) @property - def dimension_spec(self): - if self.dimension_spec_json: - return json.loads(self.dimension_spec_json) + def is_time(self): + return any([t in self.type.upper() for t in self.date_types]) - def generate_metrics(self): - """Generate metrics based on the column metadata""" - M = DruidMetric # noqa - metrics = [] - metrics.append(DruidMetric( - metric_name='count', - verbose_name='COUNT(*)', - metric_type='count', - json=json.dumps({'type': 'count', 'name': 'count'}) - )) - # Somehow we need to reassign this for UDAFs - if self.type in ('DOUBLE', 'FLOAT'): - corrected_type = 'DOUBLE' - else: - corrected_type = self.type + @property + def is_string(self): + return any([t in self.type.upper() for t in self.str_types]) - if self.sum and self.isnum: - mt = corrected_type.lower() + 'Sum' - name = 'sum__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='sum', - verbose_name='SUM({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) + @property + def sqla_col(self): + name = self.column_name + if not self.expression: + col = column(self.column_name).label(name) + else: + col = literal_column(self.expression).label(name) + return col - if self.avg and self.isnum: - mt = corrected_type.lower() + 'Avg' - name = 'avg__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='avg', - verbose_name='AVG({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) + def get_time_filter(self, start_dttm, end_dttm): + col = self.sqla_col.label('__time') + return and_( + col >= text(self.dttm_sql_literal(start_dttm)), + col <= 
text(self.dttm_sql_literal(end_dttm)), + ) - if self.min and self.isnum: - mt = corrected_type.lower() + 'Min' - name = 'min__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='min', - verbose_name='MIN({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) - if self.max and self.isnum: - mt = corrected_type.lower() + 'Max' - name = 'max__' + self.column_name - metrics.append(DruidMetric( - metric_name=name, - metric_type='max', - verbose_name='MAX({})'.format(self.column_name), - json=json.dumps({ - 'type': mt, 'name': name, 'fieldName': self.column_name}) - )) - if self.count_distinct: - name = 'count_distinct__' + self.column_name - if self.type == 'hyperUnique' or self.type == 'thetaSketch': - metrics.append(DruidMetric( - metric_name=name, - verbose_name='COUNT(DISTINCT {})'.format(self.column_name), - metric_type=self.type, - json=json.dumps({ - 'type': self.type, - 'name': name, - 'fieldName': self.column_name - }) - )) - else: - mt = 'count_distinct' - metrics.append(DruidMetric( - metric_name=name, - verbose_name='COUNT(DISTINCT {})'.format(self.column_name), - metric_type='count_distinct', - json=json.dumps({ - 'type': 'cardinality', - 'name': name, - 'fieldNames': [self.column_name]}) - )) - session = get_session() - new_metrics = [] - for metric in metrics: - m = ( - session.query(M) - .filter(M.metric_name == metric.metric_name) - .filter(M.datasource_name == self.datasource_name) - .filter(DruidCluster.cluster_name == self.datasource.cluster_name) - .first() - ) - metric.datasource_name = self.datasource_name - if not m: - new_metrics.append(metric) - session.add(metric) - session.flush() + def get_timestamp_expression(self, time_grain): + """Getting the time component of the query""" + expr = self.expression or self.column_name + if not self.expression and not time_grain: + return column(expr, type_=DateTime).label(DTTM_ALIAS) + if time_grain: + pdf = self.python_date_format + if pdf in ('epoch_s', 'epoch_ms'): + # if epoch, translate to DATE using db specific conf + db_spec = self.table.database.db_engine_spec + if pdf == 'epoch_s': + expr = db_spec.epoch_to_dttm().format(col=expr) + elif pdf == 'epoch_ms': + expr = db_spec.epoch_ms_to_dttm().format(col=expr) + grain = self.table.database.grains_dict().get(time_grain, '{col}') + expr = grain.function.format(col=expr) + return literal_column(expr, type_=DateTime).label(DTTM_ALIAS) @classmethod - def import_obj(cls, i_column): - def lookup_obj(lookup_column): - return db.session.query(DruidColumn).filter( - DruidColumn.datasource_name == lookup_column.datasource_name, - DruidColumn.column_name == lookup_column.column_name).first() + def import_obj(cls, column_to_import): + session = db.session + make_transient(column_to_import) + column_to_import.id = None + column_to_import.table = None + + # find if the column was already imported + existing_column = session.query(TableColumn).filter( + TableColumn.table_id == column_to_import.table_id, + TableColumn.column_name == column_to_import.column_name).first() + column_to_import.table = None + if existing_column: + existing_column.override(column_to_import) + session.flush() + return existing_column - return import_util.import_simple_obj(db.session, i_column, lookup_obj) + session.add(column_to_import) + session.flush() + return column_to_import + def dttm_sql_literal(self, dttm): + """Convert datetime object to a SQL expression string -class DruidMetric(Model, AuditMixinNullable, 
ImportMixin): + If database_expression is empty, the internal dttm + will be parsed as the string with the pattern that + the user inputted (python_date_format) + If database_expression is not empty, the internal dttm + will be parsed as the sql sentence for the database to convert + """ - """ORM object referencing Druid metrics for a datasource""" + tf = self.python_date_format or '%Y-%m-%d %H:%M:%S.%f' + if self.database_expression: + return self.database_expression.format(dttm.strftime('%Y-%m-%d %H:%M:%S')) + elif tf == 'epoch_s': + return str((dttm - datetime(1970, 1, 1)).total_seconds()) + elif tf == 'epoch_ms': + return str((dttm - datetime(1970, 1, 1)).total_seconds() * 1000.0) + else: + s = self.table.database.db_engine_spec.convert_dttm( + self.type, dttm) + return s or "'{}'".format(dttm.strftime(tf)) + + +class DruidCluster(Model, AuditMixinNullable): + + """ORM object referencing the Druid clusters""" + + __tablename__ = 'clusters' + type = "druid" - __tablename__ = 'metrics' id = Column(Integer, primary_key=True) - metric_name = Column(String(512)) - verbose_name = Column(String(1024)) - metric_type = Column(String(32)) - datasource_name = Column( - String(255), - ForeignKey('datasources.datasource_name')) - # Setting enable_typechecks=False disables polymorphic inheritance. - datasource = relationship( - 'DruidDatasource', - backref=backref('metrics', cascade='all, delete-orphan'), - enable_typechecks=False) - json = Column(Text) - description = Column(Text) - is_restricted = Column(Boolean, default=False, nullable=True) - d3format = Column(String(128)) + cluster_name = Column(String(250), unique=True) + coordinator_host = Column(String(255)) + coordinator_port = Column(Integer) + coordinator_endpoint = Column( + String(255), default='druid/coordinator/v1/metadata') + broker_host = Column(String(255)) + broker_port = Column(Integer) + broker_endpoint = Column(String(255), default='druid/v2') + metadata_last_refreshed = Column(DateTime) + cache_timeout = Column(Integer) + + def __repr__(self): + return self.cluster_name + + def get_pydruid_client(self): + cli = PyDruid( + "http://{0}:{1}/".format(self.broker_host, self.broker_port), + self.broker_endpoint) + return cli + + def get_datasources(self): + endpoint = ( + "http://{obj.coordinator_host}:{obj.coordinator_port}/" + "{obj.coordinator_endpoint}/datasources" + ).format(obj=self) + + return json.loads(requests.get(endpoint).text) + + def get_druid_version(self): + endpoint = ( + "http://{obj.coordinator_host}:{obj.coordinator_port}/status" + ).format(obj=self) + return json.loads(requests.get(endpoint).text)['version'] def refresh_datasources(self, datasource_name=None, merge_flag=False): """Refresh metadata of all datasources in the cluster @@ -1803,37 +1666,17 @@ def refresh_datasources(self, datasource_name=None, merge_flag=False): if datasource not in config.get('DRUID_DATA_SOURCE_BLACKLIST'): if not datasource_name or datasource_name == datasource: DruidDatasource.sync_to_db(datasource, self, merge_flag) - export_fields = ( - 'metric_name', 'verbose_name', 'metric_type', 'datasource_name', - 'json', 'description', 'is_restricted', 'd3format' - ) - - @property - def json_obj(self): - try: - obj = json.loads(self.json) - except Exception: - obj = {} - return obj @property def perm(self): - return ( - "{parent_name}.[{obj.metric_name}](id:{obj.id})" - ).format(obj=self, - parent_name=self.datasource.full_name - ) if self.datasource else None + return "[{obj.cluster_name}].(id:{obj.id})".format(obj=self) - @classmethod - 
def import_obj(cls, i_metric): - def lookup_obj(lookup_metric): - return db.session.query(DruidMetric).filter( - DruidMetric.datasource_name == lookup_metric.datasource_name, - DruidMetric.metric_name == lookup_metric.metric_name).first() - return import_util.import_simple_obj(db.session, i_metric, lookup_obj) + @property + def name(self): + return self.cluster_name -class DruidDatasource(Model, AuditMixinNullable, Queryable, ImportMixin): +class DruidDatasource(Model, AuditMixinNullable, Queryable): """ORM object referencing Druid datasources (tables)""" @@ -1860,17 +1703,8 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable, ImportMixin): 'DruidCluster', backref='datasources', foreign_keys=[cluster_name]) offset = Column(Integer, default=0) cache_timeout = Column(Integer) - params = Column(String(1000)) perm = Column(String(1000)) - metric_cls = DruidMetric - column_cls = DruidColumn - - export_fields = ( - 'datasource_name', 'is_hidden', 'description', 'default_endpoint', - 'cluster_name', 'is_featured', 'offset', 'cache_timeout', 'params' - ) - @property def database(self): return self.cluster @@ -1948,27 +1782,6 @@ def get_metric_obj(self, metric_name): if m.metric_name == metric_name ][0] - @classmethod - def import_obj(cls, i_datasource, import_time=None): - """Imports the datasource from the object to the database. - - Metrics and columns and datasource will be overrided if exists. - This function can be used to import/export dashboards between multiple - superset instances. Audit metadata isn't copies over. - """ - def lookup_datasource(d): - return db.session.query(DruidDatasource).join(DruidCluster).filter( - DruidDatasource.datasource_name == d.datasource_name, - DruidCluster.cluster_name == d.cluster_name, - ).first() - - def lookup_cluster(d): - return db.session.query(DruidCluster).filter_by( - cluster_name=d.cluster_name).one() - return import_util.import_datasource( - db.session, i_datasource, lookup_cluster, lookup_datasource, - import_time) - @staticmethod def version_higher(v1, v2): """is v1 higher than v2 @@ -2602,6 +2415,183 @@ def wrapper(*args, **kwargs): return wrapper +class DruidMetric(Model, AuditMixinNullable): + + """ORM object referencing Druid metrics for a datasource""" + + __tablename__ = 'metrics' + id = Column(Integer, primary_key=True) + metric_name = Column(String(512)) + verbose_name = Column(String(1024)) + metric_type = Column(String(32)) + datasource_name = Column( + String(255), + ForeignKey('datasources.datasource_name')) + # Setting enable_typechecks=False disables polymorphic inheritance. + datasource = relationship( + 'DruidDatasource', + backref=backref('metrics', cascade='all, delete-orphan'), + enable_typechecks=False) + json = Column(Text) + description = Column(Text) + is_restricted = Column(Boolean, default=False, nullable=True) + d3format = Column(String(128)) + + @property + def json_obj(self): + try: + obj = json.loads(self.json) + except Exception: + obj = {} + return obj + + @property + def perm(self): + return ( + "{parent_name}.[{obj.metric_name}](id:{obj.id})" + ).format(obj=self, + parent_name=self.datasource.full_name + ) if self.datasource else None + + +class DruidColumn(Model, AuditMixinNullable): + + """ORM model for storing Druid datasource column metadata""" + + __tablename__ = 'columns' + id = Column(Integer, primary_key=True) + datasource_name = Column( + String(255), + ForeignKey('datasources.datasource_name')) + # Setting enable_typechecks=False disables polymorphic inheritance. 
+ datasource = relationship( + 'DruidDatasource', + backref=backref('columns', cascade='all, delete-orphan'), + enable_typechecks=False) + column_name = Column(String(255)) + is_active = Column(Boolean, default=True) + type = Column(String(32)) + groupby = Column(Boolean, default=False) + count_distinct = Column(Boolean, default=False) + sum = Column(Boolean, default=False) + avg = Column(Boolean, default=False) + max = Column(Boolean, default=False) + min = Column(Boolean, default=False) + filterable = Column(Boolean, default=False) + description = Column(Text) + dimension_spec_json = Column(Text) + + def __repr__(self): + return self.column_name + + @property + def isnum(self): + return self.type in ('LONG', 'DOUBLE', 'FLOAT', 'INT') + + @property + def dimension_spec(self): + if self.dimension_spec_json: + return json.loads(self.dimension_spec_json) + + def generate_metrics(self): + """Generate metrics based on the column metadata""" + M = DruidMetric # noqa + metrics = [] + metrics.append(DruidMetric( + metric_name='count', + verbose_name='COUNT(*)', + metric_type='count', + json=json.dumps({'type': 'count', 'name': 'count'}) + )) + # Somehow we need to reassign this for UDAFs + if self.type in ('DOUBLE', 'FLOAT'): + corrected_type = 'DOUBLE' + else: + corrected_type = self.type + + if self.sum and self.isnum: + mt = corrected_type.lower() + 'Sum' + name = 'sum__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='sum', + verbose_name='SUM({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) + + if self.avg and self.isnum: + mt = corrected_type.lower() + 'Avg' + name = 'avg__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='avg', + verbose_name='AVG({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) + + if self.min and self.isnum: + mt = corrected_type.lower() + 'Min' + name = 'min__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='min', + verbose_name='MIN({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) + if self.max and self.isnum: + mt = corrected_type.lower() + 'Max' + name = 'max__' + self.column_name + metrics.append(DruidMetric( + metric_name=name, + metric_type='max', + verbose_name='MAX({})'.format(self.column_name), + json=json.dumps({ + 'type': mt, 'name': name, 'fieldName': self.column_name}) + )) + if self.count_distinct: + name = 'count_distinct__' + self.column_name + if self.type == 'hyperUnique' or self.type == 'thetaSketch': + metrics.append(DruidMetric( + metric_name=name, + verbose_name='COUNT(DISTINCT {})'.format(self.column_name), + metric_type=self.type, + json=json.dumps({ + 'type': self.type, + 'name': name, + 'fieldName': self.column_name + }) + )) + else: + mt = 'count_distinct' + metrics.append(DruidMetric( + metric_name=name, + verbose_name='COUNT(DISTINCT {})'.format(self.column_name), + metric_type='count_distinct', + json=json.dumps({ + 'type': 'cardinality', + 'name': name, + 'fieldNames': [self.column_name]}) + )) + session = get_session() + new_metrics = [] + for metric in metrics: + m = ( + session.query(M) + .filter(M.metric_name == metric.metric_name) + .filter(M.datasource_name == self.datasource_name) + .filter(DruidCluster.cluster_name == self.datasource.cluster_name) + .first() + ) + metric.datasource_name = self.datasource_name + if not m: + 
new_metrics.append(metric) + session.add(metric) + session.flush() + + class FavStar(Model): __tablename__ = 'favstar' diff --git a/superset/source_registry.py b/superset/source_registry.py index 012d15105347..2c72157ebf0b 100644 --- a/superset/source_registry.py +++ b/superset/source_registry.py @@ -36,10 +36,7 @@ def get_datasource_by_name(cls, session, datasource_type, datasource_name, schema, database_name): datasource_class = SourceRegistry.sources[datasource_type] datasources = session.query(datasource_class).all() - - # Filter datasoures that don't have database. - db_ds = [d for d in datasources if d.database and - d.database.name == database_name and + db_ds = [d for d in datasources if d.database.name == database_name and d.name == datasource_name and schema == schema] return db_ds[0] @@ -68,12 +65,16 @@ def query_datasources_by_name( def get_eager_datasource(cls, session, datasource_type, datasource_id): """Returns datasource with columns and metrics.""" datasource_class = SourceRegistry.sources[datasource_type] - return ( - session.query(datasource_class) - .options( - subqueryload(datasource_class.columns), - subqueryload(datasource_class.metrics) + if datasource_type == 'table': + return ( + session.query(datasource_class) + .options( + subqueryload(datasource_class.columns), + subqueryload(datasource_class.metrics) + ) + .filter_by(id=datasource_id) + .one() ) - .filter_by(id=datasource_id) - .one() - ) + # TODO: support druid datasources. + return session.query(datasource_class).filter_by( + id=datasource_id).first() diff --git a/superset/views.py b/superset/views.py index f5ddf7ff77d2..6cb206bfff7f 100755 --- a/superset/views.py +++ b/superset/views.py @@ -1418,14 +1418,8 @@ def import_dashboards(self): if request.method == 'POST' and f: current_tt = int(time.time()) data = pickle.load(f) - # TODO: import DRUID datasources for table in data['datasources']: - if table.type == 'table': - models.SqlaTable.import_obj(table, import_time=current_tt) - else: - models.DruidDatasource.import_obj( - table, import_time=current_tt) - db.session.commit() + models.SqlaTable.import_obj(table, import_time=current_tt) for dashboard in data['dashboards']: models.Dashboard.import_obj( dashboard, import_time=current_tt) diff --git a/tests/import_export_tests.py b/tests/import_export_tests.py index 6c7d79753729..3201ce9d3478 100644 --- a/tests/import_export_tests.py +++ b/tests/import_export_tests.py @@ -34,9 +34,6 @@ def delete_imports(cls): for table in session.query(models.SqlaTable): if 'remote_id' in table.params_dict: session.delete(table) - for datasource in session.query(models.DruidDatasource): - if 'remote_id' in datasource.params_dict: - session.delete(datasource) session.commit() @classmethod @@ -55,11 +52,6 @@ def create_slice(self, name, ds_id=None, id=None, db_name='main', 'datasource_name': table_name, 'database_name': db_name, 'schema': '', - # Test for trailing commas - "metrics": [ - "sum__signup_attempt_email", - "sum__signup_attempt_facebook", - ], } if table_name and not ds_id: @@ -87,8 +79,7 @@ def create_dashboard(self, title, id=0, slcs=[]): json_metadata=json.dumps(json_metadata) ) - def create_table( - self, name, schema='', id=0, cols_names=[], metric_names=[]): + def create_table(self, name, schema='', id=0, cols_names=[], metric_names=[]): params = {'remote_id': id, 'database_name': 'main'} table = models.SqlaTable( id=id, @@ -103,23 +94,6 @@ def create_table( table.metrics.append(models.SqlMetric(metric_name=metric_name)) return table - def 
create_druid_datasource( - self, name, id=0, cols_names=[], metric_names=[]): - params = {'remote_id': id, 'database_name': 'druid_test'} - datasource = models.DruidDatasource( - id=id, - datasource_name=name, - cluster_name='druid_test', - params=json.dumps(params) - ) - for col_name in cols_names: - datasource.columns.append( - models.DruidColumn(column_name=col_name)) - for metric_name in metric_names: - datasource.metrics.append(models.DruidMetric( - metric_name=metric_name)) - return datasource - def get_slice(self, slc_id): return db.session.query(models.Slice).filter_by(id=slc_id).first() @@ -139,10 +113,6 @@ def get_table(self, table_id): return db.session.query(models.SqlaTable).filter_by( id=table_id).first() - def get_datasource(self, datasource_id): - return db.session.query(models.DruidDatasource).filter_by( - id=datasource_id).first() - def get_table_by_name(self, name): return db.session.query(models.SqlaTable).filter_by( table_name=name).first() @@ -177,19 +147,6 @@ def assert_table_equals(self, expected_ds, actual_ds): set([m.metric_name for m in expected_ds.metrics]), set([m.metric_name for m in actual_ds.metrics])) - def assert_datasource_equals(self, expected_ds, actual_ds): - self.assertEquals( - expected_ds.datasource_name, actual_ds.datasource_name) - self.assertEquals(expected_ds.main_dttm_col, actual_ds.main_dttm_col) - self.assertEquals(len(expected_ds.metrics), len(actual_ds.metrics)) - self.assertEquals(len(expected_ds.columns), len(actual_ds.columns)) - self.assertEquals( - set([c.column_name for c in expected_ds.columns]), - set([c.column_name for c in actual_ds.columns])) - self.assertEquals( - set([m.metric_name for m in expected_ds.metrics]), - set([m.metric_name for m in actual_ds.metrics])) - def assert_slice_equals(self, expected_slc, actual_slc): self.assertEquals(expected_slc.slice_name, actual_slc.slice_name) self.assertEquals( @@ -396,131 +353,63 @@ def test_import_override_dashboard_2_slices(self): def test_import_table_no_metadata(self): table = self.create_table('pure_table', id=10001) - imported_id = models.SqlaTable.import_obj(table, import_time=1989) - imported = self.get_table(imported_id) - self.assert_table_equals(table, imported) + imported_t_id = models.SqlaTable.import_obj(table, import_time=1989) + imported_table = self.get_table(imported_t_id) + self.assert_table_equals(table, imported_table) def test_import_table_1_col_1_met(self): table = self.create_table( 'table_1_col_1_met', id=10002, cols_names=["col1"], metric_names=["metric1"]) - imported_id = models.SqlaTable.import_obj(table, import_time=1990) - imported = self.get_table(imported_id) - self.assert_table_equals(table, imported) + imported_t_id = models.SqlaTable.import_obj(table, import_time=1990) + imported_table = self.get_table(imported_t_id) + self.assert_table_equals(table, imported_table) self.assertEquals( {'remote_id': 10002, 'import_time': 1990, 'database_name': 'main'}, - json.loads(imported.params)) + json.loads(imported_table.params)) def test_import_table_2_col_2_met(self): table = self.create_table( 'table_2_col_2_met', id=10003, cols_names=['c1', 'c2'], metric_names=['m1', 'm2']) - imported_id = models.SqlaTable.import_obj(table, import_time=1991) + imported_t_id = models.SqlaTable.import_obj(table, import_time=1991) - imported = self.get_table(imported_id) - self.assert_table_equals(table, imported) + imported_table = self.get_table(imported_t_id) + self.assert_table_equals(table, imported_table) def test_import_table_override(self): table = self.create_table( 
'table_override', id=10003, cols_names=['col1'], metric_names=['m1']) - imported_id = models.SqlaTable.import_obj(table, import_time=1991) + imported_t_id = models.SqlaTable.import_obj(table, import_time=1991) table_over = self.create_table( 'table_override', id=10003, cols_names=['new_col1', 'col2', 'col3'], metric_names=['new_metric1']) - imported_over_id = models.SqlaTable.import_obj( + imported_table_over_id = models.SqlaTable.import_obj( table_over, import_time=1992) - imported_over = self.get_table(imported_over_id) - self.assertEquals(imported_id, imported_over.id) + imported_table_over = self.get_table(imported_table_over_id) + self.assertEquals(imported_t_id, imported_table_over.id) expected_table = self.create_table( 'table_override', id=10003, metric_names=['new_metric1', 'm1'], cols_names=['col1', 'new_col1', 'col2', 'col3']) - self.assert_table_equals(expected_table, imported_over) + self.assert_table_equals(expected_table, imported_table_over) def test_import_table_override_idential(self): table = self.create_table( 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], metric_names=['new_metric1']) - imported_id = models.SqlaTable.import_obj(table, import_time=1993) + imported_t_id = models.SqlaTable.import_obj(table, import_time=1993) copy_table = self.create_table( 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], metric_names=['new_metric1']) - imported_id_copy = models.SqlaTable.import_obj( + imported_t_id_copy = models.SqlaTable.import_obj( copy_table, import_time=1994) - self.assertEquals(imported_id, imported_id_copy) - self.assert_table_equals(copy_table, self.get_table(imported_id)) - - def test_import_druid_no_metadata(self): - datasource = self.create_druid_datasource('pure_druid', id=10001) - imported_id = models.DruidDatasource.import_obj( - datasource, import_time=1989) - imported = self.get_datasource(imported_id) - self.assert_datasource_equals(datasource, imported) - - def test_import_druid_1_col_1_met(self): - datasource = self.create_druid_datasource( - 'druid_1_col_1_met', id=10002, - cols_names=["col1"], metric_names=["metric1"]) - imported_id = models.DruidDatasource.import_obj( - datasource, import_time=1990) - imported = self.get_datasource(imported_id) - self.assert_datasource_equals(datasource, imported) - self.assertEquals( - {'remote_id': 10002, 'import_time': 1990, - 'database_name': 'druid_test'}, - json.loads(imported.params)) - - def test_import_druid_2_col_2_met(self): - datasource = self.create_druid_datasource( - 'druid_2_col_2_met', id=10003, cols_names=['c1', 'c2'], - metric_names=['m1', 'm2']) - imported_id = models.DruidDatasource.import_obj( - datasource, import_time=1991) - imported = self.get_datasource(imported_id) - self.assert_datasource_equals(datasource, imported) - - def test_import_druid_override(self): - datasource = self.create_druid_datasource( - 'druid_override', id=10003, cols_names=['col1'], - metric_names=['m1']) - imported_id = models.DruidDatasource.import_obj( - datasource, import_time=1991) - - table_over = self.create_druid_datasource( - 'druid_override', id=10003, - cols_names=['new_col1', 'col2', 'col3'], - metric_names=['new_metric1']) - imported_over_id = models.DruidDatasource.import_obj( - table_over, import_time=1992) - - imported_over = self.get_datasource(imported_over_id) - self.assertEquals(imported_id, imported_over.id) - expected_datasource = self.create_druid_datasource( - 'druid_override', id=10003, metric_names=['new_metric1', 'm1'], - cols_names=['col1', 'new_col1', 'col2', 
'col3']) - self.assert_datasource_equals(expected_datasource, imported_over) - - def test_import_druid_override_idential(self): - datasource = self.create_druid_datasource( - 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], - metric_names=['new_metric1']) - imported_id = models.DruidDatasource.import_obj( - datasource, import_time=1993) - - copy_datasource = self.create_druid_datasource( - 'copy_cat', id=10004, cols_names=['new_col1', 'col2', 'col3'], - metric_names=['new_metric1']) - imported_id_copy = models.DruidDatasource.import_obj( - copy_datasource, import_time=1994) - - self.assertEquals(imported_id, imported_id_copy) - self.assert_datasource_equals( - copy_datasource, self.get_datasource(imported_id)) - + self.assertEquals(imported_t_id, imported_t_id_copy) + self.assert_table_equals(copy_table, self.get_table(imported_t_id)) if __name__ == '__main__': unittest.main()
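
For context, the change being reverted had factored the per-engine import logic (restored inline on SqlaTable, SqlMetric, and TableColumn above) into a single engine-agnostic helper in superset/import_util.py. The sketch below reconstructs that removed helper from the deleted file shown at the top of this patch; names and behavior follow that file as of this commit, not a maintained API.

    # Reconstruction of the deleted superset/import_util.py helper.
    import logging

    from sqlalchemy.orm.session import make_transient


    def import_datasource(session, i_datasource, lookup_database,
                          lookup_datasource, import_time):
        """Import a datasource, overriding metrics/columns if they already exist.

        ``lookup_database`` and ``lookup_datasource`` are callables supplied by
        the concrete datasource class (SqlaTable or DruidDatasource), which is
        what kept this helper engine-agnostic before the revert.
        """
        make_transient(i_datasource)
        logging.info('Started import of the datasource: %s',
                     i_datasource.to_json())

        # Re-point the imported copy at the local database/cluster and record
        # the import time in its params.
        i_datasource.id = None
        i_datasource.database_id = lookup_database(i_datasource).id
        i_datasource.alter_params(import_time=import_time)

        # Override a matching existing datasource, otherwise insert a copy.
        datasource = lookup_datasource(i_datasource)
        if datasource:
            datasource.override(i_datasource)
        else:
            datasource = i_datasource.copy()
            session.add(datasource)
        session.flush()

        # Metrics and columns go through the class-specific metric_cls /
        # column_cls hooks and are attached only if not already present.
        for metric in i_datasource.metrics:
            new_metric = metric.copy()
            new_metric.table_id = datasource.id
            imported = i_datasource.metric_cls.import_obj(new_metric)
            if imported.metric_name not in {m.metric_name
                                            for m in datasource.metrics}:
                datasource.metrics.append(imported)

        for col in i_datasource.columns:
            new_col = col.copy()
            new_col.table_id = datasource.id
            imported = i_datasource.column_cls.import_obj(new_col)
            if imported.column_name not in {c.column_name
                                            for c in datasource.columns}:
                datasource.columns.append(imported)

        session.flush()
        return datasource.id

The revert inlines this logic back into SqlaTable.import_obj and the per-class import_obj methods, and drops the Druid variants: DruidDatasource, DruidMetric, and DruidColumn lose ImportMixin, export_fields, and the params column added by the deleted 1296d28ec131 migration, as shown in the hunks above.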