From 9f3e79a29107f661c4fdc0489e203845ce45ce0c Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Mon, 9 Jun 2025 07:16:08 -0700 Subject: [PATCH 1/6] feat: Multiplexed sessions - Support multiplexed sessions for read/write transactions. Signed-off-by: Taylor Curran --- .../spanner_v1/database_sessions_manager.py | 12 +++------- tests/unit/test_database_session_manager.py | 23 ++++++++++++++++--- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/google/cloud/spanner_v1/database_sessions_manager.py b/google/cloud/spanner_v1/database_sessions_manager.py index 09f93cdcd6..6342c36ba8 100644 --- a/google/cloud/spanner_v1/database_sessions_manager.py +++ b/google/cloud/spanner_v1/database_sessions_manager.py @@ -86,16 +86,10 @@ def get_session(self, transaction_type: TransactionType) -> Session: :returns: a session for the given transaction type. """ - use_multiplexed = self._use_multiplexed(transaction_type) - - # TODO multiplexed: enable for read/write transactions - if use_multiplexed and transaction_type == TransactionType.READ_WRITE: - raise NotImplementedError( - f"Multiplexed sessions are not yet supported for {transaction_type} transactions." - ) - session = ( - self._get_multiplexed_session() if use_multiplexed else self._pool.get() + self._get_multiplexed_session() + if self._use_multiplexed(transaction_type) + else self._pool.get() ) add_span_event( diff --git a/tests/unit/test_database_session_manager.py b/tests/unit/test_database_session_manager.py index 7626bd0d60..9caec7d6b5 100644 --- a/tests/unit/test_database_session_manager.py +++ b/tests/unit/test_database_session_manager.py @@ -156,12 +156,29 @@ def test_read_write_pooled(self): manager.put_session(session) pool.put.assert_called_once_with(session) - # TODO multiplexed: implement support for read/write transactions. def test_read_write_multiplexed(self): + manager = self._manager + pool = manager._pool + self._enable_multiplexed_sessions() - with self.assertRaises(NotImplementedError): - self._manager.get_session(TransactionType.READ_WRITE) + # Session is created. + session_1 = manager.get_session(TransactionType.READ_WRITE) + self.assertTrue(session_1.is_multiplexed) + manager.put_session(session_1) + + # Session is re-used. + session_2 = manager.get_session(TransactionType.READ_WRITE) + self.assertEqual(session_1, session_2) + manager.put_session(session_2) + + # Verify that pool was not used. + pool.get.assert_not_called() + pool.put.assert_not_called() + + # Verify logger calls. + info = manager._database.logger.info + info.assert_called_once_with("Created multiplexed session.") def test_multiplexed_maintenance(self): manager = self._manager From bda6e558853b2ab64467beea4fb8503af34ae029 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Mon, 9 Jun 2025 09:01:20 -0700 Subject: [PATCH 2/6] feat: Multiplexed sessions - Remove `Session._transaction` attribute, since each session may not correspond to multiple transactions. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/session.py | 78 ++++++++++---------------- google/cloud/spanner_v1/transaction.py | 10 ---- tests/unit/test_session.py | 70 ++++++++++------------- tests/unit/test_transaction.py | 7 --- 4 files changed, 60 insertions(+), 105 deletions(-) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 89f610d988..6aa919e65a 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -74,9 +74,6 @@ def __init__(self, database, labels=None, database_role=None, is_multiplexed=Fal self._database = database self._session_id: Optional[str] = None - # TODO multiplexed - remove - self._transaction: Optional[Transaction] = None - if labels is None: labels = {} @@ -467,23 +464,19 @@ def batch(self): return Batch(self) - def transaction(self): + # TODO multiplexed - deprecate + def transaction(self) -> Transaction: """Create a transaction to perform a set of reads with shared staleness. :rtype: :class:`~google.cloud.spanner_v1.transaction.Transaction` :returns: a transaction bound to this session + :raises ValueError: if the session has not yet been created. """ if self._session_id is None: raise ValueError("Session has not been created.") - # TODO multiplexed - remove - if self._transaction is not None: - self._transaction.rolled_back = True - self._transaction = None - - txn = self._transaction = Transaction(self) - return txn + return Transaction(self) def run_in_transaction(self, func, *args, **kw): """Perform a unit of work in a transaction, retrying on abort. @@ -528,42 +521,36 @@ def run_in_transaction(self, func, *args, **kw): ) isolation_level = kw.pop("isolation_level", None) - attempts = 0 + database = self._database + log_commit_stats = database.log_commit_stats - observability_options = getattr(self._database, "observability_options", None) with trace_call( "CloudSpanner.Session.run_in_transaction", self, - observability_options=observability_options, + observability_options=getattr(database, "observability_options", None), ) as span, MetricsCapture(): + attempts: int = 0 + while True: - # TODO multiplexed - remove - if self._transaction is None: - txn = self.transaction() - txn.transaction_tag = transaction_tag - txn.exclude_txn_from_change_streams = ( - exclude_txn_from_change_streams - ) - txn.isolation_level = isolation_level - else: - txn = self._transaction + # [A] Build transaction + # --------------------- - span_attributes = dict() + txn = self.transaction() + txn.transaction_tag = transaction_tag + txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams + txn.isolation_level = isolation_level - try: - attempts += 1 - span_attributes["attempt"] = attempts - txn_id = getattr(txn, "_transaction_id", "") or "" - if txn_id: - span_attributes["transaction.id"] = txn_id + # [B] Run user operation + # ---------------------- + attempts += 1 + span_attributes = dict(attempt=attempts) + + try: return_value = func(txn, *args, **kw) # TODO multiplexed: store previous transaction ID. except Aborted as exc: - # TODO multiplexed - remove - self._transaction = None - if span: delay_seconds = _get_retry_delay( exc.errors[0], @@ -582,16 +569,15 @@ def run_in_transaction(self, func, *args, **kw): exc, deadline, attempts, default_retry_delay=default_retry_delay ) continue - except GoogleAPICallError: - # TODO multiplexed - remove - self._transaction = None + except GoogleAPICallError: add_span_event( span, "User operation failed due to GoogleAPICallError, not retrying", span_attributes, ) raise + except Exception: add_span_event( span, @@ -601,16 +587,17 @@ def run_in_transaction(self, func, *args, **kw): txn.rollback() raise + # [C] Commit transaction + # ---------------------- + try: txn.commit( - return_commit_stats=self._database.log_commit_stats, + return_commit_stats=log_commit_stats, request_options=commit_request_options, max_commit_delay=max_commit_delay, ) - except Aborted as exc: - # TODO multiplexed - remove - self._transaction = None + except Aborted as exc: if span: delay_seconds = _get_retry_delay( exc.errors[0], @@ -621,7 +608,7 @@ def run_in_transaction(self, func, *args, **kw): attributes.update(span_attributes) add_span_event( span, - "Transaction got aborted during commit, retrying afresh", + "Transaction was aborted during commit, retrying", attributes, ) @@ -629,9 +616,6 @@ def run_in_transaction(self, func, *args, **kw): exc, deadline, attempts, default_retry_delay=default_retry_delay ) except GoogleAPICallError: - # TODO multiplexed - remove - self._transaction = None - add_span_event( span, "Transaction.commit failed due to GoogleAPICallError, not retrying", @@ -639,8 +623,8 @@ def run_in_transaction(self, func, *args, **kw): ) raise else: - if self._database.log_commit_stats and txn.commit_stats: - self._database.logger.info( + if log_commit_stats and txn.commit_stats: + database.logger.info( "CommitStats: {}".format(txn.commit_stats), extra={"commit_stats": txn.commit_stats}, ) diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 8dfb0281e4..6a3959dac8 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -68,10 +68,6 @@ class Transaction(_SnapshotBase, _BatchBase): _read_only: bool = False def __init__(self, session): - # TODO multiplexed - remove - if session._transaction is not None: - raise ValueError("Session has existing transaction.") - super(Transaction, self).__init__(session) self.rolled_back: bool = False @@ -198,9 +194,6 @@ def wrapped_method(*args, **kwargs): self.rolled_back = True - # TODO multiplexed - remove - self._session._transaction = None - def commit( self, return_commit_stats=False, request_options=None, max_commit_delay=None ): @@ -339,9 +332,6 @@ def before_next_retry(nth_retry, delay_in_seconds): if return_commit_stats: self.commit_stats = commit_response_pb.commit_stats - # TODO multiplexed - remove - self._session._transaction = None - return self.committed @staticmethod diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1052d21dcd..c459ab10df 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -957,18 +957,6 @@ def test_transaction_created(self): self.assertIsInstance(transaction, Transaction) self.assertIs(transaction._session, session) - self.assertIs(session._transaction, transaction) - - def test_transaction_w_existing_txn(self): - database = self._make_database() - session = self._make_one(database) - session._session_id = "DEADBEEF" - - existing = session.transaction() - another = session.transaction() # invalidates existing txn - - self.assertIs(session._transaction, another) - self.assertTrue(existing.rolled_back) def test_run_in_transaction_callback_raises_non_gax_error(self): TABLE_NAME = "citizens" @@ -1000,7 +988,6 @@ def unit_of_work(txn, *args, **kw): with self.assertRaises(Testing): session.run_in_transaction(unit_of_work) - self.assertIsNone(session._transaction) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] self.assertIsInstance(txn, Transaction) @@ -1041,7 +1028,6 @@ def unit_of_work(txn, *args, **kw): with self.assertRaises(Cancelled): session.run_in_transaction(unit_of_work) - self.assertIsNone(session._transaction) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] self.assertIsInstance(txn, Transaction) @@ -1081,7 +1067,6 @@ def unit_of_work(txn, *args, **kw): return_value = session.run_in_transaction(unit_of_work, "abc", some_arg="def") - self.assertIsNone(session._transaction) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] self.assertIsInstance(txn, Transaction) @@ -1128,17 +1113,16 @@ def test_run_in_transaction_w_commit_error(self): ["phred@exammple.com", "Phred", "Phlyntstone", 32], ["bharney@example.com", "Bharney", "Rhubble", 31], ] - TRANSACTION_ID = b"FACEDACE" - gax_api = self._make_spanner_api() - gax_api.commit.side_effect = Unknown("error") database = self._make_database() - database.spanner_api = gax_api + + api = database.spanner_api = build_spanner_api() + begin_transaction = api.begin_transaction + commit = api.commit + + commit.side_effect = Unknown("error") + session = self._make_one(database) session._session_id = self.SESSION_ID - begun_txn = session._transaction = Transaction(session) - begun_txn._transaction_id = TRANSACTION_ID - - assert session._transaction._transaction_id called_with = [] @@ -1149,23 +1133,17 @@ def unit_of_work(txn, *args, **kw): with self.assertRaises(Unknown): session.run_in_transaction(unit_of_work) - self.assertIsNone(session._transaction) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] - self.assertIs(txn, begun_txn) self.assertEqual(txn.committed, None) self.assertEqual(args, ()) self.assertEqual(kw, {}) - gax_api.begin_transaction.assert_not_called() - request = CommitRequest( - session=self.SESSION_NAME, - mutations=txn._mutations, - transaction_id=TRANSACTION_ID, - request_options=RequestOptions(), - ) - gax_api.commit.assert_called_once_with( - request=request, + begin_transaction.assert_called_once_with( + request=BeginTransactionRequest( + session=session.name, + options=TransactionOptions(read_write=TransactionOptions.ReadWrite()), + ), metadata=[ ("google-cloud-resource-prefix", database.name), ("x-goog-spanner-route-to-leader", "true"), @@ -1176,6 +1154,23 @@ def unit_of_work(txn, *args, **kw): ], ) + api.commit.assert_called_once_with( + request=CommitRequest( + session=session.name, + mutations=txn._mutations, + transaction_id=begin_transaction.return_value.id, + request_options=RequestOptions(), + ), + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.2.1", + ), + ], + ) + def test_run_in_transaction_w_abort_no_retry_metadata(self): TABLE_NAME = "citizens" COLUMNS = ["email", "first_name", "last_name", "age"] @@ -1733,7 +1728,6 @@ def unit_of_work(txn, *args, **kw): return_value = session.run_in_transaction(unit_of_work, "abc", some_arg="def") - self.assertIsNone(session._transaction) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] self.assertIsInstance(txn, Transaction) @@ -1805,7 +1799,6 @@ def unit_of_work(txn, *args, **kw): with self.assertRaises(Unknown): session.run_in_transaction(unit_of_work, "abc", some_arg="def") - self.assertIsNone(session._transaction) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] self.assertIsInstance(txn, Transaction) @@ -1879,7 +1872,6 @@ def unit_of_work(txn, *args, **kw): unit_of_work, "abc", some_arg="def", transaction_tag=transaction_tag ) - self.assertIsNone(session._transaction) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] self.assertIsInstance(txn, Transaction) @@ -1951,7 +1943,6 @@ def unit_of_work(txn, *args, **kw): unit_of_work, "abc", exclude_txn_from_change_streams=True ) - self.assertIsNone(session._transaction) self.assertEqual(len(called_with), 1) txn, args, kw = called_with[0] self.assertIsInstance(txn, Transaction) @@ -2133,7 +2124,6 @@ def unit_of_work(txn, *args, **kw): unit_of_work, "abc", isolation_level="SERIALIZABLE" ) - self.assertIsNone(session._transaction) self.assertEqual(return_value, 42) expected_options = TransactionOptions( @@ -2170,7 +2160,6 @@ def unit_of_work(txn, *args, **kw): return_value = session.run_in_transaction(unit_of_work, "abc") - self.assertIsNone(session._transaction) self.assertEqual(return_value, 42) expected_options = TransactionOptions( @@ -2211,7 +2200,6 @@ def unit_of_work(txn, *args, **kw): isolation_level=TransactionOptions.IsolationLevel.REPEATABLE_READ, ) - self.assertIsNone(session._transaction) self.assertEqual(return_value, 42) expected_options = TransactionOptions( diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index d9448ef5ba..b668614a62 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -113,12 +113,6 @@ def _make_spanner_api(self): return mock.create_autospec(SpannerClient, instance=True) - def test_ctor_session_w_existing_txn(self): - session = _Session() - session._transaction = object() - with self.assertRaises(ValueError): - self._make_one(session) - def test_ctor_defaults(self): session = _Session() transaction = self._make_one(session) @@ -434,7 +428,6 @@ def _commit_helper( # Verify transaction state. self.assertEqual(transaction.committed, commit_timestamp) - self.assertIsNone(session._transaction) if return_commit_stats: self.assertEqual(transaction.commit_stats.mutation_count, 4) From 1f21bb962aa67af3c9508daa28b9a84cead5bfa6 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Mon, 9 Jun 2025 10:27:38 -0700 Subject: [PATCH 3/6] feat: Multiplexed sessions - Refactor logic for creating transaction selector to base class. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/snapshot.py | 91 ++--- google/cloud/spanner_v1/transaction.py | 41 ++- tests/_builders.py | 13 + tests/unit/test_snapshot.py | 455 +++++++++++-------------- tests/unit/test_transaction.py | 34 +- 5 files changed, 293 insertions(+), 341 deletions(-) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index fa613bc572..7c35ac3897 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -93,7 +93,7 @@ def _restart_on_unavailable( item_buffer: List[PartialResultSet] = [] if transaction is not None: - transaction_selector = transaction._make_txn_selector() + transaction_selector = transaction._build_transaction_selector_pb() elif transaction_selector is None: raise InvalidArgument( "Either transaction or transaction_selector should be set" @@ -149,7 +149,7 @@ def _restart_on_unavailable( ) as span, MetricsCapture(): request.resume_token = resume_token if transaction is not None: - transaction_selector = transaction._make_txn_selector() + transaction_selector = transaction._build_transaction_selector_pb() request.transaction = transaction_selector attempt += 1 iterator = method( @@ -180,7 +180,7 @@ def _restart_on_unavailable( ) as span, MetricsCapture(): request.resume_token = resume_token if transaction is not None: - transaction_selector = transaction._make_txn_selector() + transaction_selector = transaction._build_transaction_selector_pb() attempt += 1 request.transaction = transaction_selector iterator = method( @@ -238,17 +238,6 @@ def __init__(self, session): # threads, so we need to use a lock when updating the transaction. self._lock: threading.Lock = threading.Lock() - def _make_txn_selector(self): - """Helper for :meth:`read` / :meth:`execute_sql`. - - Subclasses must override, returning an instance of - :class:`transaction_pb2.TransactionSelector` - appropriate for making ``read`` / ``execute_sql`` requests - - :raises: NotImplementedError, always - """ - raise NotImplementedError - def begin(self) -> bytes: """Begins a transaction on the database. @@ -732,7 +721,7 @@ def partition_read( metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - transaction = self._make_txn_selector() + transaction = self._build_transaction_selector_pb() partition_options = PartitionOptions( partition_size_bytes=partition_size_bytes, max_partitions=max_partitions ) @@ -854,7 +843,7 @@ def partition_query( metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - transaction = self._make_txn_selector() + transaction = self._build_transaction_selector_pb() partition_options = PartitionOptions( partition_size_bytes=partition_size_bytes, max_partitions=max_partitions ) @@ -944,7 +933,7 @@ def _begin_transaction(self, mutation: Mutation = None) -> bytes: def wrapped_method(): begin_transaction_request = BeginTransactionRequest( session=session.name, - options=self._make_txn_selector().begin, + options=self._build_transaction_selector_pb().begin, mutation_key=mutation, ) begin_transaction_method = functools.partial( @@ -983,6 +972,34 @@ def before_next_retry(nth_retry, delay_in_seconds): self._update_for_transaction_pb(transaction_pb) return self._transaction_id + def _build_transaction_options_pb(self) -> TransactionOptions: + """Builds and returns the transaction options for this snapshot. + + :rtype: :class:`transaction_pb2.TransactionOptions` + :returns: the transaction options for this snapshot. + """ + raise NotImplementedError + + def _build_transaction_selector_pb(self) -> TransactionSelector: + """Builds and returns a transaction selector for this snapshot. + + :rtype: :class:`transaction_pb2.TransactionSelector` + :returns: a transaction selector for this snapshot. + """ + + # Select a previously begun transaction. + if self._transaction_id is not None: + return TransactionSelector(id=self._transaction_id) + + options = self._build_transaction_options_pb() + + # Select a single-use transaction. + if not self._multi_use: + return TransactionSelector(single_use=options) + + # Select a new, multi-use transaction. + return TransactionSelector(begin=options) + def _update_for_result_set_pb( self, result_set_pb: Union[ResultSet, PartialResultSet] ) -> None: @@ -1101,38 +1118,28 @@ def __init__( self._multi_use = multi_use self._transaction_id = transaction_id - # TODO multiplexed - refactor to base class - def _make_txn_selector(self): - """Helper for :meth:`read`.""" - if self._transaction_id is not None: - return TransactionSelector(id=self._transaction_id) + def _build_transaction_options_pb(self) -> TransactionOptions: + """Builds and returns transaction options for this snapshot. + + :rtype: :class:`transaction_pb2.TransactionOptions` + :returns: transaction options for this snapshot. + """ + + read_only_pb_args = dict(return_read_timestamp=True) if self._read_timestamp: - key = "read_timestamp" - value = self._read_timestamp + read_only_pb_args["read_timestamp"] = self._read_timestamp elif self._min_read_timestamp: - key = "min_read_timestamp" - value = self._min_read_timestamp + read_only_pb_args["min_read_timestamp"] = self._min_read_timestamp elif self._max_staleness: - key = "max_staleness" - value = self._max_staleness + read_only_pb_args["max_staleness"] = self._max_staleness elif self._exact_staleness: - key = "exact_staleness" - value = self._exact_staleness + read_only_pb_args["exact_staleness"] = self._exact_staleness else: - key = "strong" - value = True - - options = TransactionOptions( - read_only=TransactionOptions.ReadOnly( - **{key: value, "return_read_timestamp": True} - ) - ) + read_only_pb_args["strong"] = True - if self._multi_use: - return TransactionSelector(begin=options) - else: - return TransactionSelector(single_use=options) + read_only_pb = TransactionOptions.ReadOnly(**read_only_pb_args) + return TransactionOptions(read_only=read_only_pb) def _update_for_transaction_pb(self, transaction_pb: Transaction) -> None: """Updates the snapshot for the given transaction. diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 6a3959dac8..779010fbf0 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -35,7 +35,6 @@ ) from google.cloud.spanner_v1 import ExecuteBatchDmlRequest from google.cloud.spanner_v1 import ExecuteSqlRequest -from google.cloud.spanner_v1 import TransactionSelector from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1._helpers import AtomicCounter from google.cloud.spanner_v1.snapshot import _SnapshotBase @@ -71,27 +70,27 @@ def __init__(self, session): super(Transaction, self).__init__(session) self.rolled_back: bool = False - def _make_txn_selector(self): - """Helper for :meth:`read`. + def _build_transaction_options_pb(self) -> TransactionOptions: + """Builds and returns transaction options for this transaction. - :rtype: :class:`~.transaction_pb2.TransactionSelector` - :returns: a selector configured for read-write transaction semantics. + :rtype: :class:`~.transaction_pb2.TransactionOptions` + :returns: transaction options for this transaction. """ - if self._transaction_id is None: - txn_options = TransactionOptions( - read_write=TransactionOptions.ReadWrite(), - exclude_txn_from_change_streams=self.exclude_txn_from_change_streams, - isolation_level=self.isolation_level, - ) + default_transaction_options = ( + self._session._database.default_transaction_options.default_read_write_transaction_options + ) - txn_options = _merge_Transaction_Options( - self._session._database.default_transaction_options.default_read_write_transaction_options, - txn_options, - ) - return TransactionSelector(begin=txn_options) - else: - return TransactionSelector(id=self._transaction_id) + merge_transaction_options = TransactionOptions( + read_write=TransactionOptions.ReadWrite(), + exclude_txn_from_change_streams=self.exclude_txn_from_change_streams, + isolation_level=self.isolation_level, + ) + + return _merge_Transaction_Options( + defaultTransactionOptions=default_transaction_options, + mergeTransactionOptions=merge_transaction_options, + ) def _execute_request( self, @@ -118,7 +117,7 @@ def _execute_request( raise ValueError("Transaction already rolled back.") session = self._session - transaction = self._make_txn_selector() + transaction = self._build_transaction_selector_pb() request.transaction = transaction with trace_call( @@ -469,7 +468,7 @@ def execute_update( execute_sql_request = ExecuteSqlRequest( session=session.name, - transaction=self._make_txn_selector(), + transaction=self._build_transaction_selector_pb(), sql=dml, params=params_pb, param_types=param_types, @@ -617,7 +616,7 @@ def batch_update( execute_batch_dml_request = ExecuteBatchDmlRequest( session=session.name, - transaction=self._make_txn_selector(), + transaction=self._build_transaction_selector_pb(), statements=parsed, seqno=seqno, request_options=request_options, diff --git a/tests/_builders.py b/tests/_builders.py index 1521219dea..c2733be6de 100644 --- a/tests/_builders.py +++ b/tests/_builders.py @@ -172,6 +172,19 @@ def build_session(**kwargs: Mapping) -> Session: return Session(**kwargs) +def build_snapshot(**kwargs): + """Builds and returns a snapshot for testing using the given arguments. + If a required argument is not provided, a default value will be used.""" + + session = kwargs.pop("session", build_session()) + + # Ensure session exists. + if session.session_id is None: + session._session_id = _SESSION_ID + + return session.snapshot(**kwargs) + + def build_transaction(session=None) -> Transaction: """Builds and returns a transaction for testing using the given arguments. If a required argument is not provided, a default value will be used.""" diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 54955f735a..e7cfce3761 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from datetime import timedelta, datetime +from threading import Lock from typing import Mapping from google.api_core import gapic_v1 @@ -22,6 +24,7 @@ RequestOptions, DirectedReadOptions, BeginTransactionRequest, + TransactionOptions, TransactionSelector, ) from google.cloud.spanner_v1.snapshot import _SnapshotBase @@ -30,6 +33,7 @@ build_spanner_api, build_session, build_transaction_pb, + build_snapshot, ) from tests._helpers import ( OpenTelemetryBase, @@ -64,6 +68,9 @@ TXN_ID = b"DEAFBEAD" SECONDS = 3 MICROS = 123456 +DURATION = timedelta(seconds=SECONDS, microseconds=MICROS) +TIMESTAMP = datetime.now() + BASE_ATTRIBUTES = { "db.type": "spanner", "db.url": "spanner.googleapis.com", @@ -105,41 +112,18 @@ ) -def _makeTimestamp(): - import datetime - from google.cloud._helpers import UTC - - return datetime.datetime.utcnow().replace(tzinfo=UTC) - +class _Derived(_SnapshotBase): + """A minimally-implemented _SnapshotBase-derived class for testing""" -class Test_restart_on_unavailable(OpenTelemetryBase): - def _getTargetClass(self): - from google.cloud.spanner_v1.snapshot import _SnapshotBase + # Use a simplified implementation of _build_transaction_options_pb + # that always returns the same transaction options. + TRANSACTION_OPTIONS = TransactionOptions() - return _SnapshotBase + def _build_transaction_options_pb(self) -> TransactionOptions: + return self.TRANSACTION_OPTIONS - def _makeDerived(self, session): - class _Derived(self._getTargetClass()): - _transaction_id = None - _multi_use = False - - def _make_txn_selector(self): - from google.cloud.spanner_v1 import ( - TransactionOptions, - TransactionSelector, - ) - - if self._transaction_id: - return TransactionSelector(id=self._transaction_id) - options = TransactionOptions( - read_only=TransactionOptions.ReadOnly(strong=True) - ) - if self._multi_use: - return TransactionSelector(begin=options) - return TransactionSelector(single_use=options) - - return _Derived(session) +class Test_restart_on_unavailable(OpenTelemetryBase): def build_spanner_api(self): from google.cloud.spanner_v1 import SpannerClient @@ -184,7 +168,7 @@ def test_iteration_w_empty_raw(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), []) restart.assert_called_once_with( @@ -206,7 +190,7 @@ def test_iteration_w_non_empty_raw(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(ITEMS)) restart.assert_called_once_with( @@ -220,7 +204,7 @@ def test_iteration_w_non_empty_raw(self): ) self.assertNoSpans() - def test_iteration_w_raw_w_resume_tken(self): + def test_iteration_w_raw_w_resume_token(self): ITEMS = ( self._make_item(0), self._make_item(1, resume_token=RESUME_TOKEN), @@ -233,7 +217,7 @@ def test_iteration_w_raw_w_resume_tken(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(ITEMS)) restart.assert_called_once_with( @@ -262,7 +246,7 @@ def test_iteration_w_raw_raising_unavailable_no_token(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(ITEMS)) self.assertEqual(len(restart.mock_calls), 2) @@ -285,7 +269,7 @@ def test_iteration_w_raw_raising_retryable_internal_error_no_token(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(ITEMS)) self.assertEqual(len(restart.mock_calls), 2) @@ -307,7 +291,7 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) with self.assertRaises(InternalServerError): list(resumable) @@ -337,7 +321,7 @@ def test_iteration_w_raw_raising_unavailable(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(FIRST + LAST)) self.assertEqual(len(restart.mock_calls), 2) @@ -359,7 +343,7 @@ def test_iteration_w_raw_raising_retryable_internal_error(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(FIRST + LAST)) self.assertEqual(len(restart.mock_calls), 2) @@ -381,7 +365,7 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) with self.assertRaises(InternalServerError): list(resumable) @@ -410,7 +394,7 @@ def test_iteration_w_raw_raising_unavailable_after_token(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(FIRST + SECOND)) self.assertEqual(len(restart.mock_calls), 2) @@ -432,7 +416,7 @@ def test_iteration_w_raw_w_multiuse(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) derived._multi_use = True resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(FIRST)) @@ -463,7 +447,7 @@ def test_iteration_w_raw_raising_unavailable_w_multiuse(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) derived._multi_use = True resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(SECOND)) @@ -501,7 +485,7 @@ def test_iteration_w_raw_raising_unavailable_after_token_w_multiuse(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) derived._multi_use = True resumable = self._call_fut(derived, restart, request, session=session) @@ -535,7 +519,7 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) self.assertEqual(list(resumable), list(FIRST + SECOND)) self.assertEqual(len(restart.mock_calls), 2) @@ -556,7 +540,7 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut(derived, restart, request, session=session) with self.assertRaises(InternalServerError): list(resumable) @@ -580,7 +564,7 @@ def test_iteration_w_span_creation(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut( derived, restart, request, name, _Session(_Database()), extra_atts ) @@ -610,7 +594,7 @@ def test_iteration_w_multiple_span_creation(self): database = _Database() database.spanner_api = build_spanner_api() session = _Session(database) - derived = self._makeDerived(session) + derived = _build_snapshot_derived(session) resumable = self._call_fut( derived, restart, request, name, _Session(_Database()) ) @@ -633,56 +617,60 @@ def test_iteration_w_multiple_span_creation(self): class Test_SnapshotBase(OpenTelemetryBase): - class _Derived(_SnapshotBase): - """A minimally-implemented _SnapshotBase-derived class for testing""" + def test_ctor(self): + session = build_session() + derived = _build_snapshot_derived(session=session) - # Use a simplified implementation of _make_txn_selector - # that always returns the same transaction selector. - TRANSACTION_SELECTOR = TransactionSelector() + # Attributes from _SessionWrapper. + self.assertIs(derived._session, session) - def _make_txn_selector(self) -> TransactionSelector: - return self.TRANSACTION_SELECTOR + # Attributes from _SnapshotBase. + self.assertTrue(derived._read_only) + self.assertFalse(derived._multi_use) + self.assertEqual(derived._execute_sql_request_count, 0) + self.assertEqual(derived._read_request_count, 0) + self.assertIsNone(derived._transaction_id) + self.assertIsNone(derived._precommit_token) + self.assertIsInstance(derived._lock, type(Lock())) - @staticmethod - def _build_derived(session=None, multi_use=False, read_only=True): - """Builds and returns an instance of a minimally-implemented - _SnapshotBase-derived class for testing.""" + self.assertNoSpans() - session = session or build_session() - if session.session_id is None: - session.create() + def test__build_transaction_selector_pb_single_use(self): + derived = _build_snapshot_derived(multi_use=False) - derived = Test_SnapshotBase._Derived(session=session) - derived._multi_use = multi_use - derived._read_only = read_only + actual_selector = derived._build_transaction_selector_pb() - return derived + expected_selector = TransactionSelector(single_use=_Derived.TRANSACTION_OPTIONS) + self.assertEqual(actual_selector, expected_selector) - def test_ctor(self): - session = _Session() - base = _SnapshotBase(session) - self.assertIs(base._session, session) - self.assertEqual(base._execute_sql_request_count, 0) + def test__build_transaction_selector_pb_multi_use(self): + derived = _build_snapshot_derived(multi_use=True) - self.assertNoSpans() + # Select new transaction. + expected_options = _Derived.TRANSACTION_OPTIONS + expected_selector = TransactionSelector(begin=expected_options) + self.assertEqual(expected_selector, derived._build_transaction_selector_pb()) - def test__make_txn_selector_virtual(self): - session = _Session() - base = _SnapshotBase(session) - with self.assertRaises(NotImplementedError): - base._make_txn_selector() + # Select existing transaction. + transaction_id = b"transaction-id" + begin_transaction = derived._session._database.spanner_api.begin_transaction + begin_transaction.return_value = build_transaction_pb(id=transaction_id) + + derived.begin() + + expected_selector = TransactionSelector(id=transaction_id) + self.assertEqual(expected_selector, derived._build_transaction_selector_pb()) def test_begin_error_not_multi_use(self): - derived = self._build_derived(multi_use=False) + derived = _build_snapshot_derived(multi_use=False) - self.reset() with self.assertRaises(ValueError): derived.begin() self.assertNoSpans() def test_begin_error_already_begun(self): - derived = self._build_derived(multi_use=True) + derived = _build_snapshot_derived(multi_use=True) derived.begin() self.reset() @@ -692,13 +680,12 @@ def test_begin_error_already_begun(self): self.assertNoSpans() def test_begin_error_other(self): - derived = self._build_derived(multi_use=True) + derived = _build_snapshot_derived(multi_use=True) database = derived._session._database begin_transaction = database.spanner_api.begin_transaction begin_transaction.side_effect = RuntimeError() - self.reset() with self.assertRaises(RuntimeError): derived.begin() @@ -712,7 +699,7 @@ def test_begin_error_other(self): ) def test_begin_read_write(self): - derived = self._build_derived(multi_use=True, read_only=False) + derived = _build_snapshot_derived(multi_use=True, read_only=False) begin_transaction = derived._session._database.spanner_api.begin_transaction begin_transaction.return_value = build_transaction_pb() @@ -720,7 +707,7 @@ def test_begin_read_write(self): self._execute_begin(derived) def test_begin_read_only(self): - derived = self._build_derived(multi_use=True, read_only=True) + derived = _build_snapshot_derived(multi_use=True, read_only=True) begin_transaction = derived._session._database.spanner_api.begin_transaction begin_transaction.return_value = build_transaction_pb() @@ -728,7 +715,7 @@ def test_begin_read_only(self): self._execute_begin(derived) def test_begin_precommit_token(self): - derived = self._build_derived(multi_use=True) + derived = _build_snapshot_derived(multi_use=True) begin_transaction = derived._session._database.spanner_api.begin_transaction begin_transaction.return_value = build_transaction_pb( @@ -738,7 +725,7 @@ def test_begin_precommit_token(self): self._execute_begin(derived) def test_begin_retry_for_internal_server_error(self): - derived = self._build_derived(multi_use=True) + derived = _build_snapshot_derived(multi_use=True) begin_transaction = derived._session._database.spanner_api.begin_transaction begin_transaction.side_effect = [ @@ -758,7 +745,7 @@ def test_begin_retry_for_internal_server_error(self): self.assertEqual(expected_statuses, actual_statuses) def test_begin_retry_for_aborted(self): - derived = self._build_derived(multi_use=True) + derived = _build_snapshot_derived(multi_use=True) begin_transaction = derived._session._database.spanner_api.begin_transaction begin_transaction.side_effect = [ @@ -785,9 +772,6 @@ def _execute_begin(self, derived: _Derived, attempts: int = 1): session = derived._session database = session._database - # Clear spans. - self.reset() - transaction_id = derived.begin() # Verify transaction state. @@ -813,7 +797,7 @@ def _execute_begin(self, derived: _Derived, attempts: int = 1): database.spanner_api.begin_transaction.assert_called_with( request=BeginTransactionRequest( - session=session.name, options=self._Derived.TRANSACTION_SELECTOR.begin + session=session.name, options=_Derived.TRANSACTION_OPTIONS ), metadata=expected_metadata, ) @@ -836,7 +820,7 @@ def test_read_other_error(self): database.spanner_api = build_spanner_api() database.spanner_api.streaming_read.side_effect = RuntimeError() session = _Session(database) - derived = self._build_derived(session) + derived = _build_snapshot_derived(session) with self.assertRaises(RuntimeError): list(derived.read(TABLE_NAME, COLUMNS, keyset)) @@ -930,9 +914,10 @@ def _execute_read( api = database.spanner_api = build_spanner_api() api.streaming_read.return_value = _MockIterator(*result_sets) session = _Session(database) - derived = self._build_derived(session) + derived = _build_snapshot_derived(session) derived._multi_use = multi_use derived._read_request_count = count + if not first: derived._transaction_id = TXN_ID @@ -941,6 +926,8 @@ def _execute_read( elif type(request_options) is dict: request_options = RequestOptions(request_options) + transaction_selector_pb = derived._build_transaction_selector_pb() + if partition is not None: # 'limit' and 'partition' incompatible result_set = derived.read( TABLE_NAME, @@ -992,7 +979,7 @@ def _execute_read( table=TABLE_NAME, columns=COLUMNS, key_set=keyset._to_pb(), - transaction=self._Derived.TRANSACTION_SELECTOR, + transaction=transaction_selector_pb, index=INDEX, limit=expected_limit, partition_token=partition, @@ -1116,7 +1103,7 @@ def test_execute_sql_other_error(self): database.spanner_api = build_spanner_api() database.spanner_api.execute_streaming_sql.side_effect = RuntimeError() session = _Session(database) - derived = self._build_derived(session) + derived = _build_snapshot_derived(session) with self.assertRaises(RuntimeError): list(derived.execute_sql(SQL_QUERY)) @@ -1213,7 +1200,7 @@ def _execute_sql_helper( api = database.spanner_api = build_spanner_api() api.execute_streaming_sql.return_value = iterator session = _Session(database) - derived = self._build_derived(session, multi_use=multi_use) + derived = _build_snapshot_derived(session, multi_use=multi_use) derived._read_request_count = count derived._execute_sql_request_count = sql_count if not first: @@ -1224,6 +1211,8 @@ def _execute_sql_helper( elif type(request_options) is dict: request_options = RequestOptions(request_options) + transaction_selector_pb = derived._build_transaction_selector_pb() + result_set = derived.execute_sql( SQL_QUERY_WITH_PARAM, PARAMS, @@ -1267,7 +1256,7 @@ def _execute_sql_helper( expected_request = ExecuteSqlRequest( session=session.name, sql=SQL_QUERY_WITH_PARAM, - transaction=self._Derived.TRANSACTION_SELECTOR, + transaction=transaction_selector_pb, params=expected_params, param_types=PARAM_TYPES, query_mode=MODE, @@ -1434,10 +1423,14 @@ def _partition_read_helper( api = database.spanner_api = build_spanner_api() api.partition_read.return_value = response session = _Session(database) - derived = self._build_derived(session) + derived = _build_snapshot_derived(session) derived._multi_use = multi_use + if w_txn: derived._transaction_id = TXN_ID + + transaction_selector_pb = derived._build_transaction_selector_pb() + tokens = list( derived.partition_read( TABLE_NAME, @@ -1462,7 +1455,7 @@ def _partition_read_helper( table=TABLE_NAME, columns=COLUMNS, key_set=keyset._to_pb(), - transaction=self._Derived.TRANSACTION_SELECTOR, + transaction=transaction_selector_pb, index=index, partition_options=expected_partition_options, ) @@ -1511,7 +1504,7 @@ def test_partition_read_other_error(self): database.spanner_api = build_spanner_api() database.spanner_api.partition_read.side_effect = RuntimeError() session = _Session(database) - derived = self._build_derived(session, multi_use=True) + derived = _build_snapshot_derived(session, multi_use=True) derived._transaction_id = TXN_ID with self.assertRaises(RuntimeError): @@ -1554,7 +1547,7 @@ def test_partition_read_w_retry(self): ] session = _Session(database) - derived = self._build_derived(session) + derived = _build_snapshot_derived(session) derived._multi_use = True derived._transaction_id = TXN_ID @@ -1619,10 +1612,12 @@ def _partition_query_helper( api = database.spanner_api = build_spanner_api() api.partition_query.return_value = response session = _Session(database) - derived = self._build_derived(session, multi_use=multi_use) + derived = _build_snapshot_derived(session, multi_use=multi_use) if w_txn: derived._transaction_id = TXN_ID + transaction_selector_pb = derived._build_transaction_selector_pb() + tokens = list( derived.partition_query( SQL_QUERY_WITH_PARAM, @@ -1648,7 +1643,7 @@ def _partition_query_helper( expected_request = PartitionQueryRequest( session=session.name, sql=SQL_QUERY_WITH_PARAM, - transaction=self._Derived.TRANSACTION_SELECTOR, + transaction=transaction_selector_pb, params=expected_params, param_types=PARAM_TYPES, partition_options=expected_partition_options, @@ -1685,7 +1680,7 @@ def test_partition_query_other_error(self): database.spanner_api = build_spanner_api() database.spanner_api.partition_query.side_effect = RuntimeError() session = _Session(database) - derived = self._build_derived(session, multi_use=True) + derived = _build_snapshot_derived(session, multi_use=True) derived._transaction_id = TXN_ID with self.assertRaises(RuntimeError): @@ -1755,218 +1750,133 @@ def _makeDuration(self, seconds=1, microseconds=0): return datetime.timedelta(seconds=seconds, microseconds=microseconds) def test_ctor_defaults(self): - session = _Session() - snapshot = self._make_one(session) + session = build_session() + snapshot = build_snapshot(session=session) + + # Attributes from _SessionWrapper. self.assertIs(snapshot._session, session) + + # Attributes from _SnapshotBase. + self.assertTrue(snapshot._read_only) + self.assertFalse(snapshot._multi_use) + self.assertEqual(snapshot._execute_sql_request_count, 0) + self.assertEqual(snapshot._read_request_count, 0) + self.assertIsNone(snapshot._transaction_id) + self.assertIsNone(snapshot._precommit_token) + self.assertIsInstance(snapshot._lock, type(Lock())) + + # Attributes from Snapshot. self.assertTrue(snapshot._strong) self.assertIsNone(snapshot._read_timestamp) self.assertIsNone(snapshot._min_read_timestamp) self.assertIsNone(snapshot._max_staleness) self.assertIsNone(snapshot._exact_staleness) - self.assertFalse(snapshot._multi_use) def test_ctor_w_multiple_options(self): - timestamp = _makeTimestamp() - duration = self._makeDuration() - session = _Session() - with self.assertRaises(ValueError): - self._make_one(session, read_timestamp=timestamp, max_staleness=duration) + build_snapshot(read_timestamp=datetime.min, max_staleness=timedelta()) def test_ctor_w_read_timestamp(self): - timestamp = _makeTimestamp() - session = _Session() - snapshot = self._make_one(session, read_timestamp=timestamp) - self.assertIs(snapshot._session, session) - self.assertFalse(snapshot._strong) - self.assertEqual(snapshot._read_timestamp, timestamp) - self.assertIsNone(snapshot._min_read_timestamp) - self.assertIsNone(snapshot._max_staleness) - self.assertIsNone(snapshot._exact_staleness) - self.assertFalse(snapshot._multi_use) + snapshot = build_snapshot(read_timestamp=TIMESTAMP) + self.assertEqual(snapshot._read_timestamp, TIMESTAMP) def test_ctor_w_min_read_timestamp(self): - timestamp = _makeTimestamp() - session = _Session() - snapshot = self._make_one(session, min_read_timestamp=timestamp) - self.assertIs(snapshot._session, session) - self.assertFalse(snapshot._strong) - self.assertIsNone(snapshot._read_timestamp) - self.assertEqual(snapshot._min_read_timestamp, timestamp) - self.assertIsNone(snapshot._max_staleness) - self.assertIsNone(snapshot._exact_staleness) - self.assertFalse(snapshot._multi_use) + snapshot = build_snapshot(min_read_timestamp=TIMESTAMP) + self.assertEqual(snapshot._min_read_timestamp, TIMESTAMP) def test_ctor_w_max_staleness(self): - duration = self._makeDuration() - session = _Session() - snapshot = self._make_one(session, max_staleness=duration) - self.assertIs(snapshot._session, session) - self.assertFalse(snapshot._strong) - self.assertIsNone(snapshot._read_timestamp) - self.assertIsNone(snapshot._min_read_timestamp) - self.assertEqual(snapshot._max_staleness, duration) - self.assertIsNone(snapshot._exact_staleness) - self.assertFalse(snapshot._multi_use) + snapshot = build_snapshot(max_staleness=DURATION) + self.assertEqual(snapshot._max_staleness, DURATION) def test_ctor_w_exact_staleness(self): - duration = self._makeDuration() - session = _Session() - snapshot = self._make_one(session, exact_staleness=duration) - self.assertIs(snapshot._session, session) - self.assertFalse(snapshot._strong) - self.assertIsNone(snapshot._read_timestamp) - self.assertIsNone(snapshot._min_read_timestamp) - self.assertIsNone(snapshot._max_staleness) - self.assertEqual(snapshot._exact_staleness, duration) - self.assertFalse(snapshot._multi_use) + snapshot = build_snapshot(exact_staleness=DURATION) + self.assertEqual(snapshot._exact_staleness, DURATION) def test_ctor_w_multi_use(self): - session = _Session() - snapshot = self._make_one(session, multi_use=True) - self.assertTrue(snapshot._session is session) - self.assertTrue(snapshot._strong) - self.assertIsNone(snapshot._read_timestamp) - self.assertIsNone(snapshot._min_read_timestamp) - self.assertIsNone(snapshot._max_staleness) - self.assertIsNone(snapshot._exact_staleness) + snapshot = build_snapshot(multi_use=True) self.assertTrue(snapshot._multi_use) def test_ctor_w_multi_use_and_read_timestamp(self): - timestamp = _makeTimestamp() - session = _Session() - snapshot = self._make_one(session, read_timestamp=timestamp, multi_use=True) - self.assertTrue(snapshot._session is session) - self.assertFalse(snapshot._strong) - self.assertEqual(snapshot._read_timestamp, timestamp) - self.assertIsNone(snapshot._min_read_timestamp) - self.assertIsNone(snapshot._max_staleness) - self.assertIsNone(snapshot._exact_staleness) + snapshot = build_snapshot(multi_use=True, read_timestamp=TIMESTAMP) self.assertTrue(snapshot._multi_use) + self.assertEqual(snapshot._read_timestamp, TIMESTAMP) def test_ctor_w_multi_use_and_min_read_timestamp(self): - timestamp = _makeTimestamp() - session = _Session() - with self.assertRaises(ValueError): - self._make_one(session, min_read_timestamp=timestamp, multi_use=True) + build_snapshot(multi_use=True, min_read_timestamp=TIMESTAMP) def test_ctor_w_multi_use_and_max_staleness(self): - duration = self._makeDuration() - session = _Session() - with self.assertRaises(ValueError): - self._make_one(session, max_staleness=duration, multi_use=True) + build_snapshot(multi_use=True, max_staleness=DURATION) def test_ctor_w_multi_use_and_exact_staleness(self): - duration = self._makeDuration() - session = _Session() - snapshot = self._make_one(session, exact_staleness=duration, multi_use=True) - self.assertTrue(snapshot._session is session) - self.assertFalse(snapshot._strong) - self.assertIsNone(snapshot._read_timestamp) - self.assertIsNone(snapshot._min_read_timestamp) - self.assertIsNone(snapshot._max_staleness) - self.assertEqual(snapshot._exact_staleness, duration) + snapshot = build_snapshot(multi_use=True, exact_staleness=DURATION) self.assertTrue(snapshot._multi_use) + self.assertEqual(snapshot._exact_staleness, DURATION) + + def test__build_transaction_options_strong(self): + snapshot = build_snapshot() + options = snapshot._build_transaction_options_pb() - def test__make_txn_selector_w_transaction_id(self): - session = _Session() - snapshot = self._make_one(session) - snapshot._transaction_id = TXN_ID - selector = snapshot._make_txn_selector() - self.assertEqual(selector.id, TXN_ID) - - def test__make_txn_selector_strong(self): - session = _Session() - snapshot = self._make_one(session) - selector = snapshot._make_txn_selector() - options = selector.single_use - self.assertTrue(options.read_only.strong) - - def test__make_txn_selector_w_read_timestamp(self): - from google.cloud._helpers import _pb_timestamp_to_datetime - - timestamp = _makeTimestamp() - session = _Session() - snapshot = self._make_one(session, read_timestamp=timestamp) - selector = snapshot._make_txn_selector() - options = selector.single_use self.assertEqual( - _pb_timestamp_to_datetime( - type(options).pb(options).read_only.read_timestamp + options, + TransactionOptions( + read_only=TransactionOptions.ReadOnly( + strong=True, return_read_timestamp=True + ) ), - timestamp, ) - def test__make_txn_selector_w_min_read_timestamp(self): - from google.cloud._helpers import _pb_timestamp_to_datetime + def test__build_transaction_options_w_read_timestamp(self): + snapshot = build_snapshot(read_timestamp=TIMESTAMP) + options = snapshot._build_transaction_options_pb() - timestamp = _makeTimestamp() - session = _Session() - snapshot = self._make_one(session, min_read_timestamp=timestamp) - selector = snapshot._make_txn_selector() - options = selector.single_use self.assertEqual( - _pb_timestamp_to_datetime( - type(options).pb(options).read_only.min_read_timestamp + options, + TransactionOptions( + read_only=TransactionOptions.ReadOnly( + read_timestamp=TIMESTAMP, return_read_timestamp=True + ) ), - timestamp, ) - def test__make_txn_selector_w_max_staleness(self): - duration = self._makeDuration(seconds=3, microseconds=123456) - session = _Session() - snapshot = self._make_one(session, max_staleness=duration) - selector = snapshot._make_txn_selector() - options = selector.single_use - self.assertEqual(type(options).pb(options).read_only.max_staleness.seconds, 3) - self.assertEqual( - type(options).pb(options).read_only.max_staleness.nanos, 123456000 - ) + def test__build_transaction_options_w_min_read_timestamp(self): + snapshot = build_snapshot(min_read_timestamp=TIMESTAMP) + options = snapshot._build_transaction_options_pb() - def test__make_txn_selector_w_exact_staleness(self): - duration = self._makeDuration(seconds=3, microseconds=123456) - session = _Session() - snapshot = self._make_one(session, exact_staleness=duration) - selector = snapshot._make_txn_selector() - options = selector.single_use - self.assertEqual(type(options).pb(options).read_only.exact_staleness.seconds, 3) self.assertEqual( - type(options).pb(options).read_only.exact_staleness.nanos, 123456000 + options, + TransactionOptions( + read_only=TransactionOptions.ReadOnly( + min_read_timestamp=TIMESTAMP, return_read_timestamp=True + ) + ), ) - def test__make_txn_selector_strong_w_multi_use(self): - session = _Session() - snapshot = self._make_one(session, multi_use=True) - selector = snapshot._make_txn_selector() - options = selector.begin - self.assertTrue(options.read_only.strong) + def test__build_transaction_options_w_max_staleness(self): + snapshot = build_snapshot(max_staleness=DURATION) + options = snapshot._build_transaction_options_pb() - def test__make_txn_selector_w_read_timestamp_w_multi_use(self): - from google.cloud._helpers import _pb_timestamp_to_datetime - - timestamp = _makeTimestamp() - session = _Session() - snapshot = self._make_one(session, read_timestamp=timestamp, multi_use=True) - selector = snapshot._make_txn_selector() - options = selector.begin self.assertEqual( - _pb_timestamp_to_datetime( - type(options).pb(options).read_only.read_timestamp + options, + TransactionOptions( + read_only=TransactionOptions.ReadOnly( + max_staleness=DURATION, return_read_timestamp=True + ) ), - timestamp, ) - def test__make_txn_selector_w_exact_staleness_w_multi_use(self): - duration = self._makeDuration(seconds=3, microseconds=123456) - session = _Session() - snapshot = self._make_one(session, exact_staleness=duration, multi_use=True) - selector = snapshot._make_txn_selector() - options = selector.begin - self.assertEqual(type(options).pb(options).read_only.exact_staleness.seconds, 3) + def test__build_transaction_options_w_exact_staleness(self): + snapshot = build_snapshot(exact_staleness=DURATION) + options = snapshot._build_transaction_options_pb() + self.assertEqual( - type(options).pb(options).read_only.exact_staleness.nanos, 123456000 + options, + TransactionOptions( + read_only=TransactionOptions.ReadOnly( + exact_staleness=DURATION, return_read_timestamp=True + ) + ), ) @@ -2058,6 +1968,21 @@ def __next__(self): next = __next__ +def _build_snapshot_derived(session=None, multi_use=False, read_only=True) -> _Derived: + """Builds and returns an instance of a minimally- + implemented _Derived class for testing.""" + + session = session or build_session() + if session.session_id is None: + session._session_id = "session-id" + + derived = _Derived(session=session) + derived._multi_use = multi_use + derived._read_only = read_only + + return derived + + def _build_span_attributes(database: Database, attempt: int = 1) -> Mapping[str, str]: """Builds the attributes for spans using the given database and extra attributes.""" diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index b668614a62..307c9f9d8c 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from threading import Lock from typing import Mapping from datetime import timedelta @@ -36,6 +37,7 @@ ) from google.cloud.spanner_v1.batch import _make_write_pb from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.transaction import Transaction from google.cloud.spanner_v1.request_id_header import ( REQ_RAND_PROCESS_ID, build_request_id, @@ -114,21 +116,28 @@ def _make_spanner_api(self): return mock.create_autospec(SpannerClient, instance=True) def test_ctor_defaults(self): - session = _Session() - transaction = self._make_one(session) - self.assertIs(transaction._session, session) - self.assertIsNone(transaction._transaction_id) - self.assertIsNone(transaction.committed) - self.assertFalse(transaction.rolled_back) + session = build_session() + transaction = Transaction(session=session) + + # Attributes from _SessionWrapper + self.assertEqual(transaction._session, session) + + # Attributes from _SnapshotBase + self.assertFalse(transaction._read_only) self.assertTrue(transaction._multi_use) self.assertEqual(transaction._execute_sql_request_count, 0) + self.assertEqual(transaction._read_request_count, 0) + self.assertIsNone(transaction._transaction_id) + self.assertIsNone(transaction._precommit_token) + self.assertIsInstance(transaction._lock, type(Lock())) - def test__make_txn_selector(self): - session = _Session() - transaction = self._make_one(session) - transaction._transaction_id = TRANSACTION_ID - selector = transaction._make_txn_selector() - self.assertEqual(selector.id, TRANSACTION_ID) + # Attributes from _BatchBase + self.assertEqual(transaction._mutations, []) + self.assertIsNone(transaction._precommit_token) + self.assertIsNone(transaction.committed) + self.assertIsNone(transaction.commit_stats) + + self.assertFalse(transaction.rolled_back) def test_begin_already_rolled_back(self): session = _Session() @@ -219,7 +228,6 @@ def test_rollback_ok(self): transaction.rollback() self.assertTrue(transaction.rolled_back) - self.assertIsNone(session._transaction) session_id, txn_id, metadata = api._rolled_back self.assertEqual(session_id, session.name) From 1126b0d805de02bc5b96de63af3f646d7d7f4f06 Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Mon, 9 Jun 2025 11:38:40 -0700 Subject: [PATCH 4/6] feat: Multiplexed sessions - Add retry logic to run_in_transaction with previous transaction ID. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/session.py | 13 +- google/cloud/spanner_v1/transaction.py | 10 +- tests/unit/test_session.py | 195 +++++++++++++------------ 3 files changed, 125 insertions(+), 93 deletions(-) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 6aa919e65a..92f2488f22 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -464,7 +464,6 @@ def batch(self): return Batch(self) - # TODO multiplexed - deprecate def transaction(self) -> Transaction: """Create a transaction to perform a set of reads with shared staleness. @@ -531,6 +530,12 @@ def run_in_transaction(self, func, *args, **kw): ) as span, MetricsCapture(): attempts: int = 0 + # If the transaction is retried after an aborted user operation, it should include the previous transaction ID + # in the transaction options used to begin the transaction. This allows the backend to recognize the transaction + # and increase the lock order for the new transaction ID that is created. + # See :attr:`~google.cloud.spanner_v1.types.TransactionOptions.ReadWrite.multiplexed_session_previous_transaction_id` + previous_transaction_id: Optional[bytes] = None + while True: # [A] Build transaction # --------------------- @@ -539,6 +544,7 @@ def run_in_transaction(self, func, *args, **kw): txn.transaction_tag = transaction_tag txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams txn.isolation_level = isolation_level + txn._previous_transaction_id = previous_transaction_id # [B] Run user operation # ---------------------- @@ -549,8 +555,8 @@ def run_in_transaction(self, func, *args, **kw): try: return_value = func(txn, *args, **kw) - # TODO multiplexed: store previous transaction ID. except Aborted as exc: + previous_transaction_id = txn._transaction_id if span: delay_seconds = _get_retry_delay( exc.errors[0], @@ -598,6 +604,7 @@ def run_in_transaction(self, func, *args, **kw): ) except Aborted as exc: + previous_transaction_id = txn._transaction_id if span: delay_seconds = _get_retry_delay( exc.errors[0], @@ -615,6 +622,7 @@ def run_in_transaction(self, func, *args, **kw): _delay_until_retry( exc, deadline, attempts, default_retry_delay=default_retry_delay ) + except GoogleAPICallError: add_span_event( span, @@ -622,6 +630,7 @@ def run_in_transaction(self, func, *args, **kw): span_attributes, ) raise + else: if log_commit_stats and txn.commit_stats: database.logger.info( diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 779010fbf0..3b6ae37d86 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -70,6 +70,12 @@ def __init__(self, session): super(Transaction, self).__init__(session) self.rolled_back: bool = False + # If this transaction is used to retry a previous aborted transaction, the + # identifier for that transaction is used to increase the lock order of the new + # transaction (see :meth:`_build_transaction_options_pb`). This attribute should + # only be set by :meth:`~google.cloud.spanner_v1.session.Session.run_in_transaction`. + self._previous_transaction_id: Optional[bytes] = None + def _build_transaction_options_pb(self) -> TransactionOptions: """Builds and returns transaction options for this transaction. @@ -82,7 +88,9 @@ def _build_transaction_options_pb(self) -> TransactionOptions: ) merge_transaction_options = TransactionOptions( - read_write=TransactionOptions.ReadWrite(), + read_write=TransactionOptions.ReadWrite( + multiplexed_session_previous_transaction_id=self._previous_transaction_id + ), exclude_txn_from_change_streams=self.exclude_txn_from_change_streams, isolation_level=self.isolation_level, ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index c459ab10df..0bc1eda54d 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -33,7 +33,7 @@ from google.cloud._helpers import UTC, _datetime_to_pb_timestamp from google.cloud.spanner_v1._helpers import _delay_until_retry from google.cloud.spanner_v1.transaction import Transaction -from tests._builders import build_spanner_api +from tests._builders import build_spanner_api, build_session, build_transaction_pb from tests._helpers import ( OpenTelemetryBase, LIB_VERSION, @@ -57,8 +57,18 @@ _metadata_with_request_id, ) +TABLE_NAME = "citizens" +COLUMNS = ["email", "first_name", "last_name", "age"] +VALUES = [ + ["phred@exammple.com", "Phred", "Phlyntstone", 32], + ["bharney@example.com", "Bharney", "Rhubble", 31], +] +KEYS = ["bharney@example.com", "phred@example.com"] +KEYSET = KeySet(keys=KEYS) +TRANSACTION_ID = b"FACEDACE" -def _make_rpc_error(error_cls, trailing_metadata=None): + +def _make_rpc_error(error_cls, trailing_metadata=[]): grpc_error = mock.create_autospec(grpc.Call, instance=True) grpc_error.trailing_metadata.return_value = trailing_metadata return error_cls("error", errors=(grpc_error,)) @@ -1038,9 +1048,52 @@ def unit_of_work(txn, *args, **kw): gax_api.rollback.assert_not_called() + def test_run_in_transaction_retry_callback_raises_abort(self): + session = build_session() + database = session._database + api = database.spanner_api + + # Build API responses + previous_transaction_id = b"transaction-id" + begin_transaction = api.begin_transaction + begin_transaction.return_value = build_transaction_pb( + id=previous_transaction_id + ) + + streaming_read = api.streaming_read + streaming_read.side_effect = [_make_rpc_error(Aborted), []] + + # Run in transaction. + def unit_of_work(transaction): + transaction.begin() + list(transaction.read(TABLE_NAME, COLUMNS, KEYSET)) + + session.create() + session.run_in_transaction(unit_of_work) + + # Verify retried BeginTransaction API call. + self.assertEqual(begin_transaction.call_count, 2) + + begin_transaction.assert_called_with( + request=BeginTransactionRequest( + session=session.name, + options=TransactionOptions( + read_write=TransactionOptions.ReadWrite( + multiplexed_session_previous_transaction_id=previous_transaction_id + ) + ), + ), + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.4.1", + ), + ], + ) + def test_run_in_transaction_w_args_w_kwargs_wo_abort(self): - TABLE_NAME = "citizens" - COLUMNS = ["email", "first_name", "last_name", "age"] VALUES = [ ["phred@exammple.com", "Phred", "Phlyntstone", 32], ["bharney@example.com", "Bharney", "Rhubble", 31], @@ -1172,13 +1225,6 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_abort_no_retry_metadata(self): - TABLE_NAME = "citizens" - COLUMNS = ["email", "first_name", "last_name", "age"] - VALUES = [ - ["phred@exammple.com", "Phred", "Phlyntstone", 32], - ["bharney@example.com", "Bharney", "Rhubble", 31], - ] - TRANSACTION_ID = b"FACEDACE" transaction_pb = TransactionPB(id=TRANSACTION_ID) now = datetime.datetime.utcnow().replace(tzinfo=UTC) now_pb = _datetime_to_pb_timestamp(now) @@ -1210,13 +1256,15 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(args, ("abc",)) self.assertEqual(kw, {"some_arg": "def"}) - expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) self.assertEqual( gax_api.begin_transaction.call_args_list, [ mock.call( request=BeginTransactionRequest( - session=self.SESSION_NAME, options=expected_options + session=session.name, + options=TransactionOptions( + read_write=TransactionOptions.ReadWrite() + ), ), metadata=[ ("google-cloud-resource-prefix", database.name), @@ -1229,7 +1277,12 @@ def unit_of_work(txn, *args, **kw): ), mock.call( request=BeginTransactionRequest( - session=self.SESSION_NAME, options=expected_options + session=session.name, + options=TransactionOptions( + read_write=TransactionOptions.ReadWrite( + multiplexed_session_previous_transaction_id=TRANSACTION_ID + ) + ), ), metadata=[ ("google-cloud-resource-prefix", database.name), @@ -1277,13 +1330,6 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_abort_w_retry_metadata(self): - TABLE_NAME = "citizens" - COLUMNS = ["email", "first_name", "last_name", "age"] - VALUES = [ - ["phred@exammple.com", "Phred", "Phlyntstone", 32], - ["bharney@example.com", "Bharney", "Rhubble", 31], - ] - TRANSACTION_ID = b"FACEDACE" RETRY_SECONDS = 12 RETRY_NANOS = 3456 retry_info = RetryInfo( @@ -1326,13 +1372,15 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(args, ("abc",)) self.assertEqual(kw, {"some_arg": "def"}) - expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) self.assertEqual( gax_api.begin_transaction.call_args_list, [ mock.call( request=BeginTransactionRequest( - session=self.SESSION_NAME, options=expected_options + session=session.name, + options=TransactionOptions( + read_write=TransactionOptions.ReadWrite() + ), ), metadata=[ ("google-cloud-resource-prefix", database.name), @@ -1345,7 +1393,12 @@ def unit_of_work(txn, *args, **kw): ), mock.call( request=BeginTransactionRequest( - session=self.SESSION_NAME, options=expected_options + session=session.name, + options=TransactionOptions( + read_write=TransactionOptions.ReadWrite( + multiplexed_session_previous_transaction_id=TRANSACTION_ID + ) + ), ), metadata=[ ("google-cloud-resource-prefix", database.name), @@ -1393,13 +1446,6 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_callback_raises_abort_wo_metadata(self): - TABLE_NAME = "citizens" - COLUMNS = ["email", "first_name", "last_name", "age"] - VALUES = [ - ["phred@exammple.com", "Phred", "Phlyntstone", 32], - ["bharney@example.com", "Bharney", "Rhubble", 31], - ] - TRANSACTION_ID = b"FACEDACE" RETRY_SECONDS = 1 RETRY_NANOS = 3456 transaction_pb = TransactionPB(id=TRANSACTION_ID) @@ -1477,13 +1523,6 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_abort_w_retry_metadata_deadline(self): - TABLE_NAME = "citizens" - COLUMNS = ["email", "first_name", "last_name", "age"] - VALUES = [ - ["phred@exammple.com", "Phred", "Phlyntstone", 32], - ["bharney@example.com", "Bharney", "Rhubble", 31], - ] - TRANSACTION_ID = b"FACEDACE" RETRY_SECONDS = 1 RETRY_NANOS = 3456 transaction_pb = TransactionPB(id=TRANSACTION_ID) @@ -1562,13 +1601,6 @@ def _time(_results=[1, 1.5]): ) def test_run_in_transaction_w_timeout(self): - TABLE_NAME = "citizens" - COLUMNS = ["email", "first_name", "last_name", "age"] - VALUES = [ - ["phred@exammple.com", "Phred", "Phlyntstone", 32], - ["bharney@example.com", "Bharney", "Rhubble", 31], - ] - TRANSACTION_ID = b"FACEDACE" transaction_pb = TransactionPB(id=TRANSACTION_ID) aborted = _make_rpc_error(Aborted, trailing_metadata=[]) gax_api = self._make_spanner_api() @@ -1607,13 +1639,15 @@ def _time(_results=[1, 2, 4, 8]): self.assertEqual(args, ()) self.assertEqual(kw, {}) - expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite()) self.assertEqual( gax_api.begin_transaction.call_args_list, [ mock.call( request=BeginTransactionRequest( - session=self.SESSION_NAME, options=expected_options + session=session.name, + options=TransactionOptions( + read_write=TransactionOptions.ReadWrite() + ), ), metadata=[ ("google-cloud-resource-prefix", database.name), @@ -1626,7 +1660,12 @@ def _time(_results=[1, 2, 4, 8]): ), mock.call( request=BeginTransactionRequest( - session=self.SESSION_NAME, options=expected_options + session=session.name, + options=TransactionOptions( + read_write=TransactionOptions.ReadWrite( + multiplexed_session_previous_transaction_id=TRANSACTION_ID + ) + ), ), metadata=[ ("google-cloud-resource-prefix", database.name), @@ -1639,7 +1678,12 @@ def _time(_results=[1, 2, 4, 8]): ), mock.call( request=BeginTransactionRequest( - session=self.SESSION_NAME, options=expected_options + session=session.name, + options=TransactionOptions( + read_write=TransactionOptions.ReadWrite( + multiplexed_session_previous_transaction_id=TRANSACTION_ID + ) + ), ), metadata=[ ("google-cloud-resource-prefix", database.name), @@ -1698,13 +1742,6 @@ def _time(_results=[1, 2, 4, 8]): ) def test_run_in_transaction_w_commit_stats_success(self): - TABLE_NAME = "citizens" - COLUMNS = ["email", "first_name", "last_name", "age"] - VALUES = [ - ["phred@exammple.com", "Phred", "Phlyntstone", 32], - ["bharney@example.com", "Bharney", "Rhubble", 31], - ] - TRANSACTION_ID = b"FACEDACE" transaction_pb = TransactionPB(id=TRANSACTION_ID) now = datetime.datetime.utcnow().replace(tzinfo=UTC) now_pb = _datetime_to_pb_timestamp(now) @@ -1772,13 +1809,6 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_commit_stats_error(self): - TABLE_NAME = "citizens" - COLUMNS = ["email", "first_name", "last_name", "age"] - VALUES = [ - ["phred@exammple.com", "Phred", "Phlyntstone", 32], - ["bharney@example.com", "Bharney", "Rhubble", 31], - ] - TRANSACTION_ID = b"FACEDACE" transaction_pb = TransactionPB(id=TRANSACTION_ID) gax_api = self._make_spanner_api() gax_api.begin_transaction.return_value = transaction_pb @@ -1840,13 +1870,6 @@ def unit_of_work(txn, *args, **kw): database.logger.info.assert_not_called() def test_run_in_transaction_w_transaction_tag(self): - TABLE_NAME = "citizens" - COLUMNS = ["email", "first_name", "last_name", "age"] - VALUES = [ - ["phred@exammple.com", "Phred", "Phlyntstone", 32], - ["bharney@example.com", "Bharney", "Rhubble", 31], - ] - TRANSACTION_ID = b"FACEDACE" transaction_pb = TransactionPB(id=TRANSACTION_ID) now = datetime.datetime.utcnow().replace(tzinfo=UTC) now_pb = _datetime_to_pb_timestamp(now) @@ -1912,13 +1935,6 @@ def unit_of_work(txn, *args, **kw): ) def test_run_in_transaction_w_exclude_txn_from_change_streams(self): - TABLE_NAME = "citizens" - COLUMNS = ["email", "first_name", "last_name", "age"] - VALUES = [ - ["phred@exammple.com", "Phred", "Phlyntstone", 32], - ["bharney@example.com", "Bharney", "Rhubble", 31], - ] - TRANSACTION_ID = b"FACEDACE" transaction_pb = TransactionPB(id=TRANSACTION_ID) now = datetime.datetime.utcnow().replace(tzinfo=UTC) now_pb = _datetime_to_pb_timestamp(now) @@ -1987,13 +2003,6 @@ def unit_of_work(txn, *args, **kw): def test_run_in_transaction_w_abort_w_retry_metadata_w_exclude_txn_from_change_streams( self, ): - TABLE_NAME = "citizens" - COLUMNS = ["email", "first_name", "last_name", "age"] - VALUES = [ - ["phred@exammple.com", "Phred", "Phlyntstone", 32], - ["bharney@example.com", "Bharney", "Rhubble", 31], - ] - TRANSACTION_ID = b"FACEDACE" RETRY_SECONDS = 12 RETRY_NANOS = 3456 retry_info = RetryInfo( @@ -2041,16 +2050,16 @@ def unit_of_work(txn, *args, **kw): self.assertEqual(args, ("abc",)) self.assertEqual(kw, {"some_arg": "def"}) - expected_options = TransactionOptions( - read_write=TransactionOptions.ReadWrite(), - exclude_txn_from_change_streams=True, - ) self.assertEqual( gax_api.begin_transaction.call_args_list, [ mock.call( request=BeginTransactionRequest( - session=self.SESSION_NAME, options=expected_options + session=session.name, + options=TransactionOptions( + read_write=TransactionOptions.ReadWrite(), + exclude_txn_from_change_streams=True, + ), ), metadata=[ ("google-cloud-resource-prefix", database.name), @@ -2063,7 +2072,13 @@ def unit_of_work(txn, *args, **kw): ), mock.call( request=BeginTransactionRequest( - session=self.SESSION_NAME, options=expected_options + session=session.name, + options=TransactionOptions( + read_write=TransactionOptions.ReadWrite( + multiplexed_session_previous_transaction_id=TRANSACTION_ID + ), + exclude_txn_from_change_streams=True, + ), ), metadata=[ ("google-cloud-resource-prefix", database.name), From 18f1a6c8eec2b1289695e774db88586147e2ed7e Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Wed, 11 Jun 2025 09:42:10 -0700 Subject: [PATCH 5/6] feat: Multiplexed sessions - Remove unnecessary divider comments Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/session.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 92f2488f22..d754baaf0f 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -537,18 +537,12 @@ def run_in_transaction(self, func, *args, **kw): previous_transaction_id: Optional[bytes] = None while True: - # [A] Build transaction - # --------------------- - txn = self.transaction() txn.transaction_tag = transaction_tag txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams txn.isolation_level = isolation_level txn._previous_transaction_id = previous_transaction_id - # [B] Run user operation - # ---------------------- - attempts += 1 span_attributes = dict(attempt=attempts) @@ -593,9 +587,6 @@ def run_in_transaction(self, func, *args, **kw): txn.rollback() raise - # [C] Commit transaction - # ---------------------- - try: txn.commit( return_commit_stats=log_commit_stats, From e42f099bc2b517a0e9e4350f4eecfb43c352578f Mon Sep 17 00:00:00 2001 From: Taylor Curran Date: Wed, 11 Jun 2025 19:08:01 -0700 Subject: [PATCH 6/6] feat: Multiplexed sessions - Only populate previous transaction ID for transactions with multiplexed session. Signed-off-by: Taylor Curran --- google/cloud/spanner_v1/session.py | 14 +++- google/cloud/spanner_v1/transaction.py | 12 +-- tests/unit/test_session.py | 107 +++++++++++++++++++++---- 3 files changed, 107 insertions(+), 26 deletions(-) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index d754baaf0f..1a9313d0d3 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -530,9 +530,11 @@ def run_in_transaction(self, func, *args, **kw): ) as span, MetricsCapture(): attempts: int = 0 - # If the transaction is retried after an aborted user operation, it should include the previous transaction ID - # in the transaction options used to begin the transaction. This allows the backend to recognize the transaction - # and increase the lock order for the new transaction ID that is created. + # If a transaction using a multiplexed session is retried after an aborted + # user operation, it should include the previous transaction ID in the + # transaction options used to begin the transaction. This allows the backend + # to recognize the transaction and increase the lock order for the new + # transaction that is created. # See :attr:`~google.cloud.spanner_v1.types.TransactionOptions.ReadWrite.multiplexed_session_previous_transaction_id` previous_transaction_id: Optional[bytes] = None @@ -541,7 +543,11 @@ def run_in_transaction(self, func, *args, **kw): txn.transaction_tag = transaction_tag txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams txn.isolation_level = isolation_level - txn._previous_transaction_id = previous_transaction_id + + if self.is_multiplexed: + txn._multiplexed_session_previous_transaction_id = ( + previous_transaction_id + ) attempts += 1 span_attributes = dict(attempt=attempts) diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 3b6ae37d86..bfa43a5ea4 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -70,11 +70,11 @@ def __init__(self, session): super(Transaction, self).__init__(session) self.rolled_back: bool = False - # If this transaction is used to retry a previous aborted transaction, the - # identifier for that transaction is used to increase the lock order of the new - # transaction (see :meth:`_build_transaction_options_pb`). This attribute should - # only be set by :meth:`~google.cloud.spanner_v1.session.Session.run_in_transaction`. - self._previous_transaction_id: Optional[bytes] = None + # If this transaction is used to retry a previous aborted transaction with a + # multiplexed session, the identifier for that transaction is used to increase + # the lock order of the new transaction (see :meth:`_build_transaction_options_pb`). + # This attribute should only be set by :meth:`~google.cloud.spanner_v1.session.Session.run_in_transaction`. + self._multiplexed_session_previous_transaction_id: Optional[bytes] = None def _build_transaction_options_pb(self) -> TransactionOptions: """Builds and returns transaction options for this transaction. @@ -89,7 +89,7 @@ def _build_transaction_options_pb(self) -> TransactionOptions: merge_transaction_options = TransactionOptions( read_write=TransactionOptions.ReadWrite( - multiplexed_session_previous_transaction_id=self._previous_transaction_id + multiplexed_session_previous_transaction_id=self._multiplexed_session_previous_transaction_id ), exclude_txn_from_change_streams=self.exclude_txn_from_change_streams, isolation_level=self.isolation_level, diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 0bc1eda54d..d5b9b83478 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -33,7 +33,12 @@ from google.cloud._helpers import UTC, _datetime_to_pb_timestamp from google.cloud.spanner_v1._helpers import _delay_until_retry from google.cloud.spanner_v1.transaction import Transaction -from tests._builders import build_spanner_api, build_session, build_transaction_pb +from tests._builders import ( + build_spanner_api, + build_session, + build_transaction_pb, + build_commit_response_pb, +) from tests._helpers import ( OpenTelemetryBase, LIB_VERSION, @@ -1051,6 +1056,41 @@ def unit_of_work(txn, *args, **kw): def test_run_in_transaction_retry_callback_raises_abort(self): session = build_session() database = session._database + + # Build API responses. + api = database.spanner_api + begin_transaction = api.begin_transaction + streaming_read = api.streaming_read + streaming_read.side_effect = [_make_rpc_error(Aborted), []] + + # Run in transaction. + def unit_of_work(transaction): + transaction.begin() + list(transaction.read(TABLE_NAME, COLUMNS, KEYSET)) + + session.create() + session.run_in_transaction(unit_of_work) + + self.assertEqual(begin_transaction.call_count, 2) + + begin_transaction.assert_called_with( + request=BeginTransactionRequest( + session=session.name, + options=TransactionOptions(read_write=TransactionOptions.ReadWrite()), + ), + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.4.1", + ), + ], + ) + + def test_run_in_transaction_retry_callback_raises_abort_multiplexed(self): + session = build_session(is_multiplexed=True) + database = session._database api = database.spanner_api # Build API responses @@ -1093,6 +1133,51 @@ def unit_of_work(transaction): ], ) + def test_run_in_transaction_retry_commit_raises_abort_multiplexed(self): + session = build_session(is_multiplexed=True) + database = session._database + + # Build API responses + api = database.spanner_api + previous_transaction_id = b"transaction-id" + begin_transaction = api.begin_transaction + begin_transaction.return_value = build_transaction_pb( + id=previous_transaction_id + ) + + commit = api.commit + commit.side_effect = [_make_rpc_error(Aborted), build_commit_response_pb()] + + # Run in transaction. + def unit_of_work(transaction): + transaction.begin() + list(transaction.read(TABLE_NAME, COLUMNS, KEYSET)) + + session.create() + session.run_in_transaction(unit_of_work) + + # Verify retried BeginTransaction API call. + self.assertEqual(begin_transaction.call_count, 2) + + begin_transaction.assert_called_with( + request=BeginTransactionRequest( + session=session.name, + options=TransactionOptions( + read_write=TransactionOptions.ReadWrite( + multiplexed_session_previous_transaction_id=previous_transaction_id + ) + ), + ), + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ( + "x-goog-spanner-request-id", + f"1.{REQ_RAND_PROCESS_ID}.{database._nth_client_id}.{database._channel_id}.5.1", + ), + ], + ) + def test_run_in_transaction_w_args_w_kwargs_wo_abort(self): VALUES = [ ["phred@exammple.com", "Phred", "Phlyntstone", 32], @@ -1279,9 +1364,7 @@ def unit_of_work(txn, *args, **kw): request=BeginTransactionRequest( session=session.name, options=TransactionOptions( - read_write=TransactionOptions.ReadWrite( - multiplexed_session_previous_transaction_id=TRANSACTION_ID - ) + read_write=TransactionOptions.ReadWrite() ), ), metadata=[ @@ -1395,9 +1478,7 @@ def unit_of_work(txn, *args, **kw): request=BeginTransactionRequest( session=session.name, options=TransactionOptions( - read_write=TransactionOptions.ReadWrite( - multiplexed_session_previous_transaction_id=TRANSACTION_ID - ) + read_write=TransactionOptions.ReadWrite() ), ), metadata=[ @@ -1662,9 +1743,7 @@ def _time(_results=[1, 2, 4, 8]): request=BeginTransactionRequest( session=session.name, options=TransactionOptions( - read_write=TransactionOptions.ReadWrite( - multiplexed_session_previous_transaction_id=TRANSACTION_ID - ) + read_write=TransactionOptions.ReadWrite() ), ), metadata=[ @@ -1680,9 +1759,7 @@ def _time(_results=[1, 2, 4, 8]): request=BeginTransactionRequest( session=session.name, options=TransactionOptions( - read_write=TransactionOptions.ReadWrite( - multiplexed_session_previous_transaction_id=TRANSACTION_ID - ) + read_write=TransactionOptions.ReadWrite() ), ), metadata=[ @@ -2074,9 +2151,7 @@ def unit_of_work(txn, *args, **kw): request=BeginTransactionRequest( session=session.name, options=TransactionOptions( - read_write=TransactionOptions.ReadWrite( - multiplexed_session_previous_transaction_id=TRANSACTION_ID - ), + read_write=TransactionOptions.ReadWrite(), exclude_txn_from_change_streams=True, ), ),