diff --git a/spanner/google/cloud/spanner/session.py b/spanner/google/cloud/spanner/session.py index 45baffa92d43..f25abdd6261a 100644 --- a/spanner/google/cloud/spanner/session.py +++ b/spanner/google/cloud/spanner/session.py @@ -302,7 +302,6 @@ def run_in_transaction(self, func, *args, **kw): continue except Exception: txn.rollback() - del self._transaction raise try: @@ -312,7 +311,6 @@ def run_in_transaction(self, func, *args, **kw): del self._transaction else: committed = txn.committed - del self._transaction return committed diff --git a/spanner/google/cloud/spanner/transaction.py b/spanner/google/cloud/spanner/transaction.py index af2140896830..7c0272d41132 100644 --- a/spanner/google/cloud/spanner/transaction.py +++ b/spanner/google/cloud/spanner/transaction.py @@ -93,6 +93,7 @@ def rollback(self): options = _options_with_prefix(database.name) api.rollback(self._session.name, self._id, options=options) self._rolled_back = True + del self._session._transaction def commit(self): """Commit mutations to the database. @@ -114,6 +115,7 @@ def commit(self): transaction_id=self._id, options=options) self.committed = _pb_timestamp_to_datetime( response.commit_timestamp) + del self._session._transaction return self.committed def __enter__(self): diff --git a/spanner/tests/unit/test_transaction.py b/spanner/tests/unit/test_transaction.py index 997f4d5153c8..973aeedb179d 100644 --- a/spanner/tests/unit/test_transaction.py +++ b/spanner/tests/unit/test_transaction.py @@ -42,8 +42,10 @@ def _getTargetClass(self): return Transaction - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) + def _make_one(self, session, *args, **kwargs): + transaction = self._getTargetClass()(session, *args, **kwargs) + session._transaction = transaction + return transaction def test_ctor_defaults(self): session = _Session() @@ -208,6 +210,7 @@ def test_rollback_ok(self): transaction.rollback() self.assertTrue(transaction._rolled_back) + self.assertIsNone(session._transaction) session_id, txn_id, options = api._rolled_back self.assertEqual(session_id, session.name) @@ -290,6 +293,7 @@ def test_commit_ok(self): transaction.commit() self.assertEqual(transaction.committed, now) + self.assertIsNone(session._transaction) session_id, mutations, txn_id, options = api._committed self.assertEqual(session_id, session.name) @@ -368,6 +372,8 @@ class _Database(object): class _Session(object): + _transaction = None + def __init__(self, database=None, name=TestTransaction.SESSION_NAME): self._database = database self.name = name