@@ -231,78 +231,108 @@ def database(self) -> Database:
231231 )
232232 return self ._database
233233
234- def assert_requests_sequence (self , requests , expected_types , transaction_type , allow_multiple_batch_create = True ):
234+ def assert_requests_sequence (
235+ self ,
236+ requests ,
237+ expected_types ,
238+ transaction_type ,
239+ allow_multiple_batch_create = True ,
240+ ):
235241 """Assert that the requests sequence matches the expected types, accounting for multiplexed sessions and retries.
236-
242+
237243 Args:
238244 requests: List of requests from spanner_service.requests
239245 expected_types: List of expected request types (excluding session creation requests)
240246 transaction_type: TransactionType enum value to check multiplexed session status
241247 allow_multiple_batch_create: If True, skip all leading BatchCreateSessionsRequest and one optional CreateSessionRequest
242248 """
243- from google .cloud .spanner_v1 import BatchCreateSessionsRequest , CreateSessionRequest
249+ from google .cloud .spanner_v1 import (
250+ BatchCreateSessionsRequest ,
251+ CreateSessionRequest ,
252+ )
253+
244254 mux_enabled = is_multiplexed_enabled (transaction_type )
245255 idx = 0
246256 # Skip all leading BatchCreateSessionsRequest (for retries)
247257 if allow_multiple_batch_create :
248- while idx < len (requests ) and isinstance (requests [idx ], BatchCreateSessionsRequest ):
258+ while idx < len (requests ) and isinstance (
259+ requests [idx ], BatchCreateSessionsRequest
260+ ):
249261 idx += 1
250262 # For multiplexed, optionally skip a CreateSessionRequest
251- if mux_enabled and idx < len (requests ) and isinstance (requests [idx ], CreateSessionRequest ):
263+ if (
264+ mux_enabled
265+ and idx < len (requests )
266+ and isinstance (requests [idx ], CreateSessionRequest )
267+ ):
252268 idx += 1
253269 else :
254270 if mux_enabled :
255- self .assertTrue (isinstance (requests [idx ], BatchCreateSessionsRequest ),
256- f"Expected BatchCreateSessionsRequest at index { idx } , got { type (requests [idx ])} " )
271+ self .assertTrue (
272+ isinstance (requests [idx ], BatchCreateSessionsRequest ),
273+ f"Expected BatchCreateSessionsRequest at index { idx } , got { type (requests [idx ])} " ,
274+ )
257275 idx += 1
258- self .assertTrue (isinstance (requests [idx ], CreateSessionRequest ),
259- f"Expected CreateSessionRequest at index { idx } , got { type (requests [idx ])} " )
276+ self .assertTrue (
277+ isinstance (requests [idx ], CreateSessionRequest ),
278+ f"Expected CreateSessionRequest at index { idx } , got { type (requests [idx ])} " ,
279+ )
260280 idx += 1
261281 else :
262- self .assertTrue (isinstance (requests [idx ], BatchCreateSessionsRequest ),
263- f"Expected BatchCreateSessionsRequest at index { idx } , got { type (requests [idx ])} " )
282+ self .assertTrue (
283+ isinstance (requests [idx ], BatchCreateSessionsRequest ),
284+ f"Expected BatchCreateSessionsRequest at index { idx } , got { type (requests [idx ])} " ,
285+ )
264286 idx += 1
265287 # Check the rest of the expected request types
266288 for expected_type in expected_types :
267- self .assertTrue (isinstance (requests [idx ], expected_type ),
268- f"Expected { expected_type } at index { idx } , got { type (requests [idx ])} " )
289+ self .assertTrue (
290+ isinstance (requests [idx ], expected_type ),
291+ f"Expected { expected_type } at index { idx } , got { type (requests [idx ])} " ,
292+ )
269293 idx += 1
270- self .assertEqual (idx , len (requests ),
271- f"Expected { idx } requests, got { len (requests )} " )
294+ self .assertEqual (
295+ idx , len (requests ), f"Expected { idx } requests, got { len (requests )} "
296+ )
272297
273298 def adjust_request_id_sequence (self , expected_segments , requests , transaction_type ):
274299 """Adjust expected request ID sequence numbers based on actual session creation requests.
275-
300+
276301 Args:
277302 expected_segments: List of expected (method, (sequence_numbers)) tuples
278303 requests: List of actual requests from spanner_service.requests
279304 transaction_type: TransactionType enum value to check multiplexed session status
280-
305+
281306 Returns:
282307 List of adjusted expected segments with corrected sequence numbers
283308 """
284- from google .cloud .spanner_v1 import BatchCreateSessionsRequest , CreateSessionRequest , ExecuteSqlRequest , BeginTransactionRequest
285-
309+ from google .cloud .spanner_v1 import (
310+ BatchCreateSessionsRequest ,
311+ CreateSessionRequest ,
312+ ExecuteSqlRequest ,
313+ BeginTransactionRequest ,
314+ )
315+
286316 # Count session creation requests that come before the first non-session request
287317 session_requests_before = 0
288318 for req in requests :
289319 if isinstance (req , (BatchCreateSessionsRequest , CreateSessionRequest )):
290320 session_requests_before += 1
291321 elif isinstance (req , (ExecuteSqlRequest , BeginTransactionRequest )):
292322 break
293-
323+
294324 # For multiplexed sessions, we expect 2 session requests (BatchCreateSessions + CreateSession)
295325 # For non-multiplexed, we expect 1 session request (BatchCreateSessions)
296326 mux_enabled = is_multiplexed_enabled (transaction_type )
297327 expected_session_requests = 2 if mux_enabled else 1
298328 extra_session_requests = session_requests_before - expected_session_requests
299-
329+
300330 # Adjust sequence numbers based on extra session requests
301331 adjusted_segments = []
302332 for method , seq_nums in expected_segments :
303333 # Adjust the sequence number (5th element in the tuple)
304334 adjusted_seq_nums = list (seq_nums )
305335 adjusted_seq_nums [4 ] += extra_session_requests
306336 adjusted_segments .append ((method , tuple (adjusted_seq_nums )))
307-
337+
308338 return adjusted_segments
0 commit comments