diff --git a/flask_sqlalchemy/__init__.py b/flask_sqlalchemy/__init__.py index 8e0a8f2d..966a8d7d 100644 --- a/flask_sqlalchemy/__init__.py +++ b/flask_sqlalchemy/__init__.py @@ -68,34 +68,34 @@ def _make_table(*args, **kwargs): return _make_table -def _set_default_query_class(d): +def _set_default_query_class(d, cls): if 'query_class' not in d: - d['query_class'] = BaseQuery + d['query_class'] = cls -def _wrap_with_default_query_class(fn): +def _wrap_with_default_query_class(fn, cls): @functools.wraps(fn) def newfn(*args, **kwargs): - _set_default_query_class(kwargs) + _set_default_query_class(kwargs, cls) if "backref" in kwargs: backref = kwargs['backref'] if isinstance(backref, string_types): backref = (backref, {}) - _set_default_query_class(backref[1]) + _set_default_query_class(backref[1], cls) return fn(*args, **kwargs) return newfn -def _include_sqlalchemy(obj): +def _include_sqlalchemy(obj, cls): for module in sqlalchemy, sqlalchemy.orm: for key in module.__all__: if not hasattr(obj, key): setattr(obj, key, getattr(module, key)) # Note: obj.Table does not attempt to be a SQLAlchemy Table class. obj.Table = _make_table(obj) - obj.relationship = _wrap_with_default_query_class(obj.relationship) - obj.relation = _wrap_with_default_query_class(obj.relation) - obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader) + obj.relationship = _wrap_with_default_query_class(obj.relationship, cls) + obj.relation = _wrap_with_default_query_class(obj.relation, cls) + obj.dynamic_loader = _wrap_with_default_query_class(obj.dynamic_loader, cls) obj.event = event @@ -730,19 +730,21 @@ class User(db.Model): naming conventions among other, non-trivial things. """ - def __init__(self, app=None, use_native_unicode=True, session_options=None, metadata=None): + def __init__(self, app=None, use_native_unicode=True, session_options=None, + metadata=None, query_class=BaseQuery, model_class=Model): if session_options is None: session_options = {} session_options.setdefault('scopefunc', connection_stack.__ident_func__) + session_options.setdefault('query_cls', query_class) self.use_native_unicode = use_native_unicode self.session = self.create_scoped_session(session_options) - self.Model = self.make_declarative_base(metadata) - self.Query = BaseQuery + self.Query = query_class + self.Model = self.make_declarative_base(model_class, metadata) self._engine_lock = Lock() self.app = app - _include_sqlalchemy(self) + _include_sqlalchemy(self, query_class) if app is not None: self.init_app(app) @@ -770,11 +772,15 @@ def create_session(self, options): """ return SignallingSession(self, **options) - def make_declarative_base(self, metadata=None): + def make_declarative_base(self, model, metadata=None): """Creates the declarative base.""" - base = declarative_base(cls=Model, name='Model', + base = declarative_base(cls=model, name='Model', metadata=metadata, metaclass=_BoundDeclarativeMeta) + + if not getattr(base, 'query_class', None): + base.query_class = self.Query + base.query = _QueryProperty(self) return base diff --git a/test_sqlalchemy.py b/test_sqlalchemy.py index 22f07fd1..775a43c5 100644 --- a/test_sqlalchemy.py +++ b/test_sqlalchemy.py @@ -467,18 +467,87 @@ def test_default_query_class(self): class Parent(db.Model): id = db.Column(db.Integer, primary_key=True) - children = db.relationship("Child", backref = "parents", lazy='dynamic') + children = db.relationship("Child", backref = "parent", lazy='dynamic') + class Child(db.Model): id = db.Column(db.Integer, primary_key=True) parent_id = db.Column(db.Integer, db.ForeignKey('parent.id')) + p = Parent() c = Child() c.parent = p + self.assertEqual(type(Parent.query), sqlalchemy.BaseQuery) self.assertEqual(type(Child.query), sqlalchemy.BaseQuery) self.assertTrue(isinstance(p.children, sqlalchemy.BaseQuery)) - #self.assertTrue(isinstance(c.parents, sqlalchemy.BaseQuery)) + self.assertTrue(isinstance(db.session.query(Parent), sqlalchemy.BaseQuery)) + + +class CustomQueryClassTestCase(unittest.TestCase): + + def test_custom_query_class(self): + class CustomQueryClass(sqlalchemy.BaseQuery): + pass + + class MyModelClass(object): + pass + + app = flask.Flask(__name__) + app.config['SQLALCHEMY_ENGINE'] = 'sqlite://' + app.config['TESTING'] = True + db = sqlalchemy.SQLAlchemy(app, query_class=CustomQueryClass, + model_class=MyModelClass) + + class Parent(db.Model): + id = db.Column(db.Integer, primary_key=True) + children = db.relationship("Child", backref = "parent", lazy='dynamic') + + class Child(db.Model): + id = db.Column(db.Integer, primary_key=True) + parent_id = db.Column(db.Integer, db.ForeignKey('parent.id')) + + p = Parent() + c = Child() + c.parent = p + + self.assertEqual(type(Parent.query), CustomQueryClass) + self.assertEqual(type(Child.query), CustomQueryClass) + self.assertTrue(isinstance(p.children, CustomQueryClass)) + self.assertEqual(db.Query, CustomQueryClass) + self.assertEqual(db.Model.query_class, CustomQueryClass) + self.assertTrue(isinstance(db.session.query(Parent), CustomQueryClass)) + + + def test_dont_override_model_default(self): + class CustomQueryClass(sqlalchemy.BaseQuery): + pass + + app = flask.Flask(__name__) + app.config['SQLALCHEMY_ENGINE'] = 'sqlite://' + app.config['TESTING'] = True + db = sqlalchemy.SQLAlchemy(app, query_class=CustomQueryClass) + + class SomeModel(db.Model): + id = db.Column(db.Integer, primary_key=True) + + self.assertEqual(type(SomeModel.query), sqlalchemy.BaseQuery) + + +class CustomModelClassTestCase(unittest.TestCase): + + def test_custom_query_class(self): + class CustomModelClass(sqlalchemy.Model): + pass + + app = flask.Flask(__name__) + app.config['SQLALCHEMY_ENGINE'] = 'sqlite://' + app.config['TESTING'] = True + db = sqlalchemy.SQLAlchemy(app, model_class=CustomModelClass) + + class SomeModel(db.Model): + id = db.Column(db.Integer, primary_key=True) + self.assertTrue(isinstance(SomeModel(), CustomModelClass)) class SQLAlchemyIncludesTestCase(unittest.TestCase):