Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions flask_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sqlalchemy.orm.session import Session as SessionBase
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
from sqlalchemy.schema import MetaData
from flask_sqlalchemy._compat import iteritems, itervalues, xrange, string_types

# the best timer function for the platform
Expand Down Expand Up @@ -57,10 +58,10 @@

def _make_table(db):
def _make_table(*args, **kwargs):
if len(args) > 1 and isinstance(args[1], db.Column):
args = (args[0], db.metadata) + args[1:]
info = kwargs.pop('info', None) or {}
info.setdefault('bind_key', None)
if len(args) > 1 and isinstance(args[1], db.Column):
args = (args[0], db.get_metadata(bind=info['bind_key'])) + args[1:]
kwargs['info'] = info
return sqlalchemy.Table(*args, **kwargs)
return _make_table
Expand Down Expand Up @@ -598,6 +599,10 @@ def _join(match):

def __init__(self, name, bases, d):
bind_key = d.pop('__bind_key__', None) or getattr(self, '__bind_key__', None)
if bind_key:
if bind_key not in self._metadata:
self._metadata[bind_key] = MetaData()
self.metadata = self._metadata[bind_key]
DeclarativeMeta.__init__(self, name, bases, d)
if bind_key is not None and hasattr(self, '__table__'):
self.__table__.info['bind_key'] = bind_key
Expand Down Expand Up @@ -730,6 +735,7 @@ def __init__(self, app=None, use_native_unicode=True, session_options=None,
self.use_native_unicode = use_native_unicode
self.Query = query_class
self.session = self.create_scoped_session(session_options)
model_class._metadata = {}
self.Model = self.make_declarative_base(model_class, metadata)
self._engine_lock = Lock()
self.app = app
Expand All @@ -744,6 +750,13 @@ def metadata(self):

return self.Model.metadata

def get_metadata(self, bind=None):
if not bind:
return self.metadata
if bind not in self.Model._metadata:
self.Model._metadata[bind] = MetaData()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use copy/deepcopy here for not throwing customized metadata away?

self.Model._metadata[bind] = copy(self.metadata)

return self.Model._metadata.get(bind)

def create_scoped_session(self, options=None):
"""Create a :class:`~sqlalchemy.orm.scoping.scoped_session`
on the factory from :meth:`create_session`.
Expand Down Expand Up @@ -943,7 +956,7 @@ def get_app(self, reference_app=None):
def get_tables_for_bind(self, bind=None):
"""Returns a list of all tables relevant for a bind."""
result = []
for table in itervalues(self.Model.metadata.tables):
for table in itervalues(self.get_metadata(bind=bind).tables):
if table.info.get('bind_key') == bind:
result.append(table)
return result
Expand Down Expand Up @@ -977,7 +990,7 @@ def _execute_for_all_tables(self, app, bind, operation, skip_tables=False):
if not skip_tables:
tables = self.get_tables_for_bind(bind)
extra['tables'] = tables
op = getattr(self.Model.metadata, operation)
op = getattr(self.get_metadata(bind=bind), operation)
op(bind=self.get_engine(app, bind), **extra)

def create_all(self, bind='__all__', app=None):
Expand Down
30 changes: 27 additions & 3 deletions test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,8 @@ class Baz(db.Model):
app.config['SQLALCHEMY_BINDS'][key])

# do the models have the correct engines?
self.assertEqual(db.metadata.tables['foo'].info['bind_key'], 'foo')
self.assertEqual(db.metadata.tables['bar'].info['bind_key'], 'bar')
self.assertEqual(db.get_metadata(bind='foo').tables['foo'].info['bind_key'], 'foo')
self.assertEqual(db.get_metadata(bind='bar').tables['bar'].info['bind_key'], 'bar')
self.assertEqual(db.metadata.tables['baz'].info.get('bind_key'), None)

# see the tables created in an engine
Expand All @@ -431,6 +431,30 @@ class Baz(db.Model):
Baz.__table__: db.get_engine(app, None)
})

def test_binds_with_same_table_name(self):
app = flask.Flask(__name__)
app.config['SQLALCHEMY_BINDS'] = {
'foo': 'sqlite://',
'bar': 'sqlite://'
}
db = fsa.SQLAlchemy(app)

class FooUser(db.Model):
__bind_key__ = 'foo'
__tablename__ = 'users'
__table_args__ = {"info": {"bind_key": "foo"}}
id = db.Column(db.Integer, primary_key=True)

class BarUser(db.Model):
__bind_key__ = 'bar'
__tablename__ = 'users'
id = db.Column(db.Integer, primary_key=True)

db.create_all()

self.assertEqual(db.get_metadata(bind='foo').tables['users'].info['bind_key'], 'foo')
self.assertEqual(db.get_metadata(bind='bar').tables['users'].info['bind_key'], 'bar')

def test_abstract_binds(self):
app = flask.Flask(__name__)
app.config['SQLALCHEMY_BINDS'] = {
Expand All @@ -448,7 +472,7 @@ class FooBoundModel(AbstractFooBoundModel):
db.create_all()

# does the model have the correct engines?
self.assertEqual(db.metadata.tables['foo_bound_model'].info['bind_key'], 'foo')
self.assertEqual(db.get_metadata(bind='foo').tables['foo_bound_model'].info['bind_key'], 'foo')

# see the tables created in an engine
metadata = db.MetaData()
Expand Down