diff --git a/aiopg/sa/connection.py b/aiopg/sa/connection.py index 917ddbf7..cb7f691b 100644 --- a/aiopg/sa/connection.py +++ b/aiopg/sa/connection.py @@ -211,6 +211,10 @@ async def _commit_impl(self): self._transaction = None async def _rollback_impl(self): + if self._connection.closed: + self._transaction = None + return + cur = await self._get_cursor() try: await cur.execute('ROLLBACK') @@ -253,6 +257,10 @@ async def _savepoint_impl(self, name=None): cur.close() async def _rollback_to_savepoint_impl(self, name, parent): + if self._connection.closed: + self._transaction = None + return + cur = await self._get_cursor() try: await cur.execute(f'ROLLBACK TO SAVEPOINT {name}') diff --git a/aiopg/transaction.py b/aiopg/transaction.py index 48bb87dc..830ab92e 100644 --- a/aiopg/transaction.py +++ b/aiopg/transaction.py @@ -127,13 +127,15 @@ async def commit(self): async def rollback(self): self._check_commit_rollback() - await self._cur.execute(self._isolation.rollback()) + if not self._cur.closed: + await self._cur.execute(self._isolation.rollback()) self._is_begin = False async def rollback_savepoint(self): self._check_release_rollback() - await self._cur.execute( - self._isolation.rollback_savepoint(self._unique_id)) + if not self._cur.closed: + await self._cur.execute( + self._isolation.rollback_savepoint(self._unique_id)) self._unique_id = None async def release_savepoint(self): diff --git a/tests/test_sa_transaction.py b/tests/test_sa_transaction.py index d40afd68..f5218a7d 100644 --- a/tests/test_sa_transaction.py +++ b/tests/test_sa_transaction.py @@ -1,3 +1,4 @@ +import asyncio from unittest import mock import pytest @@ -411,3 +412,26 @@ async def test_transaction_mode(connect): res1 = await conn.scalar(select([func.count()]).select_from(tbl)) assert 5 == res1 await tr8.commit() + + +async def test_timeout_in_transaction_context_manager(make_engine): + engine = await make_engine(timeout=1) + with pytest.raises(asyncio.TimeoutError): + async with engine.acquire() as connection: + async with connection.begin(): + await connection.execute("SELECT pg_sleep(10)") + + engine.terminate() + await engine.wait_closed() + + +async def test_timeout_in_nested_transaction_context_manager(make_engine): + engine = await make_engine(timeout=1) + with pytest.raises(asyncio.TimeoutError): + async with engine.acquire() as connection: + async with connection.begin(): + async with connection.begin_nested(): + await connection.execute("SELECT pg_sleep(10)") + + engine.terminate() + await engine.wait_closed() diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 392ead07..495e9cc4 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,3 +1,5 @@ +import asyncio + import psycopg2 import pytest @@ -181,3 +183,26 @@ async def test_transaction_point_oldstyle(engine): (3, 'data')] await tr.commit() + + +async def test_timeout_in_transaction_context_manager(make_engine): + engine = await make_engine(timeout=1) + with pytest.raises(asyncio.TimeoutError): + async with engine.acquire() as connection: + async with Transaction(connection, IsolationLevel.read_committed): + await connection.execute("SELECT pg_sleep(10)") + + engine.terminate() + await engine.wait_closed() + + +async def test_timeout_in_savepoint_context_manager(make_engine): + engine = await make_engine(timeout=1) + with pytest.raises(asyncio.TimeoutError): + async with engine.acquire() as connection: + async with Transaction(connection, IsolationLevel.read_committed) as transaction: + async with transaction.point(): + await connection.execute("SELECT pg_sleep(10)") + + engine.terminate() + await engine.wait_closed()