From 08380962c30112f123220a18475de34fdfe11029 Mon Sep 17 00:00:00 2001 From: David Baumgold Date: Fri, 29 Jan 2016 13:26:15 -0500 Subject: [PATCH] Support binds on abstract models --- flask_sqlalchemy/__init__.py | 4 ++-- test_sqlalchemy.py | 42 ++++++++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/flask_sqlalchemy/__init__.py b/flask_sqlalchemy/__init__.py index f0fb1a10..6ed10539 100644 --- a/flask_sqlalchemy/__init__.py +++ b/flask_sqlalchemy/__init__.py @@ -599,9 +599,9 @@ def _join(match): return DeclarativeMeta.__new__(cls, name, bases, d) def __init__(self, name, bases, d): - bind_key = d.pop('__bind_key__', None) + bind_key = d.pop('__bind_key__', None) or getattr(self, '__bind_key__', None) DeclarativeMeta.__init__(self, name, bases, d) - if bind_key is not None: + if bind_key is not None and hasattr(self, '__table__'): self.__table__.info['bind_key'] = bind_key diff --git a/test_sqlalchemy.py b/test_sqlalchemy.py index c2f9b6ad..52ee56ed 100644 --- a/test_sqlalchemy.py +++ b/test_sqlalchemy.py @@ -1,6 +1,8 @@ from __future__ import with_statement import atexit +import tempfile +import os import unittest from datetime import datetime import flask @@ -384,12 +386,10 @@ def index(): class BindsTestCase(unittest.TestCase): def test_basic_binds(self): - import tempfile _, db1 = tempfile.mkstemp() _, db2 = tempfile.mkstemp() def _remove_files(): - import os try: os.remove(db1) os.remove(db2) @@ -456,6 +456,44 @@ class Baz(db.Model): Baz.__table__: db.get_engine(app, None) }) + def test_abstract_binds(self): + _, db1 = tempfile.mkstemp() + _, db2 = tempfile.mkstemp() + + def _remove_files(): + try: + os.remove(db1) + os.remove(db2) + except IOError: + pass + atexit.register(_remove_files) + + app = flask.Flask(__name__) + app.config['SQLALCHEMY_ENGINE'] = 'sqlite://' + app.config['SQLALCHEMY_BINDS'] = { + 'foo': 'sqlite:///' + db1, + 'bar': 'sqlite:///' + db2 + } + db = sqlalchemy.SQLAlchemy(app) + + class AbstractFooBoundModel(db.Model): + __abstract__ = True + __bind_key__ = 'foo' + + class FooBoundModel(AbstractFooBoundModel): + id = db.Column(db.Integer, primary_key=True) + + db.create_all() + + # does the model have the correct engines? + self.assertEqual(db.metadata.tables['foo_bound_model'].info['bind_key'], 'foo') + + # see the tables created in an engine + metadata = db.MetaData() + metadata.reflect(bind=db.get_engine(app, 'foo')) + self.assertEqual(len(metadata.tables), 1) + self.assertTrue('foo_bound_model' in metadata.tables) + class DefaultQueryClassTestCase(unittest.TestCase):