4 changes: 4 additions & 0 deletions .gitignore
@@ -62,3 +62,7 @@ system_tests/local_test_setup
# Make sure a generated file isn't accidentally committed.
pylintrc
pylintrc.test


# Ignore coverage files
.coverage*
2 changes: 1 addition & 1 deletion google/cloud/spanner_dbapi/transaction_helper.py
@@ -20,7 +20,7 @@

from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
from google.cloud.spanner_dbapi.exceptions import RetryAborted
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1._helpers import _get_retry_delay

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection, Cursor
77 changes: 75 additions & 2 deletions google/cloud/spanner_v1/_helpers.py
@@ -27,11 +27,15 @@
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper

from google.api_core import datetime_helpers
from google.api_core.exceptions import Aborted
from google.cloud._helpers import _date_from_iso8601_date
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1.request_id_header import with_request_id
from google.rpc.error_details_pb2 import RetryInfo

import random

# Validation error messages
NUMERIC_MAX_SCALE_ERR_MSG = (
@@ -466,13 +470,19 @@ def _retry(
delay=2,
allowed_exceptions=None,
beforeNextRetry=None,
deadline=None,
):
"""
Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions.
Retry a specified function with different logic based on the type of exception raised.

If the exception is of type google.api_core.exceptions.Aborted,
apply an alternate retry strategy that relies on the provided deadline value instead of a fixed number of retries.
For all other exceptions, retry the function up to a specified number of times.
Contributor:

This API is not really logical. I would suggest splitting this into two separate functions:

  1. Keep the current retry function as-is.
  2. Add a new function _retry_on_aborted_exception that handles that specific case.

In the current form, the API is quite 'magical' and hard to understand. What, for example, is the definition of this function if you call it with Aborted as one of the allowed exceptions? Will it use the specific logic for Aborted in all cases, or only if you have also supplied a deadline? What is the meaning of retry_count if you use it to retry Aborted errors? Etc.

Contributor Author (@aakashanandg, Jan 6, 2025):

If the exception is of type Aborted, it will activate the custom retry strategy. However, this will only occur if the user has listed this exception in the allowed_exceptions map and provided a deadline value. If either condition is missing, the exception will not be retried. For the batch API use case, we will specifically allow this exception to be retried.

Contributor:

I meant that the _retry function and the _retry_on_aborted_exception function should be completely separated. I don't really see any advantage of combining them, as the actual code that can be shared is minimal, and the API surface of this function is not logical.

E.g. if you have defined Aborted as a retryable exception but forget to supply a deadline, then all of a sudden it is not retryable. Also, deadline is only used if you add Aborted as a possible retryable error, and is otherwise ignored if you only supply other error codes. Same with retry_count; it is only used for non-Aborted errors. The fact that there are many combinations of input arguments that don't make any sense is an indication that the function itself should be split.

Contributor Author:

Thanks for the clarification. I've implemented the new retry logic as suggested, separating the _retry and _retry_on_aborted_exception functions. This ensures clearer logic, as combining them led to confusing combinations of parameters that didn't make sense. Now, the retry logic for non-Aborted and Aborted exceptions is more distinct and easier to manage.
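A minimal sketch of the split the reviewer suggests (and that the author says was adopted): a count-bounded `_retry` for generic errors next to a deadline-bounded `_retry_on_aborted_exception`. Names, signatures, and the simplified handling of `allowed_exceptions` (handler callbacks omitted) are illustrative, not the final implementation:

```python
import time

from google.api_core.exceptions import Aborted

# Assumes the _delay_until_retry helper added to _helpers.py in this PR.
from google.cloud.spanner_v1._helpers import _delay_until_retry


def _retry(func, retry_count=5, delay=2, allowed_exceptions=None):
    """Retry ``func`` up to ``retry_count`` times for non-Aborted errors (sketch)."""
    retries = 0
    while True:
        try:
            return func()
        except Exception as exc:
            retryable = allowed_exceptions is None or exc.__class__ in allowed_exceptions
            if retryable and retries < retry_count:
                retries += 1
                time.sleep(delay)
            else:
                raise


def _retry_on_aborted_exception(func, deadline):
    """Retry ``func`` on Aborted until ``deadline``, honoring server-supplied delays (sketch)."""
    attempts = 0
    while True:
        try:
            return func()
        except Aborted as exc:
            attempts += 1
            # Raises once the deadline has passed; otherwise sleeps for the
            # RetryInfo-supplied delay or an exponential backoff.
            _delay_until_retry(exc, deadline=deadline, attempts=attempts)
```

Under this split, Batch.commit would call the Aborted-specific helper with a deadline, while RST_STREAM-style internal errors would keep using the count-bounded `_retry`.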


Args:
func: The function to be retried.
retry_count: The maximum number of times to retry the function.
deadline: This will be used in case of Aborted transactions.
Contributor:

nit: remove, this is not relevant anymore

delay: The delay in seconds between retries.
allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry.
Passing allowed_exceptions as None will lead to retrying for all exceptions.
@@ -481,13 +491,21 @@ def _retry(
The result of the function if it is successful, or raises the last exception if all retries fail.
"""
retries = 0
while retries <= retry_count:
while True:
Contributor:

This breaks existing use cases that rely on this function to stop retrying after N retries.

Contributor Author:

I believe the check for retries < retry_count is already in place for generic retries. This ensures that the while loop terminates early and an exception is raised once the retry count is exceeded. So, in my opinion, this logic should work correctly for generic retries as well.

if retries > 0 and beforeNextRetry:
beforeNextRetry(retries, delay)

try:
return func()
except Exception as exc:
if isinstance(exc, Aborted) and deadline is not None:
if (
allowed_exceptions is not None
and allowed_exceptions.get(exc.__class__) is not None
):
retries += 1
_delay_until_retry(exc, deadline=deadline, attempts=retries)
continue
if (
allowed_exceptions is None or exc.__class__ in allowed_exceptions
) and retries < retry_count:
@@ -529,6 +547,61 @@ def _metadata_with_leader_aware_routing(value, **kw):
return ("x-goog-spanner-route-to-leader", str(value).lower())


def _delay_until_retry(exc, deadline, attempts):
"""Helper for :meth:`Session.run_in_transaction`.

Detect retryable abort, and impose server-supplied delay.

:type exc: :class:`google.api_core.exceptions.Aborted`
:param exc: exception for aborted transaction

:type deadline: float
:param deadline: maximum timestamp to continue retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""

cause = exc.errors[0]
now = time.time()
if now >= deadline:
raise

delay = _get_retry_delay(cause, attempts)
print(now, delay, deadline)
Contributor:

nit: remove

Contributor Author:

I extracted these methods to make them more generic, allowing other clients to reuse the logic instead of it being tightly coupled with the session object.

Contributor:

I meant: remove the print(...) line. We should not print debug info in non-test code (and normally also not in test code).

Contributor Author:

Noted. Removed from subsequent commits.

if delay is not None:
if now + delay > deadline:
raise

time.sleep(delay)


def _get_retry_delay(cause, attempts):
"""Helper for :func:`_delay_until_retry`.

:type exc: :class:`grpc.Call`
:param exc: exception for aborted transaction

:rtype: float
:returns: seconds to wait before retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""
if hasattr(cause, "trailing_metadata"):
metadata = dict(cause.trailing_metadata())
else:
metadata = {}
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
if retry_info_pb is not None:
retry_info = RetryInfo()
retry_info.ParseFromString(retry_info_pb)
nanos = retry_info.retry_delay.nanos
return retry_info.retry_delay.seconds + nanos / 1.0e9

return 2**attempts + random.random()


class AtomicCounter:
def __init__(self, start_value=0):
self.__lock = threading.Lock()
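For illustration, a self-contained sketch of how the relocated `_get_retry_delay` logic derives a wait time. The `_FakeRpcCall` stand-in and the `suggested_delay` mirror function are hypothetical, used only to show the two code paths: a server-supplied RetryInfo delay versus the jittered exponential fallback.

```python
import random

from google.rpc.error_details_pb2 import RetryInfo


class _FakeRpcCall:
    """Hypothetical stand-in for a grpc.Call whose trailing metadata carries RetryInfo."""

    def __init__(self, retry_info_bytes):
        self._metadata = (("google.rpc.retryinfo-bin", retry_info_bytes),)

    def trailing_metadata(self):
        return self._metadata


def suggested_delay(cause, attempts):
    """Mirror of the helper's logic: prefer the server-supplied delay, else back off."""
    metadata = dict(cause.trailing_metadata()) if hasattr(cause, "trailing_metadata") else {}
    retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
    if retry_info_pb is not None:
        retry_info = RetryInfo()
        retry_info.ParseFromString(retry_info_pb)
        return retry_info.retry_delay.seconds + retry_info.retry_delay.nanos / 1.0e9
    return 2**attempts + random.random()


retry_info = RetryInfo()
retry_info.retry_delay.seconds = 1
retry_info.retry_delay.nanos = 500_000_000  # 0.5 s

print(suggested_delay(_FakeRpcCall(retry_info.SerializeToString()), attempts=1))  # 1.5
print(suggested_delay(object(), attempts=3))  # ~8-9 s: exponential backoff plus jitter
```

Preferring the server's RetryInfo keeps client backoff aligned with what Spanner reports, while the jittered exponential fallback avoids synchronized retries when no hint is returned.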
32 changes: 29 additions & 3 deletions google/cloud/spanner_v1/batch.py
@@ -31,6 +31,10 @@
from google.cloud.spanner_v1._helpers import _retry
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.api_core.exceptions import InternalServerError
from google.api_core.exceptions import Aborted
import time

DEFAULT_RETRY_TIMEOUT_SECS = 30


class _BatchBase(_SessionWrapper):
@@ -162,6 +166,7 @@ def commit(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kwargs,
):
"""Commit mutations to the database.

@@ -227,9 +232,16 @@
request=request,
metadata=metadata,
)
deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
allowed_exceptions={
InternalServerError: _check_rst_stream_error,
Aborted: no_op_handler,
},
deadline=deadline,
)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
@@ -293,7 +305,9 @@ def group(self):
self._mutation_groups.append(mutation_group)
return MutationGroup(self._session, mutation_group.mutations)

def batch_write(self, request_options=None, exclude_txn_from_change_streams=False):
def batch_write(
Contributor:

batch_write is a bit different. I don't think we should include it in this PR, as it is a non-atomic, streaming operation that probably needs different error handling than 'just retry if it fails with an aborted error'.

Contributor Author:

Understood. In that case, we can bypass the retry behavior for this operation.

Contributor:

Please also remove the **kwargs addition from this PR again. It would just be confusing if that is added in this PR when it is not relevant to the actual change.

self, request_options=None, exclude_txn_from_change_streams=False, **kwargs
):
"""Executes batch_write.

:type request_options:
@@ -346,9 +360,16 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=False):
request=request,
metadata=metadata,
)
deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
allowed_exceptions={
InternalServerError: _check_rst_stream_error,
Aborted: no_op_handler,
},
deadline=deadline,
)
self.committed = True
return response
@@ -372,3 +393,8 @@ def _make_write_pb(table, columns, values):
return Mutation.Write(
table=table, columns=columns, values=_make_list_value_pbs(values)
)


def no_op_handler(exc):
# No-op (does nothing)
pass
Contributor:

Can we remove this and just pass in a lambda where a no-op handler is needed (if it is needed at all after we separate the normal retry function from the aborted retry function)?

Contributor Author:

Yes, we can use a no-op lambda for this.

Contributor Author:

Removed the redundant code, as this is no longer required with the new implementation.
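Had the handler still been needed, the reviewer's lambda suggestion would look roughly like this. A sketch only: `method` is a placeholder for the zero-argument callable wrapping the commit RPC, and the `deadline` keyword reflects how `_retry` is defined in this commit, before the later split.

```python
import time

from google.api_core.exceptions import Aborted, InternalServerError
from google.cloud.spanner_v1._helpers import _check_rst_stream_error, _retry


def method():
    """Placeholder for the zero-argument callable that wraps the actual commit RPC."""
    return "commit-response"


response = _retry(
    method,
    allowed_exceptions={
        InternalServerError: _check_rst_stream_error,
        Aborted: lambda exc: None,  # inline no-op instead of a module-level no_op_handler
    },
    deadline=time.time() + 30,  # illustrative 30-second retry budget
)
```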

10 changes: 9 additions & 1 deletion google/cloud/spanner_v1/database.py
@@ -775,6 +775,7 @@ def batch(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kw,
):
"""Return an object which wraps a batch.

@@ -805,7 +806,11 @@
:returns: new wrapper
"""
return BatchCheckout(
self, request_options, max_commit_delay, exclude_txn_from_change_streams
self,
request_options,
max_commit_delay,
exclude_txn_from_change_streams,
**kw,
)

def mutation_groups(self):
@@ -1166,6 +1171,7 @@ def __init__(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kw,
):
self._database = database
self._session = self._batch = None
@@ -1177,6 +1183,7 @@
self._request_options = request_options
self._max_commit_delay = max_commit_delay
self._exclude_txn_from_change_streams = exclude_txn_from_change_streams
self._kw = kw

def __enter__(self):
"""Begin ``with`` block."""
@@ -1197,6 +1204,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
request_options=self._request_options,
max_commit_delay=self._max_commit_delay,
exclude_txn_from_change_streams=self._exclude_txn_from_change_streams,
**self._kw,
)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
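Taken together with the batch.py change, the extra **kw plumbing would let a caller bound Aborted retries on commit roughly as follows. This is a sketch under the assumption that `timeout_secs` remains the keyword consumed by `Batch.commit`; the instance, database, and table names are illustrative.

```python
from google.cloud import spanner

client = spanner.Client()
database = client.instance("my-instance").database("my-database")  # illustrative names

# timeout_secs is forwarded through database.batch() and BatchCheckout.__exit__
# to Batch.commit, where it bounds how long Aborted errors are retried.
with database.batch(timeout_secs=60) as batch:
    batch.insert(
        table="Singers",  # illustrative table
        columns=("SingerId", "FirstName"),
        values=[(1, "Marc")],
    )
```

If `timeout_secs` is omitted, commit falls back to DEFAULT_RETRY_TIMEOUT_SECS (30 seconds in this commit).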
58 changes: 2 additions & 56 deletions google/cloud/spanner_v1/session.py
@@ -15,15 +15,15 @@
"""Wrapper for Cloud Spanner Session objects."""

from functools import total_ordering
import random
import time
from datetime import datetime

from google.api_core.exceptions import Aborted
from google.api_core.exceptions import GoogleAPICallError
from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1 import method
from google.rpc.error_details_pb2 import RetryInfo
from google.cloud.spanner_v1._helpers import _delay_until_retry
from google.cloud.spanner_v1._helpers import _get_retry_delay

from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import CreateSessionRequest
@@ -554,57 +554,3 @@ def run_in_transaction(self, func, *args, **kw):
extra={"commit_stats": txn.commit_stats},
)
return return_value


# Rational: this function factors out complex shared deadline / retry
# handling from two `except:` clauses.
def _delay_until_retry(exc, deadline, attempts):
"""Helper for :meth:`Session.run_in_transaction`.

Detect retryable abort, and impose server-supplied delay.

:type exc: :class:`google.api_core.exceptions.Aborted`
:param exc: exception for aborted transaction

:type deadline: float
:param deadline: maximum timestamp to continue retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""
cause = exc.errors[0]

now = time.time()

if now >= deadline:
raise

delay = _get_retry_delay(cause, attempts)
if delay is not None:
if now + delay > deadline:
raise

time.sleep(delay)


def _get_retry_delay(cause, attempts):
"""Helper for :func:`_delay_until_retry`.

:type exc: :class:`grpc.Call`
:param exc: exception for aborted transaction

:rtype: float
:returns: seconds to wait before retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""
metadata = dict(cause.trailing_metadata())
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
if retry_info_pb is not None:
retry_info = RetryInfo()
retry_info.ParseFromString(retry_info_pb)
nanos = retry_info.retry_delay.nanos
return retry_info.retry_delay.seconds + nanos / 1.0e9

return 2**attempts + random.random()