diff --git a/flask_sqlalchemy/__init__.py b/flask_sqlalchemy/__init__.py index f5d3eaf5..cfa9af3c 100644 --- a/flask_sqlalchemy/__init__.py +++ b/flask_sqlalchemy/__init__.py @@ -26,6 +26,7 @@ from sqlalchemy.orm.session import Session as SessionBase from sqlalchemy.engine.url import make_url from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta +from sqlalchemy.schema import MetaData from flask_sqlalchemy._compat import iteritems, itervalues, xrange, string_types # the best timer function for the platform @@ -57,10 +58,10 @@ def _make_table(db): def _make_table(*args, **kwargs): - if len(args) > 1 and isinstance(args[1], db.Column): - args = (args[0], db.metadata) + args[1:] info = kwargs.pop('info', None) or {} info.setdefault('bind_key', None) + if len(args) > 1 and isinstance(args[1], db.Column): + args = (args[0], db.get_metadata(bind=info['bind_key'])) + args[1:] kwargs['info'] = info return sqlalchemy.Table(*args, **kwargs) return _make_table @@ -598,6 +599,10 @@ def _join(match): def __init__(self, name, bases, d): bind_key = d.pop('__bind_key__', None) or getattr(self, '__bind_key__', None) + if bind_key: + if bind_key not in self._metadata: + self._metadata[bind_key] = MetaData() + self.metadata = self._metadata[bind_key] DeclarativeMeta.__init__(self, name, bases, d) if bind_key is not None and hasattr(self, '__table__'): self.__table__.info['bind_key'] = bind_key @@ -730,6 +735,7 @@ def __init__(self, app=None, use_native_unicode=True, session_options=None, self.use_native_unicode = use_native_unicode self.Query = query_class self.session = self.create_scoped_session(session_options) + model_class._metadata = {} self.Model = self.make_declarative_base(model_class, metadata) self._engine_lock = Lock() self.app = app @@ -744,6 +750,13 @@ def metadata(self): return self.Model.metadata + def get_metadata(self, bind=None): + if not bind: + return self.metadata + if bind not in self.Model._metadata: + self.Model._metadata[bind] = MetaData() + return self.Model._metadata.get(bind) + def create_scoped_session(self, options=None): """Create a :class:`~sqlalchemy.orm.scoping.scoped_session` on the factory from :meth:`create_session`. @@ -943,7 +956,7 @@ def get_app(self, reference_app=None): def get_tables_for_bind(self, bind=None): """Returns a list of all tables relevant for a bind.""" result = [] - for table in itervalues(self.Model.metadata.tables): + for table in itervalues(self.get_metadata(bind=bind).tables): if table.info.get('bind_key') == bind: result.append(table) return result @@ -977,7 +990,7 @@ def _execute_for_all_tables(self, app, bind, operation, skip_tables=False): if not skip_tables: tables = self.get_tables_for_bind(bind) extra['tables'] = tables - op = getattr(self.Model.metadata, operation) + op = getattr(self.get_metadata(bind=bind), operation) op(bind=self.get_engine(app, bind), **extra) def create_all(self, bind='__all__', app=None): diff --git a/test_sqlalchemy.py b/test_sqlalchemy.py index 24a0072d..a93bae27 100755 --- a/test_sqlalchemy.py +++ b/test_sqlalchemy.py @@ -404,8 +404,8 @@ class Baz(db.Model): app.config['SQLALCHEMY_BINDS'][key]) # do the models have the correct engines? - self.assertEqual(db.metadata.tables['foo'].info['bind_key'], 'foo') - self.assertEqual(db.metadata.tables['bar'].info['bind_key'], 'bar') + self.assertEqual(db.get_metadata(bind='foo').tables['foo'].info['bind_key'], 'foo') + self.assertEqual(db.get_metadata(bind='bar').tables['bar'].info['bind_key'], 'bar') self.assertEqual(db.metadata.tables['baz'].info.get('bind_key'), None) # see the tables created in an engine @@ -431,6 +431,30 @@ class Baz(db.Model): Baz.__table__: db.get_engine(app, None) }) + def test_binds_with_same_table_name(self): + app = flask.Flask(__name__) + app.config['SQLALCHEMY_BINDS'] = { + 'foo': 'sqlite://', + 'bar': 'sqlite://' + } + db = fsa.SQLAlchemy(app) + + class FooUser(db.Model): + __bind_key__ = 'foo' + __tablename__ = 'users' + __table_args__ = {"info": {"bind_key": "foo"}} + id = db.Column(db.Integer, primary_key=True) + + class BarUser(db.Model): + __bind_key__ = 'bar' + __tablename__ = 'users' + id = db.Column(db.Integer, primary_key=True) + + db.create_all() + + self.assertEqual(db.get_metadata(bind='foo').tables['users'].info['bind_key'], 'foo') + self.assertEqual(db.get_metadata(bind='bar').tables['users'].info['bind_key'], 'bar') + def test_abstract_binds(self): app = flask.Flask(__name__) app.config['SQLALCHEMY_BINDS'] = { @@ -448,7 +472,7 @@ class FooBoundModel(AbstractFooBoundModel): db.create_all() # does the model have the correct engines? - self.assertEqual(db.metadata.tables['foo_bound_model'].info['bind_key'], 'foo') + self.assertEqual(db.get_metadata(bind='foo').tables['foo_bound_model'].info['bind_key'], 'foo') # see the tables created in an engine metadata = db.MetaData()