From 5599861ef61ac7c0f9139c525c61ab4ac0cef5b4 Mon Sep 17 00:00:00 2001 From: Nick Whyte Date: Sat, 4 Jun 2016 22:52:25 +1000 Subject: [PATCH] Patch to allow use of a raw SQLAlchemy declarative base. --- docs/models.rst | 68 ++++++++ flask_sqlalchemy/__init__.py | 56 +++++-- test_sqlalchemy.py | 291 +++++++++++++++++++++++++++++++++++ 3 files changed, 404 insertions(+), 11 deletions(-) diff --git a/docs/models.rst b/docs/models.rst index e6fc9e1d..8b80b36d 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -143,3 +143,71 @@ Here we configured `Page.tags` to be a list of tags once loaded because we don't expect too many tags per page. The list of pages per tag (`Tag.pages`) however is a dynamic backref. As mentioned above this means that you will get a query object back you can use to fire a select yourself. + + +External Declarative Bases +-------------------------- + +Flask-SQLAlchemy allows you to provide your own declarative base if you +feel the need to do so. Doing so can allow you to cut down circular +imports, allow you to use the app factory pattern, or even share your +SQLAlchemy model across different python application using SQLAlchemy, +without the requirement of running a different set of models for use with +Flask-SQLAlchemy. + +We declare our model just as you would above, however, using SQLAlchemy +constructs, rather than accessing classes via ``db.*``. Once defined, call +call :meth:`SQLAlchemy.register_base` to register your delcarative base with +Flask-SQLAlchemy. + +A minimal example:: + + from sqlalchemy import Column, Integer, String + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class User(Base): + __tablename__ = 'user' + + id = Column(Integer, primary_key=True) + username = Column(String(80), unique=True) + email = Column(String(255), unique=True) + + def __repr__(self): + return '' % self.username + + from flask import Flask + from flask_sqlalchemy import SQLAlchemy + + app = Flask(__name__) + app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' + db = SQLAlchemy(app) + db.register_base(Base) + + db.create_all() + + @app.before_first_request + def insert_user(): + # We can create new objects the normal way + user = User(id=1, username='foo', email='foo@bar.com') + db.session.add(user) + db.session.commit() + + @app.route('/') + def index(user_id): + # We can query the model two ways: + user = db.session.query(User).get_or_404(user_id) + + # Or we can using the model's query property + user = User.query.get_or_404(user_id) + + return "Hello, {}".format(user.username) + + if __name__ == '__main__': + app.run(debug=True) + +If you're using binds, you'll need to use your own session that knows how +to handle them, or use the same ``Model.__bind__`` system and register the +extra engines with ``SQLALCHEMY_BINDS`` + diff --git a/flask_sqlalchemy/__init__.py b/flask_sqlalchemy/__init__.py index f5d3eaf5..8954c329 100644 --- a/flask_sqlalchemy/__init__.py +++ b/flask_sqlalchemy/__init__.py @@ -734,13 +734,17 @@ def __init__(self, app=None, use_native_unicode=True, session_options=None, self._engine_lock = Lock() self.app = app _include_sqlalchemy(self, query_class) + self.external_bases = [] if app is not None: self.init_app(app) @property def metadata(self): - """The metadata associated with ``db.Model``.""" + """The metadata associated with ``db.Model``. + Access to raw SQLA models added using register_base should + be referenced directly using it's own declarative base. + """ return self.Model.metadata @@ -943,9 +947,11 @@ def get_app(self, reference_app=None): def get_tables_for_bind(self, bind=None): """Returns a list of all tables relevant for a bind.""" result = [] - for table in itervalues(self.Model.metadata.tables): - if table.info.get('bind_key') == bind: - result.append(table) + for Base in self.bases: + for table in itervalues(Base.metadata.tables): + if table.info.get('bind_key') == bind: + result.append(table) + return result def get_binds(self, app=None): @@ -972,13 +978,14 @@ def _execute_for_all_tables(self, app, bind, operation, skip_tables=False): else: binds = bind - for bind in binds: - extra = {} - if not skip_tables: - tables = self.get_tables_for_bind(bind) - extra['tables'] = tables - op = getattr(self.Model.metadata, operation) - op(bind=self.get_engine(app, bind), **extra) + for Base in self.bases: + for bind in binds: + extra = {} + if not skip_tables: + tables = self.get_tables_for_bind(bind) + extra['tables'] = tables + op = getattr(Base.metadata, operation) + op(bind=self.get_engine(app, bind), **extra) def create_all(self, bind='__all__', app=None): """Creates all tables. @@ -1017,6 +1024,33 @@ def __repr__(self): app and app.config['SQLALCHEMY_DATABASE_URI'] or None ) + @property + def bases(self): + return [self.Model] + self.external_bases + + def register_base(self, Base): + """Register an external raw SQLAlchemy declarative base. + Allows usage of the base with our session management and + adds convenience query property using self.Query by default.""" + + self.external_bases.append(Base) + for c in Base._decl_class_registry.values(): + if isinstance(c, type): + if not hasattr(c, 'query') and not hasattr(c, 'query_class'): + c.query_class = self.Query + if not hasattr(c, 'query'): + c.query = _QueryProperty(self) + + # for name in dir(c): + # attr = getattr(c, name) + # if type(attr) == orm.attributes.InstrumentedAttribute: + # if hasattr(attr.prop, 'query_class'): + # attr.prop.query_class = self.Query + + # if hasattr(c , 'rel_dynamic'): + # c.rel_dynamic.prop.query_class = self.Query + + class FSADeprecationWarning(DeprecationWarning): pass diff --git a/test_sqlalchemy.py b/test_sqlalchemy.py index 24a0072d..e39126c4 100755 --- a/test_sqlalchemy.py +++ b/test_sqlalchemy.py @@ -734,6 +734,295 @@ def test_listen_to_session_event(self): db = fsa.SQLAlchemy(app) sa.event.listen(db.session, 'after_commit', lambda session: None) +class RawSQLADeclarativeBaseTestCase(unittest.TestCase): + + def sqla_raw_declarative_base(self): + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy import Column, String, Integer, ForeignKey + from sqlalchemy.orm import relationship + + Base = declarative_base() + + class Bar(Base): + __tablename__ = 'bar' + id = Column(Integer, primary_key=True) + parent_id = Column(Integer, ForeignKey('foo.id')) + + + class Foo(Base): + __tablename__ = 'foo' + id = Column(Integer, primary_key=True) + string = Column(String(255)) + children = relationship("Bar", lazy='dynamic') + + + return Base, Foo, Bar + + def setUp(self): + app = flask.Flask(__name__) + app.config['SQLALCHEMY_ENGINE'] = 'sqlite://' + app.config['TESTING'] = True + self.Base, self.Foo, self.Bar = self.sqla_raw_declarative_base() + db = fsa.SQLAlchemy(app) + + + db.register_base(self.Base) + db.create_all() + + self.db = db + self.app = app + + def tearDown(self): + self.db.drop_all() + + + def test_register_base_success(self): + + self.assertTrue(self.db.engine.dialect.has_table( + self.db.engine.connect(), 'foo')) + self.assertTrue(self.db.engine.dialect.has_table( + self.db.engine.connect(), 'bar')) + self.assertFalse(self.db.engine.dialect.has_table( + self.db.engine.connect(), 'faketable')) + + def test_drop_all(self): + # Make sure the tables were originally created so we can compare + # the fact that they have been dropped. + self.assertTrue(self.db.engine.dialect.has_table( + self.db.engine.connect(), 'foo')) + self.assertTrue(self.db.engine.dialect.has_table( + self.db.engine.connect(), 'bar')) + + self.db.drop_all() + self.assertFalse(self.db.engine.dialect.has_table( + self.db.engine.connect(), 'foo')) + self.assertFalse(self.db.engine.dialect.has_table( + self.db.engine.connect(), 'bar')) + + + def test_query_insert(self): + + self.assertEqual(len(self.Foo.query.all()), 0) + + foo = self.Foo(string='Foo') + + self.db.session.add(foo) + self.db.session.commit() + + self.assertEqual(len(self.db.session.query(self.Foo).all()), 1) + self.assertEqual(self.db.session.query(self.Foo).count(), 1) + + first_foo = self.db.session.query(self.Foo).first() + self.assertEqual(first_foo.string, 'Foo') + + def test_query_property(self): + + self.assertEqual(len(self.Foo.query.all()), 0) + foo = self.Foo(string='Foo') + + self.db.session.add(foo) + self.db.session.commit() + + self.assertEqual(len(self.Foo.query.all()), 1) + self.assertEqual(self.Foo.query.count(), 1) + + first_foo = self.Foo.query.first() + self.assertEqual(first_foo.string, 'Foo') + + + def test_default_query_class(self): + # Also test children. + p = self.Foo() + c = self.Bar() + c.parent = p + + self.assertEqual(type(self.Foo.query), fsa.BaseQuery) + self.assertEqual(type(self.Bar.query), fsa.BaseQuery) + + # Unable to override SQLA's relationship constructor to use our + # own query class for relationships, since we cannot inspect the + # relationship. If we can get enough info about how the original + # relationship property was constructed, we could reconstruct using + # a wrapped relationship property. Disabling this test for now. + + # self.assertTrue(isinstance(p.children, sqlalchemy.BaseQuery)) + + + +class RawSQLAMultipleDeclarativeBaseTestCase(unittest.TestCase): + + def sqla_raw_declarative_base(self): + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy import Column, String, Integer, ForeignKey + from sqlalchemy.orm import relationship + + models = dict() + + Base_A = declarative_base() + Base_B = declarative_base() + Base_C = declarative_base() + + class Bar_A(Base_A): + __tablename__ = 'bar_A' + id = Column(Integer, primary_key=True) + parent_id = Column(Integer, ForeignKey('foo_A.id')) + + + class Foo_A(Base_A): + __tablename__ = 'foo_A' + id = Column(Integer, primary_key=True) + string = Column(String(255)) + children = relationship("Bar_A", lazy='dynamic') + + models['A'] = dict( + Foo=Foo_A, + Bar=Bar_A, + Base=Base_A + ) + + class Bar_B(Base_B): + __tablename__ = 'bar_B' + id = Column(Integer, primary_key=True) + parent_id = Column(Integer, ForeignKey('foo_B.id')) + + + class Foo_B(Base_B): + __tablename__ = 'foo_B' + id = Column(Integer, primary_key=True) + string = Column(String(255)) + children = relationship("Bar_B", lazy='dynamic') + + models['B'] = dict( + Foo=Foo_B, + Bar=Bar_B, + Base=Base_B + ) + + class Bar_C(Base_C): + __tablename__ = 'bar_C' + id = Column(Integer, primary_key=True) + parent_id = Column(Integer, ForeignKey('foo_C.id')) + + + class Foo_C(Base_C): + __tablename__ = 'foo_C' + id = Column(Integer, primary_key=True) + string = Column(String(255)) + children = relationship("Bar_C", lazy='dynamic') + + models['C'] = dict( + Foo=Foo_C, + Bar=Bar_C, + Base=Base_C + ) + + + return models + + def setUp(self): + app = flask.Flask(__name__) + app.config['SQLALCHEMY_ENGINE'] = 'sqlite://' + app.config['TESTING'] = True + self.model_suffixes = ['A','B','C'] + + self.Models = self.sqla_raw_declarative_base() + db = fsa.SQLAlchemy(app) + + for base_group in self.Models.values(): + db.register_base(base_group['Base']) + db.create_all() + + self.db = db + self.app = app + + + def tearDown(self): + self.db.drop_all() + + def test_register_base_success(self): + for suffix in self.model_suffixes: + self.assertTrue(self.db.engine.dialect.has_table( + self.db.engine.connect(), + 'foo_{0}'.format(suffix))) + self.assertTrue(self.db.engine.dialect.has_table( + self.db.engine.connect(), + 'bar_{0}'.format(suffix))) + + self.assertFalse(self.db.engine.dialect.has_table( + self.db.engine.connect(), + 'faketable')) + + + def test_query_insert(self): + for suffix in self.model_suffixes: + self.assertEqual(len(self.db.session.query( + self.Models[suffix]['Foo']).all()), 0) + + foo = self.Models[suffix]['Foo'](string='Foo_{0}'.format(suffix)) + self.db.session.add(foo) + self.db.session.commit() + + self.assertEqual(len(self.db.session.query( + self.Models[suffix]['Foo']).all()), 1) + self.assertEqual(self.db.session.query( + self.Models[suffix]['Foo']).count(), 1) + + first_foo = self.db.session.query( + self.Models[suffix]['Foo']).first() + self.assertEqual(first_foo.string, 'Foo_{0}'.format(suffix)) + + + def test_query_property(self): + for suffix in self.model_suffixes: + self.assertEqual(len(self.Models[suffix]['Foo'].query.all()), 0) + + foo = self.Models[suffix]['Foo'](string='Foo') + self.db.session.add(foo) + self.db.session.commit() + + self.assertEqual(len(self.Models[suffix]['Foo'].query.all()), 1) + self.assertEqual(self.Models[suffix]['Foo'].query.count(), 1) + + first_foo = self.Models[suffix]['Foo'].query.first() + self.assertEqual(first_foo.string, 'Foo') + + + def test_default_query_class(self): + # Also test children. + for suffix in self.model_suffixes: + p = self.Models[suffix]['Foo']() + c = self.Models[suffix]['Bar']() + c.parent = p + + self.assertEqual( + type(self.Models[suffix]['Foo'].query), + fsa.BaseQuery) + self.assertEqual( + type(self.Models[suffix]['Bar'].query), + fsa.BaseQuery) + + # Unable to override SQLA's relationship constructor to use our + # own query class for relationships, since we cannot inspect the + # relationship. If we can get enough info about how the original + # relationship property was constructed, we could reconstruct using + # a wrapped relationship property. Disabling this test for now. + + # self.assertTrue(isinstance(p.children, sqlalchemy.BaseQuery)) + + def test_drop_all(self): + # Make sure they exist before drop, so we can compare the result. + for suffix in self.model_suffixes: + self.assertTrue(self.db.engine.dialect.has_table( + self.db.engine.connect(), 'foo_{0}'.format(suffix))) + self.assertTrue(self.db.engine.dialect.has_table( + self.db.engine.connect(), 'bar_{0}'.format(suffix))) + + self.db.drop_all() + for suffix in self.model_suffixes: + self.assertFalse(self.db.engine.dialect.has_table( + self.db.engine.connect(), 'foo_{0}'.format(suffix))) + self.assertFalse(self.db.engine.dialect.has_table( + self.db.engine.connect(), 'bar_{0}'.format(suffix))) def suite(): suite = unittest.TestSuite() @@ -752,6 +1041,8 @@ def suite(): suite.addTest(unittest.makeSuite(SessionScopingTestCase)) suite.addTest(unittest.makeSuite(CommitOnTeardownTestCase)) suite.addTest(unittest.makeSuite(CustomModelClassTestCase)) + suite.addTest(unittest.makeSuite(RawSQLADeclarativeBaseTestCase)) + suite.addTest(unittest.makeSuite(RawSQLAMultipleDeclarativeBaseTestCase)) if flask.signals_available: suite.addTest(unittest.makeSuite(SignallingTestCase))