Skip to content

Commit 3c5c5ad

Browse files
committed
more fixes
1 parent 8976025 commit 3c5c5ad

File tree

6 files changed

+287
-241
lines changed

6 files changed

+287
-241
lines changed

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
SpannerServicer,
4242
start_mock_server,
4343
)
44+
from tests._helpers import is_multiplexed_enabled
45+
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
4446

4547

4648
# Creates an aborted status with the smallest possible retry delay.
@@ -228,3 +230,79 @@ def database(self) -> Database:
228230
enable_interceptors_in_tests=True,
229231
)
230232
return self._database
233+
234+
def assert_requests_sequence(self, requests, expected_types, transaction_type, allow_multiple_batch_create=True):
235+
"""Assert that the requests sequence matches the expected types, accounting for multiplexed sessions and retries.
236+
237+
Args:
238+
requests: List of requests from spanner_service.requests
239+
expected_types: List of expected request types (excluding session creation requests)
240+
transaction_type: TransactionType enum value to check multiplexed session status
241+
allow_multiple_batch_create: If True, skip all leading BatchCreateSessionsRequest and one optional CreateSessionRequest
242+
"""
243+
from google.cloud.spanner_v1 import BatchCreateSessionsRequest, CreateSessionRequest
244+
mux_enabled = is_multiplexed_enabled(transaction_type)
245+
idx = 0
246+
# Skip all leading BatchCreateSessionsRequest (for retries)
247+
if allow_multiple_batch_create:
248+
while idx < len(requests) and isinstance(requests[idx], BatchCreateSessionsRequest):
249+
idx += 1
250+
# For multiplexed, optionally skip a CreateSessionRequest
251+
if mux_enabled and idx < len(requests) and isinstance(requests[idx], CreateSessionRequest):
252+
idx += 1
253+
else:
254+
if mux_enabled:
255+
self.assertTrue(isinstance(requests[idx], BatchCreateSessionsRequest),
256+
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}")
257+
idx += 1
258+
self.assertTrue(isinstance(requests[idx], CreateSessionRequest),
259+
f"Expected CreateSessionRequest at index {idx}, got {type(requests[idx])}")
260+
idx += 1
261+
else:
262+
self.assertTrue(isinstance(requests[idx], BatchCreateSessionsRequest),
263+
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}")
264+
idx += 1
265+
# Check the rest of the expected request types
266+
for expected_type in expected_types:
267+
self.assertTrue(isinstance(requests[idx], expected_type),
268+
f"Expected {expected_type} at index {idx}, got {type(requests[idx])}")
269+
idx += 1
270+
self.assertEqual(idx, len(requests),
271+
f"Expected {idx} requests, got {len(requests)}")
272+
273+
def adjust_request_id_sequence(self, expected_segments, requests, transaction_type):
274+
"""Adjust expected request ID sequence numbers based on actual session creation requests.
275+
276+
Args:
277+
expected_segments: List of expected (method, (sequence_numbers)) tuples
278+
requests: List of actual requests from spanner_service.requests
279+
transaction_type: TransactionType enum value to check multiplexed session status
280+
281+
Returns:
282+
List of adjusted expected segments with corrected sequence numbers
283+
"""
284+
from google.cloud.spanner_v1 import BatchCreateSessionsRequest, CreateSessionRequest, ExecuteSqlRequest, BeginTransactionRequest
285+
286+
# Count session creation requests that come before the first non-session request
287+
session_requests_before = 0
288+
for req in requests:
289+
if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)):
290+
session_requests_before += 1
291+
elif isinstance(req, (ExecuteSqlRequest, BeginTransactionRequest)):
292+
break
293+
294+
# For multiplexed sessions, we expect 2 session requests (BatchCreateSessions + CreateSession)
295+
# For non-multiplexed, we expect 1 session request (BatchCreateSessions)
296+
mux_enabled = is_multiplexed_enabled(transaction_type)
297+
expected_session_requests = 2 if mux_enabled else 1
298+
extra_session_requests = session_requests_before - expected_session_requests
299+
300+
# Adjust sequence numbers based on extra session requests
301+
adjusted_segments = []
302+
for method, seq_nums in expected_segments:
303+
# Adjust the sequence number (5th element in the tuple)
304+
adjusted_seq_nums = list(seq_nums)
305+
adjusted_seq_nums[4] += extra_session_requests
306+
adjusted_segments.append((method, tuple(adjusted_seq_nums)))
307+
308+
return adjusted_segments

