@@ -1263,25 +1263,32 @@ def _execute_partitioned_dml_helper(
12631263 session = _Session ()
12641264 pool .put (session )
12651265 database = self ._make_one (self .DATABASE_ID , instance , pool = pool )
1266-
1266+
12671267 # Check if multiplexed sessions are enabled for partitioned operations
1268- multiplexed_partitioned_enabled = os .environ .get (
1269- "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS" , "false"
1270- ).lower () == "true"
1271-
1268+ multiplexed_partitioned_enabled = (
1269+ os .environ .get (
1270+ "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS" , "false"
1271+ ).lower ()
1272+ == "true"
1273+ )
1274+
12721275 if multiplexed_partitioned_enabled :
12731276 # When multiplexed sessions are enabled, create a mock multiplexed session
12741277 # that the sessions manager will return
12751278 multiplexed_session = _Session ()
1276- multiplexed_session .name = self .SESSION_NAME # Use the expected session name
1279+ multiplexed_session .name = (
1280+ self .SESSION_NAME
1281+ ) # Use the expected session name
12771282 multiplexed_session .is_multiplexed = True
12781283 # Configure the sessions manager to return the multiplexed session
1279- database ._sessions_manager .get_session = mock .Mock (return_value = multiplexed_session )
1284+ database ._sessions_manager .get_session = mock .Mock (
1285+ return_value = multiplexed_session
1286+ )
12801287 expected_session = multiplexed_session
12811288 else :
12821289 # When multiplexed sessions are disabled, use the regular pool session
12831290 expected_session = session
1284-
1291+
12851292 api = database ._spanner_api = self ._make_spanner_api ()
12861293 api ._method_configs = {"ExecuteStreamingSql" : MethodConfig (retry = Retry ())}
12871294 if retried :
@@ -1446,12 +1453,15 @@ def _execute_partitioned_dml_helper(
14461453 ],
14471454 )
14481455 self .assertEqual (api .execute_streaming_sql .call_count , 1 )
1449-
1456+
14501457 # Verify that the correct session type was used based on environment
14511458 if multiplexed_partitioned_enabled :
14521459 # Verify that sessions_manager.get_session was called with PARTITIONED transaction type
14531460 from google .cloud .spanner_v1 .session_options import TransactionType
1454- database ._sessions_manager .get_session .assert_called_with (TransactionType .PARTITIONED )
1461+
1462+ database ._sessions_manager .get_session .assert_called_with (
1463+ TransactionType .PARTITIONED
1464+ )
14551465 # If multiplexed sessions are not enabled, the regular pool session should be used
14561466
14571467 def test_execute_partitioned_dml_wo_params (self ):
@@ -1542,16 +1552,20 @@ def test_snapshot_defaults(self):
15421552 database = self ._make_one (self .DATABASE_ID , instance , pool = pool )
15431553
15441554 # Check if multiplexed sessions are enabled for read operations
1545- multiplexed_enabled = os .getenv ("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" ) == "true"
1546-
1555+ multiplexed_enabled = (
1556+ os .getenv ("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" ) == "true"
1557+ )
1558+
15471559 if multiplexed_enabled :
15481560 # When multiplexed sessions are enabled, configure the sessions manager
15491561 # to return a multiplexed session for read operations
15501562 multiplexed_session = _Session ()
15511563 multiplexed_session .name = self .SESSION_NAME
15521564 multiplexed_session .is_multiplexed = True
15531565 # Override the side_effect to return the multiplexed session
1554- database ._sessions_manager .get_session = mock .Mock (return_value = multiplexed_session )
1566+ database ._sessions_manager .get_session = mock .Mock (
1567+ return_value = multiplexed_session
1568+ )
15551569 expected_session = multiplexed_session
15561570 else :
15571571 expected_session = session
@@ -1588,16 +1602,20 @@ def test_snapshot_w_read_timestamp_and_multi_use(self):
15881602 database = self ._make_one (self .DATABASE_ID , instance , pool = pool )
15891603
15901604 # Check if multiplexed sessions are enabled for read operations
1591- multiplexed_enabled = os .getenv ("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" ) == "true"
1592-
1605+ multiplexed_enabled = (
1606+ os .getenv ("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS" ) == "true"
1607+ )
1608+
15931609 if multiplexed_enabled :
15941610 # When multiplexed sessions are enabled, configure the sessions manager
15951611 # to return a multiplexed session for read operations
15961612 multiplexed_session = _Session ()
15971613 multiplexed_session .name = self .SESSION_NAME
15981614 multiplexed_session .is_multiplexed = True
15991615 # Override the side_effect to return the multiplexed session
1600- database ._sessions_manager .get_session = mock .Mock (return_value = multiplexed_session )
1616+ database ._sessions_manager .get_session = mock .Mock (
1617+ return_value = multiplexed_session
1618+ )
16011619 expected_session = multiplexed_session
16021620 else :
16031621 expected_session = session
@@ -2557,7 +2575,10 @@ def test__get_session_new(self):
25572575 self .assertIs (batch_txn ._get_session (), session )
25582576 # Verify that sessions_manager.get_session was called with PARTITIONED transaction type
25592577 from google .cloud .spanner_v1 .session_options import TransactionType
2560- database .sessions_manager .get_session .assert_called_once_with (TransactionType .PARTITIONED )
2578+
2579+ database .sessions_manager .get_session .assert_called_once_with (
2580+ TransactionType .PARTITIONED
2581+ )
25612582
25622583 def test__get_snapshot_already (self ):
25632584 database = self ._make_database ()
@@ -3518,17 +3539,25 @@ def __init__(self, name, instance=None):
35183539 self .default_transaction_options = DefaultTransactionOptions ()
35193540 self ._nth_request = AtomicCounter ()
35203541 self ._nth_client_id = _Database .NTH_CLIENT_ID .increment ()
3521-
3542+
35223543 # Mock sessions manager for multiplexed sessions support
35233544 self ._sessions_manager = mock .Mock ()
3524- # Configure get_session to return sessions from the pool
3525- self ._sessions_manager .get_session = mock .Mock (side_effect = lambda tx_type : self ._pool .get () if hasattr (self , '_pool' ) and self ._pool else None )
3526- self ._sessions_manager .put_session = mock .Mock (side_effect = lambda session : self ._pool .put (session ) if hasattr (self , '_pool' ) and self ._pool else None )
3545+ # Configure get_session to return sessions from the pool
3546+ self ._sessions_manager .get_session = mock .Mock (
3547+ side_effect = lambda tx_type : self ._pool .get ()
3548+ if hasattr (self , "_pool" ) and self ._pool
3549+ else None
3550+ )
3551+ self ._sessions_manager .put_session = mock .Mock (
3552+ side_effect = lambda session : self ._pool .put (session )
3553+ if hasattr (self , "_pool" ) and self ._pool
3554+ else None
3555+ )
35273556
35283557 @property
35293558 def sessions_manager (self ):
35303559 """Returns the database sessions manager.
3531-
3560+
35323561 :rtype: Mock
35333562 :returns: The mock sessions manager for this database.
35343563 """
0 commit comments