diff --git a/flask_sqlalchemy/__init__.py b/flask_sqlalchemy/__init__.py index 7cfc032a..b534c193 100644 --- a/flask_sqlalchemy/__init__.py +++ b/flask_sqlalchemy/__init__.py @@ -25,11 +25,12 @@ from flask.signals import Namespace from sqlalchemy import event, inspect, orm from sqlalchemy.engine.url import make_url -from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base +from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base, \ + declared_attr from sqlalchemy.orm.exc import UnmappedClassError from sqlalchemy.orm.session import Session as SessionBase -from ._compat import itervalues, string_types, xrange +from ._compat import iteritems, itervalues, string_types, xrange __version__ = '3.0-dev' @@ -547,18 +548,9 @@ def get_engine(self): return rv -def _should_set_tablename(bases, d): - """Check what values are set by a class and its bases to determine if a - tablename should be automatically generated. - - The class and its bases are checked in order of precedence: the class - itself then each base in the order they were given at class definition. - - Abstract classes do not generate a tablename, although they may have set - or inherited a tablename elsewhere. - - If a class defines a tablename or table, a new one will not be generated. - Otherwise, if the class defines a primary key, a new name will be generated. +def _should_set_tablename(cls): + """Traverse the model's MRO. If a primary key column is found before a + table or tablename, then a new tablename should be generated. This supports: @@ -566,44 +558,49 @@ def _should_set_tablename(bases, d): * Single table inheritance. * Inheriting from mixins or abstract models. - :param bases: base classes of new class - :param d: new class dict + :param cls: model to check :return: True if tablename should be set """ - if '__tablename__' in d or '__table__' in d or '__abstract__' in d: - return False - - if any(v.primary_key for v in itervalues(d) if isinstance(v, sqlalchemy.Column)): - return True + for base in cls.__mro__: + d = base.__dict__ - for base in bases: - if hasattr(base, '__tablename__') or hasattr(base, '__table__'): + if '__tablename__' in d or '__table__' in d: return False - for name in dir(base): - attr = getattr(base, name) + for name, obj in iteritems(d): + if isinstance(obj, declared_attr): + obj = getattr(cls, name) - if isinstance(attr, sqlalchemy.Column) and attr.primary_key: + if isinstance(obj, sqlalchemy.Column) and obj.primary_key: return True -class _BoundDeclarativeMeta(DeclarativeMeta): +def camel_to_snake_case(name): + def _join(match): + word = match.group() + + if len(word) > 1: + return ('_%s_%s' % (word[:-1], word[-1])).lower() + + return '_' + word.lower() + return _camelcase_re.sub(_join, name).lstrip('_') + + +class _BoundDeclarativeMeta(DeclarativeMeta): def __new__(cls, name, bases, d): - if _should_set_tablename(bases, d): - def _join(match): - word = match.group() - if len(word) > 1: - return ('_%s_%s' % (word[:-1], word[-1])).lower() - return '_' + word.lower() - d['__tablename__'] = _camelcase_re.sub(_join, name).lstrip('_') + # if tablename is set explicitly, move it to the cache attribute so + # that future subclasses still have auto behavior + if '__tablename__' in d: + d['_cached_tablename'] = d.pop('__tablename__') return DeclarativeMeta.__new__(cls, name, bases, d) def __init__(self, name, bases, d): bind_key = d.pop('__bind_key__', None) or getattr(self, '__bind_key__', None) DeclarativeMeta.__init__(self, name, bases, d) + if bind_key is not None and hasattr(self, '__table__'): self.__table__.info['bind_key'] = bind_key @@ -639,6 +636,18 @@ class Model(object): #: Equivalent to ``db.session.query(Model)`` unless :attr:`query_class` has been changed. query = None + _cached_tablename = None + + @declared_attr + def __tablename__(cls): + if ( + '_cached_tablename' not in cls.__dict__ and + _should_set_tablename(cls) + ): + cls._cached_tablename = camel_to_snake_case(cls.__name__) + + return cls._cached_tablename + class SQLAlchemy(object): """This class is used to control the SQLAlchemy integration to one diff --git a/test_sqlalchemy.py b/test_sqlalchemy.py index 24a0072d..ffcd079a 100755 --- a/test_sqlalchemy.py +++ b/test_sqlalchemy.py @@ -300,7 +300,7 @@ class Base(db.Model): class Duck(Base): pass - self.assertFalse(hasattr(Base, '__tablename__')) + self.assertEqual(Base.__tablename__, 'base') self.assertEqual(Duck.__tablename__, 'duck') def test_complex_inheritance(self): @@ -322,6 +322,53 @@ class RubberDuck(IdMixin, Duck): self.assertEqual(RubberDuck.__tablename__, 'rubber_duck') + def test_manual_name(self): + app = flask.Flask(__name__) + db = fsa.SQLAlchemy(app) + + class Duck(db.Model): + __tablename__ = 'DUCK' + id = db.Column(db.Integer, primary_key=True) + + class Daffy(Duck): + id = db.Column(db.Integer, db.ForeignKey(Duck.id), primary_key=True) + + self.assertEqual(Duck.__tablename__, 'DUCK') + self.assertEqual(Daffy.__tablename__, 'daffy') + + def test_no_access_to_class_property(self): + app = flask.Flask(__name__) + db = fsa.SQLAlchemy(app) + + class class_property(object): + def __init__(self, f): + self.f = f + + def __get__(self, instance, owner): + return self.f(owner) + + class Duck(db.Model): + id = db.Column(db.Integer, primary_key=True) + + class ns(object): + accessed = False + + # Since there's no id provided by the following model, + # _should_set_tablename will scan all attributes. If it's working + # properly, it won't access the class property, but will access the + # declared_attr. + + class Witch(Duck): + @declared_attr + def is_duck(self): + ns.accessed = True + + @class_property + def floats(self): + assert False + + self.assertTrue(ns.accessed) + class PaginationTestCase(unittest.TestCase): def test_basic_pagination(self): @@ -486,13 +533,9 @@ def test_custom_query_class(self): class CustomQueryClass(fsa.BaseQuery): pass - class MyModelClass(object): - pass - app = flask.Flask(__name__) app.config['TESTING'] = True - db = fsa.SQLAlchemy(app, query_class=CustomQueryClass, - model_class=MyModelClass) + db = fsa.SQLAlchemy(app, query_class=CustomQueryClass) class Parent(db.Model): id = db.Column(db.Integer, primary_key=True)