Skip to content

Commit 15d245f

Browse files
1 parent bbbc41e commit 15d245f

File tree

3 files changed

+64
-34
lines changed

3 files changed

+64
-34
lines changed

‎google/cloud/spanner_v1/database.py‎

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def __init__(
201201

202202
self._pool = pool
203203
pool.bind(self)
204-
204+
205205
# Initialize session options and sessions manager for multiplexed session support
206206
self.session_options = SessionOptions()
207207
self._sessions_manager = DatabaseSessionsManager(self, pool)
@@ -766,9 +766,11 @@ def execute_pdml():
766766
observability_options=self.observability_options,
767767
) as span, MetricsCapture():
768768
from google.cloud.spanner_v1.session_options import TransactionType
769-
769+
770770
# Use sessions manager for partitioned DML operations
771-
session = self._sessions_manager.get_session(TransactionType.PARTITIONED)
771+
session = self._sessions_manager.get_session(
772+
TransactionType.PARTITIONED
773+
)
772774
try:
773775
add_span_event(span, "Starting BeginTransaction")
774776
txn = api.begin_transaction(
@@ -1255,7 +1257,7 @@ def observability_options(self):
12551257
@property
12561258
def sessions_manager(self):
12571259
"""Returns the database sessions manager.
1258-
1260+
12591261
:rtype: :class:`~google.cloud.spanner_v1.database_sessions_manager.DatabaseSessionsManager`
12601262
:returns: The sessions manager for this database.
12611263
"""
@@ -1312,7 +1314,7 @@ def __init__(
13121314
def __enter__(self):
13131315
"""Begin ``with`` block."""
13141316
from google.cloud.spanner_v1.session_options import TransactionType
1315-
1317+
13161318
current_span = get_current_span()
13171319
session = self._session = self._database.sessions_manager.get_session(
13181320
TransactionType.READ_WRITE
@@ -1370,7 +1372,7 @@ def __init__(self, database):
13701372
def __enter__(self):
13711373
"""Begin ``with`` block."""
13721374
from google.cloud.spanner_v1.session_options import TransactionType
1373-
1375+
13741376
session = self._session = self._database.sessions_manager.get_session(
13751377
TransactionType.READ_WRITE
13761378
)
@@ -1413,7 +1415,7 @@ def __init__(self, database, **kw):
14131415
def __enter__(self):
14141416
"""Begin ``with`` block."""
14151417
from google.cloud.spanner_v1.session_options import TransactionType
1416-
1418+
14171419
session = self._session = self._database.sessions_manager.get_session(
14181420
TransactionType.READ_ONLY
14191421
)
@@ -1508,7 +1510,7 @@ def _get_session(self):
15081510
"""
15091511
if self._session is None:
15101512
from google.cloud.spanner_v1.session_options import TransactionType
1511-
1513+
15121514
# Use sessions manager for partition operations
15131515
session = self._session = self._database.sessions_manager.get_session(
15141516
TransactionType.PARTITIONED

‎google/cloud/spanner_v1/session_options.py‎

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,8 @@ def use_multiplexed(self, transaction_type: TransactionType) -> bool:
6969
"""
7070

7171
if transaction_type is TransactionType.READ_ONLY:
72-
return (
73-
self._is_multiplexed_enabled[transaction_type]
74-
and self._getenv(self.ENV_VAR_ENABLE_MULTIPLEXED)
72+
return self._is_multiplexed_enabled[transaction_type] and self._getenv(
73+
self.ENV_VAR_ENABLE_MULTIPLEXED
7574
)
7675

7776
elif transaction_type is TransactionType.PARTITIONED:
@@ -130,4 +129,4 @@ def _getenv(name: str) -> bool:
130129
considered false.
131130
"""
132131
env_var = os.getenv(name, "").lower().strip()
133-
return env_var in ["1", "true"]
132+
return env_var in ["1", "true"]

‎tests/unit/test_database.py‎

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,25 +1263,32 @@ def _execute_partitioned_dml_helper(
12631263
session = _Session()
12641264
pool.put(session)
12651265
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
1266-
1266+
12671267
# Check if multiplexed sessions are enabled for partitioned operations
1268-
multiplexed_partitioned_enabled = os.environ.get(
1269-
"GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS", "false"
1270-
).lower() == "true"
1271-
1268+
multiplexed_partitioned_enabled = (
1269+
os.environ.get(
1270+
"GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS", "false"
1271+
).lower()
1272+
== "true"
1273+
)
1274+
12721275
if multiplexed_partitioned_enabled:
12731276
# When multiplexed sessions are enabled, create a mock multiplexed session
12741277
# that the sessions manager will return
12751278
multiplexed_session = _Session()
1276-
multiplexed_session.name = self.SESSION_NAME # Use the expected session name
1279+
multiplexed_session.name = (
1280+
self.SESSION_NAME
1281+
) # Use the expected session name
12771282
multiplexed_session.is_multiplexed = True
12781283
# Configure the sessions manager to return the multiplexed session
1279-
database._sessions_manager.get_session = mock.Mock(return_value=multiplexed_session)
1284+
database._sessions_manager.get_session = mock.Mock(
1285+
return_value=multiplexed_session
1286+
)
12801287
expected_session = multiplexed_session
12811288
else:
12821289
# When multiplexed sessions are disabled, use the regular pool session
12831290
expected_session = session
1284-
1291+
12851292
api = database._spanner_api = self._make_spanner_api()
12861293
api._method_configs = {"ExecuteStreamingSql": MethodConfig(retry=Retry())}
12871294
if retried:
@@ -1446,12 +1453,15 @@ def _execute_partitioned_dml_helper(
14461453
],
14471454
)
14481455
self.assertEqual(api.execute_streaming_sql.call_count, 1)
1449-
1456+
14501457
# Verify that the correct session type was used based on environment
14511458
if multiplexed_partitioned_enabled:
14521459
# Verify that sessions_manager.get_session was called with PARTITIONED transaction type
14531460
from google.cloud.spanner_v1.session_options import TransactionType
1454-
database._sessions_manager.get_session.assert_called_with(TransactionType.PARTITIONED)
1461+
1462+
database._sessions_manager.get_session.assert_called_with(
1463+
TransactionType.PARTITIONED
1464+
)
14551465
# If multiplexed sessions are not enabled, the regular pool session should be used
14561466

14571467
def test_execute_partitioned_dml_wo_params(self):
@@ -1542,16 +1552,20 @@ def test_snapshot_defaults(self):
15421552
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
15431553

15441554
# Check if multiplexed sessions are enabled for read operations
1545-
multiplexed_enabled = os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS") == "true"
1546-
1555+
multiplexed_enabled = (
1556+
os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS") == "true"
1557+
)
1558+
15471559
if multiplexed_enabled:
15481560
# When multiplexed sessions are enabled, configure the sessions manager
15491561
# to return a multiplexed session for read operations
15501562
multiplexed_session = _Session()
15511563
multiplexed_session.name = self.SESSION_NAME
15521564
multiplexed_session.is_multiplexed = True
15531565
# Override the side_effect to return the multiplexed session
1554-
database._sessions_manager.get_session = mock.Mock(return_value=multiplexed_session)
1566+
database._sessions_manager.get_session = mock.Mock(
1567+
return_value=multiplexed_session
1568+
)
15551569
expected_session = multiplexed_session
15561570
else:
15571571
expected_session = session
@@ -1588,16 +1602,20 @@ def test_snapshot_w_read_timestamp_and_multi_use(self):
15881602
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
15891603

15901604
# Check if multiplexed sessions are enabled for read operations
1591-
multiplexed_enabled = os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS") == "true"
1592-
1605+
multiplexed_enabled = (
1606+
os.getenv("GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS") == "true"
1607+
)
1608+
15931609
if multiplexed_enabled:
15941610
# When multiplexed sessions are enabled, configure the sessions manager
15951611
# to return a multiplexed session for read operations
15961612
multiplexed_session = _Session()
15971613
multiplexed_session.name = self.SESSION_NAME
15981614
multiplexed_session.is_multiplexed = True
15991615
# Override the side_effect to return the multiplexed session
1600-
database._sessions_manager.get_session = mock.Mock(return_value=multiplexed_session)
1616+
database._sessions_manager.get_session = mock.Mock(
1617+
return_value=multiplexed_session
1618+
)
16011619
expected_session = multiplexed_session
16021620
else:
16031621
expected_session = session
@@ -2557,7 +2575,10 @@ def test__get_session_new(self):
25572575
self.assertIs(batch_txn._get_session(), session)
25582576
# Verify that sessions_manager.get_session was called with PARTITIONED transaction type
25592577
from google.cloud.spanner_v1.session_options import TransactionType
2560-
database.sessions_manager.get_session.assert_called_once_with(TransactionType.PARTITIONED)
2578+
2579+
database.sessions_manager.get_session.assert_called_once_with(
2580+
TransactionType.PARTITIONED
2581+
)
25612582

25622583
def test__get_snapshot_already(self):
25632584
database = self._make_database()
@@ -3518,17 +3539,25 @@ def __init__(self, name, instance=None):
35183539
self.default_transaction_options = DefaultTransactionOptions()
35193540
self._nth_request = AtomicCounter()
35203541
self._nth_client_id = _Database.NTH_CLIENT_ID.increment()
3521-
3542+
35223543
# Mock sessions manager for multiplexed sessions support
35233544
self._sessions_manager = mock.Mock()
3524-
# Configure get_session to return sessions from the pool
3525-
self._sessions_manager.get_session = mock.Mock(side_effect=lambda tx_type: self._pool.get() if hasattr(self, '_pool') and self._pool else None)
3526-
self._sessions_manager.put_session = mock.Mock(side_effect=lambda session: self._pool.put(session) if hasattr(self, '_pool') and self._pool else None)
3545+
# Configure get_session to return sessions from the pool
3546+
self._sessions_manager.get_session = mock.Mock(
3547+
side_effect=lambda tx_type: self._pool.get()
3548+
if hasattr(self, "_pool") and self._pool
3549+
else None
3550+
)
3551+
self._sessions_manager.put_session = mock.Mock(
3552+
side_effect=lambda session: self._pool.put(session)
3553+
if hasattr(self, "_pool") and self._pool
3554+
else None
3555+
)
35273556

35283557
@property
35293558
def sessions_manager(self):
35303559
"""Returns the database sessions manager.
3531-
3560+
35323561
:rtype: Mock
35333562
:returns: The mock sessions manager for this database.
35343563
"""

0 commit comments

Comments
 (0)