From 25824299f0cf174261bcd05de4d3e58e175e31d1 Mon Sep 17 00:00:00 2001 From: Malthe Borch Date: Mon, 12 Aug 2013 14:55:46 +0200 Subject: [PATCH] Use the 'sessionmaker' function as argument to 'scoped_session'. As a result, it's no longer possible to provide a session object, but rather a session class, as a means of using a custom implementation. This fixes issue #147. --- flask_sqlalchemy/__init__.py | 43 +++++++++++++++++++++++++----------- test_sqlalchemy.py | 17 ++++++++++++-- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/flask_sqlalchemy/__init__.py b/flask_sqlalchemy/__init__.py index 22941de6..0eaa294d 100644 --- a/flask_sqlalchemy/__init__.py +++ b/flask_sqlalchemy/__init__.py @@ -16,7 +16,6 @@ import functools import sqlalchemy from math import ceil -from functools import partial from flask import _request_ctx_stack, abort from flask.signals import Namespace from operator import itemgetter @@ -138,16 +137,25 @@ class SignallingSession(SessionBase): uses. It extends the default session system with bind selection and modification tracking. - If you want to use a different session you can override the - :meth:`SQLAlchemy.create_session` function. + If you want to use a different session class you can override the + :meth:`SQLAlchemy.create_session_class` function. .. versionadded:: 2.0 """ - def __init__(self, db, autocommit=False, autoflush=True, **options): + _db = None + _options = {} + + def __init__(self, db=None, bind=None, **options): + if db is None: + db = self._db #: The application that this session belongs to. self.app = db.get_app() self._model_changes = {} + options.update(self._options) + options.setdefault('autocommit', False) + options.setdefault('autoflush', False) + options.setdefault('binds', db.get_binds(self.app)) #: A flag that controls weather this session should keep track of #: model modifications. The default value for this attribute #: is set from the ``SQLALCHEMY_TRACK_MODIFICATIONS`` config @@ -155,9 +163,7 @@ def __init__(self, db, autocommit=False, autoflush=True, **options): self.emit_modification_signals = \ self.app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] bind = options.pop('bind', None) or db.engine - SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush, - bind=bind, - binds=db.get_binds(self.app), **options) + SessionBase.__init__(self, bind=bind, **options) def get_bind(self, mapper, clause=None): # mapper is None if someone tries to just get a connection @@ -687,21 +693,32 @@ def metadata(self): def create_scoped_session(self, options=None): """Helper factory method that creates a scoped session. It - internally calls :meth:`create_session`. + internally calls :meth:`create_session_class`. """ if options is None: options = {} + scopefunc = options.pop('scopefunc', None) - return orm.scoped_session(partial(self.create_session, options), - scopefunc=scopefunc) + session_class = self.create_session_class(options) + + return orm.scoped_session( + orm.sessionmaker(class_=session_class), + scopefunc=scopefunc + ) - def create_session(self, options): - """Creates the session. The default implementation returns a + def create_session_class(self, options): + """Creates the session class. The default implementation returns a :class:`SignallingSession`. .. versionadded:: 2.0 """ - return SignallingSession(self, **options) + + return type( + "SignallingSession", (SignallingSession, ), { + '_db': self, + '_options': options, + } + ) def make_declarative_base(self): """Creates the declarative base.""" diff --git a/test_sqlalchemy.py b/test_sqlalchemy.py index 84807516..44eead57 100644 --- a/test_sqlalchemy.py +++ b/test_sqlalchemy.py @@ -8,7 +8,6 @@ from flask.ext import sqlalchemy from sqlalchemy.orm import sessionmaker - def make_todo_model(db): class Todo(db.Model): __tablename__ = 'todos' @@ -156,6 +155,21 @@ def committed(sender, changes): self.assertEqual(recorded[0][0], todo) self.assertEqual(recorded[0][1], 'delete') + def test_session_events(self): + app = flask.Flask(__name__) + app.config['SQLALCHEMY_ENGINE'] = 'sqlite://' + app.config['TESTING'] = True + db = sqlalchemy.SQLAlchemy(app) + + from sqlalchemy.event import listens_for + + seen = [] + register = listens_for(db.session, 'after_commit') + register(seen.append) + + db.session.commit() + self.assertEqual(seen, [db.session()]) + class HelperTestCase(unittest.TestCase): @@ -420,7 +434,6 @@ class FOOBar(db.Model): assert fb not in db.session # because a new scope is generated on each call - class CommitOnTeardownTestCase(unittest.TestCase): def setUp(self):