From 454e7b9655d8ca2767c5569b32dfac3cce8eabe0 Mon Sep 17 00:00:00 2001 From: ankiaga Date: Wed, 27 Dec 2023 15:41:04 +0530 Subject: [PATCH 1/9] feat: Fixing and refactoring transaction retry logic in dbapi. Also adding interceptors support for testing --- .../cloud/spanner_dbapi/batch_dml_executor.py | 23 +- google/cloud/spanner_dbapi/checksum.py | 6 +- google/cloud/spanner_dbapi/connection.py | 109 +--- google/cloud/spanner_dbapi/cursor.py | 232 ++++---- google/cloud/spanner_dbapi/parse_utils.py | 2 - .../cloud/spanner_dbapi/parsed_statement.py | 3 - .../cloud/spanner_dbapi/transaction_helper.py | 230 ++++++++ google/cloud/spanner_v1/instance.py | 43 +- .../cloud/spanner_v1/testing/database_test.py | 112 ++++ .../cloud/spanner_v1/testing/interceptors.py | 59 ++ setup.py | 1 + testing/constraints-3.7.txt | 2 + tests/system/test_dbapi.py | 296 ++++++---- tests/unit/spanner_dbapi/test_connection.py | 348 +---------- tests/unit/spanner_dbapi/test_cursor.py | 460 +++++++-------- .../spanner_dbapi/test_transaction_helper.py | 541 ++++++++++++++++++ 16 files changed, 1552 insertions(+), 915 deletions(-) create mode 100644 google/cloud/spanner_dbapi/transaction_helper.py create mode 100644 google/cloud/spanner_v1/testing/database_test.py create mode 100644 google/cloud/spanner_v1/testing/interceptors.py create mode 100644 tests/unit/spanner_dbapi/test_transaction_helper.py diff --git a/google/cloud/spanner_dbapi/batch_dml_executor.py b/google/cloud/spanner_dbapi/batch_dml_executor.py index f91cf37b59..57e36d7991 100644 --- a/google/cloud/spanner_dbapi/batch_dml_executor.py +++ b/google/cloud/spanner_dbapi/batch_dml_executor.py @@ -16,7 +16,6 @@ from enum import Enum from typing import TYPE_CHECKING, List -from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.parsed_statement import ( ParsedStatement, StatementType, @@ -80,8 +79,10 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]): """ from google.cloud.spanner_dbapi import OperationalError - connection = cursor.connection many_result_set = StreamedManyResultSets() + if not statements: + return many_result_set + connection = cursor.connection statements_tuple = [] for statement in statements: statements_tuple.append(statement.get_tuple()) @@ -90,28 +91,24 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]): many_result_set.add_iter(res) cursor._row_count = sum([max(val, 0) for val in res]) else: - retried = False while True: try: transaction = connection.transaction_checkout() status, res = transaction.batch_update(statements_tuple) - many_result_set.add_iter(res) - res_checksum = ResultsChecksum() - res_checksum.consume_result(res) - res_checksum.consume_result(status.code) - if not retried: - connection._statements.append((statements, res_checksum)) - cursor._row_count = sum([max(val, 0) for val in res]) - if status.code == ABORTED: connection._transaction = None raise Aborted(status.message) elif status.code != OK: raise OperationalError(status.message) + + many_result_set.add_iter(res) + cursor._row_count = sum([max(val, 0) for val in res]) return many_result_set except Aborted: - connection.retry_transaction() - retried = True + if cursor._in_retry_mode: + raise + else: + connection._transaction_helper.retry_transaction() def _do_batch_update(transaction, statements): diff --git a/google/cloud/spanner_dbapi/checksum.py b/google/cloud/spanner_dbapi/checksum.py index 7a2a1d75b9..b2b3297db2 100644 --- a/google/cloud/spanner_dbapi/checksum.py +++ b/google/cloud/spanner_dbapi/checksum.py @@ -62,6 +62,8 @@ def consume_result(self, result): def _compare_checksums(original, retried): + from google.cloud.spanner_dbapi.transaction_helper import RETRY_ABORTED_ERROR + """Compare the given checksums. Raise an error if the given checksums are not equal. @@ -75,6 +77,4 @@ def _compare_checksums(original, retried): :raises: :exc:`google.cloud.spanner_dbapi.exceptions.RetryAborted` in case if checksums are not equal. """ if retried != original: - raise RetryAborted( - "The transaction was aborted and could not be retried due to a concurrent modification." - ) + raise RetryAborted(RETRY_ABORTED_ERROR) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 47680fd550..fb635c003b 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -13,7 +13,6 @@ # limitations under the License. """DB-API Connection for the Google Cloud Spanner.""" -import time import warnings from google.api_core.exceptions import Aborted @@ -23,18 +22,15 @@ from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor from google.cloud.spanner_dbapi.parse_utils import _get_statement_type from google.cloud.spanner_dbapi.parsed_statement import ( - ParsedStatement, - Statement, StatementType, ) from google.cloud.spanner_dbapi.partition_helper import PartitionId +from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement +from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper from google.cloud.spanner_v1 import RequestOptions -from google.cloud.spanner_v1.session import _get_retry_delay from google.cloud.spanner_v1.snapshot import Snapshot from deprecated import deprecated -from google.cloud.spanner_dbapi.checksum import _compare_checksums -from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.cursor import Cursor from google.cloud.spanner_dbapi.exceptions import ( InterfaceError, @@ -44,13 +40,10 @@ from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT from google.cloud.spanner_dbapi.version import PY_VERSION -from google.rpc.code_pb2 import ABORTED - CLIENT_TRANSACTION_NOT_STARTED_WARNING = ( "This method is non-operational as a transaction has not been started." ) -MAX_INTERNAL_RETRIES = 50 def check_not_closed(function): @@ -106,9 +99,6 @@ def __init__(self, instance, database=None, read_only=False): self._transaction = None self._session = None self._snapshot = None - # SQL statements, which were executed - # within the current transaction - self._statements = [] self.is_closed = False self._autocommit = False @@ -125,6 +115,7 @@ def __init__(self, instance, database=None, read_only=False): self._spanner_transaction_started = False self._batch_mode = BatchMode.NONE self._batch_dml_executor: BatchDmlExecutor = None + self._transaction_helper = TransactionRetryHelper(self) @property def autocommit(self): @@ -288,76 +279,6 @@ def _release_session(self): self.database._pool.put(self._session) self._session = None - def retry_transaction(self): - """Retry the aborted transaction. - - All the statements executed in the original transaction - will be re-executed in new one. Results checksums of the - original statements and the retried ones will be compared. - - :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted` - If results checksum of the retried statement is - not equal to the checksum of the original one. - """ - attempt = 0 - while True: - self._spanner_transaction_started = False - attempt += 1 - if attempt > MAX_INTERNAL_RETRIES: - raise - - try: - self._rerun_previous_statements() - break - except Aborted as exc: - delay = _get_retry_delay(exc.errors[0], attempt) - if delay: - time.sleep(delay) - - def _rerun_previous_statements(self): - """ - Helper to run all the remembered statements - from the last transaction. - """ - for statement in self._statements: - if isinstance(statement, list): - statements, checksum = statement - - transaction = self.transaction_checkout() - statements_tuple = [] - for single_statement in statements: - statements_tuple.append(single_statement.get_tuple()) - status, res = transaction.batch_update(statements_tuple) - - if status.code == ABORTED: - raise Aborted(status.details) - - retried_checksum = ResultsChecksum() - retried_checksum.consume_result(res) - retried_checksum.consume_result(status.code) - - _compare_checksums(checksum, retried_checksum) - else: - res_iter, retried_checksum = self.run_statement(statement, retried=True) - # executing all the completed statements - if statement != self._statements[-1]: - for res in res_iter: - retried_checksum.consume_result(res) - - _compare_checksums(statement.checksum, retried_checksum) - # executing the failed statement - else: - # streaming up to the failed result or - # to the end of the streaming iterator - while len(retried_checksum) < len(statement.checksum): - try: - res = next(iter(res_iter)) - retried_checksum.consume_result(res) - except StopIteration: - break - - _compare_checksums(statement.checksum, retried_checksum) - def transaction_checkout(self): """Get a Cloud Spanner transaction. @@ -450,11 +371,11 @@ def commit(self): if self._spanner_transaction_started and not self._read_only: self._transaction.commit() except Aborted: - self.retry_transaction() + self._transaction_helper.retry_transaction() self.commit() finally: self._release_session() - self._statements = [] + self._transaction_helper.reset() self._transaction_begin_marked = False self._spanner_transaction_started = False @@ -474,7 +395,7 @@ def rollback(self): self._transaction.rollback() finally: self._release_session() - self._statements = [] + self._transaction_helper.reset() self._transaction_begin_marked = False self._spanner_transaction_started = False @@ -493,7 +414,7 @@ def run_prior_DDL_statements(self): return self.database.update_ddl(ddl_statements).result() - def run_statement(self, statement: Statement, retried=False): + def run_statement(self, statement: Statement): """Run single SQL statement in begun transaction. This method is never used in autocommit mode. In @@ -513,17 +434,11 @@ def run_statement(self, statement: Statement, retried=False): checksum of this statement results. """ transaction = self.transaction_checkout() - if not retried: - self._statements.append(statement) - - return ( - transaction.execute_sql( - statement.sql, - statement.params, - param_types=statement.param_types, - request_options=self.request_options, - ), - ResultsChecksum() if retried else statement.checksum, + return transaction.execute_sql( + statement.sql, + statement.params, + param_types=statement.param_types, + request_options=self.request_options, ) @check_not_closed diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index ff91e9e666..1a6dd12e95 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -13,7 +13,6 @@ # limitations under the License. """Database cursor for Google Cloud Spanner DB API.""" - from collections import namedtuple import sqlparse @@ -47,6 +46,7 @@ Statement, ParsedStatement, ) +from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType from google.cloud.spanner_dbapi.utils import PeekIterator from google.cloud.spanner_dbapi.utils import StreamedManyResultSets @@ -90,11 +90,12 @@ def __init__(self, connection): self._row_count = _UNSET_COUNT self.lastrowid = None self.connection = connection + self.transaction_helper = self.connection._transaction_helper self._is_closed = False - # the currently running SQL statement results checksum - self._checksum = None # the number of rows to fetch at a time with fetchmany() self.arraysize = 1 + self._parsed_statement: ParsedStatement = None + self._in_retry_mode = False @property def is_closed(self): @@ -219,7 +220,7 @@ def _batch_DDLs(self, sql): self.connection._ddl_statements.extend(statements) @check_not_closed - def execute(self, sql, args=None): + def execute(self, sql, args=None, call_from_execute_many=False): """Prepares and executes a Spanner database operation. :type sql: str @@ -233,14 +234,13 @@ def execute(self, sql, args=None): self._itr = None self._result_set = None self._row_count = _UNSET_COUNT + exception = None try: - parsed_statement: ParsedStatement = parse_utils.classify_statement( - sql, args - ) - if parsed_statement.statement_type == StatementType.CLIENT_SIDE: + self._parsed_statement = parse_utils.classify_statement(sql, args) + if self._parsed_statement.statement_type == StatementType.CLIENT_SIDE: self._result_set = client_side_statement_executor.execute( - self, parsed_statement + self, self._parsed_statement ) if self._result_set is not None: if isinstance(self._result_set, StreamedManyResultSets): @@ -248,53 +248,60 @@ def execute(self, sql, args=None): else: self._itr = PeekIterator(self._result_set) elif self.connection._batch_mode == BatchMode.DML: - self.connection.execute_batch_dml_statement(parsed_statement) + self.connection.execute_batch_dml_statement(self._parsed_statement) elif self.connection.read_only or ( not self.connection._client_transaction_started - and parsed_statement.statement_type == StatementType.QUERY + and self._parsed_statement.statement_type == StatementType.QUERY ): self._handle_DQL(sql, args or None) - elif parsed_statement.statement_type == StatementType.DDL: + elif self._parsed_statement.statement_type == StatementType.DDL: self._batch_DDLs(sql) if not self.connection._client_transaction_started: self.connection.run_prior_DDL_statements() else: - self._execute_in_rw_transaction(parsed_statement) + self._execute_in_rw_transaction() except (AlreadyExists, FailedPrecondition, OutOfRange) as e: + exception = e raise IntegrityError(getattr(e, "details", e)) from e except InvalidArgument as e: + exception = e raise ProgrammingError(getattr(e, "details", e)) from e except InternalServerError as e: + exception = e raise OperationalError(getattr(e, "details", e)) from e + except Exception as e: + exception = e + raise finally: + if not self._in_retry_mode and not call_from_execute_many: + self.transaction_helper.add_execute_statement_for_retry( + self._parsed_statement, sql, args, self.rowcount, exception, False + ) if self.connection._client_transaction_started is False: self.connection._spanner_transaction_started = False - def _execute_in_rw_transaction(self, parsed_statement: ParsedStatement): + def _execute_in_rw_transaction(self): # For every other operation, we've got to ensure that # any prior DDL statements were run. self.connection.run_prior_DDL_statements() + statement = self._parsed_statement.statement if self.connection._client_transaction_started: - ( - self._result_set, - self._checksum, - ) = self.connection.run_statement(parsed_statement.statement) - while True: try: + self._result_set = self.connection.run_statement(statement) self._itr = PeekIterator(self._result_set) - break + return except Aborted: - self.connection.retry_transaction() - except Exception as ex: - self.connection._statements.remove(parsed_statement.statement) - raise ex + if self._in_retry_mode: + raise + else: + self.transaction_helper.retry_transaction() else: self.connection.database.run_in_transaction( self._do_execute_update_in_autocommit, - parsed_statement.statement.sql, - parsed_statement.statement.params or None, + statement.sql, + statement.params or None, ) @check_not_closed @@ -314,82 +321,75 @@ def executemany(self, operation, seq_of_params): self._itr = None self._result_set = None self._row_count = _UNSET_COUNT + exception = None - parsed_statement = parse_utils.classify_statement(operation) - if parsed_statement.statement_type == StatementType.DDL: - raise ProgrammingError( - "Executing DDL statements with executemany() method is not allowed." - ) - - if parsed_statement.statement_type == StatementType.CLIENT_SIDE: - raise ProgrammingError( - "Executing the following operation: " - + operation - + ", with executemany() method is not allowed." - ) + try: + self._parsed_statement = parse_utils.classify_statement(operation) + if self._parsed_statement.statement_type == StatementType.DDL: + raise ProgrammingError( + "Executing DDL statements with executemany() method is not allowed." + ) - # For every operation, we've got to ensure that any prior DDL - # statements were run. - self.connection.run_prior_DDL_statements() - if parsed_statement.statement_type in ( - StatementType.INSERT, - StatementType.UPDATE, - ): - statements = [] - for params in seq_of_params: - sql, params = parse_utils.sql_pyformat_args_to_spanner( - operation, params + if self._parsed_statement.statement_type == StatementType.CLIENT_SIDE: + raise ProgrammingError( + "Executing the following operation: " + + operation + + ", with executemany() method is not allowed." ) - statements.append(Statement(sql, params, get_param_types(params))) - many_result_set = batch_dml_executor.run_batch_dml(self, statements) - else: - many_result_set = StreamedManyResultSets() - for params in seq_of_params: - self.execute(operation, params) - many_result_set.add_iter(self._itr) - self._result_set = many_result_set - self._itr = many_result_set + # For every operation, we've got to ensure that any prior DDL + # statements were run. + self.connection.run_prior_DDL_statements() + if self._parsed_statement.statement_type in ( + StatementType.INSERT, + StatementType.UPDATE, + ): + statements = [] + for params in seq_of_params: + sql, params = parse_utils.sql_pyformat_args_to_spanner( + operation, params + ) + statements.append(Statement(sql, params, get_param_types(params))) + many_result_set = batch_dml_executor.run_batch_dml(self, statements) + else: + many_result_set = StreamedManyResultSets() + for params in seq_of_params: + self.execute(operation, params, True) + many_result_set.add_iter(self._itr) + + self._result_set = many_result_set + self._itr = many_result_set + except Exception as e: + exception = e + raise + finally: + if not self._in_retry_mode: + self.transaction_helper.add_execute_statement_for_retry( + self._parsed_statement, + operation, + seq_of_params, + self.rowcount, + exception, + True, + ) + if self.connection._client_transaction_started is False: + self.connection._spanner_transaction_started = False @check_not_closed def fetchone(self): """Fetch the next row of a query result set, returning a single sequence, or None when no more data is available.""" - try: - res = next(self) - if ( - self.connection._client_transaction_started - and not self.connection.read_only - ): - self._checksum.consume_result(res) - return res - except StopIteration: + rows = self._fetch(CursorStatementType.FETCH_ONE) + if not rows: return - except Aborted: - if not self.connection.read_only: - self.connection.retry_transaction() - return self.fetchone() + return rows[0] @check_not_closed def fetchall(self): """Fetch all (remaining) rows of a query result, returning them as a sequence of sequences. """ - res = [] - try: - for row in self: - if ( - self.connection._client_transaction_started - and not self.connection.read_only - ): - self._checksum.consume_result(row) - res.append(row) - except Aborted: - if not self.connection.read_only: - self.connection.retry_transaction() - return self.fetchall() - - return res + return self._fetch(CursorStatementType.FETCH_ALL) @check_not_closed def fetchmany(self, size=None): @@ -405,25 +405,49 @@ def fetchmany(self, size=None): """ if size is None: size = self.arraysize + return self._fetch(CursorStatementType.FETCH_MANY, size) - items = [] - for _ in range(size): - try: - res = next(self) - if ( - self.connection._client_transaction_started - and not self.connection.read_only - ): - self._checksum.consume_result(res) - items.append(res) - except StopIteration: - break - except Aborted: - if not self.connection.read_only: - self.connection.retry_transaction() - return self.fetchmany(size) - - return items + def _fetch(self, cursor_statement_type: CursorStatementType, size=None): + exception = None + rows = [] + is_fetch_all = False + try: + while True: + rows = [] + try: + if cursor_statement_type == CursorStatementType.FETCH_ALL: + is_fetch_all = True + for row in self: + rows.append(row) + elif cursor_statement_type == CursorStatementType.FETCH_MANY: + for _ in range(size): + try: + row = next(self) + rows.append(row) + except StopIteration: + break + elif cursor_statement_type == CursorStatementType.FETCH_ONE: + try: + row = next(self) + rows.append(row) + except StopIteration: + return + break + except Aborted: + if not self.connection.read_only: + if self._in_retry_mode: + raise + else: + self.transaction_helper.retry_transaction() + except Exception as e: + exception = e + raise + finally: + if not self._in_retry_mode: + self.transaction_helper.add_fetch_statement_for_retry( + rows, exception, is_fetch_all + ) + return rows def _handle_DQL_with_snapshot(self, snapshot, sql, params): self._result_set = snapshot.execute_sql( diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index 008f21bf93..b642daf084 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -24,7 +24,6 @@ from . import client_side_statement_parser from deprecated import deprecated -from .checksum import ResultsChecksum from .exceptions import Error from .parsed_statement import ParsedStatement, StatementType, Statement from .types import DateStr, TimestampStr @@ -230,7 +229,6 @@ def classify_statement(query, args=None): query, args, get_param_types(args or None), - ResultsChecksum(), ) statement_type = _get_statement_type(statement) return ParsedStatement(statement_type, statement) diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py index 798f5126c3..b489da14cc 100644 --- a/google/cloud/spanner_dbapi/parsed_statement.py +++ b/google/cloud/spanner_dbapi/parsed_statement.py @@ -15,8 +15,6 @@ from enum import Enum from typing import Any, List -from google.cloud.spanner_dbapi.checksum import ResultsChecksum - class StatementType(Enum): CLIENT_SIDE = 1 @@ -44,7 +42,6 @@ class Statement: sql: str params: Any = None param_types: Any = None - checksum: ResultsChecksum = None def get_tuple(self): return self.sql, self.params, self.param_types diff --git a/google/cloud/spanner_dbapi/transaction_helper.py b/google/cloud/spanner_dbapi/transaction_helper.py new file mode 100644 index 0000000000..e7d8ea22d0 --- /dev/null +++ b/google/cloud/spanner_dbapi/transaction_helper.py @@ -0,0 +1,230 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +from dataclasses import dataclass +from enum import Enum +from typing import TYPE_CHECKING, List, Union, Any +from google.api_core.exceptions import Aborted + +import time + +from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode +from google.cloud.spanner_dbapi.exceptions import RetryAborted +from google.cloud.spanner_dbapi.parsed_statement import ( + StatementType, + ClientSideStatementType, +) + +if TYPE_CHECKING: + from google.cloud.spanner_dbapi import Connection, Cursor +from google.cloud.spanner_dbapi.checksum import ResultsChecksum, _compare_checksums + +MAX_INTERNAL_RETRIES = 50 +RETRY_ABORTED_ERROR = "The transaction was aborted and could not be retried due to a concurrent modification." + + +class TransactionRetryHelper: + def __init__(self, connection: "Connection"): + """Helper class used in retrying the transaction when aborted This will + maintain all the statements executed on original transaction and replay + them again in the retried transaction. + + :type connection: :class:`~google.cloud.spanner_dbapi.connection.Connection` + :param connection: A DB-API connection to Google Cloud Spanner. + """ + + self._connection = connection + # list of all single and batch statements in the same order as executed + # in original transaction along with their checksum value + self._statement_result_details_list: List[StatementResultDetails] = [] + self._last_statement_result_details: StatementResultDetails = None + + def _set_connection_for_retry(self): + self._connection._spanner_transaction_started = False + self._connection._transaction_begin_marked = False + self._connection._batch_mode = BatchMode.NONE + + def reset(self): + """ + Resets the state of the class when the ongoing transaction is committed + or aborted + """ + self._statement_result_details_list = [] + self._last_statement_result_details = None + + def add_fetch_statement_for_retry(self, rows, exception, is_fetch_all): + if not self._connection._client_transaction_started: + return + if ( + self._last_statement_result_details is not None + and self._last_statement_result_details.statement_type + == CursorStatementType.FETCH_MANY + ): + if exception is not None: + self._last_statement_result_details.result_details = exception + else: + for row in rows: + self._last_statement_result_details.result_details.consume_result( + row + ) + self._last_statement_result_details.size += len(rows) + else: + result_details = _get_statement_result_checksum(rows) + if is_fetch_all: + self._last_statement_result_details = FetchStatement( + statement_type=CursorStatementType.FETCH_ALL, + result_details=result_details, + ) + else: + self._last_statement_result_details = FetchStatement( + statement_type=CursorStatementType.FETCH_MANY, + result_details=result_details, + size=len(rows), + ) + self._statement_result_details_list.append( + self._last_statement_result_details + ) + + def add_execute_statement_for_retry( + self, parsed_statement, sql, args, rowcount, exception, is_execute_many + ): + if not self._connection._client_transaction_started: + return + statement_type = CursorStatementType.EXECUTE + if is_execute_many: + statement_type = CursorStatementType.EXECUTE_MANY + result_details = None + if exception is not None: + result_details = exception + elif ( + parsed_statement.statement_type == StatementType.INSERT + or parsed_statement.statement_type == StatementType.UPDATE + or parsed_statement.client_side_statement_type + == ClientSideStatementType.RUN_BATCH + ): + result_details = rowcount + + self._last_statement_result_details = ExecuteStatement( + statement_type=statement_type, + sql=sql, + args=args, + result_details=result_details, + ) + self._statement_result_details_list.append(self._last_statement_result_details) + + def retry_transaction(self): + """Retry the aborted transaction. + + All the statements executed in the original transaction + will be re-executed in new one. Results checksums of the + original statements and the retried ones will be compared. + + :raises: :class:`google.cloud.spanner_dbapi.exceptions.RetryAborted` + If results checksum of the retried statement is + not equal to the checksum of the original one. + """ + attempt = 0 + while True: + attempt += 1 + if attempt > MAX_INTERNAL_RETRIES: + raise + + self._set_connection_for_retry() + cursor: Cursor = self._connection.cursor() + cursor._in_retry_mode = True + try: + for statement_result_details in self._statement_result_details_list: + try: + self._handle_statement(statement_result_details, cursor) + except Aborted: + raise + except RetryAborted: + raise + except Exception as ex: + if ( + type(statement_result_details.result_details) + is not type(ex) + or ex.args != statement_result_details.result_details.args + ): + raise RetryAborted(RETRY_ABORTED_ERROR, ex) + return + except Aborted: + delay = 2**attempt + random.random() + if delay: + time.sleep(delay) + + def _handle_statement(self, statement_result_details, cursor): + statement_type = statement_result_details.statement_type + if statement_type == CursorStatementType.EXECUTE: + cursor.execute(statement_result_details.sql, statement_result_details.args) + if ( + type(statement_result_details.result_details) is int + and statement_result_details.result_details != cursor.rowcount + ): + raise RetryAborted(RETRY_ABORTED_ERROR) + elif statement_type == CursorStatementType.EXECUTE_MANY: + cursor.executemany( + statement_result_details.sql, + statement_result_details.args, + ) + if ( + type(statement_result_details.result_details) is int + and statement_result_details.result_details != cursor.rowcount + ): + raise RetryAborted(RETRY_ABORTED_ERROR) + elif statement_type == CursorStatementType.FETCH_ALL: + res = cursor.fetchall() + checksum = _get_statement_result_checksum(res) + _compare_checksums(checksum, statement_result_details.result_details) + elif statement_type == CursorStatementType.FETCH_MANY: + res = cursor.fetchmany(statement_result_details.size) + checksum = _get_statement_result_checksum(res) + _compare_checksums(checksum, statement_result_details.result_details) + + +def _get_statement_result_checksum(res_iter): + retried_checksum = ResultsChecksum() + for res in res_iter: + retried_checksum.consume_result(res) + return retried_checksum + + +class CursorStatementType(Enum): + EXECUTE = 1 + EXECUTE_MANY = 2 + FETCH_ONE = 3 + FETCH_ALL = 4 + FETCH_MANY = 5 + + +@dataclass +class StatementResultDetails: + statement_type: CursorStatementType + # This would be one of + # 1. checksum of ResultSet in case of fetch call on query statement + # 2. Total rows updated in case of DML + # 3. Exception details in case of statement execution throws exception + # 4. None in case of execute calls + result_details: Union[ResultsChecksum, int, Exception, None] + + +@dataclass +class ExecuteStatement(StatementResultDetails): + sql: str + args: Any = None + + +@dataclass +class FetchStatement(StatementResultDetails): + size: int = None diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 1b426f8cc2..26627fb9b1 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -34,7 +34,7 @@ from google.cloud.spanner_v1._helpers import _metadata_with_prefix from google.cloud.spanner_v1.backup import Backup from google.cloud.spanner_v1.database import Database - +from google.cloud.spanner_v1.testing.database_test import TestDatabase _INSTANCE_NAME_RE = re.compile( r"^projects/(?P[^/]+)/" r"instances/(?P[a-z][-a-z0-9]*)$" @@ -433,6 +433,8 @@ def database( database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, database_role=None, enable_drop_protection=False, + # should be only set for tests if tests want to use interceptors + enable_interceptors_in_tests=False, ): """Factory to create a database within this instance. @@ -472,20 +474,37 @@ def database( :param enable_drop_protection: (Optional) Represents whether the database has drop protection enabled or not. + :type enable_interceptors_in_tests: boolean + :param enable_interceptors_in_tests: (Optional) should only be set to True + for tests if the tests want to use interceptors. + :rtype: :class:`~google.cloud.spanner_v1.database.Database` :returns: a database owned by this instance. """ - return Database( - database_id, - self, - ddl_statements=ddl_statements, - pool=pool, - logger=logger, - encryption_config=encryption_config, - database_dialect=database_dialect, - database_role=database_role, - enable_drop_protection=enable_drop_protection, - ) + if not enable_interceptors_in_tests: + return Database( + database_id, + self, + ddl_statements=ddl_statements, + pool=pool, + logger=logger, + encryption_config=encryption_config, + database_dialect=database_dialect, + database_role=database_role, + enable_drop_protection=enable_drop_protection, + ) + else: + return TestDatabase( + database_id, + self, + ddl_statements=ddl_statements, + pool=pool, + logger=logger, + encryption_config=encryption_config, + database_dialect=database_dialect, + database_role=database_role, + enable_drop_protection=enable_drop_protection, + ) def list_databases(self, page_size=None): """List databases for the instance. diff --git a/google/cloud/spanner_v1/testing/database_test.py b/google/cloud/spanner_v1/testing/database_test.py new file mode 100644 index 0000000000..54afda11e0 --- /dev/null +++ b/google/cloud/spanner_v1/testing/database_test.py @@ -0,0 +1,112 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import grpc + +from google.api_core import grpc_helpers +import google.auth.credentials +from google.cloud.spanner_admin_database_v1 import DatabaseDialect +from google.cloud.spanner_v1 import SpannerClient +from google.cloud.spanner_v1.database import Database, SPANNER_DATA_SCOPE +from google.cloud.spanner_v1.services.spanner.transports import ( + SpannerGrpcTransport, + SpannerTransport, +) +from google.cloud.spanner_v1.testing.interceptors import ( + MethodCountInterceptor, + MethodAbortInterceptor, +) + + +class TestDatabase(Database): + """Representation of a Cloud Spanner Database. This class is only used for + system testing as there is no support for interceptors in grpc client + currently, and we don't want to make changes in the Database class for + testing purpose as this is a hack to use interceptors in tests.""" + + def __init__( + self, + database_id, + instance, + ddl_statements=(), + pool=None, + logger=None, + encryption_config=None, + database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, + database_role=None, + enable_drop_protection=False, + ): + super().__init__( + database_id, + instance, + ddl_statements, + pool, + logger, + encryption_config, + database_dialect, + database_role, + enable_drop_protection, + ) + + self._method_count_interceptor = MethodCountInterceptor() + self._method_abort_interceptor = MethodAbortInterceptor() + self._interceptors = [ + self._method_count_interceptor, + self._method_abort_interceptor, + ] + + @property + def spanner_api(self): + """Helper for session-related API calls.""" + if self._spanner_api is None: + client = self._instance._client + client_info = client._client_info + client_options = client._client_options + if self._instance.emulator_host is not None: + channel = grpc.insecure_channel(self._instance.emulator_host) + channel = grpc.intercept_channel(channel, *self._interceptors) + transport = SpannerGrpcTransport(channel=channel) + self._spanner_api = SpannerClient( + client_info=client_info, + transport=transport, + ) + return self._spanner_api + credentials = client.credentials + if isinstance(credentials, google.auth.credentials.Scoped): + credentials = credentials.with_scopes((SPANNER_DATA_SCOPE,)) + self._spanner_api = self._create_spanner_client_for_tests( + client_options, + credentials, + ) + return self._spanner_api + + def _create_spanner_client_for_tests(self, client_options, credentials): + ( + api_endpoint, + client_cert_source_func, + ) = SpannerClient.get_mtls_endpoint_and_cert_source(client_options) + channel = grpc_helpers.create_channel( + api_endpoint, + credentials=credentials, + credentials_file=client_options.credentials_file, + quota_project_id=client_options.quota_project_id, + default_scopes=SpannerTransport.AUTH_SCOPES, + scopes=client_options.scopes, + default_host=SpannerTransport.DEFAULT_HOST, + ) + channel = grpc.intercept_channel(channel, *self._interceptors) + transport = SpannerGrpcTransport(channel=channel) + return SpannerClient( + client_options=client_options, + transport=transport, + ) diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py new file mode 100644 index 0000000000..a439ba4f86 --- /dev/null +++ b/google/cloud/spanner_v1/testing/interceptors.py @@ -0,0 +1,59 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict +from grpc_interceptor import ClientInterceptor +from google.api_core.exceptions import Aborted + + +class MethodCountInterceptor(ClientInterceptor): + """Test interceptor that counts number of times a method is being called.""" + + def __init__(self): + self._counts = defaultdict(int) + + def intercept(self, method, request_or_iterator, call_details): + """Count number of times a method is being called.""" + self._counts[call_details.method] += 1 + return method(request_or_iterator, call_details) + + def reset(self): + self._counts = defaultdict(int) + + +class MethodAbortInterceptor(ClientInterceptor): + """Test interceptor that throws Aborted exception for a specific method.""" + + def __init__(self): + self._method_to_abort = None + self._count = 0 + self._max_raise_count = 1 + + def intercept(self, method, request_or_iterator, call_details): + if ( + self._count < self._max_raise_count + and call_details.method == self._method_to_abort + ): + self._count += 1 + raise Aborted("Thrown from ClientInterceptor for testing") + return method(request_or_iterator, call_details) + + def set_method_to_abort(self, method_to_abort, max_raise_count=1): + self._method_to_abort = method_to_abort + self._max_raise_count = max_raise_count + + def reset(self): + """Reset the interceptor to the original state.""" + self._method_to_abort = None + self._count = 0 diff --git a/setup.py b/setup.py index ec4d94c05e..4518234679 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ "proto-plus >= 1.22.2, <2.0.0dev; python_version>='3.11'", "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", "deprecated >= 1.2.14", + "grpc-interceptor >= 0.15.4", ] extras = { "tracing": [ diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt index 165814fd90..b0162a8987 100644 --- a/testing/constraints-3.7.txt +++ b/testing/constraints-3.7.txt @@ -14,3 +14,5 @@ opentelemetry-api==1.1.0 opentelemetry-sdk==1.1.0 opentelemetry-instrumentation==0.20b0 protobuf==3.19.5 +deprecated==1.2.14 +grpc-interceptor==0.15.4 diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index aa3fd610e1..e012299925 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -13,8 +13,7 @@ # limitations under the License. import datetime -import hashlib -import pickle +from collections import defaultdict import pytest import time @@ -29,6 +28,11 @@ from . import _helpers DATABASE_NAME = "dbapi-txn" +SPANNER_RPC_PREFIX = "/google.spanner.v1.Spanner/" +EXECUTE_BATCH_DML_METHOD = SPANNER_RPC_PREFIX + "ExecuteBatchDml" +COMMIT_METHOD = SPANNER_RPC_PREFIX + "Commit" +EXECUTE_SQL_METHOD = SPANNER_RPC_PREFIX + "ExecuteSql" +EXECUTE_STREAMING_SQL_METHOD = SPANNER_RPC_PREFIX + "ExecuteStreamingSql" DDL_STATEMENTS = ( """CREATE TABLE contacts ( @@ -49,6 +53,7 @@ def raw_database(shared_instance, database_operation_timeout, not_postgres): database_id, ddl_statements=DDL_STATEMENTS, pool=pool, + enable_interceptors_in_tests=True, ) op = database.create() op.result(database_operation_timeout) # raises on failure / timeout. @@ -65,6 +70,9 @@ def clear_table(transaction): @pytest.fixture(scope="function") def dbapi_database(self, raw_database): + # Resetting the count so that each test gives correct count of the api + # methods called during that test + raw_database._method_count_interceptor._counts = defaultdict(int) raw_database.run_in_transaction(self.clear_table) yield raw_database @@ -126,7 +134,10 @@ def test_commit(self, client_side): assert got_rows == [updated_row] - @pytest.mark.skip(reason="b/315807641") + @pytest.mark.skipif( + _helpers.USE_EMULATOR, + reason="Emulator does not support multiple parallel transactions.", + ) def test_commit_exception(self): """Test that if exception during commit method is caught, then subsequent operations on same Cursor and Connection object works @@ -148,7 +159,10 @@ def test_commit_exception(self): assert got_rows == [updated_row] - @pytest.mark.skip(reason="b/315807641") + @pytest.mark.skipif( + _helpers.USE_EMULATOR, + reason="Emulator does not support multiple parallel transactions.", + ) def test_rollback_exception(self): """Test that if exception during rollback method is caught, then subsequent operations on same Cursor and Connection object works @@ -170,7 +184,6 @@ def test_rollback_exception(self): assert got_rows == [updated_row] - @pytest.mark.skip(reason="b/315807641") def test_cursor_execute_exception(self): """Test that if exception in Cursor's execute method is caught when Connection is not in autocommit mode, then subsequent operations on @@ -250,27 +263,35 @@ def test_begin_client_side(self, shared_instance, dbapi_database): conn3 = Connection(shared_instance, dbapi_database) cursor3 = conn3.cursor() cursor3.execute("SELECT * FROM contacts") - conn3.commit() got_rows = cursor3.fetchall() + conn3.commit() cursor3.close() conn3.close() assert got_rows == [updated_row] - def test_begin_and_commit(self): + def test_noop_sql_statements(self, dbapi_database): """Test beginning and then committing a transaction is a Noop""" + dbapi_database._method_count_interceptor.reset() self._cursor.execute("begin transaction") self._cursor.execute("commit transaction") + assert dbapi_database._method_count_interceptor._counts == {} self._cursor.execute("SELECT * FROM contacts") self._conn.commit() assert self._cursor.fetchall() == [] - def test_begin_and_rollback(self): """Test beginning and then rolling back a transaction is a Noop""" + dbapi_database._method_count_interceptor.reset() self._cursor.execute("begin transaction") self._cursor.execute("rollback transaction") + assert dbapi_database._method_count_interceptor._counts == {} self._cursor.execute("SELECT * FROM contacts") - self._conn.commit() assert self._cursor.fetchall() == [] + self._conn.commit() + + dbapi_database._method_count_interceptor.reset() + self._cursor.execute("start batch dml") + self._cursor.execute("run batch") + assert dbapi_database._method_count_interceptor._counts == {} def test_read_and_commit_timestamps(self): """Test COMMIT_TIMESTAMP is not available after read statement and @@ -420,19 +441,18 @@ def test_read_timestamp_client_side_autocommit(self): assert self._cursor.description[0].name == "SHOW_READ_TIMESTAMP" assert isinstance(read_timestamp_query_result_1[0][0], DatetimeWithNanoseconds) - self._conn.read_only = False - self._insert_row(3) - - self._conn.read_only = True + time.sleep(0.25) self._cursor.execute("SELECT * FROM contacts") self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") read_timestamp_query_result_2 = self._cursor.fetchall() assert read_timestamp_query_result_1 != read_timestamp_query_result_2 @pytest.mark.parametrize("auto_commit", [False, True]) - def test_batch_dml(self, auto_commit): + def test_batch_dml(self, auto_commit, dbapi_database): """Test batch dml.""" + method_count_interceptor = dbapi_database._method_count_interceptor + method_count_interceptor.reset() if auto_commit: self._conn.autocommit = True self._insert_row(1) @@ -481,6 +501,8 @@ def test_batch_dml(self, auto_commit): self._cursor.execute("SELECT * FROM contacts") assert len(self._cursor.fetchall()) == 9 + # Test that ExecuteBatchDml rpc is called + assert method_count_interceptor._counts[EXECUTE_BATCH_DML_METHOD] == 3 def test_abort_batch_dml(self): """Test abort batch dml.""" @@ -540,81 +562,185 @@ def test_batch_dml_invalid_statements(self): with pytest.raises(OperationalError): self._cursor.execute("run batch") - def test_partitioned_query(self): - """Test partition query works in read-only mode.""" + def _insert_row(self, i): + self._cursor.execute( + f""" + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES ({i}, 'first-name-{i}', 'last-name-{i}', 'test.email@domen.ru') + """ + ) + + def test_commit_abort_retry(self, dbapi_database): + """Test that when commit failed with Abort exception, then the retry + succeeds with transaction having insert as well as query type of + statements along with batch dml statements. + We are trying to test all types of statements like execute, executemany, + fetchone, fetchmany, fetchall""" + + method_count_interceptor = dbapi_database._method_count_interceptor + method_count_interceptor.reset() + # called 2 times + self._insert_row(1) + # called 2 times + self._cursor.execute("SELECT * FROM contacts") + self._cursor.fetchall() self._cursor.execute("start batch dml") - for i in range(1, 11): - self._insert_row(i) + self._insert_row(2) + self._insert_row(3) + # called 2 times for batch dml rpc self._cursor.execute("run batch") + row_data = [ + (4, "first-name4", "last-name4", "test.email4@example.com"), + (5, "first-name5", "last-name5", "test.email5@example.com"), + ] + # called 2 times for batch dml rpc + self._cursor.executemany( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (%s, %s, %s, %s) + """, + row_data, + ) + # called 2 times and as this would make 3 execute streaming sql calls + # so total 6 calls + self._cursor.executemany( + """SELECT * FROM contacts WHERE contact_id = %s""", + ((1,), (2,), (3,)), + ) + self._cursor.fetchone() + self._cursor.fetchmany(2) + dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) + # called 2 times self._conn.commit() + dbapi_database._method_abort_interceptor.reset() + assert method_count_interceptor._counts[COMMIT_METHOD] == 2 + assert method_count_interceptor._counts[EXECUTE_BATCH_DML_METHOD] == 4 + assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] == 10 - self._conn.read_only = True - self._cursor.execute("PARTITION SELECT * FROM contacts") - partition_id_rows = self._cursor.fetchall() - assert len(partition_id_rows) > 0 - - rows = [] - for partition_id_row in partition_id_rows: - self._cursor.execute("RUN PARTITION " + partition_id_row[0]) - rows = rows + self._cursor.fetchall() - assert len(rows) == 10 - self._conn.commit() + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + assert len(got_rows) == 5 - def test_partitioned_query_in_rw_transaction(self): - """Test partition query throws exception when connection is not in - read-only mode and neither in auto-commit mode.""" + def test_execute_sql_abort_retry_multiple_times(self, dbapi_database): + """Test that when execute sql failed 2 times with Abort exception, then + the retry succeeds 3rd time.""" - with pytest.raises(ProgrammingError): - self._cursor.execute("PARTITION SELECT * FROM contacts") + method_count_interceptor = dbapi_database._method_count_interceptor + method_count_interceptor.reset() + self._cursor.execute("start batch dml") + self._insert_row(1) + self._insert_row(2) + self._cursor.execute("run batch") + # aborting method 2 times before succeeding + dbapi_database._method_abort_interceptor.set_method_to_abort( + EXECUTE_STREAMING_SQL_METHOD, 2 + ) + self._cursor.execute("SELECT * FROM contacts") + self._cursor.fetchmany(2) + dbapi_database._method_abort_interceptor.reset() + self._conn.commit() + # Check that all rpcs except commit should be called 3 times the original + print(method_count_interceptor._counts) + assert method_count_interceptor._counts[COMMIT_METHOD] == 1 + assert method_count_interceptor._counts[EXECUTE_BATCH_DML_METHOD] == 3 + assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] == 3 - def test_partitioned_query_with_dml_query(self): - """Test partition query throws exception when sql query is a DML query.""" + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + assert len(got_rows) == 2 - self._conn.read_only = True - with pytest.raises(ProgrammingError): - self._cursor.execute( - """ - PARTITION INSERT INTO contacts (contact_id, first_name, last_name, email) - VALUES (1111, 'first-name', 'last-name', 'test.email@domen.ru') - """ - ) + def test_execute_batch_dml_abort_retry(self, dbapi_database): + """Test that when any execute batch dml failed with Abort exception, + then the retry succeeds with transaction having insert as well as query + type of statements along with batch dml statements.""" - def test_partitioned_query_in_autocommit_mode(self): - """Test partition query works when connection is not in read-only mode - but is in auto-commit mode.""" + method_count_interceptor = dbapi_database._method_count_interceptor + method_count_interceptor.reset() + # called 3 times + self._insert_row(1) + # called 3 times + self._cursor.execute("SELECT * FROM contacts") + self._cursor.fetchall() self._cursor.execute("start batch dml") - for i in range(1, 11): - self._insert_row(i) + self._insert_row(2) + self._insert_row(3) + dbapi_database._method_abort_interceptor.set_method_to_abort( + EXECUTE_BATCH_DML_METHOD, 2 + ) + # called 3 times self._cursor.execute("run batch") + dbapi_database._method_abort_interceptor.reset() self._conn.commit() + assert method_count_interceptor._counts[COMMIT_METHOD] == 1 + assert method_count_interceptor._counts[EXECUTE_BATCH_DML_METHOD] == 3 + assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] == 6 - self._conn.autocommit = True - self._cursor.execute("PARTITION SELECT * FROM contacts") - partition_id_rows = self._cursor.fetchall() - assert len(partition_id_rows) > 0 + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + assert len(got_rows) == 3 - rows = [] - for partition_id_row in partition_id_rows: - self._cursor.execute("RUN PARTITION " + partition_id_row[0]) - rows = rows + self._cursor.fetchall() - assert len(rows) == 10 + def test_multiple_aborts_in_transaction(self, dbapi_database): + """Test that when there are multiple Abort exceptions in a transaction + on different statements, then the retry succeeds.""" - def test_partitioned_query_with_client_transaction_started(self): - """Test partition query throws exception when connection is not in - read-only mode and transaction started using client side statement.""" + method_count_interceptor = dbapi_database._method_count_interceptor + method_count_interceptor.reset() + # called 3 times + self._insert_row(1) + dbapi_database._method_abort_interceptor.set_method_to_abort( + EXECUTE_STREAMING_SQL_METHOD + ) + # called 3 times + self._cursor.execute("SELECT * FROM contacts") + dbapi_database._method_abort_interceptor.reset() + self._cursor.fetchall() + # called 2 times + self._insert_row(2) + # called 2 times + self._cursor.execute("SELECT * FROM contacts") + self._cursor.fetchone() + dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) + # called 2 times + self._conn.commit() + dbapi_database._method_abort_interceptor.reset() + assert method_count_interceptor._counts[COMMIT_METHOD] == 2 + assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] == 10 - self._conn.autocommit = True - self._cursor.execute("begin transaction") - with pytest.raises(ProgrammingError): - self._cursor.execute("PARTITION SELECT * FROM contacts") + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + assert len(got_rows) == 2 - def _insert_row(self, i): - self._cursor.execute( - f""" - INSERT INTO contacts (contact_id, first_name, last_name, email) - VALUES ({i}, 'first-name-{i}', 'last-name-{i}', 'test.email@domen.ru') - """ - ) + def test_consecutive_aborted_transactions(self, dbapi_database): + """Test 2 consecutive transactions with Abort exceptions on the same + connection works.""" + + method_count_interceptor = dbapi_database._method_count_interceptor + method_count_interceptor.reset() + self._insert_row(1) + self._insert_row(2) + self._cursor.execute("SELECT * FROM contacts") + self._cursor.fetchall() + dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) + self._conn.commit() + dbapi_database._method_abort_interceptor.reset() + assert method_count_interceptor._counts[COMMIT_METHOD] == 2 + assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] == 6 + + method_count_interceptor = dbapi_database._method_count_interceptor + method_count_interceptor.reset() + self._insert_row(3) + self._insert_row(4) + self._cursor.execute("SELECT * FROM contacts") + self._cursor.fetchall() + dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) + self._conn.commit() + dbapi_database._method_abort_interceptor.reset() + assert method_count_interceptor._counts[COMMIT_METHOD] == 2 + assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] == 6 + + self._cursor.execute("SELECT * FROM contacts") + got_rows = self._cursor.fetchall() + assert len(got_rows) == 4 def test_begin_success_post_commit(self): """Test beginning a new transaction post commiting an existing transaction @@ -763,32 +889,6 @@ def test_rollback_on_connection_closing(self, shared_instance, dbapi_database): cursor.close() conn.close() - def test_results_checksum(self): - """Test that results checksum is calculated properly.""" - - self._cursor.execute( - """ - INSERT INTO contacts (contact_id, first_name, last_name, email) - VALUES - (1, 'first-name', 'last-name', 'test.email@domen.ru'), - (2, 'first-name2', 'last-name2', 'test.email2@domen.ru') - """ - ) - assert len(self._conn._statements) == 1 - self._conn.commit() - - self._cursor.execute("SELECT * FROM contacts") - got_rows = self._cursor.fetchall() - - assert len(self._conn._statements) == 1 - self._conn.commit() - - checksum = hashlib.sha256() - checksum.update(pickle.dumps(got_rows[0])) - checksum.update(pickle.dumps(got_rows[1])) - - assert self._cursor._checksum.checksum.digest() == checksum.digest() - def test_execute_many(self): row_data = [ (1, "first-name", "last-name", "test.email@example.com"), diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 8996a06ce6..e2706e1966 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -139,10 +139,6 @@ def test_read_only_not_retried(self): ) ) - cursor.fetchone() - cursor.fetchall() - cursor.fetchmany(5) - connection.retry_transaction.assert_not_called() @staticmethod @@ -280,6 +276,8 @@ def test_commit(self): self._under_test._transaction = mock_transaction = mock.MagicMock() self._under_test._spanner_transaction_started = True mock_transaction.commit = mock_commit = mock.MagicMock() + transaction_helper = self._under_test._transaction_helper + transaction_helper._statement_result_details_list = [{}, {}] with mock.patch( "google.cloud.spanner_dbapi.connection.Connection._release_session" @@ -288,6 +286,7 @@ def test_commit(self): mock_commit.assert_called_once_with() mock_release.assert_called_once_with() + self.assertEqual(len(transaction_helper._statement_result_details_list), 0) @mock.patch.object(warnings, "warn") def test_commit_in_autocommit_mode(self, mock_warn): @@ -325,12 +324,14 @@ def test_rollback(self, mock_warn): self._under_test._transaction = mock_transaction mock_rollback = mock.MagicMock() mock_transaction.rollback = mock_rollback - + transaction_helper = self._under_test._transaction_helper + transaction_helper._statement_result_details_list = [{}, {}] with mock.patch( "google.cloud.spanner_dbapi.connection.Connection._release_session" ) as mock_release: self._under_test.rollback() + self.assertEqual(len(transaction_helper._statement_result_details_list), 0) mock_rollback.assert_called_once_with() mock_release.assert_called_once_with() @@ -493,347 +494,22 @@ def test_begin(self): self.assertEqual(self._under_test._transaction_begin_marked, True) - def test_run_statement_wo_retried(self): - """Check that Connection remembers executed statements.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - sql = """SELECT 23 FROM table WHERE id = @a1""" - params = {"a1": "value"} - param_types = {"a1": str} - - connection = self._make_connection() - connection.transaction_checkout = mock.Mock() - statement = Statement(sql, params, param_types, ResultsChecksum()) - connection.run_statement(statement) - - self.assertEqual(connection._statements[0].sql, sql) - self.assertEqual(connection._statements[0].params, params) - self.assertEqual(connection._statements[0].param_types, param_types) - self.assertIsInstance(connection._statements[0].checksum, ResultsChecksum) - - def test_run_statement_w_retried(self): - """Check that Connection doesn't remember re-executed statements.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - sql = """SELECT 23 FROM table WHERE id = @a1""" - params = {"a1": "value"} - param_types = {"a1": str} - - connection = self._make_connection() - connection.transaction_checkout = mock.Mock() - statement = Statement(sql, params, param_types, ResultsChecksum()) - connection.run_statement(statement, retried=True) - - self.assertEqual(len(connection._statements), 0) - - def test_run_statement_w_heterogenous_insert_statements(self): - """Check that Connection executed heterogenous insert statements.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - from google.rpc.status_pb2 import Status - from google.rpc.code_pb2 import OK - - sql = "INSERT INTO T (f1, f2) VALUES (1, 2)" - params = None - param_types = None - - connection = self._make_connection() - transaction = mock.MagicMock() - connection.transaction_checkout = mock.Mock(return_value=transaction) - transaction.batch_update = mock.Mock(return_value=(Status(code=OK), 1)) - statement = Statement(sql, params, param_types, ResultsChecksum()) - - connection.run_statement(statement, retried=True) - - self.assertEqual(len(connection._statements), 0) - - def test_run_statement_w_homogeneous_insert_statements(self): - """Check that Connection executed homogeneous insert statements.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - from google.rpc.status_pb2 import Status - from google.rpc.code_pb2 import OK - - sql = "INSERT INTO T (f1, f2) VALUES (%s, %s), (%s, %s)" - params = ["a", "b", "c", "d"] - param_types = {"f1": str, "f2": str} - - connection = self._make_connection() - transaction = mock.MagicMock() - connection.transaction_checkout = mock.Mock(return_value=transaction) - transaction.batch_update = mock.Mock(return_value=(Status(code=OK), 1)) - statement = Statement(sql, params, param_types, ResultsChecksum()) - - connection.run_statement(statement, retried=True) - - self.assertEqual(len(connection._statements), 0) - - @mock.patch("google.cloud.spanner_v1.transaction.Transaction") - def test_commit_clears_statements(self, mock_transaction): - """ - Check that all the saved statements are - cleared, when the transaction is commited. - """ - connection = self._make_connection() - connection._spanner_transaction_started = True - connection._transaction = mock.Mock() - connection._statements = [{}, {}] - - self.assertEqual(len(connection._statements), 2) - - connection.commit() - - self.assertEqual(len(connection._statements), 0) - - @mock.patch("google.cloud.spanner_v1.transaction.Transaction") - def test_rollback_clears_statements(self, mock_transaction): - """ - Check that all the saved statements are - cleared, when the transaction is roll backed. - """ - connection = self._make_connection() - connection._spanner_transaction_started = True - connection._transaction = mock_transaction - connection._statements = [{}, {}] - - self.assertEqual(len(connection._statements), 2) - - connection.rollback() - - self.assertEqual(len(connection._statements), 0) - - def test_retry_transaction_w_checksum_match(self): - """Check retrying an aborted transaction.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] - connection = self._make_connection() - checksum = ResultsChecksum() - checksum.consume_result(row) - - retried_checkum = ResultsChecksum() - run_mock = connection.run_statement = mock.Mock() - run_mock.return_value = ([row], retried_checkum) - - statement = Statement("SELECT 1", [], {}, checksum) - connection._statements.append(statement) - - with mock.patch( - "google.cloud.spanner_dbapi.connection._compare_checksums" - ) as compare_mock: - connection.retry_transaction() - - compare_mock.assert_called_with(checksum, retried_checkum) - run_mock.assert_called_with(statement, retried=True) - - def test_retry_transaction_w_checksum_mismatch(self): - """ - Check retrying an aborted transaction - with results checksums mismatch. - """ - from google.cloud.spanner_dbapi.exceptions import RetryAborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] - retried_row = ["field3", "field4"] - connection = self._make_connection() - - checksum = ResultsChecksum() - checksum.consume_result(row) - retried_checkum = ResultsChecksum() - run_mock = connection.run_statement = mock.Mock() - run_mock.return_value = ([retried_row], retried_checkum) - - statement = Statement("SELECT 1", [], {}, checksum) - connection._statements.append(statement) - - with self.assertRaises(RetryAborted): - connection.retry_transaction() - @mock.patch("google.cloud.spanner_v1.Client") def test_commit_retry_aborted_statements(self, mock_client): """Check that retried transaction executing the same statements.""" from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) mock_transaction = mock.Mock() connection._spanner_transaction_started = True connection._transaction = mock_transaction mock_transaction.commit.side_effect = [Aborted("Aborted"), None] - run_mock = connection.run_statement = mock.Mock() - run_mock.return_value = ([row], ResultsChecksum()) + run_mock = connection._transaction_helper = mock.Mock() connection.commit() - run_mock.assert_called_with(statement, retried=True) - - @mock.patch("google.cloud.spanner_v1.Client") - def test_retry_aborted_retry(self, mock_client): - """ - Check that in case of a retried transaction failed, - the connection will retry it once again. - """ - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] - - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) - metadata_mock = mock.Mock() - metadata_mock.trailing_metadata.return_value = {} - run_mock = connection.run_statement = mock.Mock() - run_mock.side_effect = [ - Aborted("Aborted", errors=[metadata_mock]), - ([row], ResultsChecksum()), - ] - - connection.retry_transaction() - - run_mock.assert_has_calls( - ( - mock.call(statement, retried=True), - mock.call(statement, retried=True), - ) - ) - - def test_retry_transaction_raise_max_internal_retries(self): - """Check retrying raise an error of max internal retries.""" - from google.cloud.spanner_dbapi import connection as conn - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - conn.MAX_INTERNAL_RETRIES = 0 - row = ["field1", "field2"] - connection = self._make_connection() - - checksum = ResultsChecksum() - checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, checksum) - connection._statements.append(statement) - - with self.assertRaises(Exception): - connection.retry_transaction() - - conn.MAX_INTERNAL_RETRIES = 50 - - @mock.patch("google.cloud.spanner_v1.Client") - def test_retry_aborted_retry_without_delay(self, mock_client): - """ - Check that in case of a retried transaction failed, - the connection will retry it once again. - """ - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] - - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) - metadata_mock = mock.Mock() - metadata_mock.trailing_metadata.return_value = {} - run_mock = connection.run_statement = mock.Mock() - run_mock.side_effect = [ - Aborted("Aborted", errors=[metadata_mock]), - ([row], ResultsChecksum()), - ] - connection._get_retry_delay = mock.Mock(return_value=False) - - connection.retry_transaction() - - run_mock.assert_has_calls( - ( - mock.call(statement, retried=True), - mock.call(statement, retried=True), - ) - ) - - def test_retry_transaction_w_multiple_statement(self): - """Check retrying an aborted transaction.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = ["field1", "field2"] - connection = self._make_connection() - - checksum = ResultsChecksum() - checksum.consume_result(row) - retried_checkum = ResultsChecksum() - - statement = Statement("SELECT 1", [], {}, checksum) - statement1 = Statement("SELECT 2", [], {}, checksum) - connection._statements.append(statement) - connection._statements.append(statement1) - run_mock = connection.run_statement = mock.Mock() - run_mock.return_value = ([row], retried_checkum) - - with mock.patch( - "google.cloud.spanner_dbapi.connection._compare_checksums" - ) as compare_mock: - connection.retry_transaction() - - compare_mock.assert_called_with(checksum, retried_checkum) - - run_mock.assert_called_with(statement1, retried=True) - - def test_retry_transaction_w_empty_response(self): - """Check retrying an aborted transaction.""" - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.parsed_statement import Statement - - row = [] - connection = self._make_connection() - - checksum = ResultsChecksum() - checksum.count = 1 - retried_checkum = ResultsChecksum() - - statement = Statement("SELECT 1", [], {}, checksum) - connection._statements.append(statement) - run_mock = connection.run_statement = mock.Mock() - run_mock.return_value = ([row], retried_checkum) - - with mock.patch( - "google.cloud.spanner_dbapi.connection._compare_checksums" - ) as compare_mock: - connection.retry_transaction() - - compare_mock.assert_called_with(checksum, retried_checkum) - - run_mock.assert_called_with(statement, retried=True) + assert run_mock.retry_transaction.called def test_validate_ok(self): connection = self._make_connection() @@ -978,6 +654,7 @@ def test_staleness_single_use_autocommit(self, MockedPeekIterator): snapshot_obj = mock.Mock() _result_set = mock.Mock() snapshot_obj.execute_sql.return_value = _result_set + _result_set.stats = None snapshot_ctx = mock.Mock() snapshot_ctx.__enter__ = mock.Mock(return_value=snapshot_obj) @@ -1011,6 +688,8 @@ def test_staleness_single_use_readonly_autocommit(self, MockedPeekIterator): # mock snapshot context manager snapshot_obj = mock.Mock() _result_set = mock.Mock() + _result_set.stats = None + snapshot_obj.execute_sql.return_value = _result_set snapshot_ctx = mock.Mock() @@ -1026,7 +705,6 @@ def test_staleness_single_use_readonly_autocommit(self, MockedPeekIterator): connection.database.snapshot.assert_called_with(read_timestamp=timestamp) def test_request_priority(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_dbapi.parsed_statement import Statement from google.cloud.spanner_v1 import RequestOptions @@ -1044,7 +722,7 @@ def test_request_priority(self): req_opts = RequestOptions(priority=priority) - connection.run_statement(Statement(sql, params, param_types, ResultsChecksum())) + connection.run_statement(Statement(sql, params, param_types)) connection._transaction.execute_sql.assert_called_with( sql, params, param_types=param_types, request_options=req_opts @@ -1052,7 +730,7 @@ def test_request_priority(self): assert connection.request_priority is None # check that priority is applied for only one request - connection.run_statement(Statement(sql, params, param_types, ResultsChecksum())) + connection.run_statement(Statement(sql, params, param_types)) connection._transaction.execute_sql.assert_called_with( sql, params, param_types=param_types, request_options=None diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 3328b0e17f..b91955bfb7 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -16,12 +16,15 @@ from unittest import mock import sys import unittest +from google.rpc.code_pb2 import ABORTED from google.cloud.spanner_dbapi.parsed_statement import ( ParsedStatement, StatementType, Statement, ) +from google.api_core.exceptions import Aborted +from google.cloud.spanner_dbapi.connection import connect class TestCursor(unittest.TestCase): @@ -44,7 +47,7 @@ def _make_connection(self, *args, **kwargs): def _transaction_mock(self, mock_response=[]): from google.rpc.code_pb2 import OK - transaction = mock.Mock(committed=False, rolled_back=False) + transaction = mock.Mock() transaction.batch_update = mock.Mock( return_value=[mock.Mock(code=OK), mock_response] ) @@ -175,8 +178,6 @@ def test_execute_database_error(self): cursor.execute(sql="SELECT 1") def test_execute_autocommit_off(self): - from google.cloud.spanner_dbapi.utils import PeekIterator - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) cursor.connection._autocommit = False @@ -184,30 +185,24 @@ def test_execute_autocommit_off(self): cursor.execute("sql") self.assertIsInstance(cursor._result_set, mock.MagicMock) - self.assertIsInstance(cursor._itr, PeekIterator) def test_execute_insert_statement_autocommit_off(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.utils import PeekIterator - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) cursor.connection._autocommit = False cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) - cursor._checksum = ResultsChecksum() sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_statement", - return_value=ParsedStatement(StatementType.UPDATE, sql), + return_value=ParsedStatement(StatementType.UPDATE, Statement(sql)), ): with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=(mock.MagicMock(), ResultsChecksum()), + return_value=(mock.MagicMock()), ): cursor.execute(sql) self.assertIsInstance(cursor._result_set, mock.MagicMock) - self.assertIsInstance(cursor._itr, PeekIterator) def test_execute_statement(self): connection = self._make_connection(self.INSTANCE, mock.MagicMock()) @@ -261,6 +256,143 @@ def test_execute_statement(self): cursor._do_execute_update_in_autocommit, "sql", None ) + def test_execute_statement_with_cursor_not_in_retry_mode(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + sql = "sql" + transaction_helper_mock = cursor.transaction_helper = mock.Mock() + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.QUERY, Statement(sql)), + ): + cursor.execute(sql=sql) + + transaction_helper_mock.add_execute_statement_for_retry.assert_called_once() + transaction_helper_mock.retry_transaction.assert_not_called() + + def test_executemany_query_statement_with_cursor_not_in_retry_mode(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + sql = "sql" + transaction_helper_mock = cursor.transaction_helper = mock.Mock() + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.QUERY, Statement(sql)), + ): + cursor.executemany(operation=sql, seq_of_params=[]) + + transaction_helper_mock.add_execute_statement_for_retry.assert_called_once() + transaction_helper_mock.retry_transaction.assert_not_called() + + def test_executemany_dml_statement_with_cursor_not_in_retry_mode(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + sql = "sql" + transaction_helper_mock = cursor.transaction_helper = mock.Mock() + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.INSERT, Statement(sql)), + ): + cursor.executemany(operation=sql, seq_of_params=[]) + + transaction_helper_mock.add_execute_statement_for_retry.assert_called_once() + transaction_helper_mock.retry_transaction.assert_not_called() + + def test_execute_statement_with_cursor_in_retry_mode(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor._in_retry_mode = True + sql = "sql" + transaction_helper_mock = cursor.transaction_helper = mock.Mock() + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.QUERY, Statement(sql)), + ): + cursor.execute(sql=sql) + + transaction_helper_mock.add_execute_statement_for_retry.assert_not_called() + transaction_helper_mock.retry_transaction.assert_not_called() + + def test_executemany_statement_with_cursor_in_retry_mode(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor._in_retry_mode = True + sql = "sql" + transaction_helper_mock = cursor.transaction_helper = mock.Mock() + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.QUERY, Statement(sql)), + ): + cursor.executemany(operation=sql, seq_of_params=[]) + + transaction_helper_mock.add_execute_statement_for_retry.assert_not_called() + transaction_helper_mock.retry_transaction.assert_not_called() + + @mock.patch("google.cloud.spanner_dbapi.cursor.PeekIterator") + def test_execute_statement_aborted_with_cursor_not_in_retry_mode( + self, mock_peek_iterator + ): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + sql = "sql" + transaction_helper_mock = cursor.transaction_helper = mock.Mock() + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.QUERY, Statement(sql)), + ): + connection.run_statement = mock.Mock( + side_effect=(Aborted("Aborted"), None), + ) + cursor.execute(sql=sql) + + transaction_helper_mock.add_execute_statement_for_retry.assert_called_once() + transaction_helper_mock.retry_transaction.assert_called_once() + + def test_execute_statement_aborted_with_cursor_in_retry_mode(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor._in_retry_mode = True + sql = "sql" + transaction_helper_mock = cursor.transaction_helper = mock.Mock() + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.QUERY, Statement(sql)), + ): + connection.run_statement = mock.Mock( + side_effect=Aborted("Aborted"), + ) + with self.assertRaises(Aborted): + cursor.execute(sql=sql) + + transaction_helper_mock.add_execute_statement_for_retry.assert_not_called() + transaction_helper_mock.retry_transaction.assert_not_called() + + def test_execute_statement_exception_with_cursor_not_in_retry_mode(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + sql = "sql" + transaction_helper_mock = cursor.transaction_helper = mock.Mock() + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_statement", + return_value=ParsedStatement(StatementType.QUERY, Statement(sql)), + ): + connection.run_statement = mock.Mock( + side_effect=(Exception("Exception"), None), + ) + with self.assertRaises(Exception): + cursor.execute(sql=sql) + + transaction_helper_mock.add_execute_statement_for_retry.assert_called_once() + transaction_helper_mock.retry_transaction.assert_not_called() + def test_execute_integrity_error(self): from google.api_core import exceptions from google.cloud.spanner_dbapi.exceptions import IntegrityError @@ -378,7 +510,7 @@ def test_executemany(self, mock_client): cursor.executemany(operation, params_seq) execute_mock.assert_has_calls( - (mock.call(operation, (1,)), mock.call(operation, (2,))) + (mock.call(operation, (1,), True), mock.call(operation, (2,), True)) ) def test_executemany_delete_batch_autocommit(self): @@ -547,7 +679,7 @@ def test_executemany_insert_batch_failed(self): connection.autocommit = True cursor = connection.cursor() - transaction = mock.Mock(committed=False, rolled_back=False) + transaction = mock.Mock() transaction.batch_update = mock.Mock( return_value=(mock.Mock(code=UNKNOWN, message=err_details), []) ) @@ -565,16 +697,15 @@ def test_executemany_insert_batch_failed(self): def test_executemany_insert_batch_aborted(self): from google.cloud.spanner_dbapi import connect - from google.cloud.spanner_dbapi.checksum import ResultsChecksum from google.cloud.spanner_v1.param_types import INT64 - from google.rpc.code_pb2 import ABORTED sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)""" + args = [(1, 2, 3, 4), (5, 6, 7, 8)] err_details = "Aborted details here" connection = connect("test-instance", "test-database") - transaction1 = mock.Mock(committed=False, rolled_back=False) + transaction1 = mock.Mock() transaction1.batch_update = mock.Mock( side_effect=[(mock.Mock(code=ABORTED, message=err_details), [])] ) @@ -584,10 +715,9 @@ def test_executemany_insert_batch_aborted(self): connection.transaction_checkout = mock.Mock( side_effect=[transaction1, transaction2] ) - connection.retry_transaction = mock.Mock() cursor = connection.cursor() - cursor.executemany(sql, [(1, 2, 3, 4), (5, 6, 7, 8)]) + cursor.executemany(sql, args) transaction1.batch_update.assert_called_with( [ @@ -617,24 +747,6 @@ def test_executemany_insert_batch_aborted(self): ), ] ) - connection.retry_transaction.assert_called_once() - - self.assertEqual( - connection._statements[0][0], - [ - Statement( - """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""", - {"a0": 1, "a1": 2, "a2": 3, "a3": 4}, - {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64}, - ), - Statement( - """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (@a0, @a1, @a2, @a3)""", - {"a0": 5, "a1": 6, "a2": 7, "a3": 8}, - {"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64}, - ), - ], - ) - self.assertIsInstance(connection._statements[0][1], ResultsChecksum) @mock.patch("google.cloud.spanner_v1.Client") def test_executemany_database_error(self, mock_client): @@ -650,11 +762,9 @@ def test_executemany_database_error(self, mock_client): sys.version_info[0] < 3, "Python 2 has an outdated iterator definition" ) def test_fetchone(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) - cursor._checksum = ResultsChecksum() + cursor._parsed_statement = mock.Mock() lst = [1, 2, 3] cursor._itr = iter(lst) for i in range(len(lst)): @@ -665,12 +775,9 @@ def test_fetchone(self): sys.version_info[0] < 3, "Python 2 has an outdated iterator definition" ) def test_fetchone_w_autocommit(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) connection.autocommit = True cursor = self._make_one(connection) - cursor._checksum = ResultsChecksum() lst = [1, 2, 3] cursor._itr = iter(lst) for i in range(len(lst)): @@ -678,11 +785,9 @@ def test_fetchone_w_autocommit(self): self.assertIsNone(cursor.fetchone()) def test_fetchmany(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) - cursor._checksum = ResultsChecksum() + cursor._parsed_statement = mock.Mock() lst = [(1,), (2,), (3,)] cursor._itr = iter(lst) @@ -692,12 +797,9 @@ def test_fetchmany(self): self.assertEqual(result, lst[1:]) def test_fetchmany_w_autocommit(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) connection.autocommit = True cursor = self._make_one(connection) - cursor._checksum = ResultsChecksum() lst = [(1,), (2,), (3,)] cursor._itr = iter(lst) @@ -707,22 +809,22 @@ def test_fetchmany_w_autocommit(self): self.assertEqual(result, lst[1:]) def test_fetchall(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) cursor = self._make_one(connection) - cursor._checksum = ResultsChecksum() + cursor._parsed_statement = mock.Mock() + transaction_helper_mock = cursor.transaction_helper = mock.Mock() + lst = [(1,), (2,), (3,)] cursor._itr = iter(lst) self.assertEqual(cursor.fetchall(), lst) - def test_fetchall_w_autocommit(self): - from google.cloud.spanner_dbapi.checksum import ResultsChecksum + transaction_helper_mock.add_fetch_statement_for_retry.assert_called_once() + transaction_helper_mock.retry_transaction.assert_not_called() + def test_fetchall_w_autocommit(self): connection = self._make_connection(self.INSTANCE, mock.MagicMock()) connection.autocommit = True cursor = self._make_one(connection) - cursor._checksum = ResultsChecksum() lst = [(1,), (2,), (3,)] cursor._itr = iter(lst) self.assertEqual(cursor.fetchall(), lst) @@ -905,283 +1007,145 @@ def test_peek_iterator_aborted(self, mock_client): from google.cloud.spanner_dbapi.connection import connect connection = connect("test-instance", "test-database") - cursor = connection.cursor() with mock.patch( "google.cloud.spanner_dbapi.utils.PeekIterator.__init__", side_effect=(Aborted("Aborted"), None), ): with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" + "google.cloud.spanner_dbapi.transaction_helper.TransactionRetryHelper.retry_transaction" ) as retry_mock: with mock.patch( "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=((1, 2, 3), None), + return_value=(1, 2, 3), ): cursor.execute("SELECT * FROM table_name") - retry_mock.assert_called_with() - - @mock.patch("google.cloud.spanner_v1.Client") - def test_fetchone_retry_aborted(self, mock_client): - """Check that aborted fetch re-executing transaction.""" - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - - with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor.__next__", - side_effect=(Aborted("Aborted"), None), - ): - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" - ) as retry_mock: - cursor.fetchone() - - retry_mock.assert_called_with() + retry_mock.assert_called_with() @mock.patch("google.cloud.spanner_v1.Client") - def test_fetchone_retry_aborted_statements(self, mock_client): - """Check that retried transaction executing the same statements.""" - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.cursor import Statement - - row = ["field1", "field2"] + def test_fetchone_aborted_with_cursor_not_in_retry_mode(self, mock_client): connection = connect("test-instance", "test-database") - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + transaction_helper_mock = cursor.transaction_helper = mock.Mock() with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__next__", - side_effect=(Aborted("Aborted"), None), + side_effect=(Aborted("Aborted"), iter([])), ): - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row], ResultsChecksum()), - ) as run_mock: - cursor.fetchone() + cursor.fetchone() - run_mock.assert_called_with(statement, retried=True) + transaction_helper_mock.add_fetch_statement_for_retry.assert_called_once() + transaction_helper_mock.retry_transaction.assert_called_once() @mock.patch("google.cloud.spanner_v1.Client") - def test_fetchone_retry_aborted_statements_checksums_mismatch(self, mock_client): - """Check transaction retrying with underlying data being changed.""" - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.exceptions import RetryAborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.cursor import Statement - - row = ["field1", "field2"] - row2 = ["updated_field1", "field2"] - + def test_fetchone_aborted_with_cursor_in_retry_mode(self, mock_client): connection = connect("test-instance", "test-database") - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + cursor._in_retry_mode = True + transaction_helper_mock = cursor.transaction_helper = mock.Mock() with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__next__", - side_effect=(Aborted("Aborted"), None), + side_effect=(Aborted("Aborted"), iter([])), ): - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row2], ResultsChecksum()), - ) as run_mock: - with self.assertRaises(RetryAborted): - cursor.fetchone() + cursor.fetchone() - run_mock.assert_called_with(statement, retried=True) + transaction_helper_mock.add_fetch_statement_for_retry.assert_not_called() + transaction_helper_mock.retry_transaction.assert_not_called() @mock.patch("google.cloud.spanner_v1.Client") - def test_fetchall_retry_aborted(self, mock_client): - """Check that aborted fetch re-executing transaction.""" - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - + def test_fetchall_aborted_with_cursor_not_in_retry_mode(self, mock_client): connection = connect("test-instance", "test-database") - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() + transaction_helper_mock = cursor.transaction_helper = mock.Mock() with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__iter__", side_effect=(Aborted("Aborted"), iter([])), ): - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" - ) as retry_mock: - cursor.fetchall() + cursor.fetchall() - retry_mock.assert_called_with() + transaction_helper_mock.add_fetch_statement_for_retry.assert_called_once() + transaction_helper_mock.retry_transaction.assert_called_once() @mock.patch("google.cloud.spanner_v1.Client") - def test_fetchall_retry_aborted_statements(self, mock_client): - """Check that retried transaction executing the same statements.""" - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.cursor import Statement - - row = ["field1", "field2"] + def test_fetchall_aborted_with_cursor_in_retry_mode(self, mock_client): connection = connect("test-instance", "test-database") - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + cursor._in_retry_mode = True + transaction_helper_mock = cursor.transaction_helper = mock.Mock() with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__iter__", - side_effect=(Aborted("Aborted"), iter(row)), + side_effect=(Aborted("Aborted"), iter([])), ): - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row], ResultsChecksum()), - ) as run_mock: - cursor.fetchall() + cursor.fetchall() - run_mock.assert_called_with(statement, retried=True) + transaction_helper_mock.add_fetch_statement_for_retry.assert_not_called() + transaction_helper_mock.retry_transaction.assert_not_called() @mock.patch("google.cloud.spanner_v1.Client") - def test_fetchall_retry_aborted_statements_checksums_mismatch(self, mock_client): - """Check transaction retrying with underlying data being changed.""" - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.exceptions import RetryAborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.cursor import Statement - - row = ["field1", "field2"] - row2 = ["updated_field1", "field2"] - + def test_fetchmany_aborted_with_cursor_not_in_retry_mode(self, mock_client): connection = connect("test-instance", "test-database") - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + transaction_helper_mock = cursor.transaction_helper = mock.Mock() with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor.__iter__", - side_effect=(Aborted("Aborted"), iter(row)), + "google.cloud.spanner_dbapi.cursor.Cursor.__next__", + side_effect=(Aborted("Aborted"), iter([])), ): - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row2], ResultsChecksum()), - ) as run_mock: - with self.assertRaises(RetryAborted): - cursor.fetchall() + cursor.fetchmany() - run_mock.assert_called_with(statement, retried=True) + transaction_helper_mock.add_fetch_statement_for_retry.assert_called_once() + transaction_helper_mock.retry_transaction.assert_called_once() @mock.patch("google.cloud.spanner_v1.Client") - def test_fetchmany_retry_aborted(self, mock_client): - """Check that aborted fetch re-executing transaction.""" - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - + def test_fetchmany_aborted_with_cursor_in_retry_mode(self, mock_client): connection = connect("test-instance", "test-database") - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() + cursor._in_retry_mode = True + transaction_helper_mock = cursor.transaction_helper = mock.Mock() with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__next__", - side_effect=(Aborted("Aborted"), None), + side_effect=(Aborted("Aborted"), iter([])), ): - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.retry_transaction" - ) as retry_mock: - cursor.fetchmany() + cursor.fetchmany() - retry_mock.assert_called_with() + transaction_helper_mock.add_fetch_statement_for_retry.assert_not_called() + transaction_helper_mock.retry_transaction.assert_not_called() @mock.patch("google.cloud.spanner_v1.Client") - def test_fetchmany_retry_aborted_statements(self, mock_client): - """Check that retried transaction executing the same statements.""" - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.cursor import Statement - - row = ["field1", "field2"] + def test_fetch_exception_with_cursor_not_in_retry_mode(self, mock_client): connection = connect("test-instance", "test-database") - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + transaction_helper_mock = cursor.transaction_helper = mock.Mock() with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor.__next__", - side_effect=(Aborted("Aborted"), None), + "google.cloud.spanner_dbapi.cursor.Cursor.__iter__", + side_effect=Exception("Exception"), ): - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row], ResultsChecksum()), - ) as run_mock: - cursor.fetchmany(len(row)) + cursor.fetchall() - run_mock.assert_called_with(statement, retried=True) + transaction_helper_mock.add_fetch_statement_for_retry.assert_called_once() + transaction_helper_mock.retry_transaction.assert_not_called() @mock.patch("google.cloud.spanner_v1.Client") - def test_fetchmany_retry_aborted_statements_checksums_mismatch(self, mock_client): - """Check transaction retrying with underlying data being changed.""" - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.exceptions import RetryAborted - from google.cloud.spanner_dbapi.checksum import ResultsChecksum - from google.cloud.spanner_dbapi.connection import connect - from google.cloud.spanner_dbapi.cursor import Statement - - row = ["field1", "field2"] - row2 = ["updated_field1", "field2"] - + def test_fetch_exception_with_cursor_in_retry_mode(self, mock_client): connection = connect("test-instance", "test-database") - cursor = connection.cursor() - cursor._checksum = ResultsChecksum() - cursor._checksum.consume_result(row) - - statement = Statement("SELECT 1", [], {}, cursor._checksum) - connection._statements.append(statement) + cursor._in_retry_mode = True + transaction_helper_mock = cursor.transaction_helper = mock.Mock() with mock.patch( "google.cloud.spanner_dbapi.cursor.Cursor.__next__", - side_effect=(Aborted("Aborted"), None), + side_effect=Exception("Exception"), ): - with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_statement", - return_value=([row2], ResultsChecksum()), - ) as run_mock: - with self.assertRaises(RetryAborted): - cursor.fetchmany(len(row)) + cursor.fetchmany() - run_mock.assert_called_with(statement, retried=True) + transaction_helper_mock.add_fetch_statement_for_retry.assert_not_called() + transaction_helper_mock.retry_transaction.assert_not_called() @mock.patch("google.cloud.spanner_v1.Client") def test_ddls_with_semicolon(self, mock_client): diff --git a/tests/unit/spanner_dbapi/test_transaction_helper.py b/tests/unit/spanner_dbapi/test_transaction_helper.py new file mode 100644 index 0000000000..dfe0d96538 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_transaction_helper.py @@ -0,0 +1,541 @@ +# Copyright 2023 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from unittest import mock + +from google.cloud.spanner_dbapi.exceptions import ( + RetryAborted, +) +from google.cloud.spanner_dbapi.checksum import ResultsChecksum +from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, StatementType +from google.api_core.exceptions import Aborted + +from google.cloud.spanner_dbapi.transaction_helper import ( + TransactionRetryHelper, + ExecuteStatement, + CursorStatementType, + FetchStatement, +) + + +def _get_checksum(row): + checksum = ResultsChecksum() + checksum.consume_result(row) + return checksum + + +SQL = "SELECT 1" +ARGS = [] + + +class TestTransactionHelper(unittest.TestCase): + @mock.patch("google.cloud.spanner_dbapi.connection.Connection") + def setUp(self, mock_connection): + self._under_test = TransactionRetryHelper(mock_connection) + self._mock_connection = mock_connection + + def test_retry_transaction_execute(self): + """ + Test retrying a transaction with an execute statement works. + """ + execute_statement = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE, + sql=SQL, + args=ARGS, + result_details=None, + ) + self._under_test._statement_result_details_list.append(execute_statement) + run_mock = self._under_test._connection.cursor().execute = mock.Mock() + + self._under_test.retry_transaction() + + run_mock.assert_called_with(SQL, ARGS) + + def test_retry_transaction_dml_execute(self): + """ + Test retrying a transaction with an execute DML statement works. + """ + update_count = 3 + execute_statement = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE, + sql=SQL, + args=ARGS, + result_details=update_count, + ) + self._under_test._statement_result_details_list.append(execute_statement) + run_mock = self._under_test._connection.cursor = mock.Mock() + run_mock().rowcount = update_count + + self._under_test.retry_transaction() + + run_mock().execute.assert_called_with(SQL, ARGS) + + def test_retry_transaction_dml_execute_exception(self): + """ + Test retrying a transaction with an execute DML statement with different + row update count than original throws RetryAborted exception. + """ + execute_statement = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE, + sql=SQL, + args=ARGS, + result_details=2, + ) + self._under_test._statement_result_details_list.append(execute_statement) + run_mock = self._under_test._connection.cursor = mock.Mock() + run_mock().rowcount = 3 + + with self.assertRaises(RetryAborted): + self._under_test.retry_transaction() + + run_mock().execute.assert_called_with(SQL, ARGS) + + def test_retry_transaction_execute_many(self): + """ + Test retrying a transaction with an executemany on Query statement works. + """ + execute_statement = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE_MANY, + sql=SQL, + args=ARGS, + result_details=None, + ) + self._under_test._statement_result_details_list.append(execute_statement) + run_mock = self._under_test._connection.cursor().executemany = mock.Mock() + + self._under_test.retry_transaction() + + run_mock.assert_called_with(SQL, ARGS) + + def test_retry_transaction_dml_execute_many(self): + """ + Test retrying a transaction with an executemany on DML statement works. + """ + update_count = 3 + execute_statement = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE_MANY, + sql=SQL, + args=ARGS, + result_details=update_count, + ) + self._under_test._statement_result_details_list.append(execute_statement) + run_mock = self._under_test._connection.cursor = mock.Mock() + run_mock().rowcount = update_count + + self._under_test.retry_transaction() + + run_mock().executemany.assert_called_with(SQL, ARGS) + + def test_retry_transaction_dml_executemany_exception(self): + """ + Test retrying a transaction with an executemany DML statement with different + row update count than original throws RetryAborted exception. + """ + execute_statement = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE_MANY, + sql=SQL, + args=ARGS, + result_details=2, + ) + self._under_test._statement_result_details_list.append(execute_statement) + run_mock = self._under_test._connection.cursor = mock.Mock() + run_mock().rowcount = 3 + + with self.assertRaises(RetryAborted): + self._under_test.retry_transaction() + + run_mock().executemany.assert_called_with(SQL, ARGS) + + def test_retry_transaction_fetchall(self): + """ + Test retrying a transaction on a fetchall statement works. + """ + result_row = ("field1", "field2") + fetch_statement = FetchStatement( + statement_type=CursorStatementType.FETCH_ALL, + result_details=_get_checksum(result_row), + ) + self._under_test._statement_result_details_list.append(fetch_statement) + run_mock = self._under_test._connection.cursor().fetchall = mock.Mock() + run_mock.return_value = [result_row] + + self._under_test.retry_transaction() + + run_mock.assert_called_with() + + def test_retry_transaction_fetchall_exception(self): + """ + Test retrying a transaction on a fetchall statement throws exception + when results is different from original in retry. + """ + result_row = ("field1", "field2") + fetch_statement = FetchStatement( + statement_type=CursorStatementType.FETCH_ALL, + result_details=_get_checksum(result_row), + ) + self._under_test._statement_result_details_list.append(fetch_statement) + run_mock = self._under_test._connection.cursor().fetchall = mock.Mock() + retried_result_row = "field3" + run_mock.return_value = [retried_result_row] + + with self.assertRaises(RetryAborted): + self._under_test.retry_transaction() + + run_mock.assert_called_with() + + def test_retry_transaction_fetchmany(self): + """ + Test retrying a transaction on a fetchmany statement works. + """ + result_row = ("field1", "field2") + fetch_statement = FetchStatement( + statement_type=CursorStatementType.FETCH_MANY, + result_details=_get_checksum(result_row), + size=1, + ) + self._under_test._statement_result_details_list.append(fetch_statement) + run_mock = self._under_test._connection.cursor().fetchmany = mock.Mock() + run_mock.return_value = [result_row] + + self._under_test.retry_transaction() + + run_mock.assert_called_with(1) + + def test_retry_transaction_fetchmany_exception(self): + """ + Test retrying a transaction on a fetchmany statement throws exception + when results is different from original in retry. + """ + result_row = ("field1", "field2") + fetch_statement = FetchStatement( + statement_type=CursorStatementType.FETCH_MANY, + result_details=_get_checksum(result_row), + size=1, + ) + self._under_test._statement_result_details_list.append(fetch_statement) + run_mock = self._under_test._connection.cursor().fetchmany = mock.Mock() + retried_result_row = "field3" + run_mock.return_value = [retried_result_row] + + with self.assertRaises(RetryAborted): + self._under_test.retry_transaction() + + run_mock.assert_called_with(1) + + def test_retry_transaction_same_exception(self): + """ + Test retrying a transaction with statement throwing same exception in + retry works. + """ + exception = Exception("Test") + execute_statement = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE, + sql=SQL, + args=ARGS, + result_details=exception, + ) + self._under_test._statement_result_details_list.append(execute_statement) + run_mock = self._under_test._connection.cursor().execute = mock.Mock() + run_mock.side_effect = exception + + self._under_test.retry_transaction() + + run_mock.assert_called_with(SQL, ARGS) + + def test_retry_transaction_different_exception(self): + """ + Test retrying a transaction with statement throwing different exception + in retry results in RetryAborted exception. + """ + execute_statement = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE, + sql=SQL, + args=ARGS, + result_details=Exception("Test"), + ) + self._under_test._statement_result_details_list.append(execute_statement) + run_mock = self._under_test._connection.cursor().execute = mock.Mock() + run_mock.side_effect = Exception("Test2") + + with self.assertRaises(RetryAborted): + self._under_test.retry_transaction() + + run_mock.assert_called_with(SQL, ARGS) + + def test_retry_transaction_aborted_retry(self): + """ + Check that in case of a retried transaction aborted, + it will be retried once again. + """ + execute_statement = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE, + sql=SQL, + args=ARGS, + result_details=None, + ) + self._under_test._statement_result_details_list.append(execute_statement) + run_mock = self._under_test._connection.cursor().execute = mock.Mock() + metadata_mock = mock.Mock() + metadata_mock.trailing_metadata.return_value = {} + run_mock.side_effect = [ + Aborted("Aborted", errors=[metadata_mock]), + None, + ] + + self._under_test.retry_transaction() + + run_mock.assert_has_calls( + ( + mock.call(SQL, ARGS), + mock.call(SQL, ARGS), + ) + ) + + def test_add_execute_statement_for_retry(self): + """ + Test add_execute_statement_for_retry method works + """ + parsed_statement = ParsedStatement( + statement_type=StatementType.INSERT, statement=None + ) + + sql = "INSERT INTO Table" + rows_inserted = 3 + self._under_test.add_execute_statement_for_retry( + parsed_statement, sql, [], rows_inserted, None, False + ) + + expected_statement_result_details = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE, + sql=sql, + args=[], + result_details=rows_inserted, + ) + self.assertEqual( + self._under_test._last_statement_result_details, + expected_statement_result_details, + ) + self.assertEqual( + self._under_test._statement_result_details_list, + [expected_statement_result_details], + ) + + def test_add_execute_statement_for_retry_with_exception(self): + """ + Test add_execute_statement_for_retry method with exception + """ + parsed_statement = ParsedStatement( + statement_type=StatementType.INSERT, statement=None + ) + + sql = "INSERT INTO Table" + exception = Exception("Test") + self._under_test.add_execute_statement_for_retry( + parsed_statement, sql, [], -1, exception, False + ) + + expected_statement_result_details = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE, + sql=sql, + args=[], + result_details=exception, + ) + self.assertEqual( + self._under_test._last_statement_result_details, + expected_statement_result_details, + ) + self.assertEqual( + self._under_test._statement_result_details_list, + [expected_statement_result_details], + ) + + def test_add_execute_statement_for_retry_query_statement(self): + """ + Test add_execute_statement_for_retry method works for non DML statement + """ + parsed_statement = ParsedStatement( + statement_type=StatementType.QUERY, statement=None + ) + + sql = "SELECT 1" + self._under_test.add_execute_statement_for_retry( + parsed_statement, sql, [], -1, None, False + ) + + expected_statement_result_details = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE, + sql=sql, + args=[], + result_details=None, + ) + self.assertEqual( + self._under_test._last_statement_result_details, + expected_statement_result_details, + ) + self.assertEqual( + self._under_test._statement_result_details_list, + [expected_statement_result_details], + ) + + def test_add_execute_many_statement_for_retry(self): + """ + Test add_execute_statement_for_retry method works for executemany + """ + parsed_statement = ParsedStatement( + statement_type=StatementType.INSERT, statement=None + ) + + sql = "INSERT INTO Table" + rows_inserted = 3 + self._under_test.add_execute_statement_for_retry( + parsed_statement, sql, [], rows_inserted, None, True + ) + + expected_statement_result_details = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE_MANY, + sql=sql, + args=[], + result_details=rows_inserted, + ) + self.assertEqual( + self._under_test._last_statement_result_details, + expected_statement_result_details, + ) + self.assertEqual( + self._under_test._statement_result_details_list, + [expected_statement_result_details], + ) + + def test_add_fetch_statement_for_retry(self): + """ + Test add_execute_statement_for_retry method when + last_statement_result_details is a Fetch statement + """ + result_row = ("field1", "field2") + result_checksum = _get_checksum(result_row) + original_checksum_digest = result_checksum.checksum.digest() + self._under_test._last_statement_result_details = FetchStatement( + statement_type=CursorStatementType.FETCH_MANY, + result_details=result_checksum, + size=1, + ) + new_rows = [("field3", "field4"), ("field5", "field6")] + + self._under_test.add_fetch_statement_for_retry(new_rows, None, False) + + self.assertEqual( + self._under_test._last_statement_result_details.size, + 3, + ) + self.assertNotEqual( + self._under_test._last_statement_result_details.result_details.checksum.digest(), + original_checksum_digest, + ) + + def test_add_fetch_statement_for_retry_with_exception(self): + """ + Test add_execute_statement_for_retry method with exception + """ + result_row = ("field1", "field2") + self._under_test._last_statement_result_details = FetchStatement( + statement_type=CursorStatementType.FETCH_MANY, + result_details=_get_checksum(result_row), + size=1, + ) + exception = Exception("Test") + + self._under_test.add_fetch_statement_for_retry([], exception, False) + + self.assertEqual( + self._under_test._last_statement_result_details, + FetchStatement( + statement_type=CursorStatementType.FETCH_MANY, + result_details=exception, + size=1, + ), + ) + + def test_add_fetch_statement_for_retry_last_statement_not_exists(self): + """ + Test add_execute_statement_for_retry method when + last_statement_result_details doesn't exists + """ + row = ("field3", "field4") + + self._under_test.add_fetch_statement_for_retry([row], None, False) + + expected_statement = FetchStatement( + statement_type=CursorStatementType.FETCH_MANY, + result_details=_get_checksum(row), + size=1, + ) + self.assertEqual( + self._under_test._last_statement_result_details, + expected_statement, + ) + self.assertEqual( + self._under_test._statement_result_details_list, + [expected_statement], + ) + + def test_add_fetch_statement_for_retry_fetch_all_statement(self): + """ + Test add_execute_statement_for_retry method for fetchall statement + """ + row = ("field3", "field4") + + self._under_test.add_fetch_statement_for_retry([row], None, True) + + expected_statement = FetchStatement( + statement_type=CursorStatementType.FETCH_ALL, + result_details=_get_checksum(row), + ) + self.assertEqual( + self._under_test._last_statement_result_details, + expected_statement, + ) + self.assertEqual( + self._under_test._statement_result_details_list, + [expected_statement], + ) + + def test_add_fetch_statement_for_retry_when_last_statement_is_not_fetch(self): + """ + Test add_execute_statement_for_retry method when last statement is not + a fetch type of statement + """ + execute_statement = ExecuteStatement( + statement_type=CursorStatementType.EXECUTE, + sql=SQL, + args=ARGS, + result_details=2, + ) + self._under_test._last_statement_result_details = execute_statement + self._under_test._statement_result_details_list.append(execute_statement) + row = ("field3", "field4") + + self._under_test.add_fetch_statement_for_retry([row], None, False) + + expected_fetch_statement = FetchStatement( + statement_type=CursorStatementType.FETCH_MANY, + result_details=_get_checksum(row), + size=1, + ) + self.assertEqual( + self._under_test._last_statement_result_details, expected_fetch_statement + ) + self.assertEqual( + self._under_test._statement_result_details_list, + [execute_statement, expected_fetch_statement], + ) From 0f864339c53052211f63611f47eb17e77fc0c003 Mon Sep 17 00:00:00 2001 From: ankiaga Date: Wed, 3 Jan 2024 20:14:18 +0530 Subject: [PATCH 2/9] Comments incorporated and changes for also storing Cursor object with the statements details added for retry --- .../cloud/spanner_dbapi/batch_dml_executor.py | 1 + google/cloud/spanner_dbapi/connection.py | 22 ++--- google/cloud/spanner_dbapi/cursor.py | 17 ++-- .../cloud/spanner_dbapi/transaction_helper.py | 79 ++++++++++++++---- tests/system/test_dbapi.py | 63 ++++++++++++++ .../spanner_dbapi/test_transaction_helper.py | 83 ++++++++++++++----- 6 files changed, 205 insertions(+), 60 deletions(-) diff --git a/google/cloud/spanner_dbapi/batch_dml_executor.py b/google/cloud/spanner_dbapi/batch_dml_executor.py index 57e36d7991..7aaa69b2bd 100644 --- a/google/cloud/spanner_dbapi/batch_dml_executor.py +++ b/google/cloud/spanner_dbapi/batch_dml_executor.py @@ -105,6 +105,7 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]): cursor._row_count = sum([max(val, 0) for val in res]) return many_result_set except Aborted: + # We are raising it so it could be handled in transaction_helper.py and is retried if cursor._in_retry_mode: raise else: diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index fb635c003b..1c18dbbf9c 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -27,11 +27,11 @@ from google.cloud.spanner_dbapi.partition_helper import PartitionId from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper +from google.cloud.spanner_dbapi.cursor import Cursor from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1.snapshot import Snapshot from deprecated import deprecated -from google.cloud.spanner_dbapi.cursor import Cursor from google.cloud.spanner_dbapi.exceptions import ( InterfaceError, OperationalError, @@ -354,12 +354,10 @@ def begin(self): def commit(self): """Commits any pending transaction to the database. - This is a no-op if there is no active client transaction. """ if self.database is None: raise ValueError("Database needs to be passed for this operation") - if not self._client_transaction_started: warnings.warn( CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2 @@ -374,14 +372,10 @@ def commit(self): self._transaction_helper.retry_transaction() self.commit() finally: - self._release_session() - self._transaction_helper.reset() - self._transaction_begin_marked = False - self._spanner_transaction_started = False + self._reset_post_commit_or_rollback() def rollback(self): """Rolls back any pending transaction. - This is a no-op if there is no active client transaction. """ if not self._client_transaction_started: @@ -389,15 +383,17 @@ def rollback(self): CLIENT_TRANSACTION_NOT_STARTED_WARNING, UserWarning, stacklevel=2 ) return - try: if self._spanner_transaction_started and not self._read_only: self._transaction.rollback() finally: - self._release_session() - self._transaction_helper.reset() - self._transaction_begin_marked = False - self._spanner_transaction_started = False + self._reset_post_commit_or_rollback() + + def _reset_post_commit_or_rollback(self): + self._release_session() + self._transaction_helper.reset() + self._transaction_begin_marked = False + self._spanner_transaction_started = False @check_not_closed def cursor(self): diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 1a6dd12e95..9bff1907d6 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -220,7 +220,10 @@ def _batch_DDLs(self, sql): self.connection._ddl_statements.extend(statements) @check_not_closed - def execute(self, sql, args=None, call_from_execute_many=False): + def execute(self, sql, args=None): + self._execute(sql, args, False) + + def _execute(self, sql, args=None, call_from_execute_many=False): """Prepares and executes a Spanner database operation. :type sql: str @@ -276,7 +279,7 @@ def execute(self, sql, args=None, call_from_execute_many=False): finally: if not self._in_retry_mode and not call_from_execute_many: self.transaction_helper.add_execute_statement_for_retry( - self._parsed_statement, sql, args, self.rowcount, exception, False + self, sql, args, exception, False ) if self.connection._client_transaction_started is False: self.connection._spanner_transaction_started = False @@ -293,6 +296,7 @@ def _execute_in_rw_transaction(self): self._itr = PeekIterator(self._result_set) return except Aborted: + # We are raising it so it could be handled in transaction_helper.py and is retried if self._in_retry_mode: raise else: @@ -354,7 +358,7 @@ def executemany(self, operation, seq_of_params): else: many_result_set = StreamedManyResultSets() for params in seq_of_params: - self.execute(operation, params, True) + self._execute(operation, params, True) many_result_set.add_iter(self._itr) self._result_set = many_result_set @@ -365,10 +369,9 @@ def executemany(self, operation, seq_of_params): finally: if not self._in_retry_mode: self.transaction_helper.add_execute_statement_for_retry( - self._parsed_statement, + self, operation, seq_of_params, - self.rowcount, exception, True, ) @@ -407,7 +410,7 @@ def fetchmany(self, size=None): size = self.arraysize return self._fetch(CursorStatementType.FETCH_MANY, size) - def _fetch(self, cursor_statement_type: CursorStatementType, size=None): + def _fetch(self, cursor_statement_type, size=None): exception = None rows = [] is_fetch_all = False @@ -445,7 +448,7 @@ def _fetch(self, cursor_statement_type: CursorStatementType, size=None): finally: if not self._in_retry_mode: self.transaction_helper.add_fetch_statement_for_retry( - rows, exception, is_fetch_all + self, rows, exception, is_fetch_all ) return rows diff --git a/google/cloud/spanner_dbapi/transaction_helper.py b/google/cloud/spanner_dbapi/transaction_helper.py index e7d8ea22d0..8abe8a41ae 100644 --- a/google/cloud/spanner_dbapi/transaction_helper.py +++ b/google/cloud/spanner_dbapi/transaction_helper.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import random from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, List, Union, Any +from typing import TYPE_CHECKING, List, Union, Any, Dict from google.api_core.exceptions import Aborted import time @@ -25,6 +24,7 @@ StatementType, ClientSideStatementType, ) +from google.cloud.spanner_v1.session import _get_retry_delay if TYPE_CHECKING: from google.cloud.spanner_dbapi import Connection, Cursor @@ -47,8 +47,12 @@ def __init__(self, connection: "Connection"): self._connection = connection # list of all single and batch statements in the same order as executed # in original transaction along with their checksum value - self._statement_result_details_list: List[StatementResultDetails] = [] - self._last_statement_result_details: StatementResultDetails = None + self._statement_result_details_list: List[StatementDetails] = [] + # last StatementDetails that was added in the _statement_result_details_list + self._last_statement_result_details: StatementDetails = None + # 1-1 map from original cursor object on which transaction ran to the + # new cursor object used in the retry + self._cursor_map: Dict[Cursor, Cursor] = {} def _set_connection_for_retry(self): self._connection._spanner_transaction_started = False @@ -62,49 +66,80 @@ def reset(self): """ self._statement_result_details_list = [] self._last_statement_result_details = None + self._cursor_map = {} - def add_fetch_statement_for_retry(self, rows, exception, is_fetch_all): + def add_fetch_statement_for_retry( + self, cursor, result_rows, exception, is_fetch_all + ): + """ + StatementDetails to be added to _statement_result_details_list whenever fetchone, fetchmany or + fetchall method is called on the cursor. + If fetchone is consecutively called n times then it is stored as fetchmany with size as n. + Same for fetchmany, so consecutive fetchone and fetchmany statements are stored as one + fetchmany statement in _statement_result_details_list with size param appropriately set + + :param cursor: original Cursor object on which statement executed in the transaction + :param result_rows: All the rows from the resultSet from fetch statement execution + :param exception: Not none in case non-aborted exception is thrown on the original + statement execution + :param is_fetch_all: True in case of fetchall statement execution + """ if not self._connection._client_transaction_started: return if ( self._last_statement_result_details is not None and self._last_statement_result_details.statement_type == CursorStatementType.FETCH_MANY + and self._last_statement_result_details.cursor == cursor ): if exception is not None: self._last_statement_result_details.result_details = exception else: - for row in rows: + for row in result_rows: self._last_statement_result_details.result_details.consume_result( row ) - self._last_statement_result_details.size += len(rows) + self._last_statement_result_details.size += len(result_rows) else: - result_details = _get_statement_result_checksum(rows) + result_details = _get_statement_result_checksum(result_rows) if is_fetch_all: self._last_statement_result_details = FetchStatement( + cursor=cursor, statement_type=CursorStatementType.FETCH_ALL, result_details=result_details, ) else: self._last_statement_result_details = FetchStatement( + cursor=cursor, statement_type=CursorStatementType.FETCH_MANY, result_details=result_details, - size=len(rows), + size=len(result_rows), ) self._statement_result_details_list.append( self._last_statement_result_details ) def add_execute_statement_for_retry( - self, parsed_statement, sql, args, rowcount, exception, is_execute_many + self, cursor, sql, args, exception, is_execute_many ): + """ + StatementDetails to be added to _statement_result_details_list whenever execute or + executemany method is called on the cursor. + + :param cursor: original Cursor object on which statement executed in the transaction + :param sql: Input param of the execute/executemany method + :param args: Input param of the execute/executemany method + :param exception: Not none in case non-aborted exception is thrown on the original + statement execution + :param is_execute_many: True in case of executemany statement execution + """ if not self._connection._client_transaction_started: return statement_type = CursorStatementType.EXECUTE if is_execute_many: statement_type = CursorStatementType.EXECUTE_MANY result_details = None + parsed_statement = cursor._parsed_statement if exception is not None: result_details = exception elif ( @@ -113,9 +148,10 @@ def add_execute_statement_for_retry( or parsed_statement.client_side_statement_type == ClientSideStatementType.RUN_BATCH ): - result_details = rowcount + result_details = cursor.rowcount self._last_statement_result_details = ExecuteStatement( + cursor=cursor, statement_type=statement_type, sql=sql, args=args, @@ -141,10 +177,15 @@ def retry_transaction(self): raise self._set_connection_for_retry() - cursor: Cursor = self._connection.cursor() - cursor._in_retry_mode = True + try: for statement_result_details in self._statement_result_details_list: + if statement_result_details.cursor in self._cursor_map: + cursor = self._cursor_map.get(statement_result_details.cursor) + else: + cursor: Cursor = self._connection.cursor() + cursor._in_retry_mode = True + self._cursor_map[statement_result_details.cursor] = cursor try: self._handle_statement(statement_result_details, cursor) except Aborted: @@ -159,8 +200,8 @@ def retry_transaction(self): ): raise RetryAborted(RETRY_ABORTED_ERROR, ex) return - except Aborted: - delay = 2**attempt + random.random() + except Aborted as ex: + delay = _get_retry_delay(ex.errors[0], attempt) if delay: time.sleep(delay) @@ -209,8 +250,10 @@ class CursorStatementType(Enum): @dataclass -class StatementResultDetails: +class StatementDetails: statement_type: CursorStatementType + # The cursor object on which this statement was executed + cursor: "Cursor" # This would be one of # 1. checksum of ResultSet in case of fetch call on query statement # 2. Total rows updated in case of DML @@ -220,11 +263,11 @@ class StatementResultDetails: @dataclass -class ExecuteStatement(StatementResultDetails): +class ExecuteStatement(StatementDetails): sql: str args: Any = None @dataclass -class FetchStatement(StatementResultDetails): +class FetchStatement(StatementDetails): size: int = None diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index e012299925..a433f52d67 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -649,6 +649,69 @@ def test_execute_sql_abort_retry_multiple_times(self, dbapi_database): got_rows = self._cursor.fetchall() assert len(got_rows) == 2 + @pytest.mark.noautofixt + def test_abort_retry_multiple_cursors(self, shared_instance, dbapi_database): + """Test that retry works when multiple cursors are involved in the transaction.""" + + try: + conn = Connection(shared_instance, dbapi_database) + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + Name STRING(1024), + ) PRIMARY KEY (SingerId) + """ + ) + cur.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + cur.execute( + """ + INSERT INTO Singers (SingerId, Name) + VALUES (1, 'first-name') + """ + ) + cur.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + cur.execute( + """ + INSERT INTO Singers (SingerId, Name) + VALUES (2, 'first-name') + """ + ) + conn.commit() + + cur1 = conn.cursor() + cur1.execute("SELECT * FROM contacts") + cur2 = conn.cursor() + cur2.execute("SELECT * FROM Singers") + row1 = cur1.fetchone() + row2 = cur2.fetchone() + row3 = cur1.fetchone() + row4 = cur2.fetchone() + dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) + conn.commit() + dbapi_database._method_abort_interceptor.reset() + assert row1 == (1, "first-name", "last-name", "test.email@domen.ru") + assert row2 == (1, "first-name") + assert row3 == (2, "first-name", "last-name", "test.email@domen.ru") + assert row4 == (2, "first-name") + finally: + # Delete table + table = dbapi_database.table("Singers") + if table.exists(): + op = dbapi_database.update_ddl(["DROP TABLE Singers"]) + op.result() + def test_execute_batch_dml_abort_retry(self, dbapi_database): """Test that when any execute batch dml failed with Abort exception, then the retry succeeds with transaction having insert as well as query diff --git a/tests/unit/spanner_dbapi/test_transaction_helper.py b/tests/unit/spanner_dbapi/test_transaction_helper.py index dfe0d96538..65c274f5d7 100644 --- a/tests/unit/spanner_dbapi/test_transaction_helper.py +++ b/tests/unit/spanner_dbapi/test_transaction_helper.py @@ -40,10 +40,11 @@ def _get_checksum(row): class TestTransactionHelper(unittest.TestCase): + @mock.patch("google.cloud.spanner_dbapi.cursor.Cursor") @mock.patch("google.cloud.spanner_dbapi.connection.Connection") - def setUp(self, mock_connection): + def setUp(self, mock_connection, mock_cursor): self._under_test = TransactionRetryHelper(mock_connection) - self._mock_connection = mock_connection + self._mock_cursor = mock_cursor def test_retry_transaction_execute(self): """ @@ -51,6 +52,7 @@ def test_retry_transaction_execute(self): """ execute_statement = ExecuteStatement( statement_type=CursorStatementType.EXECUTE, + cursor=self._mock_cursor, sql=SQL, args=ARGS, result_details=None, @@ -69,6 +71,7 @@ def test_retry_transaction_dml_execute(self): update_count = 3 execute_statement = ExecuteStatement( statement_type=CursorStatementType.EXECUTE, + cursor=self._mock_cursor, sql=SQL, args=ARGS, result_details=update_count, @@ -88,6 +91,7 @@ def test_retry_transaction_dml_execute_exception(self): """ execute_statement = ExecuteStatement( statement_type=CursorStatementType.EXECUTE, + cursor=self._mock_cursor, sql=SQL, args=ARGS, result_details=2, @@ -107,6 +111,7 @@ def test_retry_transaction_execute_many(self): """ execute_statement = ExecuteStatement( statement_type=CursorStatementType.EXECUTE_MANY, + cursor=self._mock_cursor, sql=SQL, args=ARGS, result_details=None, @@ -125,6 +130,7 @@ def test_retry_transaction_dml_execute_many(self): update_count = 3 execute_statement = ExecuteStatement( statement_type=CursorStatementType.EXECUTE_MANY, + cursor=self._mock_cursor, sql=SQL, args=ARGS, result_details=update_count, @@ -144,6 +150,7 @@ def test_retry_transaction_dml_executemany_exception(self): """ execute_statement = ExecuteStatement( statement_type=CursorStatementType.EXECUTE_MANY, + cursor=self._mock_cursor, sql=SQL, args=ARGS, result_details=2, @@ -163,6 +170,7 @@ def test_retry_transaction_fetchall(self): """ result_row = ("field1", "field2") fetch_statement = FetchStatement( + cursor=self._mock_cursor, statement_type=CursorStatementType.FETCH_ALL, result_details=_get_checksum(result_row), ) @@ -181,6 +189,7 @@ def test_retry_transaction_fetchall_exception(self): """ result_row = ("field1", "field2") fetch_statement = FetchStatement( + cursor=self._mock_cursor, statement_type=CursorStatementType.FETCH_ALL, result_details=_get_checksum(result_row), ) @@ -200,6 +209,7 @@ def test_retry_transaction_fetchmany(self): """ result_row = ("field1", "field2") fetch_statement = FetchStatement( + cursor=self._mock_cursor, statement_type=CursorStatementType.FETCH_MANY, result_details=_get_checksum(result_row), size=1, @@ -219,6 +229,7 @@ def test_retry_transaction_fetchmany_exception(self): """ result_row = ("field1", "field2") fetch_statement = FetchStatement( + cursor=self._mock_cursor, statement_type=CursorStatementType.FETCH_MANY, result_details=_get_checksum(result_row), size=1, @@ -241,6 +252,7 @@ def test_retry_transaction_same_exception(self): exception = Exception("Test") execute_statement = ExecuteStatement( statement_type=CursorStatementType.EXECUTE, + cursor=self._mock_cursor, sql=SQL, args=ARGS, result_details=exception, @@ -260,6 +272,7 @@ def test_retry_transaction_different_exception(self): """ execute_statement = ExecuteStatement( statement_type=CursorStatementType.EXECUTE, + cursor=self._mock_cursor, sql=SQL, args=ARGS, result_details=Exception("Test"), @@ -280,6 +293,7 @@ def test_retry_transaction_aborted_retry(self): """ execute_statement = ExecuteStatement( statement_type=CursorStatementType.EXECUTE, + cursor=self._mock_cursor, sql=SQL, args=ARGS, result_details=None, @@ -306,18 +320,20 @@ def test_add_execute_statement_for_retry(self): """ Test add_execute_statement_for_retry method works """ - parsed_statement = ParsedStatement( + self._mock_cursor._parsed_statement = ParsedStatement( statement_type=StatementType.INSERT, statement=None ) sql = "INSERT INTO Table" rows_inserted = 3 + self._mock_cursor.rowcount = rows_inserted self._under_test.add_execute_statement_for_retry( - parsed_statement, sql, [], rows_inserted, None, False + self._mock_cursor, sql, [], None, False ) expected_statement_result_details = ExecuteStatement( statement_type=CursorStatementType.EXECUTE, + cursor=self._mock_cursor, sql=sql, args=[], result_details=rows_inserted, @@ -335,18 +351,20 @@ def test_add_execute_statement_for_retry_with_exception(self): """ Test add_execute_statement_for_retry method with exception """ - parsed_statement = ParsedStatement( + self._mock_cursor._parsed_statement = ParsedStatement( statement_type=StatementType.INSERT, statement=None ) + self._mock_cursor.rowcount = -1 sql = "INSERT INTO Table" exception = Exception("Test") self._under_test.add_execute_statement_for_retry( - parsed_statement, sql, [], -1, exception, False + self._mock_cursor, sql, [], exception, False ) expected_statement_result_details = ExecuteStatement( statement_type=CursorStatementType.EXECUTE, + cursor=self._mock_cursor, sql=sql, args=[], result_details=exception, @@ -364,17 +382,19 @@ def test_add_execute_statement_for_retry_query_statement(self): """ Test add_execute_statement_for_retry method works for non DML statement """ - parsed_statement = ParsedStatement( + self._mock_cursor._parsed_statement = ParsedStatement( statement_type=StatementType.QUERY, statement=None ) + self._mock_cursor.rowcount = -1 sql = "SELECT 1" self._under_test.add_execute_statement_for_retry( - parsed_statement, sql, [], -1, None, False + self._mock_cursor, sql, [], None, False ) expected_statement_result_details = ExecuteStatement( statement_type=CursorStatementType.EXECUTE, + cursor=self._mock_cursor, sql=sql, args=[], result_details=None, @@ -392,18 +412,20 @@ def test_add_execute_many_statement_for_retry(self): """ Test add_execute_statement_for_retry method works for executemany """ - parsed_statement = ParsedStatement( + self._mock_cursor._parsed_statement = ParsedStatement( statement_type=StatementType.INSERT, statement=None ) sql = "INSERT INTO Table" rows_inserted = 3 + self._mock_cursor.rowcount = rows_inserted self._under_test.add_execute_statement_for_retry( - parsed_statement, sql, [], rows_inserted, None, True + self._mock_cursor, sql, [], None, True ) expected_statement_result_details = ExecuteStatement( statement_type=CursorStatementType.EXECUTE_MANY, + cursor=self._mock_cursor, sql=sql, args=[], result_details=rows_inserted, @@ -419,20 +441,23 @@ def test_add_execute_many_statement_for_retry(self): def test_add_fetch_statement_for_retry(self): """ - Test add_execute_statement_for_retry method when - last_statement_result_details is a Fetch statement + Test add_fetch_statement_for_retry method when last_statement_result_details is a + Fetch statement """ result_row = ("field1", "field2") result_checksum = _get_checksum(result_row) original_checksum_digest = result_checksum.checksum.digest() self._under_test._last_statement_result_details = FetchStatement( statement_type=CursorStatementType.FETCH_MANY, + cursor=self._mock_cursor, result_details=result_checksum, size=1, ) new_rows = [("field3", "field4"), ("field5", "field6")] - self._under_test.add_fetch_statement_for_retry(new_rows, None, False) + self._under_test.add_fetch_statement_for_retry( + self._mock_cursor, new_rows, None, False + ) self.assertEqual( self._under_test._last_statement_result_details.size, @@ -445,22 +470,26 @@ def test_add_fetch_statement_for_retry(self): def test_add_fetch_statement_for_retry_with_exception(self): """ - Test add_execute_statement_for_retry method with exception + Test add_fetch_statement_for_retry method with exception """ result_row = ("field1", "field2") self._under_test._last_statement_result_details = FetchStatement( statement_type=CursorStatementType.FETCH_MANY, + cursor=self._mock_cursor, result_details=_get_checksum(result_row), size=1, ) exception = Exception("Test") - self._under_test.add_fetch_statement_for_retry([], exception, False) + self._under_test.add_fetch_statement_for_retry( + self._mock_cursor, [], exception, False + ) self.assertEqual( self._under_test._last_statement_result_details, FetchStatement( statement_type=CursorStatementType.FETCH_MANY, + cursor=self._mock_cursor, result_details=exception, size=1, ), @@ -468,15 +497,18 @@ def test_add_fetch_statement_for_retry_with_exception(self): def test_add_fetch_statement_for_retry_last_statement_not_exists(self): """ - Test add_execute_statement_for_retry method when - last_statement_result_details doesn't exists + Test add_fetch_statement_for_retry method when last_statement_result_details + doesn't exists """ row = ("field3", "field4") - self._under_test.add_fetch_statement_for_retry([row], None, False) + self._under_test.add_fetch_statement_for_retry( + self._mock_cursor, [row], None, False + ) expected_statement = FetchStatement( statement_type=CursorStatementType.FETCH_MANY, + cursor=self._mock_cursor, result_details=_get_checksum(row), size=1, ) @@ -491,14 +523,17 @@ def test_add_fetch_statement_for_retry_last_statement_not_exists(self): def test_add_fetch_statement_for_retry_fetch_all_statement(self): """ - Test add_execute_statement_for_retry method for fetchall statement + Test add_fetch_statement_for_retry method for fetchall statement """ row = ("field3", "field4") - self._under_test.add_fetch_statement_for_retry([row], None, True) + self._under_test.add_fetch_statement_for_retry( + self._mock_cursor, [row], None, True + ) expected_statement = FetchStatement( statement_type=CursorStatementType.FETCH_ALL, + cursor=self._mock_cursor, result_details=_get_checksum(row), ) self.assertEqual( @@ -512,11 +547,12 @@ def test_add_fetch_statement_for_retry_fetch_all_statement(self): def test_add_fetch_statement_for_retry_when_last_statement_is_not_fetch(self): """ - Test add_execute_statement_for_retry method when last statement is not + Test add_fetch_statement_for_retry method when last statement is not a fetch type of statement """ execute_statement = ExecuteStatement( statement_type=CursorStatementType.EXECUTE, + cursor=self._mock_cursor, sql=SQL, args=ARGS, result_details=2, @@ -525,10 +561,13 @@ def test_add_fetch_statement_for_retry_when_last_statement_is_not_fetch(self): self._under_test._statement_result_details_list.append(execute_statement) row = ("field3", "field4") - self._under_test.add_fetch_statement_for_retry([row], None, False) + self._under_test.add_fetch_statement_for_retry( + self._mock_cursor, [row], None, False + ) expected_fetch_statement = FetchStatement( statement_type=CursorStatementType.FETCH_MANY, + cursor=self._mock_cursor, result_details=_get_checksum(row), size=1, ) From 42e6e5761e623adc631666a65f102e9b1ab491d4 Mon Sep 17 00:00:00 2001 From: ankiaga Date: Thu, 4 Jan 2024 13:20:29 +0530 Subject: [PATCH 3/9] Some refactoring of transaction_helper.py and maintaining state of rows update count for batch dml in cursor --- .../cloud/spanner_dbapi/batch_dml_executor.py | 1 + google/cloud/spanner_dbapi/cursor.py | 35 +++-- .../cloud/spanner_dbapi/transaction_helper.py | 98 +++++++------ tests/system/test_dbapi.py | 130 +++++++++--------- tests/unit/spanner_dbapi/test_cursor.py | 13 +- .../spanner_dbapi/test_transaction_helper.py | 25 ++++ 6 files changed, 166 insertions(+), 136 deletions(-) diff --git a/google/cloud/spanner_dbapi/batch_dml_executor.py b/google/cloud/spanner_dbapi/batch_dml_executor.py index 7aaa69b2bd..7c4272a0ca 100644 --- a/google/cloud/spanner_dbapi/batch_dml_executor.py +++ b/google/cloud/spanner_dbapi/batch_dml_executor.py @@ -101,6 +101,7 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]): elif status.code != OK: raise OperationalError(status.message) + cursor._batch_dml_rows_count = res many_result_set.add_iter(res) cursor._row_count = sum([max(val, 0) for val in res]) return many_result_set diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 9bff1907d6..ed6178e054 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -50,8 +50,6 @@ from google.cloud.spanner_dbapi.utils import PeekIterator from google.cloud.spanner_dbapi.utils import StreamedManyResultSets -_UNSET_COUNT = -1 - ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) @@ -87,7 +85,7 @@ class Cursor(object): def __init__(self, connection): self._itr = None self._result_set = None - self._row_count = _UNSET_COUNT + self._row_count = None self.lastrowid = None self.connection = connection self.transaction_helper = self.connection._transaction_helper @@ -96,6 +94,7 @@ def __init__(self, connection): self.arraysize = 1 self._parsed_statement: ParsedStatement = None self._in_retry_mode = False + self._batch_dml_rows_count = None @property def is_closed(self): @@ -150,14 +149,14 @@ def rowcount(self): :returns: The number of rows updated by the last INSERT, UPDATE, DELETE request's .execute*() call. """ - if self._row_count != _UNSET_COUNT or self._result_set is None: + if self._row_count is not None or self._result_set is None: return self._row_count stats = getattr(self._result_set, "stats", None) if stats is not None and "row_count_exact" in stats: return stats.row_count_exact - return _UNSET_COUNT + return -1 @check_not_closed def callproc(self, procname, args=None): @@ -191,7 +190,7 @@ def _do_execute_update_in_autocommit(self, transaction, sql, params): sql, params=params, param_types=get_param_types(params) ) self._itr = PeekIterator(self._result_set) - self._row_count = _UNSET_COUNT + self._row_count = None def _batch_DDLs(self, sql): """ @@ -219,6 +218,14 @@ def _batch_DDLs(self, sql): # Only queue DDL statements if they are all correctly classified. self.connection._ddl_statements.extend(statements) + def _reset(self): + if self.connection.database is None: + raise ValueError("Database needs to be passed for this operation") + self._itr = None + self._result_set = None + self._row_count = None + self._batch_dml_rows_count = None + @check_not_closed def execute(self, sql, args=None): self._execute(sql, args, False) @@ -232,13 +239,8 @@ def _execute(self, sql, args=None, call_from_execute_many=False): :type args: list :param args: Additional parameters to supplement the SQL query. """ - if self.connection.database is None: - raise ValueError("Database needs to be passed for this operation") - self._itr = None - self._result_set = None - self._row_count = _UNSET_COUNT + self._reset() exception = None - try: self._parsed_statement = parse_utils.classify_statement(sql, args) if self._parsed_statement.statement_type == StatementType.CLIENT_SIDE: @@ -320,13 +322,8 @@ def executemany(self, operation, seq_of_params): :param seq_of_params: Sequence of additional parameters to run the query with. """ - if self.connection.database is None: - raise ValueError("Database needs to be passed for this operation") - self._itr = None - self._result_set = None - self._row_count = _UNSET_COUNT + self._reset() exception = None - try: self._parsed_statement = parse_utils.classify_statement(operation) if self._parsed_statement.statement_type == StatementType.DDL: @@ -464,7 +461,7 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params): self._itr = PeekIterator(self._result_set) # Unfortunately, Spanner doesn't seem to send back # information about the number of rows available. - self._row_count = _UNSET_COUNT + self._row_count = None if self._result_set.metadata.transaction.read_timestamp is not None: snapshot._transaction_read_timestamp = ( self._result_set.metadata.transaction.read_timestamp diff --git a/google/cloud/spanner_dbapi/transaction_helper.py b/google/cloud/spanner_dbapi/transaction_helper.py index 8abe8a41ae..b9418db947 100644 --- a/google/cloud/spanner_dbapi/transaction_helper.py +++ b/google/cloud/spanner_dbapi/transaction_helper.py @@ -13,17 +13,13 @@ # limitations under the License. from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, List, Union, Any, Dict +from typing import TYPE_CHECKING, List, Any, Dict from google.api_core.exceptions import Aborted import time from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode from google.cloud.spanner_dbapi.exceptions import RetryAborted -from google.cloud.spanner_dbapi.parsed_statement import ( - StatementType, - ClientSideStatementType, -) from google.cloud.spanner_v1.session import _get_retry_delay if TYPE_CHECKING: @@ -45,8 +41,8 @@ def __init__(self, connection: "Connection"): """ self._connection = connection - # list of all single and batch statements in the same order as executed - # in original transaction along with their checksum value + # list of all statements in the same order as executed in original + # transaction along with their results self._statement_result_details_list: List[StatementDetails] = [] # last StatementDetails that was added in the _statement_result_details_list self._last_statement_result_details: StatementDetails = None @@ -93,6 +89,7 @@ def add_fetch_statement_for_retry( and self._last_statement_result_details.cursor == cursor ): if exception is not None: + self._last_statement_result_details.result_type = ResultType.EXCEPTION self._last_statement_result_details.result_details = exception else: for row in result_rows: @@ -106,12 +103,14 @@ def add_fetch_statement_for_retry( self._last_statement_result_details = FetchStatement( cursor=cursor, statement_type=CursorStatementType.FETCH_ALL, + result_type=ResultType.CHECKSUM, result_details=result_details, ) else: self._last_statement_result_details = FetchStatement( cursor=cursor, statement_type=CursorStatementType.FETCH_MANY, + result_type=ResultType.CHECKSUM, result_details=result_details, size=len(result_rows), ) @@ -138,16 +137,15 @@ def add_execute_statement_for_retry( statement_type = CursorStatementType.EXECUTE if is_execute_many: statement_type = CursorStatementType.EXECUTE_MANY + + result_type = ResultType.NONE result_details = None - parsed_statement = cursor._parsed_statement if exception is not None: + result_type = ResultType.EXCEPTION result_details = exception - elif ( - parsed_statement.statement_type == StatementType.INSERT - or parsed_statement.statement_type == StatementType.UPDATE - or parsed_statement.client_side_statement_type - == ClientSideStatementType.RUN_BATCH - ): + # True in case of DML statement + elif cursor.rowcount != -1: + result_type = ResultType.ROW_COUNT result_details = cursor.rowcount self._last_statement_result_details = ExecuteStatement( @@ -155,6 +153,7 @@ def add_execute_statement_for_retry( statement_type=statement_type, sql=sql, args=args, + result_type=result_type, result_details=result_details, ) self._statement_result_details_list.append(self._last_statement_result_details) @@ -175,19 +174,17 @@ def retry_transaction(self): attempt += 1 if attempt > MAX_INTERNAL_RETRIES: raise - self._set_connection_for_retry() - try: for statement_result_details in self._statement_result_details_list: if statement_result_details.cursor in self._cursor_map: cursor = self._cursor_map.get(statement_result_details.cursor) else: - cursor: Cursor = self._connection.cursor() + cursor = self._connection.cursor() cursor._in_retry_mode = True self._cursor_map[statement_result_details.cursor] = cursor try: - self._handle_statement(statement_result_details, cursor) + _handle_statement(statement_result_details, cursor) except Aborted: raise except RetryAborted: @@ -205,33 +202,37 @@ def retry_transaction(self): if delay: time.sleep(delay) - def _handle_statement(self, statement_result_details, cursor): - statement_type = statement_result_details.statement_type + +def _handle_statement(statement_result_details, cursor): + statement_type = statement_result_details.statement_type + if _is_execute_type_statement(statement_type): if statement_type == CursorStatementType.EXECUTE: cursor.execute(statement_result_details.sql, statement_result_details.args) - if ( - type(statement_result_details.result_details) is int - and statement_result_details.result_details != cursor.rowcount - ): - raise RetryAborted(RETRY_ABORTED_ERROR) - elif statement_type == CursorStatementType.EXECUTE_MANY: + else: cursor.executemany( - statement_result_details.sql, - statement_result_details.args, + statement_result_details.sql, statement_result_details.args ) - if ( - type(statement_result_details.result_details) is int - and statement_result_details.result_details != cursor.rowcount - ): - raise RetryAborted(RETRY_ABORTED_ERROR) - elif statement_type == CursorStatementType.FETCH_ALL: + if ( + statement_result_details.result_type == ResultType.ROW_COUNT + and statement_result_details.result_details != cursor.rowcount + ): + raise RetryAborted(RETRY_ABORTED_ERROR) + else: + if statement_type == CursorStatementType.FETCH_ALL: res = cursor.fetchall() - checksum = _get_statement_result_checksum(res) - _compare_checksums(checksum, statement_result_details.result_details) - elif statement_type == CursorStatementType.FETCH_MANY: + else: res = cursor.fetchmany(statement_result_details.size) - checksum = _get_statement_result_checksum(res) - _compare_checksums(checksum, statement_result_details.result_details) + checksum = _get_statement_result_checksum(res) + _compare_checksums(checksum, statement_result_details.result_details) + if statement_result_details.result_type == ResultType.EXCEPTION: + raise RetryAborted(RETRY_ABORTED_ERROR) + + +def _is_execute_type_statement(statement_type): + return statement_type in ( + CursorStatementType.EXECUTE, + CursorStatementType.EXECUTE_MANY, + ) def _get_statement_result_checksum(res_iter): @@ -249,17 +250,24 @@ class CursorStatementType(Enum): FETCH_MANY = 5 +class ResultType(Enum): + # checksum of ResultSet in case of fetch call on query statement + CHECKSUM = (1,) + # None in case of execute call on query statement + NONE = (2,) + # Exception details in case of any statement execution throws exception + EXCEPTION = (3,) + # Total rows updated in case of execute call on DML statement + ROW_COUNT = 4 + + @dataclass class StatementDetails: statement_type: CursorStatementType # The cursor object on which this statement was executed cursor: "Cursor" - # This would be one of - # 1. checksum of ResultSet in case of fetch call on query statement - # 2. Total rows updated in case of DML - # 3. Exception details in case of statement execution throws exception - # 4. None in case of execute calls - result_details: Union[ResultsChecksum, int, Exception, None] + result_type: ResultType + result_details: Any @dataclass diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index a433f52d67..c7e88c1b26 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -649,69 +649,6 @@ def test_execute_sql_abort_retry_multiple_times(self, dbapi_database): got_rows = self._cursor.fetchall() assert len(got_rows) == 2 - @pytest.mark.noautofixt - def test_abort_retry_multiple_cursors(self, shared_instance, dbapi_database): - """Test that retry works when multiple cursors are involved in the transaction.""" - - try: - conn = Connection(shared_instance, dbapi_database) - cur = conn.cursor() - cur.execute( - """ - CREATE TABLE Singers ( - SingerId INT64 NOT NULL, - Name STRING(1024), - ) PRIMARY KEY (SingerId) - """ - ) - cur.execute( - """ - INSERT INTO contacts (contact_id, first_name, last_name, email) - VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') - """ - ) - cur.execute( - """ - INSERT INTO Singers (SingerId, Name) - VALUES (1, 'first-name') - """ - ) - cur.execute( - """ - INSERT INTO contacts (contact_id, first_name, last_name, email) - VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') - """ - ) - cur.execute( - """ - INSERT INTO Singers (SingerId, Name) - VALUES (2, 'first-name') - """ - ) - conn.commit() - - cur1 = conn.cursor() - cur1.execute("SELECT * FROM contacts") - cur2 = conn.cursor() - cur2.execute("SELECT * FROM Singers") - row1 = cur1.fetchone() - row2 = cur2.fetchone() - row3 = cur1.fetchone() - row4 = cur2.fetchone() - dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) - conn.commit() - dbapi_database._method_abort_interceptor.reset() - assert row1 == (1, "first-name", "last-name", "test.email@domen.ru") - assert row2 == (1, "first-name") - assert row3 == (2, "first-name", "last-name", "test.email@domen.ru") - assert row4 == (2, "first-name") - finally: - # Delete table - table = dbapi_database.table("Singers") - if table.exists(): - op = dbapi_database.update_ddl(["DROP TABLE Singers"]) - op.result() - def test_execute_batch_dml_abort_retry(self, dbapi_database): """Test that when any execute batch dml failed with Abort exception, then the retry succeeds with transaction having insert as well as query @@ -805,6 +742,73 @@ def test_consecutive_aborted_transactions(self, dbapi_database): got_rows = self._cursor.fetchall() assert len(got_rows) == 4 + @pytest.mark.noautofixt + def test_abort_retry_multiple_cursors(self, shared_instance, dbapi_database): + """Test that retry works when multiple cursors are involved in the transaction.""" + + try: + conn = Connection(shared_instance, dbapi_database) + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + Name STRING(1024), + ) PRIMARY KEY (SingerId) + """ + ) + cur.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + cur.execute( + """ + INSERT INTO Singers (SingerId, Name) + VALUES (1, 'first-name') + """ + ) + cur.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + cur.execute( + """ + INSERT INTO Singers (SingerId, Name) + VALUES (2, 'first-name') + """ + ) + conn.commit() + + cur1 = conn.cursor() + cur1.execute("SELECT * FROM contacts") + cur2 = conn.cursor() + cur2.execute("SELECT * FROM Singers") + row1 = cur1.fetchone() + row2 = cur2.fetchone() + row3 = cur1.fetchone() + row4 = cur2.fetchone() + dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) + conn.commit() + dbapi_database._method_abort_interceptor.reset() + + assert set([row1, row3]) == set( + [ + (1, "first-name", "last-name", "test.email@domen.ru"), + (2, "first-name", "last-name", "test.email@domen.ru"), + ] + ) + assert set([row2, row4]) == set([(1, "first-name"), (2, "first-name")]) + finally: + # Delete table + table = dbapi_database.table("Singers") + if table.exists(): + op = dbapi_database.update_ddl(["DROP TABLE Singers"]) + op.result() + def test_begin_success_post_commit(self): """Test beginning a new transaction post commiting an existing transaction is possible on a connection, when connection is in autocommit mode.""" diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index b91955bfb7..9735185a5c 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -71,12 +71,10 @@ def test_property_description(self): self.assertIsInstance(cursor.description[0], ColumnInfo) def test_property_rowcount(self): - from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT - connection = self._make_connection(self.INSTANCE, self.DATABASE) cursor = self._make_one(connection) - self.assertEqual(cursor.rowcount, _UNSET_COUNT) + self.assertEqual(cursor.rowcount, None) def test_callproc(self): from google.cloud.spanner_dbapi.exceptions import InterfaceError @@ -505,7 +503,7 @@ def test_executemany(self, mock_client): cursor._itr = iter([1, 2, 3]) with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor.execute" + "google.cloud.spanner_dbapi.cursor.Cursor._execute" ) as execute_mock: cursor.executemany(operation, params_seq) @@ -858,8 +856,6 @@ def test_setoutputsize(self): @mock.patch("google.cloud.spanner_dbapi.cursor.PeekIterator") def test_handle_dql(self, MockedPeekIterator): - from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT - connection = self._make_connection(self.INSTANCE, mock.MagicMock()) connection.database.snapshot.return_value.__enter__.return_value = ( mock_snapshot @@ -871,11 +867,10 @@ def test_handle_dql(self, MockedPeekIterator): cursor._handle_DQL("sql", params=None) self.assertEqual(cursor._result_set, _result_set) self.assertEqual(cursor._itr, MockedPeekIterator()) - self.assertEqual(cursor._row_count, _UNSET_COUNT) + self.assertEqual(cursor._row_count, None) @mock.patch("google.cloud.spanner_dbapi.cursor.PeekIterator") def test_handle_dql_priority(self, MockedPeekIterator): - from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT from google.cloud.spanner_v1 import RequestOptions connection = self._make_connection(self.INSTANCE, mock.MagicMock()) @@ -892,7 +887,7 @@ def test_handle_dql_priority(self, MockedPeekIterator): cursor._handle_DQL(sql, params=None) self.assertEqual(cursor._result_set, _result_set) self.assertEqual(cursor._itr, MockedPeekIterator()) - self.assertEqual(cursor._row_count, _UNSET_COUNT) + self.assertEqual(cursor._row_count, None) mock_snapshot.execute_sql.assert_called_with( sql, None, None, request_options=RequestOptions(priority=1) ) diff --git a/tests/unit/spanner_dbapi/test_transaction_helper.py b/tests/unit/spanner_dbapi/test_transaction_helper.py index 65c274f5d7..2173af0dea 100644 --- a/tests/unit/spanner_dbapi/test_transaction_helper.py +++ b/tests/unit/spanner_dbapi/test_transaction_helper.py @@ -26,6 +26,7 @@ ExecuteStatement, CursorStatementType, FetchStatement, + ResultType, ) @@ -55,6 +56,7 @@ def test_retry_transaction_execute(self): cursor=self._mock_cursor, sql=SQL, args=ARGS, + result_type=ResultType.NONE, result_details=None, ) self._under_test._statement_result_details_list.append(execute_statement) @@ -74,6 +76,7 @@ def test_retry_transaction_dml_execute(self): cursor=self._mock_cursor, sql=SQL, args=ARGS, + result_type=ResultType.ROW_COUNT, result_details=update_count, ) self._under_test._statement_result_details_list.append(execute_statement) @@ -94,6 +97,7 @@ def test_retry_transaction_dml_execute_exception(self): cursor=self._mock_cursor, sql=SQL, args=ARGS, + result_type=ResultType.ROW_COUNT, result_details=2, ) self._under_test._statement_result_details_list.append(execute_statement) @@ -114,6 +118,7 @@ def test_retry_transaction_execute_many(self): cursor=self._mock_cursor, sql=SQL, args=ARGS, + result_type=ResultType.NONE, result_details=None, ) self._under_test._statement_result_details_list.append(execute_statement) @@ -133,6 +138,7 @@ def test_retry_transaction_dml_execute_many(self): cursor=self._mock_cursor, sql=SQL, args=ARGS, + result_type=ResultType.ROW_COUNT, result_details=update_count, ) self._under_test._statement_result_details_list.append(execute_statement) @@ -153,6 +159,7 @@ def test_retry_transaction_dml_executemany_exception(self): cursor=self._mock_cursor, sql=SQL, args=ARGS, + result_type=ResultType.ROW_COUNT, result_details=2, ) self._under_test._statement_result_details_list.append(execute_statement) @@ -172,6 +179,7 @@ def test_retry_transaction_fetchall(self): fetch_statement = FetchStatement( cursor=self._mock_cursor, statement_type=CursorStatementType.FETCH_ALL, + result_type=ResultType.CHECKSUM, result_details=_get_checksum(result_row), ) self._under_test._statement_result_details_list.append(fetch_statement) @@ -191,6 +199,7 @@ def test_retry_transaction_fetchall_exception(self): fetch_statement = FetchStatement( cursor=self._mock_cursor, statement_type=CursorStatementType.FETCH_ALL, + result_type=ResultType.CHECKSUM, result_details=_get_checksum(result_row), ) self._under_test._statement_result_details_list.append(fetch_statement) @@ -211,6 +220,7 @@ def test_retry_transaction_fetchmany(self): fetch_statement = FetchStatement( cursor=self._mock_cursor, statement_type=CursorStatementType.FETCH_MANY, + result_type=ResultType.CHECKSUM, result_details=_get_checksum(result_row), size=1, ) @@ -231,6 +241,7 @@ def test_retry_transaction_fetchmany_exception(self): fetch_statement = FetchStatement( cursor=self._mock_cursor, statement_type=CursorStatementType.FETCH_MANY, + result_type=ResultType.CHECKSUM, result_details=_get_checksum(result_row), size=1, ) @@ -255,6 +266,7 @@ def test_retry_transaction_same_exception(self): cursor=self._mock_cursor, sql=SQL, args=ARGS, + result_type=ResultType.EXCEPTION, result_details=exception, ) self._under_test._statement_result_details_list.append(execute_statement) @@ -275,6 +287,7 @@ def test_retry_transaction_different_exception(self): cursor=self._mock_cursor, sql=SQL, args=ARGS, + result_type=ResultType.EXCEPTION, result_details=Exception("Test"), ) self._under_test._statement_result_details_list.append(execute_statement) @@ -296,6 +309,7 @@ def test_retry_transaction_aborted_retry(self): cursor=self._mock_cursor, sql=SQL, args=ARGS, + result_type=ResultType.NONE, result_details=None, ) self._under_test._statement_result_details_list.append(execute_statement) @@ -336,6 +350,7 @@ def test_add_execute_statement_for_retry(self): cursor=self._mock_cursor, sql=sql, args=[], + result_type=ResultType.ROW_COUNT, result_details=rows_inserted, ) self.assertEqual( @@ -367,6 +382,7 @@ def test_add_execute_statement_for_retry_with_exception(self): cursor=self._mock_cursor, sql=sql, args=[], + result_type=ResultType.EXCEPTION, result_details=exception, ) self.assertEqual( @@ -397,6 +413,7 @@ def test_add_execute_statement_for_retry_query_statement(self): cursor=self._mock_cursor, sql=sql, args=[], + result_type=ResultType.NONE, result_details=None, ) self.assertEqual( @@ -428,6 +445,7 @@ def test_add_execute_many_statement_for_retry(self): cursor=self._mock_cursor, sql=sql, args=[], + result_type=ResultType.ROW_COUNT, result_details=rows_inserted, ) self.assertEqual( @@ -450,6 +468,7 @@ def test_add_fetch_statement_for_retry(self): self._under_test._last_statement_result_details = FetchStatement( statement_type=CursorStatementType.FETCH_MANY, cursor=self._mock_cursor, + result_type=ResultType.CHECKSUM, result_details=result_checksum, size=1, ) @@ -476,6 +495,7 @@ def test_add_fetch_statement_for_retry_with_exception(self): self._under_test._last_statement_result_details = FetchStatement( statement_type=CursorStatementType.FETCH_MANY, cursor=self._mock_cursor, + result_type=ResultType.CHECKSUM, result_details=_get_checksum(result_row), size=1, ) @@ -490,6 +510,7 @@ def test_add_fetch_statement_for_retry_with_exception(self): FetchStatement( statement_type=CursorStatementType.FETCH_MANY, cursor=self._mock_cursor, + result_type=ResultType.EXCEPTION, result_details=exception, size=1, ), @@ -509,6 +530,7 @@ def test_add_fetch_statement_for_retry_last_statement_not_exists(self): expected_statement = FetchStatement( statement_type=CursorStatementType.FETCH_MANY, cursor=self._mock_cursor, + result_type=ResultType.CHECKSUM, result_details=_get_checksum(row), size=1, ) @@ -534,6 +556,7 @@ def test_add_fetch_statement_for_retry_fetch_all_statement(self): expected_statement = FetchStatement( statement_type=CursorStatementType.FETCH_ALL, cursor=self._mock_cursor, + result_type=ResultType.CHECKSUM, result_details=_get_checksum(row), ) self.assertEqual( @@ -555,6 +578,7 @@ def test_add_fetch_statement_for_retry_when_last_statement_is_not_fetch(self): cursor=self._mock_cursor, sql=SQL, args=ARGS, + result_type=ResultType.ROW_COUNT, result_details=2, ) self._under_test._last_statement_result_details = execute_statement @@ -568,6 +592,7 @@ def test_add_fetch_statement_for_retry_when_last_statement_is_not_fetch(self): expected_fetch_statement = FetchStatement( statement_type=CursorStatementType.FETCH_MANY, cursor=self._mock_cursor, + result_type=ResultType.CHECKSUM, result_details=_get_checksum(row), size=1, ) From d7ef54f729b267bed47df53f08b3a374e20f75ec Mon Sep 17 00:00:00 2001 From: ankiaga Date: Thu, 4 Jan 2024 13:53:00 +0530 Subject: [PATCH 4/9] Small fix --- .../cloud/spanner_dbapi/transaction_helper.py | 23 +++++++++++++------ .../spanner_dbapi/test_transaction_helper.py | 18 +++++++++------ 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/google/cloud/spanner_dbapi/transaction_helper.py b/google/cloud/spanner_dbapi/transaction_helper.py index b9418db947..a8cd718354 100644 --- a/google/cloud/spanner_dbapi/transaction_helper.py +++ b/google/cloud/spanner_dbapi/transaction_helper.py @@ -143,8 +143,10 @@ def add_execute_statement_for_retry( if exception is not None: result_type = ResultType.EXCEPTION result_details = exception - # True in case of DML statement - elif cursor.rowcount != -1: + elif cursor._batch_dml_rows_count is not None: + result_type = ResultType.BATCH_DML_ROWS_COUNT + result_details = cursor._batch_dml_rows_count + elif cursor._row_count is not None: result_type = ResultType.ROW_COUNT result_details = cursor.rowcount @@ -208,13 +210,18 @@ def _handle_statement(statement_result_details, cursor): if _is_execute_type_statement(statement_type): if statement_type == CursorStatementType.EXECUTE: cursor.execute(statement_result_details.sql, statement_result_details.args) + if ( + statement_result_details.result_type == ResultType.ROW_COUNT + and statement_result_details.result_details != cursor.rowcount + ): + raise RetryAborted(RETRY_ABORTED_ERROR) else: cursor.executemany( statement_result_details.sql, statement_result_details.args ) if ( - statement_result_details.result_type == ResultType.ROW_COUNT - and statement_result_details.result_details != cursor.rowcount + statement_result_details.result_type == ResultType.BATCH_DML_ROWS_COUNT + and statement_result_details.result_details != cursor._batch_dml_rows_count ): raise RetryAborted(RETRY_ABORTED_ERROR) else: @@ -252,13 +259,15 @@ class CursorStatementType(Enum): class ResultType(Enum): # checksum of ResultSet in case of fetch call on query statement - CHECKSUM = (1,) + CHECKSUM = 1 # None in case of execute call on query statement - NONE = (2,) + NONE = 2 # Exception details in case of any statement execution throws exception - EXCEPTION = (3,) + EXCEPTION = 3 # Total rows updated in case of execute call on DML statement ROW_COUNT = 4 + # Total rows updated in case of Batch DML statement execution + BATCH_DML_ROWS_COUNT = 5 @dataclass diff --git a/tests/unit/spanner_dbapi/test_transaction_helper.py b/tests/unit/spanner_dbapi/test_transaction_helper.py index 2173af0dea..282edaf913 100644 --- a/tests/unit/spanner_dbapi/test_transaction_helper.py +++ b/tests/unit/spanner_dbapi/test_transaction_helper.py @@ -154,17 +154,19 @@ def test_retry_transaction_dml_executemany_exception(self): Test retrying a transaction with an executemany DML statement with different row update count than original throws RetryAborted exception. """ + rows_inserted = [3, 4] + self._mock_cursor._batch_dml_rows_count = rows_inserted execute_statement = ExecuteStatement( statement_type=CursorStatementType.EXECUTE_MANY, cursor=self._mock_cursor, sql=SQL, args=ARGS, - result_type=ResultType.ROW_COUNT, - result_details=2, + result_type=ResultType.BATCH_DML_ROWS_COUNT, + result_details=rows_inserted, ) self._under_test._statement_result_details_list.append(execute_statement) run_mock = self._under_test._connection.cursor = mock.Mock() - run_mock().rowcount = 3 + run_mock()._batch_dml_rows_count = [4, 3] with self.assertRaises(RetryAborted): self._under_test.retry_transaction() @@ -341,6 +343,7 @@ def test_add_execute_statement_for_retry(self): sql = "INSERT INTO Table" rows_inserted = 3 self._mock_cursor.rowcount = rows_inserted + self._mock_cursor._batch_dml_rows_count = None self._under_test.add_execute_statement_for_retry( self._mock_cursor, sql, [], None, False ) @@ -401,7 +404,8 @@ def test_add_execute_statement_for_retry_query_statement(self): self._mock_cursor._parsed_statement = ParsedStatement( statement_type=StatementType.QUERY, statement=None ) - self._mock_cursor.rowcount = -1 + self._mock_cursor._row_count = None + self._mock_cursor._batch_dml_rows_count = None sql = "SELECT 1" self._under_test.add_execute_statement_for_retry( @@ -434,8 +438,8 @@ def test_add_execute_many_statement_for_retry(self): ) sql = "INSERT INTO Table" - rows_inserted = 3 - self._mock_cursor.rowcount = rows_inserted + rows_inserted = [3, 4] + self._mock_cursor._batch_dml_rows_count = rows_inserted self._under_test.add_execute_statement_for_retry( self._mock_cursor, sql, [], None, True ) @@ -445,7 +449,7 @@ def test_add_execute_many_statement_for_retry(self): cursor=self._mock_cursor, sql=sql, args=[], - result_type=ResultType.ROW_COUNT, + result_type=ResultType.BATCH_DML_ROWS_COUNT, result_details=rows_inserted, ) self.assertEqual( From 721a9aeeccfa7ecb1a790867135f69c3bd9668d1 Mon Sep 17 00:00:00 2001 From: ankiaga Date: Thu, 4 Jan 2024 18:01:00 +0530 Subject: [PATCH 5/9] Maintaining a map from cursor to last statement added in transaction_helper.py --- .../cloud/spanner_dbapi/transaction_helper.py | 60 ++++++++++--------- .../spanner_dbapi/test_transaction_helper.py | 50 ++++++++++------ 2 files changed, 62 insertions(+), 48 deletions(-) diff --git a/google/cloud/spanner_dbapi/transaction_helper.py b/google/cloud/spanner_dbapi/transaction_helper.py index a8cd718354..bc896009c7 100644 --- a/google/cloud/spanner_dbapi/transaction_helper.py +++ b/google/cloud/spanner_dbapi/transaction_helper.py @@ -44,8 +44,8 @@ def __init__(self, connection: "Connection"): # list of all statements in the same order as executed in original # transaction along with their results self._statement_result_details_list: List[StatementDetails] = [] - # last StatementDetails that was added in the _statement_result_details_list - self._last_statement_result_details: StatementDetails = None + # Map of last StatementDetails that was added to a particular cursor + self._last_statement_details_per_cursor: Dict[Cursor, StatementDetails] = {} # 1-1 map from original cursor object on which transaction ran to the # new cursor object used in the retry self._cursor_map: Dict[Cursor, Cursor] = {} @@ -61,7 +61,7 @@ def reset(self): or aborted """ self._statement_result_details_list = [] - self._last_statement_result_details = None + self._last_statement_details_per_cursor = {} self._cursor_map = {} def add_fetch_statement_for_retry( @@ -82,41 +82,42 @@ def add_fetch_statement_for_retry( """ if not self._connection._client_transaction_started: return + + last_statement_result_details = self._last_statement_details_per_cursor.get( + cursor + ) if ( - self._last_statement_result_details is not None - and self._last_statement_result_details.statement_type + last_statement_result_details is not None + and last_statement_result_details.statement_type == CursorStatementType.FETCH_MANY - and self._last_statement_result_details.cursor == cursor ): if exception is not None: - self._last_statement_result_details.result_type = ResultType.EXCEPTION - self._last_statement_result_details.result_details = exception + last_statement_result_details.result_type = ResultType.EXCEPTION + last_statement_result_details.result_details = exception else: for row in result_rows: - self._last_statement_result_details.result_details.consume_result( - row - ) - self._last_statement_result_details.size += len(result_rows) + last_statement_result_details.result_details.consume_result(row) + last_statement_result_details.size += len(result_rows) else: result_details = _get_statement_result_checksum(result_rows) if is_fetch_all: - self._last_statement_result_details = FetchStatement( - cursor=cursor, - statement_type=CursorStatementType.FETCH_ALL, - result_type=ResultType.CHECKSUM, - result_details=result_details, - ) + statement_type = CursorStatementType.FETCH_ALL + size = None else: - self._last_statement_result_details = FetchStatement( - cursor=cursor, - statement_type=CursorStatementType.FETCH_MANY, - result_type=ResultType.CHECKSUM, - result_details=result_details, - size=len(result_rows), - ) - self._statement_result_details_list.append( - self._last_statement_result_details + statement_type = CursorStatementType.FETCH_MANY + size = len(result_rows) + + last_statement_result_details = FetchStatement( + cursor=cursor, + statement_type=statement_type, + result_type=ResultType.CHECKSUM, + result_details=result_details, + size=size, ) + self._last_statement_details_per_cursor[ + cursor + ] = last_statement_result_details + self._statement_result_details_list.append(last_statement_result_details) def add_execute_statement_for_retry( self, cursor, sql, args, exception, is_execute_many @@ -150,7 +151,7 @@ def add_execute_statement_for_retry( result_type = ResultType.ROW_COUNT result_details = cursor.rowcount - self._last_statement_result_details = ExecuteStatement( + last_statement_result_details = ExecuteStatement( cursor=cursor, statement_type=statement_type, sql=sql, @@ -158,7 +159,8 @@ def add_execute_statement_for_retry( result_type=result_type, result_details=result_details, ) - self._statement_result_details_list.append(self._last_statement_result_details) + self._last_statement_details_per_cursor[cursor] = last_statement_result_details + self._statement_result_details_list.append(last_statement_result_details) def retry_transaction(self): """Retry the aborted transaction. diff --git a/tests/unit/spanner_dbapi/test_transaction_helper.py b/tests/unit/spanner_dbapi/test_transaction_helper.py index 282edaf913..1d50a51825 100644 --- a/tests/unit/spanner_dbapi/test_transaction_helper.py +++ b/tests/unit/spanner_dbapi/test_transaction_helper.py @@ -357,8 +357,8 @@ def test_add_execute_statement_for_retry(self): result_details=rows_inserted, ) self.assertEqual( - self._under_test._last_statement_result_details, - expected_statement_result_details, + self._under_test._last_statement_details_per_cursor, + {self._mock_cursor: expected_statement_result_details}, ) self.assertEqual( self._under_test._statement_result_details_list, @@ -389,8 +389,8 @@ def test_add_execute_statement_for_retry_with_exception(self): result_details=exception, ) self.assertEqual( - self._under_test._last_statement_result_details, - expected_statement_result_details, + self._under_test._last_statement_details_per_cursor, + {self._mock_cursor: expected_statement_result_details}, ) self.assertEqual( self._under_test._statement_result_details_list, @@ -421,8 +421,8 @@ def test_add_execute_statement_for_retry_query_statement(self): result_details=None, ) self.assertEqual( - self._under_test._last_statement_result_details, - expected_statement_result_details, + self._under_test._last_statement_details_per_cursor, + {self._mock_cursor: expected_statement_result_details}, ) self.assertEqual( self._under_test._statement_result_details_list, @@ -453,8 +453,8 @@ def test_add_execute_many_statement_for_retry(self): result_details=rows_inserted, ) self.assertEqual( - self._under_test._last_statement_result_details, - expected_statement_result_details, + self._under_test._last_statement_details_per_cursor, + {self._mock_cursor: expected_statement_result_details}, ) self.assertEqual( self._under_test._statement_result_details_list, @@ -469,25 +469,31 @@ def test_add_fetch_statement_for_retry(self): result_row = ("field1", "field2") result_checksum = _get_checksum(result_row) original_checksum_digest = result_checksum.checksum.digest() - self._under_test._last_statement_result_details = FetchStatement( + last_statement_result_details = FetchStatement( statement_type=CursorStatementType.FETCH_MANY, cursor=self._mock_cursor, result_type=ResultType.CHECKSUM, result_details=result_checksum, size=1, ) + self._under_test._last_statement_details_per_cursor = { + self._mock_cursor: last_statement_result_details + } new_rows = [("field3", "field4"), ("field5", "field6")] self._under_test.add_fetch_statement_for_retry( self._mock_cursor, new_rows, None, False ) + updated_last_statement_result_details = ( + self._under_test._last_statement_details_per_cursor.get(self._mock_cursor) + ) self.assertEqual( - self._under_test._last_statement_result_details.size, + updated_last_statement_result_details.size, 3, ) self.assertNotEqual( - self._under_test._last_statement_result_details.result_details.checksum.digest(), + updated_last_statement_result_details.result_details.checksum.digest(), original_checksum_digest, ) @@ -496,13 +502,16 @@ def test_add_fetch_statement_for_retry_with_exception(self): Test add_fetch_statement_for_retry method with exception """ result_row = ("field1", "field2") - self._under_test._last_statement_result_details = FetchStatement( + fetch_statement = FetchStatement( statement_type=CursorStatementType.FETCH_MANY, cursor=self._mock_cursor, result_type=ResultType.CHECKSUM, result_details=_get_checksum(result_row), size=1, ) + self._under_test._last_statement_details_per_cursor = { + self._mock_cursor: fetch_statement + } exception = Exception("Test") self._under_test.add_fetch_statement_for_retry( @@ -510,7 +519,7 @@ def test_add_fetch_statement_for_retry_with_exception(self): ) self.assertEqual( - self._under_test._last_statement_result_details, + self._under_test._last_statement_details_per_cursor.get(self._mock_cursor), FetchStatement( statement_type=CursorStatementType.FETCH_MANY, cursor=self._mock_cursor, @@ -539,8 +548,8 @@ def test_add_fetch_statement_for_retry_last_statement_not_exists(self): size=1, ) self.assertEqual( - self._under_test._last_statement_result_details, - expected_statement, + self._under_test._last_statement_details_per_cursor, + {self._mock_cursor: expected_statement}, ) self.assertEqual( self._under_test._statement_result_details_list, @@ -564,8 +573,8 @@ def test_add_fetch_statement_for_retry_fetch_all_statement(self): result_details=_get_checksum(row), ) self.assertEqual( - self._under_test._last_statement_result_details, - expected_statement, + self._under_test._last_statement_details_per_cursor, + {self._mock_cursor: expected_statement}, ) self.assertEqual( self._under_test._statement_result_details_list, @@ -585,7 +594,9 @@ def test_add_fetch_statement_for_retry_when_last_statement_is_not_fetch(self): result_type=ResultType.ROW_COUNT, result_details=2, ) - self._under_test._last_statement_result_details = execute_statement + self._under_test._last_statement_details_per_cursor = { + self._mock_cursor: execute_statement + } self._under_test._statement_result_details_list.append(execute_statement) row = ("field3", "field4") @@ -601,7 +612,8 @@ def test_add_fetch_statement_for_retry_when_last_statement_is_not_fetch(self): size=1, ) self.assertEqual( - self._under_test._last_statement_result_details, expected_fetch_statement + self._under_test._last_statement_details_per_cursor, + {self._mock_cursor: expected_fetch_statement}, ) self.assertEqual( self._under_test._statement_result_details_list, From 86db58a872933628d97f0e74d3c51752fc81026d Mon Sep 17 00:00:00 2001 From: ankiaga Date: Tue, 9 Jan 2024 15:39:48 +0530 Subject: [PATCH 6/9] Rolling back the transaction when Aborted exception is thrown from interceptor --- .../cloud/spanner_v1/testing/interceptors.py | 9 +- tests/system/test_dbapi.py | 155 ++++++++++-------- tests/unit/spanner_dbapi/test_connection.py | 37 ----- 3 files changed, 92 insertions(+), 109 deletions(-) diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py index a439ba4f86..a16ec4a516 100644 --- a/google/cloud/spanner_v1/testing/interceptors.py +++ b/google/cloud/spanner_v1/testing/interceptors.py @@ -39,6 +39,7 @@ def __init__(self): self._method_to_abort = None self._count = 0 self._max_raise_count = 1 + self._connection = None def intercept(self, method, request_or_iterator, call_details): if ( @@ -46,14 +47,20 @@ def intercept(self, method, request_or_iterator, call_details): and call_details.method == self._method_to_abort ): self._count += 1 + if self._connection is not None: + self._connection._transaction.rollback() + self._connection._transaction.rolled_back = False raise Aborted("Thrown from ClientInterceptor for testing") return method(request_or_iterator, call_details) - def set_method_to_abort(self, method_to_abort, max_raise_count=1): + def set_method_to_abort(self, method_to_abort, connection=None, max_raise_count=1): self._method_to_abort = method_to_abort + self._count = 0 self._max_raise_count = max_raise_count + self._connection = connection def reset(self): """Reset the interceptor to the original state.""" self._method_to_abort = None self._count = 0 + self._connection = None diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index c7e88c1b26..64d6d06880 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -21,7 +21,11 @@ from google.cloud._helpers import UTC from google.cloud.spanner_dbapi.connection import Connection, connect -from google.cloud.spanner_dbapi.exceptions import ProgrammingError, OperationalError +from google.cloud.spanner_dbapi.exceptions import ( + ProgrammingError, + OperationalError, + RetryAborted, +) from google.cloud.spanner_v1 import JsonObject from google.cloud.spanner_v1 import gapic_version as package_version from google.api_core.datetime_helpers import DatetimeWithNanoseconds @@ -609,7 +613,9 @@ def test_commit_abort_retry(self, dbapi_database): ) self._cursor.fetchone() self._cursor.fetchmany(2) - dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) + dbapi_database._method_abort_interceptor.set_method_to_abort( + COMMIT_METHOD, self._conn + ) # called 2 times self._conn.commit() dbapi_database._method_abort_interceptor.reset() @@ -621,6 +627,39 @@ def test_commit_abort_retry(self, dbapi_database): got_rows = self._cursor.fetchall() assert len(got_rows) == 5 + def test_retry_aborted(self, shared_instance, dbapi_database): + """Test that retry fails with RetryAborted error when rows are updated during retry.""" + + conn1 = Connection(shared_instance, dbapi_database) + cursor1 = conn1.cursor() + cursor1.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + conn1.commit() + cursor1.execute("SELECT * FROM contacts") + cursor1.fetchall() + + conn2 = Connection(shared_instance, dbapi_database) + cursor2 = conn2.cursor() + cursor2.execute( + """ + UPDATE contacts + SET email = 'test.email_updated@domen.ru' + WHERE contact_id = 1 + """ + ) + conn2.commit() + + dbapi_database._method_abort_interceptor.set_method_to_abort( + COMMIT_METHOD, conn1 + ) + with pytest.raises(RetryAborted): + conn1.commit() + dbapi_database._method_abort_interceptor.reset() + def test_execute_sql_abort_retry_multiple_times(self, dbapi_database): """Test that when execute sql failed 2 times with Abort exception, then the retry succeeds 3rd time.""" @@ -633,14 +672,13 @@ def test_execute_sql_abort_retry_multiple_times(self, dbapi_database): self._cursor.execute("run batch") # aborting method 2 times before succeeding dbapi_database._method_abort_interceptor.set_method_to_abort( - EXECUTE_STREAMING_SQL_METHOD, 2 + EXECUTE_STREAMING_SQL_METHOD, self._conn, 2 ) self._cursor.execute("SELECT * FROM contacts") self._cursor.fetchmany(2) dbapi_database._method_abort_interceptor.reset() self._conn.commit() # Check that all rpcs except commit should be called 3 times the original - print(method_count_interceptor._counts) assert method_count_interceptor._counts[COMMIT_METHOD] == 1 assert method_count_interceptor._counts[EXECUTE_BATCH_DML_METHOD] == 3 assert method_count_interceptor._counts[EXECUTE_STREAMING_SQL_METHOD] == 3 @@ -665,7 +703,7 @@ def test_execute_batch_dml_abort_retry(self, dbapi_database): self._insert_row(2) self._insert_row(3) dbapi_database._method_abort_interceptor.set_method_to_abort( - EXECUTE_BATCH_DML_METHOD, 2 + EXECUTE_BATCH_DML_METHOD, self._conn, 2 ) # called 3 times self._cursor.execute("run batch") @@ -688,7 +726,7 @@ def test_multiple_aborts_in_transaction(self, dbapi_database): # called 3 times self._insert_row(1) dbapi_database._method_abort_interceptor.set_method_to_abort( - EXECUTE_STREAMING_SQL_METHOD + EXECUTE_STREAMING_SQL_METHOD, self._conn ) # called 3 times self._cursor.execute("SELECT * FROM contacts") @@ -699,7 +737,9 @@ def test_multiple_aborts_in_transaction(self, dbapi_database): # called 2 times self._cursor.execute("SELECT * FROM contacts") self._cursor.fetchone() - dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) + dbapi_database._method_abort_interceptor.set_method_to_abort( + COMMIT_METHOD, self._conn + ) # called 2 times self._conn.commit() dbapi_database._method_abort_interceptor.reset() @@ -720,7 +760,9 @@ def test_consecutive_aborted_transactions(self, dbapi_database): self._insert_row(2) self._cursor.execute("SELECT * FROM contacts") self._cursor.fetchall() - dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) + dbapi_database._method_abort_interceptor.set_method_to_abort( + COMMIT_METHOD, self._conn + ) self._conn.commit() dbapi_database._method_abort_interceptor.reset() assert method_count_interceptor._counts[COMMIT_METHOD] == 2 @@ -732,7 +774,9 @@ def test_consecutive_aborted_transactions(self, dbapi_database): self._insert_row(4) self._cursor.execute("SELECT * FROM contacts") self._cursor.fetchall() - dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) + dbapi_database._method_abort_interceptor.set_method_to_abort( + COMMIT_METHOD, self._conn + ) self._conn.commit() dbapi_database._method_abort_interceptor.reset() assert method_count_interceptor._counts[COMMIT_METHOD] == 2 @@ -742,72 +786,41 @@ def test_consecutive_aborted_transactions(self, dbapi_database): got_rows = self._cursor.fetchall() assert len(got_rows) == 4 - @pytest.mark.noautofixt - def test_abort_retry_multiple_cursors(self, shared_instance, dbapi_database): + def test_abort_retry_multiple_cursors(self, dbapi_database): """Test that retry works when multiple cursors are involved in the transaction.""" - try: - conn = Connection(shared_instance, dbapi_database) - cur = conn.cursor() - cur.execute( - """ - CREATE TABLE Singers ( - SingerId INT64 NOT NULL, - Name STRING(1024), - ) PRIMARY KEY (SingerId) - """ - ) - cur.execute( - """ - INSERT INTO contacts (contact_id, first_name, last_name, email) - VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') - """ - ) - cur.execute( - """ - INSERT INTO Singers (SingerId, Name) - VALUES (1, 'first-name') - """ - ) - cur.execute( - """ - INSERT INTO contacts (contact_id, first_name, last_name, email) - VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') - """ - ) - cur.execute( - """ - INSERT INTO Singers (SingerId, Name) - VALUES (2, 'first-name') - """ - ) - conn.commit() + self._insert_row(1) + self._insert_row(2) + self._insert_row(3) + self._insert_row(4) + self._conn.commit() - cur1 = conn.cursor() - cur1.execute("SELECT * FROM contacts") - cur2 = conn.cursor() - cur2.execute("SELECT * FROM Singers") - row1 = cur1.fetchone() - row2 = cur2.fetchone() - row3 = cur1.fetchone() - row4 = cur2.fetchone() - dbapi_database._method_abort_interceptor.set_method_to_abort(COMMIT_METHOD) - conn.commit() - dbapi_database._method_abort_interceptor.reset() + cur1 = self._conn.cursor() + cur1.execute("SELECT * FROM contacts WHERE contact_id IN (1, 2)") + cur2 = self._conn.cursor() + cur2.execute("SELECT * FROM contacts WHERE contact_id IN (3, 4)") + row1 = cur1.fetchone() + row2 = cur2.fetchone() + row3 = cur1.fetchone() + row4 = cur2.fetchone() + dbapi_database._method_abort_interceptor.set_method_to_abort( + COMMIT_METHOD, self._conn + ) + self._conn.commit() + dbapi_database._method_abort_interceptor.reset() - assert set([row1, row3]) == set( - [ - (1, "first-name", "last-name", "test.email@domen.ru"), - (2, "first-name", "last-name", "test.email@domen.ru"), - ] - ) - assert set([row2, row4]) == set([(1, "first-name"), (2, "first-name")]) - finally: - # Delete table - table = dbapi_database.table("Singers") - if table.exists(): - op = dbapi_database.update_ddl(["DROP TABLE Singers"]) - op.result() + assert set([row1, row3]) == set( + [ + (1, "first-name-1", "last-name-1", "test.email@domen.ru"), + (2, "first-name-2", "last-name-2", "test.email@domen.ru"), + ] + ) + assert set([row2, row4]) == set( + [ + (3, "first-name-3", "last-name-3", "test.email@domen.ru"), + (4, "first-name-4", "last-name-4", "test.email@domen.ru"), + ] + ) def test_begin_success_post_commit(self): """Test beginning a new transaction post commiting an existing transaction diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index e2706e1966..eece10c741 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -121,26 +121,6 @@ def test_read_only_connection(self): connection.read_only = False self.assertFalse(connection.read_only) - def test_read_only_not_retried(self): - """ - Testing the unlikely case of a read-only transaction - failed with Aborted exception. In this case the - transaction should not be automatically retried. - """ - from google.api_core.exceptions import Aborted - - connection = self._make_connection(read_only=True) - connection.retry_transaction = mock.Mock() - - cursor = connection.cursor() - cursor._itr = mock.Mock( - __next__=mock.Mock( - side_effect=Aborted("Aborted"), - ) - ) - - connection.retry_transaction.assert_not_called() - @staticmethod def _make_pool(): from google.cloud.spanner_v1.pool import AbstractSessionPool @@ -494,23 +474,6 @@ def test_begin(self): self.assertEqual(self._under_test._transaction_begin_marked, True) - @mock.patch("google.cloud.spanner_v1.Client") - def test_commit_retry_aborted_statements(self, mock_client): - """Check that retried transaction executing the same statements.""" - from google.api_core.exceptions import Aborted - from google.cloud.spanner_dbapi.connection import connect - - connection = connect("test-instance", "test-database") - mock_transaction = mock.Mock() - connection._spanner_transaction_started = True - connection._transaction = mock_transaction - mock_transaction.commit.side_effect = [Aborted("Aborted"), None] - run_mock = connection._transaction_helper = mock.Mock() - - connection.commit() - - assert run_mock.retry_transaction.called - def test_validate_ok(self): connection = self._make_connection() From bad9abebd17f2dd1752004936e0eb83a89452a28 Mon Sep 17 00:00:00 2001 From: ankiaga Date: Wed, 10 Jan 2024 18:04:19 +0530 Subject: [PATCH 7/9] Small change --- google/cloud/spanner_v1/testing/interceptors.py | 1 - tests/system/test_dbapi.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py index a16ec4a516..a8b015a87d 100644 --- a/google/cloud/spanner_v1/testing/interceptors.py +++ b/google/cloud/spanner_v1/testing/interceptors.py @@ -49,7 +49,6 @@ def intercept(self, method, request_or_iterator, call_details): self._count += 1 if self._connection is not None: self._connection._transaction.rollback() - self._connection._transaction.rolled_back = False raise Aborted("Thrown from ClientInterceptor for testing") return method(request_or_iterator, call_details) diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 64d6d06880..9b37ed6b54 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -445,7 +445,6 @@ def test_read_timestamp_client_side_autocommit(self): assert self._cursor.description[0].name == "SHOW_READ_TIMESTAMP" assert isinstance(read_timestamp_query_result_1[0][0], DatetimeWithNanoseconds) - time.sleep(0.25) self._cursor.execute("SELECT * FROM contacts") self._cursor.execute("SHOW VARIABLE READ_TIMESTAMP") read_timestamp_query_result_2 = self._cursor.fetchall() @@ -627,7 +626,7 @@ def test_commit_abort_retry(self, dbapi_database): got_rows = self._cursor.fetchall() assert len(got_rows) == 5 - def test_retry_aborted(self, shared_instance, dbapi_database): + def test_retry_aborted_exception(self, shared_instance, dbapi_database): """Test that retry fails with RetryAborted error when rows are updated during retry.""" conn1 = Connection(shared_instance, dbapi_database) From 9fda104bc40ae716ad28270e97a618754fbe97d6 Mon Sep 17 00:00:00 2001 From: ankiaga Date: Wed, 10 Jan 2024 19:50:58 +0530 Subject: [PATCH 8/9] Disabling a test for emulator run --- tests/system/test_dbapi.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index 9b37ed6b54..e080b1258b 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -626,6 +626,8 @@ def test_commit_abort_retry(self, dbapi_database): got_rows = self._cursor.fetchall() assert len(got_rows) == 5 + + @pytest.mark.skipif(_helpers.USE_EMULATOR, reason="Emulator does not concurrent transactions.") def test_retry_aborted_exception(self, shared_instance, dbapi_database): """Test that retry fails with RetryAborted error when rows are updated during retry.""" From a971d33622a337320449bb6a066a14492c9becf9 Mon Sep 17 00:00:00 2001 From: ankiaga Date: Wed, 10 Jan 2024 22:54:37 +0530 Subject: [PATCH 9/9] Reformatting --- tests/system/test_dbapi.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index e080b1258b..c741304b29 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -626,8 +626,10 @@ def test_commit_abort_retry(self, dbapi_database): got_rows = self._cursor.fetchall() assert len(got_rows) == 5 - - @pytest.mark.skipif(_helpers.USE_EMULATOR, reason="Emulator does not concurrent transactions.") + @pytest.mark.skipif( + _helpers.USE_EMULATOR, + reason="Emulator does not support concurrent transactions.", + ) def test_retry_aborted_exception(self, shared_instance, dbapi_database): """Test that retry fails with RetryAborted error when rows are updated during retry."""