Skip to content

Commit bb64ff5

Browse files
1 parent 3c5c5ad commit bb64ff5

File tree

5 files changed

+128
-39
lines changed

5 files changed

+128
-39
lines changed

‎tests/mockserver_tests/mock_server_test_base.py‎

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -231,78 +231,108 @@ def database(self) -> Database:
231231
)
232232
return self._database
233233

234-
def assert_requests_sequence(self, requests, expected_types, transaction_type, allow_multiple_batch_create=True):
234+
def assert_requests_sequence(
235+
self,
236+
requests,
237+
expected_types,
238+
transaction_type,
239+
allow_multiple_batch_create=True,
240+
):
235241
"""Assert that the requests sequence matches the expected types, accounting for multiplexed sessions and retries.
236-
242+
237243
Args:
238244
requests: List of requests from spanner_service.requests
239245
expected_types: List of expected request types (excluding session creation requests)
240246
transaction_type: TransactionType enum value to check multiplexed session status
241247
allow_multiple_batch_create: If True, skip all leading BatchCreateSessionsRequest and one optional CreateSessionRequest
242248
"""
243-
from google.cloud.spanner_v1 import BatchCreateSessionsRequest, CreateSessionRequest
249+
from google.cloud.spanner_v1 import (
250+
BatchCreateSessionsRequest,
251+
CreateSessionRequest,
252+
)
253+
244254
mux_enabled = is_multiplexed_enabled(transaction_type)
245255
idx = 0
246256
# Skip all leading BatchCreateSessionsRequest (for retries)
247257
if allow_multiple_batch_create:
248-
while idx < len(requests) and isinstance(requests[idx], BatchCreateSessionsRequest):
258+
while idx < len(requests) and isinstance(
259+
requests[idx], BatchCreateSessionsRequest
260+
):
249261
idx += 1
250262
# For multiplexed, optionally skip a CreateSessionRequest
251-
if mux_enabled and idx < len(requests) and isinstance(requests[idx], CreateSessionRequest):
263+
if (
264+
mux_enabled
265+
and idx < len(requests)
266+
and isinstance(requests[idx], CreateSessionRequest)
267+
):
252268
idx += 1
253269
else:
254270
if mux_enabled:
255-
self.assertTrue(isinstance(requests[idx], BatchCreateSessionsRequest),
256-
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}")
271+
self.assertTrue(
272+
isinstance(requests[idx], BatchCreateSessionsRequest),
273+
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}",
274+
)
257275
idx += 1
258-
self.assertTrue(isinstance(requests[idx], CreateSessionRequest),
259-
f"Expected CreateSessionRequest at index {idx}, got {type(requests[idx])}")
276+
self.assertTrue(
277+
isinstance(requests[idx], CreateSessionRequest),
278+
f"Expected CreateSessionRequest at index {idx}, got {type(requests[idx])}",
279+
)
260280
idx += 1
261281
else:
262-
self.assertTrue(isinstance(requests[idx], BatchCreateSessionsRequest),
263-
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}")
282+
self.assertTrue(
283+
isinstance(requests[idx], BatchCreateSessionsRequest),
284+
f"Expected BatchCreateSessionsRequest at index {idx}, got {type(requests[idx])}",
285+
)
264286
idx += 1
265287
# Check the rest of the expected request types
266288
for expected_type in expected_types:
267-
self.assertTrue(isinstance(requests[idx], expected_type),
268-
f"Expected {expected_type} at index {idx}, got {type(requests[idx])}")
289+
self.assertTrue(
290+
isinstance(requests[idx], expected_type),
291+
f"Expected {expected_type} at index {idx}, got {type(requests[idx])}",
292+
)
269293
idx += 1
270-
self.assertEqual(idx, len(requests),
271-
f"Expected {idx} requests, got {len(requests)}")
294+
self.assertEqual(
295+
idx, len(requests), f"Expected {idx} requests, got {len(requests)}"
296+
)
272297

