Skip to content

Commit

Permalink
Don't run ROLLBACK when the connection is closed.
Browse files Browse the repository at this point in the history
This can be caused when a query times out while running, for example,
and the connection is closed as a result (as opposed to cancelling the
query, since PR #570). In this case, we would rather not emit the
ROLLBACK (the connection is already closed, so the transaction is over
anyway), rather than raising an exception when trying to use a
connection which is already closed.

See issue #777.
  • Loading branch information
brianmaissy committed Dec 23, 2020
1 parent abe18d1 commit d576656
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 3 deletions.
8 changes: 8 additions & 0 deletions aiopg/sa/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ async def _commit_impl(self):
self._transaction = None

async def _rollback_impl(self):
if self._connection.closed:
self._transaction = None
return

cur = await self._get_cursor()
try:
await cur.execute('ROLLBACK')
Expand Down Expand Up @@ -253,6 +257,10 @@ async def _savepoint_impl(self, name=None):
cur.close()

async def _rollback_to_savepoint_impl(self, name, parent):
if self._connection.closed:
self._transaction = None
return

cur = await self._get_cursor()
try:
await cur.execute(f'ROLLBACK TO SAVEPOINT {name}')
Expand Down
8 changes: 5 additions & 3 deletions aiopg/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,15 @@ async def commit(self):

async def rollback(self):
self._check_commit_rollback()
await self._cur.execute(self._isolation.rollback())
if not self._cur.closed:
await self._cur.execute(self._isolation.rollback())
self._is_begin = False

async def rollback_savepoint(self):
self._check_release_rollback()
await self._cur.execute(
self._isolation.rollback_savepoint(self._unique_id))
if not self._cur.closed:
await self._cur.execute(
self._isolation.rollback_savepoint(self._unique_id))
self._unique_id = None

async def release_savepoint(self):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_sa_transaction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from unittest import mock

import pytest
Expand Down Expand Up @@ -411,3 +412,26 @@ async def test_transaction_mode(connect):
res1 = await conn.scalar(select([func.count()]).select_from(tbl))
assert 5 == res1
await tr8.commit()


async def test_timeout_in_transaction_context_manager(make_engine):
engine = await make_engine(timeout=1)
with pytest.raises(asyncio.TimeoutError):
async with engine.acquire() as connection:
async with connection.begin():
await connection.execute("SELECT pg_sleep(10)")

engine.terminate()
await engine.wait_closed()


async def test_timeout_in_nested_transaction_context_manager(make_engine):
engine = await make_engine(timeout=1)
with pytest.raises(asyncio.TimeoutError):
async with engine.acquire() as connection:
async with connection.begin():
async with connection.begin_nested():
await connection.execute("SELECT pg_sleep(10)")

engine.terminate()
await engine.wait_closed()
25 changes: 25 additions & 0 deletions tests/test_transaction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import psycopg2
import pytest

Expand Down Expand Up @@ -181,3 +183,26 @@ async def test_transaction_point_oldstyle(engine):
(3, 'data')]

await tr.commit()


async def test_timeout_in_transaction_context_manager(make_engine):
engine = await make_engine(timeout=1)
with pytest.raises(asyncio.TimeoutError):
async with engine.acquire() as connection:
async with Transaction(connection, IsolationLevel.read_committed):
await connection.execute("SELECT pg_sleep(10)")

engine.terminate()
await engine.wait_closed()


async def test_timeout_in_savepoint_context_manager(make_engine):
engine = await make_engine(timeout=1)
with pytest.raises(asyncio.TimeoutError):
async with engine.acquire() as connection:
async with Transaction(connection, IsolationLevel.read_committed) as transaction:
async with transaction.point():
await connection.execute("SELECT pg_sleep(10)")

engine.terminate()
await engine.wait_closed()

0 comments on commit d576656

Please sign in to comment.