Skip to content

Commit 6450b4d

Browse files
committed
more fixes
1 parent 8976025 commit 6450b4d

File tree

6 files changed

+364
-229
lines changed

6 files changed

+364
-229
lines changed

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
SpannerServicer,
4242
start_mock_server,
4343
)
44+
from tests._helpers import is_multiplexed_enabled
45+
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
4446

4547

4648
# Creates an aborted status with the smallest possible retry delay.
@@ -228,3 +230,109 @@ def database(self) -> Database:
228230
enable_interceptors_in_tests=True,
229231
)
230232
return self._database
233+
234+
def assert_requests_sequence(
235+
self,
236+
requests,
237+
expected_types,
238+
transaction_type,
239+
allow_multiple_batch_create=True,
240+
):
241+
"""Assert that the requests sequence matches the expected types, accounting for multiplexed sessions and retries.
242+
243+
Args:
244+
requests: List of requests from spanner_service.requests
245+
expected_types: List of expected request types (excluding session creation requests)
246+
transaction_type: TransactionType enum value to check multiplexed session status
247+
allow_multiple_batch_create: If True, skip all leading BatchCreateSessionsRequest and one optional CreateSessionRequest
248+
"""
249+
from google.cloud.spanner_v1 import (
250+
BatchCreateSessionsRequest,
251+
CreateSessionRequest,
252+
)
253+
254+
mux_enabled = is_multiplexed_enabled(transaction_type)
255+
idx = 0
256+
# Skip all leading BatchCreateSessionsRequest (for retries)
257+
if allow_multiple_batch_create:
258+
while idx < len(requests) and isinstance(
259+
requests[idx], BatchCreateSessionsRequest
260+
):
261+
idx += 1
262+
# For multiplexed, optionally skip a CreateSessionRequest
263+
if (
264+
mux_enabled
265+
and idx < len(requests)
266+
and isinstance(requests[idx], CreateSessionRequest)
267+
):
268+
idx += 1
269+
else:
270+
if mux_enabled:
271+
self.assertTrue(
272+
isinstance(requests[idx], BatchCreateSessionsRequest),
273+
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}",
274+
)
275+
idx += 1
276+
self.assertTrue(
277+
isinstance(requests[idx], CreateSessionRequest),
278+
f"Expected CreateSessionRequest at index {idx}, got {type(requests[idx])}",
279+
)
280+
idx += 1
281+
else:
282+
self.assertTrue(
283+
isinstance(requests[idx], BatchCreateSessionsRequest),
284+
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}",
285+
)
286+
idx += 1
287+
# Check the rest of the expected request types
288+
for expected_type in expected_types:
289+
self.assertTrue(
290+
isinstance(requests[idx], expected_type),
291+
f"Expected {expected_type} at index {idx}, got {type(requests[idx])}",
292+
)
293+
idx += 1
294+
self.assertEqual(
295+
idx, len(requests), f"Expected {idx} requests, got {len(requests)}"
296+
)
297+
298+
def adjust_request_id_sequence(self, expected_segments, requests, transaction_type):
299+
"""Adjust expected request ID sequence numbers based on actual session creation requests.
300+
301+
Args:
302+
expected_segments: List of expected (method, (sequence_numbers)) tuples
303+
requests: List of actual requests from spanner_service.requests
304+
transaction_type: TransactionType enum value to check multiplexed session status
305+
306+
Returns:
307+
List of adjusted expected segments with corrected sequence numbers
308+
"""
309+
from google.cloud.spanner_v1 import (
310+
BatchCreateSessionsRequest,
311+
CreateSessionRequest,
312+
ExecuteSqlRequest,
313+
BeginTransactionRequest,
314+
)
315+
316+
# Count session creation requests that come before the first non-session request
317+
session_requests_before = 0
318+
for req in requests:
319+
if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)):
320+
session_requests_before += 1
321+
elif isinstance(req, (ExecuteSqlRequest, BeginTransactionRequest)):
322+
break
323+
324+
# For multiplexed sessions, we expect 2 session requests (BatchCreateSessions + CreateSession)
325+
# For non-multiplexed, we expect 1 session request (BatchCreateSessions)
326+
mux_enabled = is_multiplexed_enabled(transaction_type)
327+
expected_session_requests = 2 if mux_enabled else 1
328+
extra_session_requests = session_requests_before - expected_session_requests
329+
330+
# Adjust sequence numbers based on extra session requests
331+
adjusted_segments = []
332+
for method, seq_nums in expected_segments:
333+
# Adjust the sequence number (5th element in the tuple)
334+
adjusted_seq_nums = list(seq_nums)
335+
adjusted_seq_nums[4] += extra_session_requests
336+
adjusted_segments.append((method, tuple(adjusted_seq_nums)))
337+
338+
return adjusted_segments

