Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ Patches and Suggestions
- Steven Harms
- David Lord @davidism
- Alec Nikolas Reiter @justanr
- Barak Alon
57 changes: 38 additions & 19 deletions flask_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from flask.signals import Namespace
from operator import itemgetter
from threading import Lock
from sqlalchemy import orm, event, inspect
from sqlalchemy import orm, event
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.orm.session import Session as SessionBase
from sqlalchemy.engine.url import make_url
Expand Down Expand Up @@ -178,67 +178,86 @@ class _SessionSignalEvents(object):
@classmethod
def register(cls, session):
if not hasattr(session, '_model_changes'):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be _model_changes_stack here?

session._model_changes = {}
session._model_changes_stack = [{}]

event.listen(session, 'after_transaction_create', cls.after_transaction_create)
event.listen(session, 'before_flush', cls.record_ops)
event.listen(session, 'before_commit', cls.record_ops)
event.listen(session, 'before_commit', cls.before_commit)
event.listen(session, 'after_commit', cls.after_commit)
event.listen(session, 'after_rollback', cls.after_rollback)
event.listen(session, 'after_soft_rollback', cls.after_rollback)

@classmethod
def unregister(cls, session):
if hasattr(session, '_model_changes'):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

del session._model_changes
del session._model_changes_stack

event.remove(session, 'after_transaction_create', cls.after_transaction_create)
event.remove(session, 'before_flush', cls.record_ops)
event.remove(session, 'before_commit', cls.record_ops)
event.remove(session, 'before_commit', cls.before_commit)
event.remove(session, 'after_commit', cls.after_commit)
event.remove(session, 'after_rollback', cls.after_rollback)
event.remove(session, 'after_soft_rollback', cls.after_rollback)

@staticmethod
def after_transaction_create(session, transaction):
try:
stack = session._model_changes_stack
except AttributeError:
return

if transaction.nested:
stack.append({})

@staticmethod
def record_ops(session, flush_context=None, instances=None):
try:
d = session._model_changes
stack = session._model_changes_stack
except AttributeError:
return

for targets, operation in ((session.new, 'insert'), (session.dirty, 'update'), (session.deleted, 'delete')):
for target in targets:
state = inspect(target)
key = state.identity_key if state.has_identity else id(target)
d[key] = (target, operation)
key = id(target)
stack[-1][key] = (target, operation)

@staticmethod
def before_commit(session):
try:
d = session._model_changes
stack = session._model_changes_stack
except AttributeError:
return

if d:
before_models_committed.send(session.app, changes=list(d.values()))
if not session.transaction.nested and stack[0]:
before_models_committed.send(session.app, changes=list(stack[0].values()))

@staticmethod
def after_commit(session):
try:
d = session._model_changes
stack = session._model_changes_stack
except AttributeError:
return

if d:
models_committed.send(session.app, changes=list(d.values()))
d.clear()
if not session.transaction.nested and stack[0]:
models_committed.send(session.app, changes=list(stack[0].values()))
stack[0].clear()
elif session.transaction.nested:
nested_changes = stack.pop()
for key, value in nested_changes.items():
# Don't overwrite keys that are lower in the stack
stack[-1].setdefault(key, value)

@staticmethod
def after_rollback(session):
def after_rollback(session, previous_transaction):
try:
d = session._model_changes
stack = session._model_changes_stack
except AttributeError:
return

d.clear()
if previous_transaction.nested:
stack.pop()
else:
stack[0].clear()


class _EngineDebuggingSignalEvents(object):
Expand Down
87 changes: 87 additions & 0 deletions test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,24 @@
from sqlalchemy.orm import sessionmaker


def fix_pysqlite(engine):
"""This ugly mess is a known issue with pysqlite and how it does
Serializable isolation / Savepoints / Transactional DDL:
http://docs.sqlalchemy.org/en/rel_1_0/dialects/sqlite.html#pysqlite-serializable
"""

@event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
# disable pysqlite's emitting of the BEGIN statement entirely.
# also stops it from emitting COMMIT before any DDL.
dbapi_connection.isolation_level = None

@event.listens_for(engine, "begin")
def do_begin(conn):
# emit our own BEGIN
conn.execute("BEGIN")


def make_todo_model(db):
class Todo(db.Model):
__tablename__ = 'todos'
Expand Down Expand Up @@ -181,6 +199,7 @@ def setUp(self):
app.config['SQLALCHEMY_ENGINE'] = 'sqlite://'
app.config['TESTING'] = True
self.db = sqlalchemy.SQLAlchemy(app)
fix_pysqlite(self.db.engine)
self.Todo = make_todo_model(self.db)
self.db.create_all()

Expand Down Expand Up @@ -227,6 +246,74 @@ def committed(sender, changes):
self.assertEqual(recorded[0][0], todo)
self.assertEqual(recorded[0][1], 'delete')

def test_model_signals_in_nested_session(self):
"""Nested sessions should not send a models_committed signal when
committing"""

recorded = []
def committed(sender, changes):
recorded.extend(changes)

with sqlalchemy.models_committed.connected_to(committed,
sender=self.app):
todo = self.Todo('Awesome', 'the text')
self.db.session.add(todo)
self.assertEqual(len(recorded), 0)
self.db.session.begin_nested()
todo.done = True
self.db.session.commit()
self.assertEqual(len(recorded), 0)
self.db.session.commit()
self.assertEqual(len(recorded), 1)
self.assertEqual(recorded[0][0], todo)
self.assertEqual(recorded[0][1], 'insert')

def test_model_signals_in_nested_session_that_is_rolled_back(self):
"""Models_committed signal does not send on a change to a model that
is rolled back in a nested session"""

recorded = []
def committed(sender, changes):
recorded.extend(changes)

with sqlalchemy.models_committed.connected_to(committed,
sender=self.app):
todos = [
self.Todo('Awesome', 'I will persist!'),
self.Todo('Radical', 'I will rollback!')
]

self.db.session.add(todos[0])
self.db.session.begin_nested()
self.db.session.add(todos[1])
self.db.session.rollback()
self.db.session.commit()
self.assertEqual(len(recorded), 1)
self.assertEqual(recorded[0][0], todos[0])
self.assertEqual(recorded[0][1], 'insert')

def test_model_signals_merge_changes_in_nested_sessions(self):
"""When committing a nested session, model modifications are merged
to the parent session"""

recorded = []
def committed(sender, changes):
recorded.extend(changes)

with sqlalchemy.models_committed.connected_to(committed,
sender=self.app):
todo = self.Todo('Awesome', 'the text')
self.db.session.add(todo)
self.assertEqual(len(recorded), 0)
self.db.session.begin_nested()
todo.text = 'this is an UPDATE, but as a whole the transaction ' \
'will still be an INSERT'
self.db.session.commit()
self.db.session.commit()
self.assertEqual(len(recorded), 1)
self.assertEqual(recorded[0][0], todo)
self.assertEqual(recorded[0][1], 'insert')


class TablenameTestCase(unittest.TestCase):
def test_name(self):
Expand Down