Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions flask_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import functools
import sqlalchemy
from math import ceil
from functools import partial
from flask import _request_ctx_stack, abort
from flask.signals import Namespace
from operator import itemgetter
Expand Down Expand Up @@ -138,26 +137,33 @@ class SignallingSession(SessionBase):
uses. It extends the default session system with bind selection and
modification tracking.

If you want to use a different session you can override the
:meth:`SQLAlchemy.create_session` function.
If you want to use a different session class you can override the
:meth:`SQLAlchemy.create_session_class` function.

.. versionadded:: 2.0
"""

def __init__(self, db, autocommit=False, autoflush=True, **options):
_db = None
_options = {}

def __init__(self, db=None, bind=None, **options):
if db is None:
db = self._db
#: The application that this session belongs to.
self.app = db.get_app()
self._model_changes = {}
options.update(self._options)
options.setdefault('autocommit', False)
options.setdefault('autoflush', False)
options.setdefault('binds', db.get_binds(self.app))
#: A flag that controls weather this session should keep track of
#: model modifications. The default value for this attribute
#: is set from the ``SQLALCHEMY_TRACK_MODIFICATIONS`` config
#: key.
self.emit_modification_signals = \
self.app.config['SQLALCHEMY_TRACK_MODIFICATIONS']
bind = options.pop('bind', None) or db.engine
SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush,
bind=bind,
binds=db.get_binds(self.app), **options)
SessionBase.__init__(self, bind=bind, **options)

def get_bind(self, mapper, clause=None):
# mapper is None if someone tries to just get a connection
Expand Down Expand Up @@ -687,21 +693,32 @@ def metadata(self):

def create_scoped_session(self, options=None):
"""Helper factory method that creates a scoped session. It
internally calls :meth:`create_session`.
internally calls :meth:`create_session_class`.
"""
if options is None:
options = {}

scopefunc = options.pop('scopefunc', None)
return orm.scoped_session(partial(self.create_session, options),
scopefunc=scopefunc)
session_class = self.create_session_class(options)

return orm.scoped_session(
orm.sessionmaker(class_=session_class),
scopefunc=scopefunc
)

def create_session(self, options):
"""Creates the session. The default implementation returns a
def create_session_class(self, options):
"""Creates the session class. The default implementation returns a
:class:`SignallingSession`.

.. versionadded:: 2.0
"""
return SignallingSession(self, **options)

return type(
"SignallingSession", (SignallingSession, ), {
'_db': self,
'_options': options,
}
)

def make_declarative_base(self):
"""Creates the declarative base."""
Expand Down
17 changes: 15 additions & 2 deletions test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from flask.ext import sqlalchemy
from sqlalchemy.orm import sessionmaker


def make_todo_model(db):
class Todo(db.Model):
__tablename__ = 'todos'
Expand Down Expand Up @@ -156,6 +155,21 @@ def committed(sender, changes):
self.assertEqual(recorded[0][0], todo)
self.assertEqual(recorded[0][1], 'delete')

def test_session_events(self):
app = flask.Flask(__name__)
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
app.config['TESTING'] = True
db = sqlalchemy.SQLAlchemy(app)

from sqlalchemy.event import listens_for

seen = []
register = listens_for(db.session, 'after_commit')
register(seen.append)

db.session.commit()
self.assertEqual(seen, [db.session()])


class HelperTestCase(unittest.TestCase):

Expand Down Expand Up @@ -420,7 +434,6 @@ class FOOBar(db.Model):
assert fb not in db.session # because a new scope is generated on each call



class CommitOnTeardownTestCase(unittest.TestCase):

def setUp(self):
Expand Down