tests/mockserver_tests/test_aborted_transaction.py

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from google.api_core import exceptions
3434
from test_utils import retry
35+
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
3536

3637
retry_maybe_aborted_txn = retry.RetryErrors(
3738
exceptions.Aborted, max_tries=5, delay=0, backoff=1
@@ -46,29 +47,28 @@ def test_run_in_transaction_commit_aborted(self):
4647
# time that the transaction tries to commit. It will then be retried
4748
# and succeed.
4849
self.database.run_in_transaction(_insert_mutations)
49-
50-
# Verify that the transaction was retried.
5150
requests = self.spanner_service.requests
52-
self.assertEqual(5, len(requests), msg=requests)
53-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
54-
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
55-
self.assertTrue(isinstance(requests[2], CommitRequest))
56-
# The transaction is aborted and retried.
57-
self.assertTrue(isinstance(requests[3], BeginTransactionRequest))
58-
self.assertTrue(isinstance(requests[4], CommitRequest))
51+
self.assert_requests_sequence(
52+
requests,
53+
[
54+
BeginTransactionRequest,
55+
CommitRequest,
56+
BeginTransactionRequest,
57+
CommitRequest,
58+
],
59+
TransactionType.READ_WRITE,
60+
)
5961

6062
def test_run_in_transaction_update_aborted(self):
6163
add_update_count("update my_table set my_col=1 where id=2", 1)
6264
add_error(SpannerServicer.ExecuteSql.__name__, aborted_status())
6365
self.database.run_in_transaction(_execute_update)
64-
65-
# Verify that the transaction was retried.
6666
requests = self.spanner_service.requests
67-
self.assertEqual(4, len(requests), msg=requests)
68-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
69-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
70-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
71-
self.assertTrue(isinstance(requests[3], CommitRequest))
67+
self.assert_requests_sequence(
68+
requests,
69+
[ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest],
70+
TransactionType.READ_WRITE,
71+
)
7272

7373
def test_run_in_transaction_query_aborted(self):
7474
add_single_result(
@@ -79,28 +79,24 @@ def test_run_in_transaction_query_aborted(self):
7979
)
8080
add_error(SpannerServicer.ExecuteStreamingSql.__name__, aborted_status())
8181
self.database.run_in_transaction(_execute_query)
82-
83-
# Verify that the transaction was retried.
8482
requests = self.spanner_service.requests
85-
self.assertEqual(4, len(requests), msg=requests)
86-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
87-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
88-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
89-
self.assertTrue(isinstance(requests[3], CommitRequest))
83+
self.assert_requests_sequence(
84+
requests,
85+
[ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest],
86+
TransactionType.READ_WRITE,
87+
)
9088

9189
def test_run_in_transaction_batch_dml_aborted(self):
9290
add_update_count("update my_table set my_col=1 where id=1", 1)
9391
add_update_count("update my_table set my_col=1 where id=2", 1)
9492
add_error(SpannerServicer.ExecuteBatchDml.__name__, aborted_status())
9593
self.database.run_in_transaction(_execute_batch_dml)
96-
97-
# Verify that the transaction was retried.
9894
requests = self.spanner_service.requests
99-
self.assertEqual(4, len(requests), msg=requests)
100-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
101-
self.assertTrue(isinstance(requests[1], ExecuteBatchDmlRequest))
102-
self.assertTrue(isinstance(requests[2], ExecuteBatchDmlRequest))
103-
self.assertTrue(isinstance(requests[3], CommitRequest))
95+
self.assert_requests_sequence(
96+
requests,
97+
[ExecuteBatchDmlRequest, ExecuteBatchDmlRequest, CommitRequest],
98+
TransactionType.READ_WRITE,
99+
)
104100

105101
def test_batch_commit_aborted(self):
106102
# Add an Aborted error for the Commit method on the mock server.
@@ -117,14 +113,12 @@ def test_batch_commit_aborted(self):
117113
(5, "David", "Lomond"),
118114
],
119115
)
120-
121-
# Verify that the transaction was retried.
122116
requests = self.spanner_service.requests
123-
self.assertEqual(3, len(requests), msg=requests)
124-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
125-
self.assertTrue(isinstance(requests[1], CommitRequest))
126-
# The transaction is aborted and retried.
127-
self.assertTrue(isinstance(requests[2], CommitRequest))
117+
self.assert_requests_sequence(
118+
requests,
119+
[CommitRequest, CommitRequest],
120+
TransactionType.READ_WRITE,
121+
)
128122

