diff --git a/caravel/__init__.py b/caravel/__init__.py index f06729d7a47c..1fcfe43acb99 100644 --- a/caravel/__init__.py +++ b/caravel/__init__.py @@ -14,7 +14,7 @@ from flask_appbuilder.baseviews import expose from flask_cache import Cache from flask_migrate import Migrate -from caravel import source_registry +from caravel.source_registry import SourceRegistry from werkzeug.contrib.fixers import ProxyFix @@ -96,7 +96,11 @@ def index(self): sm = appbuilder.sm -src_registry = source_registry.SourceRegistry() - get_session = appbuilder.get_session + +# Registering sources +module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP") +module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP")) +SourceRegistry.register_sources(module_datasource_map) + from caravel import views, config # noqa diff --git a/caravel/bin/caravel b/caravel/bin/caravel index bf09f4d61c47..a582c3e5ecb6 100755 --- a/caravel/bin/caravel +++ b/caravel/bin/caravel @@ -20,14 +20,6 @@ config = app.config manager = Manager(app) manager.add_command('db', MigrateCommand) -module_datasource_map = config.get("DEFAULT_MODULE_DS_MAP") -module_datasource_map.update(config.get("ADDITIONAL_MODULE_DS_MAP")) - -datasources = {} -for module in module_datasource_map: - datasources[module] = __import__(module, fromlist=module_datasource_map[module]) - -utils.register_sources(datasources, module_datasource_map, caravel.src_registry) @manager.option( diff --git a/caravel/models.py b/caravel/models.py index 38a457b98d4d..a04a4e96da16 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -26,7 +26,6 @@ from flask_appbuilder import Model from flask_appbuilder.models.mixins import AuditMixin from flask_appbuilder.models.decorators import renders -from flask_appbuilder.security.sqla.models import Role, PermissionView from flask_babel import lazy_gettext as _ from pydruid.client import PyDruid @@ -50,7 +49,8 @@ from werkzeug.datastructures import ImmutableMultiDict import caravel -from caravel import app, db, get_session, utils, sm, src_registry +from caravel import app, db, get_session, utils, sm +from caravel.source_registry import SourceRegistry from caravel.viz import viz_types from caravel.utils import flasher, MetricPermException, DimSelector @@ -172,7 +172,7 @@ def __repr__(self): @property def cls_model(self): - return src_registry.sources[self.datasource_type] + return SourceRegistry.sources[self.datasource_type] @property def datasource(self): @@ -2028,7 +2028,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable): @property def cls_model(self): - return src_registry.sources[self.datasource_type] + return SourceRegistry.sources[self.datasource_type] @property def username(self): diff --git a/caravel/source_registry.py b/caravel/source_registry.py index 5e253aa3d337..04b3f17c28cc 100644 --- a/caravel/source_registry.py +++ b/caravel/source_registry.py @@ -1,4 +1,3 @@ -from flask import flash class SourceRegistry(object): @@ -6,10 +5,10 @@ class SourceRegistry(object): sources = {} - def add_source(self, ds_type, cls_model): - if ds_type not in self.sources: - self.sources[ds_type] = cls_model - if self.sources[ds_type] is not cls_model: - raise Exception( - 'source type: {} is already associated with Model: {}'.format( - ds_type, self.sources[ds_type])) + @classmethod + def register_sources(cls, datasource_config): + for module_name, class_names in datasource_config.items(): + module_obj = __import__(module_name, fromlist=class_names) + for class_name in class_names: + source_class = getattr(module_obj, class_name) + cls.sources[source_class.type] = source_class diff --git a/caravel/utils.py b/caravel/utils.py index c80b94fc813d..15a168b99ada 100644 --- a/caravel/utils.py +++ b/caravel/utils.py @@ -426,14 +426,6 @@ def readfile(filepath): return content -def register_sources(datasources, module_datasource_map, registry): - for m in datasources: - datasource_list = module_datasource_map[m] - for ds in datasource_list: - ds_class = getattr(datasources[m], ds) - registry.add_source(ds_class.type, ds_class) - - def generic_find_constraint_name(table, columns, referenced, db): """Utility to find a constraint name in alembic migrations""" t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine) diff --git a/caravel/views.py b/caravel/views.py index 46037cb4699b..3fc8fd9ffe58 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -34,8 +34,9 @@ import caravel from caravel import ( appbuilder, cache, db, models, viz, utils, app, - sm, ascii_art, sql_lab, src_registry + sm, ascii_art, sql_lab ) +from caravel.source_registry import SourceRegistry from caravel.models import DatasourceAccessRequest as DAR config = app.config @@ -748,9 +749,9 @@ def add(self): if not widget: return redirect(self.get_redirect()) - sources = src_registry.sources + sources = SourceRegistry.sources for source in sources: - ds = db.session.query(src_registry.sources[source]).first() + ds = db.session.query(SourceRegistry.sources[source]).first() if ds is not None: url = "/{}/list/".format(ds.baselink) msg = _("Click on a {} link to create a Slice".format(source)) @@ -1054,7 +1055,7 @@ def approve(self): role_to_extend = request.args.get('role_to_extend') session = db.session - datasource_class = src_registry.sources[datasource_type] + datasource_class = SourceRegistry.sources[datasource_type] datasource = session.query(datasource_class).filter_by( id=datasource_id).first() @@ -1117,7 +1118,7 @@ def approve(self): @log_this def explore(self, datasource_type, datasource_id, slice_id=None): error_redirect = '/slicemodelview/list/' - datasource_class = src_registry.sources[datasource_type] + datasource_class = SourceRegistry.sources[datasource_type] datasources = db.session.query(datasource_class).all() datasources = sorted(datasources, key=lambda ds: ds.full_name) datasource = [ds for ds in datasources if int(datasource_id) == ds.id] diff --git a/tests/core_tests.py b/tests/core_tests.py index 1297cfb3088a..bb1ace23a8cd 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -17,7 +17,8 @@ from flask_appbuilder.security.sqla import models as ab_models import caravel -from caravel import app, db, models, utils, appbuilder, sm, src_registry +from caravel import app, db, models, utils, appbuilder, sm +from caravel.source_registry import SourceRegistry from caravel.models import DruidDatasource from .base_tests import CaravelTestCase @@ -243,7 +244,7 @@ def test_approve(self): self.login('admin') def prepare_request(ds_type, ds_name, role): - ds_class = src_registry.sources[ds_type] + ds_class = SourceRegistry.sources[ds_type] # TODO: generalize datasource names if ds_type == 'table': ds = session.query(ds_class).filter(