tests/mockserver_tests/test_aborted_transaction.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from google.api_core import exceptions
3434
from test_utils import retry
35+
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
3536

3637
retry_maybe_aborted_txn = retry.RetryErrors(
3738
exceptions.Aborted, max_tries=5, delay=0, backoff=1
@@ -46,29 +47,23 @@ def test_run_in_transaction_commit_aborted(self):
4647
# time that the transaction tries to commit. It will then be retried
4748
# and succeed.
4849
self.database.run_in_transaction(_insert_mutations)
49-
50-
# Verify that the transaction was retried.
5150
requests = self.spanner_service.requests
52-
self.assertEqual(5, len(requests), msg=requests)
53-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
54-
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
55-
self.assertTrue(isinstance(requests[2], CommitRequest))
56-
# The transaction is aborted and retried.
57-
self.assertTrue(isinstance(requests[3], BeginTransactionRequest))
58-
self.assertTrue(isinstance(requests[4], CommitRequest))
51+
self.assert_requests_sequence(
52+
requests,
53+
[BeginTransactionRequest, CommitRequest, BeginTransactionRequest, CommitRequest],
54+
TransactionType.READ_WRITE,
55+
)
5956

6057
def test_run_in_transaction_update_aborted(self):
6158
add_update_count("update my_table set my_col=1 where id=2", 1)
6259
add_error(SpannerServicer.ExecuteSql.__name__, aborted_status())
6360
self.database.run_in_transaction(_execute_update)
64-
65-
# Verify that the transaction was retried.
6661
requests = self.spanner_service.requests
67-
self.assertEqual(4, len(requests), msg=requests)
68-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
69-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
70-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
71-
self.assertTrue(isinstance(requests[3], CommitRequest))
62+
self.assert_requests_sequence(
63+
requests,
64+
[ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest],
65+
TransactionType.READ_WRITE,
66+
)
7267

7368
def test_run_in_transaction_query_aborted(self):
7469
add_single_result(
@@ -79,28 +74,24 @@ def test_run_in_transaction_query_aborted(self):
7974
)
8075
add_error(SpannerServicer.ExecuteStreamingSql.__name__, aborted_status())
8176
self.database.run_in_transaction(_execute_query)
82-
83-
# Verify that the transaction was retried.
8477
requests = self.spanner_service.requests
85-
self.assertEqual(4, len(requests), msg=requests)
86-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
87-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
88-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
89-
self.assertTrue(isinstance(requests[3], CommitRequest))
78+
self.assert_requests_sequence(
79+
requests,
80+
[ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest],
81+
TransactionType.READ_WRITE,
82+
)
9083

9184
def test_run_in_transaction_batch_dml_aborted(self):
9285
add_update_count("update my_table set my_col=1 where id=1", 1)
9386
add_update_count("update my_table set my_col=1 where id=2", 1)
9487
add_error(SpannerServicer.ExecuteBatchDml.__name__, aborted_status())
9588
self.database.run_in_transaction(_execute_batch_dml)
96-
97-
# Verify that the transaction was retried.
9889
requests = self.spanner_service.requests
99-
self.assertEqual(4, len(requests), msg=requests)
100-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
101-
self.assertTrue(isinstance(requests[1], ExecuteBatchDmlRequest))
102-
self.assertTrue(isinstance(requests[2], ExecuteBatchDmlRequest))
103-
self.assertTrue(isinstance(requests[3], CommitRequest))
90+
self.assert_requests_sequence(
91+
requests,
92+
[ExecuteBatchDmlRequest, ExecuteBatchDmlRequest, CommitRequest],
93+
TransactionType.READ_WRITE,
94+
)
10495

10596
def test_batch_commit_aborted(self):
10697
# Add an Aborted error for the Commit method on the mock server.
@@ -117,14 +108,12 @@ def test_batch_commit_aborted(self):
117108
(5, "David", "Lomond"),
118109
],
119110
)
120-
121-
# Verify that the transaction was retried.
122111
requests = self.spanner_service.requests
123-
self.assertEqual(3, len(requests), msg=requests)
124-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
125-
self.assertTrue(isinstance(requests[1], CommitRequest))
126-
# The transaction is aborted and retried.
127-
self.assertTrue(isinstance(requests[2], CommitRequest))
112+
self.assert_requests_sequence(
113+
requests,
114+
[CommitRequest, CommitRequest],
115+
TransactionType.READ_WRITE,
116+
)
128117

