Skip to content

Commit 259a78b

Browse files
authored
test: add test to verify that transactions are retried (#1267)
1 parent a6811af commit 259a78b

File tree

3 files changed

+94
-0
lines changed

3 files changed

+94
-0
lines changed

google/cloud/spanner_v1/testing/mock_spanner.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import base64
15+
import inspect
1516
import grpc
1617
from concurrent import futures
1718

1819
from google.protobuf import empty_pb2
20+
from grpc_status.rpc_status import _Status
1921
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
2022
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc
2123
import google.cloud.spanner_v1.testing.spanner_pb2_grpc as spanner_grpc
@@ -28,6 +30,7 @@
2830
class MockSpanner:
2931
def __init__(self):
3032
self.results = {}
33+
self.errors = {}
3134

3235
def add_result(self, sql: str, result: result_set.ResultSet):
3336
self.results[sql.lower().strip()] = result
@@ -38,6 +41,15 @@ def get_result(self, sql: str) -> result_set.ResultSet:
3841
raise ValueError(f"No result found for {sql}")
3942
return result
4043

44+
def add_error(self, method: str, error: _Status):
45+
self.errors[method] = error
46+
47+
def pop_error(self, context):
48+
name = inspect.currentframe().f_back.f_code.co_name
49+
error: _Status | None = self.errors.pop(name, None)
50+
if error:
51+
context.abort_with_status(error)
52+
4153
def get_result_as_partial_result_sets(
4254
self, sql: str
4355
) -> [result_set.PartialResultSet]:
@@ -174,6 +186,7 @@ def __create_transaction(
174186

175187
def Commit(self, request, context):
176188
self._requests.append(request)
189+
self.mock_spanner.pop_error(context)
177190
tx = self.transactions[request.transaction_id]
178191
if tx is None:
179192
raise ValueError(f"Transaction not found: {request.transaction_id}")

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,37 @@
2828
from google.cloud.spanner_v1.database import Database
2929
from google.cloud.spanner_v1.instance import Instance
3030
import grpc
31+
from google.rpc import code_pb2
32+
from google.rpc import status_pb2
33+
from google.rpc.error_details_pb2 import RetryInfo
34+
from google.protobuf.duration_pb2 import Duration
35+
from grpc_status._common import code_to_grpc_status_code
36+
from grpc_status.rpc_status import _Status
37+
38+
39+
# Creates an aborted status with the smallest possible retry delay.
40+
def aborted_status() -> _Status:
41+
error = status_pb2.Status(
42+
code=code_pb2.ABORTED,
43+
message="Transaction was aborted.",
44+
)
45+
retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1))
46+
status = _Status(
47+
code=code_to_grpc_status_code(error.code),
48+
details=error.message,
49+
trailing_metadata=(
50+
("grpc-status-details-bin", error.SerializeToString()),
51+
(
52+
"google.rpc.retryinfo-bin",
53+
retry_info.SerializeToString(),
54+
),
55+
),
56+
)
57+
return status
58+
59+
60+
def add_error(method: str, error: status_pb2.Status):
61+
MockServerTestBase.spanner_service.mock_spanner.add_error(method, error)
3162

3263

3364
def add_result(sql: str, result: result_set.ResultSet):
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.cloud.spanner_v1 import (
16+
BatchCreateSessionsRequest,
17+
BeginTransactionRequest,
18+
CommitRequest,
19+
)
20+
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
21+
from google.cloud.spanner_v1.transaction import Transaction
22+
from tests.mockserver_tests.mock_server_test_base import (
23+
MockServerTestBase,
24+
add_error,
25+
aborted_status,
26+
)
27+
28+
29+
class TestAbortedTransaction(MockServerTestBase):
30+
def test_run_in_transaction_commit_aborted(self):
31+
# Add an Aborted error for the Commit method on the mock server.
32+
add_error(SpannerServicer.Commit.__name__, aborted_status())
33+
# Run a transaction. The Commit method will return Aborted the first
34+
# time that the transaction tries to commit. It will then be retried
35+
# and succeed.
36+
self.database.run_in_transaction(_insert_mutations)
37+
38+
# Verify that the transaction was retried.
39+
requests = self.spanner_service.requests
40+
self.assertEqual(5, len(requests), msg=requests)
41+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
42+
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
43+
self.assertTrue(isinstance(requests[2], CommitRequest))
44+
# The transaction is aborted and retried.
45+
self.assertTrue(isinstance(requests[3], BeginTransactionRequest))
46+
self.assertTrue(isinstance(requests[4], CommitRequest))
47+
48+
49+
def _insert_mutations(transaction: Transaction):
50+
transaction.insert("my_table", ["col1", "col2"], ["value1", "value2"])

0 commit comments

Comments
 (0)