diff --git a/flask_sqlalchemy/__init__.py b/flask_sqlalchemy/__init__.py index eb39ab07..8b54365a 100644 --- a/flask_sqlalchemy/__init__.py +++ b/flask_sqlalchemy/__init__.py @@ -30,7 +30,7 @@ from sqlalchemy.orm.exc import UnmappedClassError from sqlalchemy.orm.session import Session as SessionBase -from ._compat import iteritems, itervalues, string_types, xrange +from ._compat import itervalues, string_types, xrange __version__ = '2.2.1' @@ -551,31 +551,39 @@ def get_engine(self): def _should_set_tablename(cls): - """Traverse the model's MRO. If a primary key column is found before a - table or tablename, then a new tablename should be generated. - - This supports: - - * Joined table inheritance without explicitly naming sub-models. - * Single table inheritance. - * Inheriting from mixins or abstract models. - - :param cls: model to check - :return: True if tablename should be set + """Determine whether ``__tablename__`` should be automatically generated + for a model. + + * If no class in the MRO sets a name, one should be generated. + * If a declared attr is found, it should be used instead. + * If a name is found, it should be used if the class is a mixin, otherwise + one should be generated. + * Abstract models should not have one generated. + + Later, :meth:`._BoundDeclarativeMeta.__table_cls__` will determine if the + model looks like single or joined-table inheritance. If no primary key is + found, the name will be unset. """ + if ( + cls.__dict__.get('__abstract__', False) + or not any(isinstance(b, DeclarativeMeta) for b in cls.__mro__[1:]) + ): + return False for base in cls.__mro__: - d = base.__dict__ + if '__tablename__' not in base.__dict__: + continue - if '__tablename__' in d or '__table__' in d: + if isinstance(base.__dict__['__tablename__'], declared_attr): return False - for name, obj in iteritems(d): - if isinstance(obj, declared_attr): - obj = getattr(cls, name) + return not ( + base is cls + or base.__dict__.get('__abstract__', False) + or not isinstance(base, DeclarativeMeta) + ) - if isinstance(obj, sqlalchemy.Column) and obj.primary_key: - return True + return True def camel_to_snake_case(name): @@ -591,20 +599,36 @@ def _join(match): class _BoundDeclarativeMeta(DeclarativeMeta): - def __new__(cls, name, bases, d): - # if tablename is set explicitly, move it to the cache attribute so - # that future subclasses still have auto behavior - if '__tablename__' in d: - d['_cached_tablename'] = d.pop('__tablename__') + def __init__(cls, name, bases, d): + if _should_set_tablename(cls): + cls.__tablename__ = camel_to_snake_case(cls.__name__) + + bind_key = ( + d.pop('__bind_key__', None) + or getattr(cls, '__bind_key__', None) + ) - return DeclarativeMeta.__new__(cls, name, bases, d) + super(_BoundDeclarativeMeta, cls).__init__(name, bases, d) - def __init__(self, name, bases, d): - bind_key = d.pop('__bind_key__', None) or getattr(self, '__bind_key__', None) - DeclarativeMeta.__init__(self, name, bases, d) + if bind_key is not None and hasattr(cls, '__table__'): + cls.__table__.info['bind_key'] = bind_key - if bind_key is not None and hasattr(self, '__table__'): - self.__table__.info['bind_key'] = bind_key + def __table_cls__(cls, *args, **kwargs): + """This is called by SQLAlchemy during mapper setup. It determines the + final table object that the model will use. + + If no primary key is found, that indicates single-table inheritance, + so no table will be created and ``__tablename__`` will be unset. + """ + for arg in args: + if ( + (isinstance(arg, sqlalchemy.Column) and arg.primary_key) + or isinstance(arg, sqlalchemy.PrimaryKeyConstraint) + ): + return sqlalchemy.Table(*args, **kwargs) + + if '__tablename__' in cls.__dict__: + del cls.__tablename__ def get_state(app): @@ -638,18 +662,6 @@ class Model(object): #: Equivalent to ``db.session.query(Model)`` unless :attr:`query_class` has been changed. query = None - _cached_tablename = None - - @declared_attr - def __tablename__(cls): - if ( - '_cached_tablename' not in cls.__dict__ and - _should_set_tablename(cls) - ): - cls._cached_tablename = camel_to_snake_case(cls.__name__) - - return cls._cached_tablename - class SQLAlchemy(object): """This class is used to control the SQLAlchemy integration to one diff --git a/tests/test_table_name.py b/tests/test_table_name.py index f1af0a55..8a7d011e 100644 --- a/tests/test_table_name.py +++ b/tests/test_table_name.py @@ -1,3 +1,5 @@ +import inspect + from sqlalchemy.ext.declarative import declared_attr @@ -25,6 +27,7 @@ class Duck(db.Model): class Mallard(Duck): pass + assert '__tablename__' not in Mallard.__dict__ assert Mallard.__tablename__ == 'duck' @@ -39,8 +42,10 @@ class Donald(Duck): assert Donald.__tablename__ == 'donald' -def test_mixin_name(db): - """Primary key provided by mixin should still allow model to set tablename.""" +def test_mixin_id(db): + """Primary key provided by mixin should still allow model to set + tablename. + """ class Base(object): id = db.Column(db.Integer, primary_key=True) @@ -51,8 +56,33 @@ class Duck(Base, db.Model): assert Duck.__tablename__ == 'duck' +def test_mixin_attr(db): + """A declared attr tablename will be used down multiple levels of + inheritance. + """ + class Mixin(object): + @declared_attr + def __tablename__(cls): + return cls.__name__.upper() + + class Bird(Mixin, db.Model): + id = db.Column(db.Integer, primary_key=True) + + class Duck(Bird): + # object reference + id = db.Column(db.ForeignKey(Bird.id), primary_key=True) + + class Mallard(Duck): + # string reference + id = db.Column(db.ForeignKey('DUCK.id'), primary_key=True) + + assert Bird.__tablename__ == 'BIRD' + assert Duck.__tablename__ == 'DUCK' + assert Mallard.__tablename__ == 'MALLARD' + + def test_abstract_name(db): - """Abstract model should not set a name. Subclass should set a name.""" + """Abstract model should not set a name. Subclass should set a name.""" class Base(db.Model): __abstract__ = True id = db.Column(db.Integer, primary_key=True) @@ -60,19 +90,23 @@ class Base(db.Model): class Duck(Base): pass - assert Base.__tablename__ == 'base' + assert '__tablename__' not in Base.__dict__ assert Duck.__tablename__ == 'duck' def test_complex_inheritance(db): - """Joined table inheritance, but the new primary key is provided by a mixin, not directly on the class.""" + """Joined table inheritance, but the new primary key is provided by a + mixin, not directly on the class. + """ class Duck(db.Model): id = db.Column(db.Integer, primary_key=True) class IdMixin(object): @declared_attr def id(cls): - return db.Column(db.Integer, db.ForeignKey(Duck.id), primary_key=True) + return db.Column( + db.Integer, db.ForeignKey(Duck.id), primary_key=True + ) class RubberDuck(IdMixin, Duck): pass @@ -81,18 +115,55 @@ class RubberDuck(IdMixin, Duck): def test_manual_name(db): + """Setting a manual name prevents generation for the immediate model. A + name is generated for joined but not single-table inheritance. + """ class Duck(db.Model): __tablename__ = 'DUCK' id = db.Column(db.Integer, primary_key=True) + type = db.Column(db.String) + + __mapper_args__ = { + 'polymorphic_on': type + } class Daffy(Duck): id = db.Column(db.Integer, db.ForeignKey(Duck.id), primary_key=True) + __mapper_args__ = { + 'polymorphic_identity': 'Warner' + } + + class Donald(Duck): + __mapper_args__ = { + 'polymorphic_identity': 'Disney' + } + assert Duck.__tablename__ == 'DUCK' assert Daffy.__tablename__ == 'daffy' + assert '__tablename__' not in Donald.__dict__ + assert Donald.__tablename__ == 'DUCK' + # polymorphic condition for single-table query + assert 'WHERE "DUCK".type' in str(Donald.query) + + +def test_primary_constraint(db): + """Primary key will be picked up from table args.""" + class Duck(db.Model): + id = db.Column(db.Integer) + + __table_args__ = ( + db.PrimaryKeyConstraint(id), + ) + + assert Duck.__table__ is not None + assert Duck.__tablename__ == 'duck' def test_no_access_to_class_property(db): + """Ensure the implementation doesn't access class properties or declared + attrs while inspecting the unmapped model. + """ class class_property(object): def __init__(self, f): self.f = f @@ -106,14 +177,13 @@ class Duck(db.Model): class ns(object): accessed = False - # Since there's no id provided by the following model, - # _should_set_tablename will scan all attributes. If it's working - # properly, it won't access the class property, but will access the - # declared_attr. - class Witch(Duck): @declared_attr def is_duck(self): + # declared attrs will be accessed during mapper configuration, + # but make sure they're not accessed before that + info = inspect.getouterframes(inspect.currentframe())[2] + assert info[3] != '_should_set_tablename' ns.accessed = True @class_property