diff --git a/AUTHORS b/AUTHORS index 17e08dc7..ffbf2632 100644 --- a/AUTHORS +++ b/AUTHORS @@ -48,3 +48,4 @@ Patches and Suggestions - Steven Harms - David Lord @davidism - Alec Nikolas Reiter @justanr +- Barak Alon diff --git a/flask_sqlalchemy/__init__.py b/flask_sqlalchemy/__init__.py index f0fb1a10..9aca2650 100644 --- a/flask_sqlalchemy/__init__.py +++ b/flask_sqlalchemy/__init__.py @@ -22,7 +22,7 @@ from flask.signals import Namespace from operator import itemgetter from threading import Lock -from sqlalchemy import orm, event, inspect +from sqlalchemy import orm, event from sqlalchemy.orm.exc import UnmappedClassError from sqlalchemy.orm.session import Session as SessionBase from sqlalchemy.engine.url import make_url @@ -178,67 +178,86 @@ class _SessionSignalEvents(object): @classmethod def register(cls, session): if not hasattr(session, '_model_changes'): - session._model_changes = {} + session._model_changes_stack = [{}] + event.listen(session, 'after_transaction_create', cls.after_transaction_create) event.listen(session, 'before_flush', cls.record_ops) event.listen(session, 'before_commit', cls.record_ops) event.listen(session, 'before_commit', cls.before_commit) event.listen(session, 'after_commit', cls.after_commit) - event.listen(session, 'after_rollback', cls.after_rollback) + event.listen(session, 'after_soft_rollback', cls.after_rollback) @classmethod def unregister(cls, session): if hasattr(session, '_model_changes'): - del session._model_changes + del session._model_changes_stack + event.remove(session, 'after_transaction_create', cls.after_transaction_create) event.remove(session, 'before_flush', cls.record_ops) event.remove(session, 'before_commit', cls.record_ops) event.remove(session, 'before_commit', cls.before_commit) event.remove(session, 'after_commit', cls.after_commit) - event.remove(session, 'after_rollback', cls.after_rollback) + event.remove(session, 'after_soft_rollback', cls.after_rollback) + + @staticmethod + def after_transaction_create(session, transaction): + try: + stack = session._model_changes_stack + except AttributeError: + return + + if transaction.nested: + stack.append({}) @staticmethod def record_ops(session, flush_context=None, instances=None): try: - d = session._model_changes + stack = session._model_changes_stack except AttributeError: return for targets, operation in ((session.new, 'insert'), (session.dirty, 'update'), (session.deleted, 'delete')): for target in targets: - state = inspect(target) - key = state.identity_key if state.has_identity else id(target) - d[key] = (target, operation) + key = id(target) + stack[-1][key] = (target, operation) @staticmethod def before_commit(session): try: - d = session._model_changes + stack = session._model_changes_stack except AttributeError: return - if d: - before_models_committed.send(session.app, changes=list(d.values())) + if not session.transaction.nested and stack[0]: + before_models_committed.send(session.app, changes=list(stack[0].values())) @staticmethod def after_commit(session): try: - d = session._model_changes + stack = session._model_changes_stack except AttributeError: return - if d: - models_committed.send(session.app, changes=list(d.values())) - d.clear() + if not session.transaction.nested and stack[0]: + models_committed.send(session.app, changes=list(stack[0].values())) + stack[0].clear() + elif session.transaction.nested: + nested_changes = stack.pop() + for key, value in nested_changes.items(): + # Don't overwrite keys that are lower in the stack + stack[-1].setdefault(key, value) @staticmethod - def after_rollback(session): + def after_rollback(session, previous_transaction): try: - d = session._model_changes + stack = session._model_changes_stack except AttributeError: return - d.clear() + if previous_transaction.nested: + stack.pop() + else: + stack[0].clear() class _EngineDebuggingSignalEvents(object): diff --git a/test_sqlalchemy.py b/test_sqlalchemy.py index c2f9b6ad..a72551f6 100644 --- a/test_sqlalchemy.py +++ b/test_sqlalchemy.py @@ -10,6 +10,24 @@ from sqlalchemy.orm import sessionmaker +def fix_pysqlite(engine): + """This ugly mess is a known issue with pysqlite and how it does + Serializable isolation / Savepoints / Transactional DDL: + http://docs.sqlalchemy.org/en/rel_1_0/dialects/sqlite.html#pysqlite-serializable + """ + + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + # disable pysqlite's emitting of the BEGIN statement entirely. + # also stops it from emitting COMMIT before any DDL. + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + # emit our own BEGIN + conn.execute("BEGIN") + + def make_todo_model(db): class Todo(db.Model): __tablename__ = 'todos' @@ -181,6 +199,7 @@ def setUp(self): app.config['SQLALCHEMY_ENGINE'] = 'sqlite://' app.config['TESTING'] = True self.db = sqlalchemy.SQLAlchemy(app) + fix_pysqlite(self.db.engine) self.Todo = make_todo_model(self.db) self.db.create_all() @@ -227,6 +246,74 @@ def committed(sender, changes): self.assertEqual(recorded[0][0], todo) self.assertEqual(recorded[0][1], 'delete') + def test_model_signals_in_nested_session(self): + """Nested sessions should not send a models_committed signal when + committing""" + + recorded = [] + def committed(sender, changes): + recorded.extend(changes) + + with sqlalchemy.models_committed.connected_to(committed, + sender=self.app): + todo = self.Todo('Awesome', 'the text') + self.db.session.add(todo) + self.assertEqual(len(recorded), 0) + self.db.session.begin_nested() + todo.done = True + self.db.session.commit() + self.assertEqual(len(recorded), 0) + self.db.session.commit() + self.assertEqual(len(recorded), 1) + self.assertEqual(recorded[0][0], todo) + self.assertEqual(recorded[0][1], 'insert') + + def test_model_signals_in_nested_session_that_is_rolled_back(self): + """Models_committed signal does not send on a change to a model that + is rolled back in a nested session""" + + recorded = [] + def committed(sender, changes): + recorded.extend(changes) + + with sqlalchemy.models_committed.connected_to(committed, + sender=self.app): + todos = [ + self.Todo('Awesome', 'I will persist!'), + self.Todo('Radical', 'I will rollback!') + ] + + self.db.session.add(todos[0]) + self.db.session.begin_nested() + self.db.session.add(todos[1]) + self.db.session.rollback() + self.db.session.commit() + self.assertEqual(len(recorded), 1) + self.assertEqual(recorded[0][0], todos[0]) + self.assertEqual(recorded[0][1], 'insert') + + def test_model_signals_merge_changes_in_nested_sessions(self): + """When committing a nested session, model modifications are merged + to the parent session""" + + recorded = [] + def committed(sender, changes): + recorded.extend(changes) + + with sqlalchemy.models_committed.connected_to(committed, + sender=self.app): + todo = self.Todo('Awesome', 'the text') + self.db.session.add(todo) + self.assertEqual(len(recorded), 0) + self.db.session.begin_nested() + todo.text = 'this is an UPDATE, but as a whole the transaction ' \ + 'will still be an INSERT' + self.db.session.commit() + self.db.session.commit() + self.assertEqual(len(recorded), 1) + self.assertEqual(recorded[0][0], todo) + self.assertEqual(recorded[0][1], 'insert') + class TablenameTestCase(unittest.TestCase): def test_name(self):