273298
def adjust_request_id_sequence(self, expected_segments, requests, transaction_type):
274299
"""Adjust expected request ID sequence numbers based on actual session creation requests.
275-
300+
276301
Args:
277302
expected_segments: List of expected (method, (sequence_numbers)) tuples
278303
requests: List of actual requests from spanner_service.requests
279304
transaction_type: TransactionType enum value to check multiplexed session status
280-
305+
281306
Returns:
282307
List of adjusted expected segments with corrected sequence numbers
283308
"""
284-
from google.cloud.spanner_v1 import BatchCreateSessionsRequest, CreateSessionRequest, ExecuteSqlRequest, BeginTransactionRequest
285-
309+
from google.cloud.spanner_v1 import (
310+
BatchCreateSessionsRequest,
311+
CreateSessionRequest,
312+
ExecuteSqlRequest,
313+
BeginTransactionRequest,
314+
)
315+
286316
# Count session creation requests that come before the first non-session request
287317
session_requests_before = 0
288318
for req in requests:
289319
if isinstance(req, (BatchCreateSessionsRequest, CreateSessionRequest)):
290320
session_requests_before += 1
291321
elif isinstance(req, (ExecuteSqlRequest, BeginTransactionRequest)):
292322
break
293-
323+
294324
# For multiplexed sessions, we expect 2 session requests (BatchCreateSessions + CreateSession)
295325
# For non-multiplexed, we expect 1 session request (BatchCreateSessions)
296326
mux_enabled = is_multiplexed_enabled(transaction_type)
297327
expected_session_requests = 2 if mux_enabled else 1
298328
extra_session_requests = session_requests_before - expected_session_requests
299-
329+
300330
# Adjust sequence numbers based on extra session requests
301331
adjusted_segments = []
302332
for method, seq_nums in expected_segments:
303333
# Adjust the sequence number (5th element in the tuple)
304334
adjusted_seq_nums = list(seq_nums)
305335
adjusted_seq_nums[4] += extra_session_requests
306336
adjusted_segments.append((method, tuple(adjusted_seq_nums)))
307-
337+
308338
return adjusted_segments

‎tests/mockserver_tests/test_aborted_transaction.py‎

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@ def test_run_in_transaction_commit_aborted(self):
5050
requests = self.spanner_service.requests
5151
self.assert_requests_sequence(
5252
requests,
53-
[BeginTransactionRequest, CommitRequest, BeginTransactionRequest, CommitRequest],
53+
[
54+
BeginTransactionRequest,
55+
CommitRequest,
56+
BeginTransactionRequest,
57+
CommitRequest,
58+
],
5459
TransactionType.READ_WRITE,
5560
)
5661

