Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 43 additions & 34 deletions flask_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@
from flask.signals import Namespace
from sqlalchemy import event, inspect, orm
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base, \
declared_attr
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.orm.session import Session as SessionBase

from ._compat import itervalues, string_types, xrange
from ._compat import iteritems, itervalues, string_types, xrange

__version__ = '3.0-dev'

Expand Down Expand Up @@ -547,63 +548,59 @@ def get_engine(self):
return rv


def _should_set_tablename(bases, d):
"""Check what values are set by a class and its bases to determine if a
tablename should be automatically generated.

The class and its bases are checked in order of precedence: the class
itself then each base in the order they were given at class definition.

Abstract classes do not generate a tablename, although they may have set
or inherited a tablename elsewhere.

If a class defines a tablename or table, a new one will not be generated.
Otherwise, if the class defines a primary key, a new name will be generated.
def _should_set_tablename(cls):
"""Traverse the model's MRO. If a primary key column is found before a
table or tablename, then a new tablename should be generated.

This supports:

* Joined table inheritance without explicitly naming sub-models.
* Single table inheritance.
* Inheriting from mixins or abstract models.

:param bases: base classes of new class
:param d: new class dict
:param cls: model to check
:return: True if tablename should be set
"""

if '__tablename__' in d or '__table__' in d or '__abstract__' in d:
return False

if any(v.primary_key for v in itervalues(d) if isinstance(v, sqlalchemy.Column)):
return True
for base in cls.__mro__:
d = base.__dict__

for base in bases:
if hasattr(base, '__tablename__') or hasattr(base, '__table__'):
if '__tablename__' in d or '__table__' in d:
return False

for name in dir(base):
attr = getattr(base, name)
for name, obj in iteritems(d):
if isinstance(obj, declared_attr):
obj = getattr(cls, name)

if isinstance(attr, sqlalchemy.Column) and attr.primary_key:
if isinstance(obj, sqlalchemy.Column) and obj.primary_key:
return True


class _BoundDeclarativeMeta(DeclarativeMeta):
def camel_to_snake_case(name):
def _join(match):
word = match.group()

if len(word) > 1:
return ('_%s_%s' % (word[:-1], word[-1])).lower()

return '_' + word.lower()

return _camelcase_re.sub(_join, name).lstrip('_')


class _BoundDeclarativeMeta(DeclarativeMeta):
def __new__(cls, name, bases, d):
if _should_set_tablename(bases, d):
def _join(match):
word = match.group()
if len(word) > 1:
return ('_%s_%s' % (word[:-1], word[-1])).lower()
return '_' + word.lower()
d['__tablename__'] = _camelcase_re.sub(_join, name).lstrip('_')
# if tablename is set explicitly, move it to the cache attribute so
# that future subclasses still have auto behavior
if '__tablename__' in d:
d['_cached_tablename'] = d.pop('__tablename__')

return DeclarativeMeta.__new__(cls, name, bases, d)

def __init__(self, name, bases, d):
bind_key = d.pop('__bind_key__', None) or getattr(self, '__bind_key__', None)
DeclarativeMeta.__init__(self, name, bases, d)

if bind_key is not None and hasattr(self, '__table__'):
self.__table__.info['bind_key'] = bind_key

Expand Down Expand Up @@ -639,6 +636,18 @@ class Model(object):
#: Equivalent to ``db.session.query(Model)`` unless :attr:`query_class` has been changed.
query = None

_cached_tablename = None

@declared_attr
def __tablename__(cls):
if (
'_cached_tablename' not in cls.__dict__ and
_should_set_tablename(cls)
):
cls._cached_tablename = camel_to_snake_case(cls.__name__)

return cls._cached_tablename


class SQLAlchemy(object):
"""This class is used to control the SQLAlchemy integration to one
Expand Down
55 changes: 49 additions & 6 deletions test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ class Base(db.Model):
class Duck(Base):
pass

self.assertFalse(hasattr(Base, '__tablename__'))
self.assertEqual(Base.__tablename__, 'base')
self.assertEqual(Duck.__tablename__, 'duck')

def test_complex_inheritance(self):
Expand All @@ -322,6 +322,53 @@ class RubberDuck(IdMixin, Duck):

self.assertEqual(RubberDuck.__tablename__, 'rubber_duck')

def test_manual_name(self):
app = flask.Flask(__name__)
db = fsa.SQLAlchemy(app)

class Duck(db.Model):
__tablename__ = 'DUCK'
id = db.Column(db.Integer, primary_key=True)

class Daffy(Duck):
id = db.Column(db.Integer, db.ForeignKey(Duck.id), primary_key=True)

self.assertEqual(Duck.__tablename__, 'DUCK')
self.assertEqual(Daffy.__tablename__, 'daffy')

def test_no_access_to_class_property(self):
app = flask.Flask(__name__)
db = fsa.SQLAlchemy(app)

class class_property(object):
def __init__(self, f):
self.f = f

def __get__(self, instance, owner):
return self.f(owner)

class Duck(db.Model):
id = db.Column(db.Integer, primary_key=True)

class ns(object):
accessed = False

# Since there's no id provided by the following model,
# _should_set_tablename will scan all attributes. If it's working
# properly, it won't access the class property, but will access the
# declared_attr.

class Witch(Duck):
@declared_attr
def is_duck(self):
ns.accessed = True

@class_property
def floats(self):
assert False

self.assertTrue(ns.accessed)


class PaginationTestCase(unittest.TestCase):
def test_basic_pagination(self):
Expand Down Expand Up @@ -486,13 +533,9 @@ def test_custom_query_class(self):
class CustomQueryClass(fsa.BaseQuery):
pass

class MyModelClass(object):
pass

app = flask.Flask(__name__)
app.config['TESTING'] = True
db = fsa.SQLAlchemy(app, query_class=CustomQueryClass,
model_class=MyModelClass)
db = fsa.SQLAlchemy(app, query_class=CustomQueryClass)

class Parent(db.Model):
id = db.Column(db.Integer, primary_key=True)
Expand Down