diff --git a/spanner/tests/_fixtures.py b/spanner/tests/_fixtures.py index 1123d03c3f2d..ace9b981b6ec 100644 --- a/spanner/tests/_fixtures.py +++ b/spanner/tests/_fixtures.py @@ -38,6 +38,10 @@ description STRING(16), exactly_hwhen TIMESTAMP) PRIMARY KEY (eye_d); +CREATE TABLE counters ( + name STRING(1024), + value INT64 ) + PRIMARY KEY (name); """ DDL_STATEMENTS = [stmt.strip() for stmt in DDL.split(';') if stmt.strip()] diff --git a/spanner/tests/system/test_system.py b/spanner/tests/system/test_system.py index b4ac62194bb1..e6d73f977e94 100644 --- a/spanner/tests/system/test_system.py +++ b/spanner/tests/system/test_system.py @@ -17,6 +17,7 @@ import operator import os import struct +import threading import unittest from google.cloud.proto.spanner.v1.type_pb2 import ARRAY @@ -358,6 +359,11 @@ class TestSessionAPI(unittest.TestCase, _TestData): 'description', 'exactly_hwhen', ) + COUNTERS_TABLE = 'counters' + COUNTERS_COLUMNS = ( + 'name', + 'value', + ) SOME_DATE = datetime.date(2011, 1, 17) SOME_TIME = datetime.datetime(1989, 1, 17, 17, 59, 12, 345612) NANO_TIME = TimestampWithNanoseconds(1995, 8, 31, nanosecond=987654321) @@ -482,6 +488,31 @@ def test_transaction_read_and_insert_then_rollback(self): rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) self.assertEqual(rows, []) + def _transaction_read_then_raise(self, transaction): + rows = list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)) + self.assertEqual(len(rows), 0) + transaction.insert(self.TABLE, self.COLUMNS, self.ROW_DATA) + raise CustomException() + + @RetryErrors(exception=GrpcRendezvous) + def test_transaction_read_and_insert_then_exception(self): + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + with session.batch() as batch: + batch.delete(self.TABLE, self.ALL) + + with self.assertRaises(CustomException): + session.run_in_transaction(self._transaction_read_then_raise) + 
+ # Transaction was rolled back. + rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) + self.assertEqual(rows, []) + @RetryErrors(exception=GrpcRendezvous) def test_transaction_read_and_insert_or_update_then_commit(self): retry = RetryInstanceState(_has_all_ddl) @@ -508,6 +539,87 @@ def test_transaction_read_and_insert_or_update_then_commit(self): rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL)) self._check_row_data(rows) + def _transaction_concurrency_helper(self, unit_of_work, pkey): + INITIAL_VALUE = 123 + NUM_THREADS = 3 # conforms to equivalent Java systest. + + retry = RetryInstanceState(_has_all_ddl) + retry(self._db.reload)() + + session = self._db.session() + session.create() + self.to_delete.append(session) + + with session.batch() as batch: + batch.insert_or_update( + self.COUNTERS_TABLE, + self.COUNTERS_COLUMNS, + [[pkey, INITIAL_VALUE]]) + + # We don't want to run the threads' transactions in the current + # session, which would fail. + txn_sessions = [] + + for _ in range(NUM_THREADS): + txn_session = self._db.session() + txn_sessions.append(txn_session) + txn_session.create() + self.to_delete.append(txn_session) + + threads = [ + threading.Thread( + target=txn_session.run_in_transaction, + args=(unit_of_work, pkey)) + for txn_session in txn_sessions] + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + keyset = KeySet(keys=[(pkey,)]) + rows = list(session.read( + self.COUNTERS_TABLE, self.COUNTERS_COLUMNS, keyset)) + self.assertEqual(len(rows), 1) + _, value = rows[0] + self.assertEqual(value, INITIAL_VALUE + len(threads)) + + def _read_w_concurrent_update(self, transaction, pkey): + keyset = KeySet(keys=[(pkey,)]) + rows = list(transaction.read( + self.COUNTERS_TABLE, self.COUNTERS_COLUMNS, keyset)) + self.assertEqual(len(rows), 1) + pkey, value = rows[0] + transaction.update( + self.COUNTERS_TABLE, + self.COUNTERS_COLUMNS, + [[pkey, value + 1]]) + + def 
test_transaction_read_w_concurrent_updates(self): + PKEY = 'read_w_concurrent_updates' + self._transaction_concurrency_helper( + self._read_w_concurrent_update, PKEY) + + def _query_w_concurrent_update(self, transaction, pkey): + SQL = 'SELECT * FROM counters WHERE name = @name' + rows = list(transaction.execute_sql( + SQL, + params={'name': pkey}, + param_types={'name': Type(code=STRING)}, + )) + self.assertEqual(len(rows), 1) + pkey, value = rows[0] + transaction.update( + self.COUNTERS_TABLE, + self.COUNTERS_COLUMNS, + [[pkey, value + 1]]) + + def test_transaction_query_w_concurrent_updates(self): + PKEY = 'query_w_concurrent_updates' + self._transaction_concurrency_helper( + self._query_w_concurrent_update, PKEY) + @staticmethod def _row_data(max_index): for index in range(max_index): @@ -910,6 +1022,10 @@ def test_four_meg(self): self._verify_two_columns(FOUR_MEG) +class CustomException(Exception): + """Placeholder for any user-defined exception.""" + + class _DatabaseDropper(object): """Helper for cleaning up databases created on-the-fly."""