129118
@retry_maybe_aborted_txn
130119
def test_retry_helper(self):

tests/mockserver_tests/test_basics.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
2727
from google.cloud.spanner_v1.transaction import Transaction
28+
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
2829

2930
from tests.mockserver_tests.mock_server_test_base import (
3031
MockServerTestBase,
@@ -36,6 +37,7 @@
3637
unavailable_status,
3738
add_execute_streaming_sql_results,
3839
)
40+
from tests._helpers import is_multiplexed_enabled
3941

4042

4143
class TestBasics(MockServerTestBase):
@@ -49,9 +51,11 @@ def test_select1(self):
4951
self.assertEqual(1, row[0])
5052
self.assertEqual(1, len(result_list))
5153
requests = self.spanner_service.requests
52-
self.assertEqual(2, len(requests), msg=requests)
53-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
54-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
54+
self.assert_requests_sequence(
55+
requests,
56+
[ExecuteSqlRequest],
57+
TransactionType.READ_ONLY,
58+
)
5559

5660
def test_create_table(self):
5761
database_admin_api = self.client.database_admin_api
@@ -84,13 +88,21 @@ def test_dbapi_partitioned_dml(self):
8488
# with no parameters.
8589
cursor.execute(sql, [])
8690
self.assertEqual(100, cursor.rowcount)
87-
8891
requests = self.spanner_service.requests
89-
self.assertEqual(3, len(requests), msg=requests)
90-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
91-
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
92-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
93-
begin_request: BeginTransactionRequest = requests[1]
92+
self.assert_requests_sequence(
93+
requests,
94+
[BeginTransactionRequest, ExecuteSqlRequest],
95+
TransactionType.PARTITIONED,
96+
allow_multiple_batch_create=True,
97+
)
98+
# Find the first BeginTransactionRequest after session creation
99+
idx = 0
100+
from google.cloud.spanner_v1 import BatchCreateSessionsRequest, CreateSessionRequest
101+
while idx < len(requests) and isinstance(requests[idx], BatchCreateSessionsRequest):
102+
idx += 1
103+
if is_multiplexed_enabled(TransactionType.PARTITIONED) and idx < len(requests) and isinstance(requests[idx], CreateSessionRequest):
104+
idx += 1
105+
begin_request: BeginTransactionRequest = requests[idx]
94106
self.assertEqual(
95107
TransactionOptions(dict(partitioned_dml={})), begin_request.options
96108
)
@@ -106,11 +118,12 @@ def test_batch_create_sessions_unavailable(self):
106118
self.assertEqual(1, row[0])
107119
self.assertEqual(1, len(result_list))
108120
requests = self.spanner_service.requests
109-
self.assertEqual(3, len(requests), msg=requests)
110-
# The BatchCreateSessions call should be retried.
111-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
112-
self.assertTrue(isinstance(requests[1], BatchCreateSessionsRequest))
113-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
121+
self.assert_requests_sequence(
122+
requests,
123+
[ExecuteSqlRequest],
124+
TransactionType.READ_ONLY,
125+
allow_multiple_batch_create=True,
126+
)
114127

115128
def test_execute_streaming_sql_unavailable(self):
116129
add_select1_result()
@@ -125,11 +138,11 @@ def test_execute_streaming_sql_unavailable(self):
125138
self.assertEqual(1, row[0])
126139
self.assertEqual(1, len(result_list))
127140
requests = self.spanner_service.requests
128-
self.assertEqual(3, len(requests), msg=requests)
129-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
130-
# The ExecuteStreamingSql call should be retried.
131-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
132-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
141+
self.assert_requests_sequence(
142+
requests,
143+
[ExecuteSqlRequest, ExecuteSqlRequest],
144+
TransactionType.READ_ONLY,
145+
)
133146

134147
def test_last_statement_update(self):
135148
sql = "update my_table set my_col=1 where id=2"
@@ -199,9 +212,11 @@ def test_execute_streaming_sql_last_field(self):
199212
count += 1
200213
self.assertEqual(3, len(result_list))
201214
requests = self.spanner_service.requests
202-
self.assertEqual(2, len(requests), msg=requests)
203-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
204-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
215+
self.assert_requests_sequence(
216+
requests,
217+
[ExecuteSqlRequest],
218+
TransactionType.READ_ONLY,
219+
)
205220

206221

207222
def _execute_query(transaction: Transaction, sql: str):

0 commit comments

Comments
 (0)