‎tests/mockserver_tests/test_basics.py‎

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,20 @@ def test_dbapi_partitioned_dml(self):
9797
)
9898
# Find the first BeginTransactionRequest after session creation
9999
idx = 0
100-
from google.cloud.spanner_v1 import BatchCreateSessionsRequest, CreateSessionRequest
101-
while idx < len(requests) and isinstance(requests[idx], BatchCreateSessionsRequest):
100+
from google.cloud.spanner_v1 import (
101+
BatchCreateSessionsRequest,
102+
CreateSessionRequest,
103+
)
104+
105+
while idx < len(requests) and isinstance(
106+
requests[idx], BatchCreateSessionsRequest
107+
):
102108
idx += 1
103-
if is_multiplexed_enabled(TransactionType.PARTITIONED) and idx < len(requests) and isinstance(requests[idx], CreateSessionRequest):
109+
if (
110+
is_multiplexed_enabled(TransactionType.PARTITIONED)
111+
and idx < len(requests)
112+
and isinstance(requests[idx], CreateSessionRequest)
113+
):
104114
idx += 1
105115
begin_request: BeginTransactionRequest = requests[idx]
106116
self.assertEqual(

‎tests/mockserver_tests/test_request_id_header.py‎

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ def test_snapshot_execute_sql(self):
5959
CHANNEL_ID = self.database._channel_id
6060
got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
6161
# Filter out CreateSessionRequest unary segments for comparison
62-
filtered_unary_segments = [seg for seg in got_unary_segments if not seg[0].endswith("/CreateSession")]
62+
filtered_unary_segments = [
63+
seg for seg in got_unary_segments if not seg[0].endswith("/CreateSession")
64+
]
6365
want_unary_segments = [
6466
(
6567
"/google.spanner.v1.Spanner/BatchCreateSessions",
@@ -76,7 +78,14 @@ def test_snapshot_execute_sql(self):
7678
want_stream_segments = [
7779
(
7880
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
79-
(1, REQ_RAND_PROCESS_ID, NTH_CLIENT, CHANNEL_ID, 1 + session_requests_before, 1),
81+
(
82+
1,
83+
REQ_RAND_PROCESS_ID,
84+
NTH_CLIENT,
85+
CHANNEL_ID,
86+
1 + session_requests_before,
87+
1,
88+
),
8089
)
8190
]
8291
assert filtered_unary_segments == want_unary_segments
@@ -89,6 +98,7 @@ def test_snapshot_read_concurrent(self):
8998
rows = snapshot.execute_sql("select 1")
9099
for row in rows:
91100
_ = row
101+
92102
def select1():
93103
with db.snapshot() as snapshot:
94104
rows = snapshot.execute_sql("select 1")
@@ -97,6 +107,7 @@ def select1():
97107
self.assertEqual(1, row[0])
98108
res_list.append(row)
99109
self.assertEqual(1, len(res_list))
110+
100111
n = 10
101112
threads = []
102113
for i in range(n):
@@ -110,7 +121,9 @@ def select1():
110121
# Allow for an extra request due to multiplexed session creation
111122
expected_min = 2 + n
112123
expected_max = expected_min + 1
113-
assert expected_min <= len(requests) <= expected_max, f"Expected {expected_min} or {expected_max} requests, got {len(requests)}: {requests}"
124+
assert (
125+
expected_min <= len(requests) <= expected_max
126+
), f"Expected {expected_min} or {expected_max} requests, got {len(requests)}: {requests}"
114127
client_id = db._nth_client_id
115128
channel_id = db._channel_id
116129
got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
@@ -132,7 +145,14 @@ def select1():
132145
want_stream_segments = [
133146
(
134147
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
135-
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, session_requests_before + i, 1),
148+
(
149+
1,
150+
REQ_RAND_PROCESS_ID,
151+
client_id,
152+
channel_id,
153+
session_requests_before + i,
154+
1,
155+
),
136156
)
137157
for i in range(1, n + 2)
138158
]
@@ -173,7 +193,9 @@ def test_database_execute_partitioned_dml_request_id(self):
173193
NTH_CLIENT = self.database._nth_client_id
174194
CHANNEL_ID = self.database._channel_id
175195
# Allow for extra unary segments due to session creation
176-
filtered_unary_segments = [seg for seg in got_unary_segments if not seg[0].endswith("/CreateSession")]
196+
filtered_unary_segments = [
197+
seg for seg in got_unary_segments if not seg[0].endswith("/CreateSession")
198+
]
177199
# Find the actual sequence number for BeginTransaction
178200
begin_txn_seq = None
179201
for seg in filtered_unary_segments:

‎tests/mockserver_tests/test_tags.py‎

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,25 @@ def test_select_read_write_transaction_with_transaction_tag(self):
149149
requests = self.spanner_service.requests
150150
self.assert_requests_sequence(
151151
requests,
152-
[BeginTransactionRequest, ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest],
152+
[
153+
BeginTransactionRequest,
154+
ExecuteSqlRequest,
155+
ExecuteSqlRequest,
156+
CommitRequest,
157+
],
153158
TransactionType.READ_WRITE,
154159
)
155160
mux_enabled = is_multiplexed_enabled(TransactionType.READ_WRITE)
156161
tag_idx = 3 if mux_enabled else 2
157-
self.assertEqual("my_transaction_tag", requests[tag_idx].request_options.transaction_tag)
158-
self.assertEqual("my_transaction_tag", requests[tag_idx + 1].request_options.transaction_tag)
159-
self.assertEqual("my_transaction_tag", requests[tag_idx + 2].request_options.transaction_tag)
162+
self.assertEqual(
163+
"my_transaction_tag", requests[tag_idx].request_options.transaction_tag
164+
)
165+
self.assertEqual(
166+
"my_transaction_tag", requests[tag_idx + 1].request_options.transaction_tag
167+
)
168+
self.assertEqual(
169+
"my_transaction_tag", requests[tag_idx + 2].request_options.transaction_tag
170+
)
160171

161172
def test_select_read_write_transaction_with_transaction_and_request_tag(self):
162173
connection = Connection(self.instance, self.database)
@@ -170,16 +181,27 @@ def test_select_read_write_transaction_with_transaction_and_request_tag(self):
170181
requests = self.spanner_service.requests
171182
self.assert_requests_sequence(
172183
requests,
173-
[BeginTransactionRequest, ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest],
184+
[
185+
BeginTransactionRequest,
186+
ExecuteSqlRequest,
187+
ExecuteSqlRequest,
188+
CommitRequest,
189+
],
174190
TransactionType.READ_WRITE,
175191
)
176192
mux_enabled = is_multiplexed_enabled(TransactionType.READ_WRITE)
177193
tag_idx = 3 if mux_enabled else 2
178-
self.assertEqual("my_transaction_tag", requests[tag_idx].request_options.transaction_tag)
194+
self.assertEqual(
195+
"my_transaction_tag", requests[tag_idx].request_options.transaction_tag
196+
)
179197
self.assertEqual("my_tag1", requests[tag_idx].request_options.request_tag)
180-
self.assertEqual("my_transaction_tag", requests[tag_idx + 1].request_options.transaction_tag)
198+
self.assertEqual(
199+
"my_transaction_tag", requests[tag_idx + 1].request_options.transaction_tag
200+
)
181201
self.assertEqual("my_tag2", requests[tag_idx + 1].request_options.request_tag)
182-
self.assertEqual("my_transaction_tag", requests[tag_idx + 2].request_options.transaction_tag)
202+
self.assertEqual(
203+
"my_transaction_tag", requests[tag_idx + 2].request_options.transaction_tag
204+
)
183205

184206
def test_request_tag_is_cleared(self):
185207
connection = Connection(self.instance, self.database)

0 commit comments

Comments
 (0)