Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions flaskext/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,25 @@ class User(db.Model):
emulates `Table` behavior but is not a class. `db.Table` exposes the
`Table` interface, but is a function which allows omission of metadata.

To use a different base class for your models, create a subclass of
:class:`Model` and provide it as the `modelclass` keyword argument to this
function::

from flaskext.sqlalchemy import Model

class MyBaseModel(Model):
def print_hello(self):
print 'Hello'

app = Flask(__name__)
db = SQLAlchemy(app)

class User(db.Model):
name = db.Column(db.String(80))

user = User(name='John')
user.print_hello() # prints 'Hello'

You may also define your own SessionExtension instances as well when
defining your SQLAlchemy class instance. You may pass your custom instances
to the `session_extensions` keyword. This can be either a single
Expand Down Expand Up @@ -591,10 +610,16 @@ class User(db.Model):
.. versionadded:: 0.16
`scopefunc` is now accepted on `session_options`. It allows specifying
a custom function which will define the SQLAlchemy session's scoping.

.. versionadded:: 0.16
`modelclass` is used as the base class for the declarative base, to
allow for user-specified base classes.

"""

def __init__(self, app=None, use_native_unicode=True,
session_extensions=None, session_options=None):
session_extensions=None, session_options=None,
modelclass=Model):
self.use_native_unicode = use_native_unicode
self.session_extensions = to_list(session_extensions, []) + \
[_SignallingSessionExtension()]
Expand All @@ -607,7 +632,7 @@ def __init__(self, app=None, use_native_unicode=True,
)

self.session = self.create_scoped_session(session_options)
self.Model = self.make_declarative_base()
self.Model = self.make_declarative_base(modelclass)
self._engine_lock = Lock()

if app is not None:
Expand All @@ -633,9 +658,15 @@ def create_scoped_session(self, options=None):
partial(_SignallingSession, self, **options), scopefunc=scopefunc
)

def make_declarative_base(self):
"""Creates the declarative base."""
base = declarative_base(cls=Model, name='Model',
def make_declarative_base(self, modelclass=Model):
"""Creates the declarative base.

.. versionadded:: 0.16
`modelclass` is used as the base class for the declarative base, to
allow for user-specified base classes.

"""
base = declarative_base(cls=modelclass, name='Model',
mapper=signalling_mapper,
metaclass=_BoundDeclarativeMeta)
base.query = _QueryProperty(self)
Expand Down
39 changes: 39 additions & 0 deletions test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,44 @@ class FOOBar(db.Model):
assert fb not in db.session # because a new scope is generated on each call


class ModelBaseClassTestCase(unittest.TestCase):
"""Tests for providing a different model base class to the
:class:`flaskext.SQLAlchemy` object.

"""

def setUp(self):
# create a new base class for models
class MyBaseModel(sqlalchemy.Model):
myattribute = 'Test'
def return_hello(self):
return 'Hello'

self.app = flask.Flask(__name__)
self.app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
self.app.config['TESTING'] = True
self.db = sqlalchemy.SQLAlchemy(self.app, modelclass=MyBaseModel)
self.db.create_all()

def test_provided_base_class(self):
# create two different model subclasses which should inherit the func
class User(self.db.Model):
name = self.db.Column(self.db.String(10), primary_key=True)
class House(self.db.Model):
address = self.db.Column(self.db.String(10), primary_key=True)

assert hasattr(User, 'myattribute')
assert hasattr(House, 'myattribute')
assert User.myattribute == 'Test'
assert House.myattribute == 'Test'

# create instances of the model
user = User(name='John')
assert user.return_hello() == 'Hello'
house = House(address='foo')
assert house.return_hello() == 'Hello'


def suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(BasicAppTestCase))
Expand All @@ -426,6 +464,7 @@ def suite():
suite.addTest(unittest.makeSuite(SQLAlchemyIncludesTestCase))
suite.addTest(unittest.makeSuite(RegressionTestCase))
suite.addTest(unittest.makeSuite(SessionScopingTestCase))
suite.addTest(unittest.makeSuite(ModelBaseClassTestCase))
if flask.signals_available:
suite.addTest(unittest.makeSuite(SignallingTestCase))
return suite
Expand Down