129123
@retry_maybe_aborted_txn
130124
def test_retry_helper(self):

tests/mockserver_tests/test_basics.py

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
2727
from google.cloud.spanner_v1.transaction import Transaction
28+
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
2829

2930
from tests.mockserver_tests.mock_server_test_base import (
3031
MockServerTestBase,
@@ -36,6 +37,7 @@
3637
unavailable_status,
3738
add_execute_streaming_sql_results,
3839
)
40+
from tests._helpers import is_multiplexed_enabled
3941

4042

4143
class TestBasics(MockServerTestBase):
@@ -49,9 +51,11 @@ def test_select1(self):
4951
self.assertEqual(1, row[0])
5052
self.assertEqual(1, len(result_list))
5153
requests = self.spanner_service.requests
52-
self.assertEqual(2, len(requests), msg=requests)
53-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
54-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
54+
self.assert_requests_sequence(
55+
requests,
56+
[ExecuteSqlRequest],
57+
TransactionType.READ_ONLY,
58+
)
5559

5660
def test_create_table(self):
5761
database_admin_api = self.client.database_admin_api
@@ -84,13 +88,31 @@ def test_dbapi_partitioned_dml(self):
8488
# with no parameters.
8589
cursor.execute(sql, [])
8690
self.assertEqual(100, cursor.rowcount)
87-
8891
requests = self.spanner_service.requests
89-
self.assertEqual(3, len(requests), msg=requests)
90-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
91-
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
92-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
93-
begin_request: BeginTransactionRequest = requests[1]
92+
self.assert_requests_sequence(
93+
requests,
94+
[BeginTransactionRequest, ExecuteSqlRequest],
95+
TransactionType.PARTITIONED,
96+
allow_multiple_batch_create=True,
97+
)
98+
# Find the first BeginTransactionRequest after session creation
99+
idx = 0
100+
from google.cloud.spanner_v1 import (
101+
BatchCreateSessionsRequest,
102+
CreateSessionRequest,
103+
)
104+
105+
while idx < len(requests) and isinstance(
106+
requests[idx], BatchCreateSessionsRequest
107+
):
108+
idx += 1
109+
if (
110+
is_multiplexed_enabled(TransactionType.PARTITIONED)
111+
and idx < len(requests)
112+
and isinstance(requests[idx], CreateSessionRequest)
113+
):
114+
idx += 1
115+
begin_request: BeginTransactionRequest = requests[idx]
94116
self.assertEqual(
95117
TransactionOptions(dict(partitioned_dml={})), begin_request.options
96118
)
@@ -106,11 +128,12 @@ def test_batch_create_sessions_unavailable(self):
106128
self.assertEqual(1, row[0])
107129
self.assertEqual(1, len(result_list))
108130
requests = self.spanner_service.requests
109-
self.assertEqual(3, len(requests), msg=requests)
110-
# The BatchCreateSessions call should be retried.
111-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
112-
self.assertTrue(isinstance(requests[1], BatchCreateSessionsRequest))
113-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
131+
self.assert_requests_sequence(
132+
requests,
133+
[ExecuteSqlRequest],
134+
TransactionType.READ_ONLY,
135+
allow_multiple_batch_create=True,
136+
)
114137

115138
def test_execute_streaming_sql_unavailable(self):
116139
add_select1_result()
@@ -125,11 +148,11 @@ def test_execute_streaming_sql_unavailable(self):
125148
self.assertEqual(1, row[0])
126149
self.assertEqual(1, len(result_list))
127150
requests = self.spanner_service.requests
128-
self.assertEqual(3, len(requests), msg=requests)
129-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
130-
# The ExecuteStreamingSql call should be retried.
131-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
132-
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
151+
self.assert_requests_sequence(
152+
requests,
153+
[ExecuteSqlRequest, ExecuteSqlRequest],
154+
TransactionType.READ_ONLY,
155+
)
133156

134157
def test_last_statement_update(self):
135158
sql = "update my_table set my_col=1 where id=2"
@@ -199,9 +222,11 @@ def test_execute_streaming_sql_last_field(self):
199222
count += 1
200223
self.assertEqual(3, len(result_list))
201224
requests = self.spanner_service.requests
202-
self.assertEqual(2, len(requests), msg=requests)
203-
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
204-
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
225+
self.assert_requests_sequence(
226+
requests,
227+
[ExecuteSqlRequest],
228+
TransactionType.READ_ONLY,
229+
)
205230

206231

207232
def _execute_query(transaction: Transaction, sql: str):

0 commit comments

Comments
 (0)