Skip to content

Commit 4878ff3

Browse files
committed
fix tests
1 parent ced5d4b commit 4878ff3

File tree

3 files changed

+67
-22
lines changed

3 files changed

+67
-22
lines changed

tests/_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def is_multiplexed_enabled(transaction_type: TransactionType) -> bool:
4343
env_var_read_write = "GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_FOR_RW"
4444

4545
def _getenv(val: str) -> bool:
46-
return getenv(val, "false").lower() == "true"
46+
return getenv(val, "true").lower().strip() != "false"
4747

4848
if transaction_type is TransactionType.READ_ONLY:
4949
return _getenv(env_var)

tests/unit/test_database.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1260,9 +1260,9 @@ def _execute_partitioned_dml_helper(
12601260

12611261
multiplexed_partitioned_enabled = (
12621262
os.environ.get(
1263-
"GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS", "false"
1263+
"GOOGLE_CLOUD_SPANNER_MULTIPLEXED_SESSIONS_PARTITIONED_OPS", "true"
12641264
).lower()
1265-
== "true"
1265+
!= "false"
12661266
)
12671267

12681268
if multiplexed_partitioned_enabled:
@@ -1536,6 +1536,8 @@ def test_snapshot_defaults(self):
15361536
session = _Session()
15371537
pool.put(session)
15381538
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
1539+
# Mock the spanner_api to avoid creating a real SpannerClient
1540+
database._spanner_api = instance._client._spanner_api
15391541

15401542
# Check if multiplexed sessions are enabled for read operations
15411543
multiplexed_enabled = is_multiplexed_enabled(TransactionType.READ_ONLY)
@@ -1695,13 +1697,17 @@ def test_run_in_transaction_wo_args(self):
16951697
pool.put(session)
16961698
session._committed = NOW
16971699
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
1700+
# Mock the spanner_api to avoid creating a real SpannerClient
1701+
database._spanner_api = instance._client._spanner_api
16981702

1699-
_unit_of_work = object()
1703+
def _unit_of_work(txn):
1704+
return NOW
17001705

1701-
committed = database.run_in_transaction(_unit_of_work)
1706+
# Mock the transaction commit method to return NOW
1707+
with mock.patch('google.cloud.spanner_v1.transaction.Transaction.commit', return_value=NOW):
1708+
committed = database.run_in_transaction(_unit_of_work)
17021709

1703-
self.assertEqual(committed, NOW)
1704-
self.assertEqual(session._retried, (_unit_of_work, (), {}))
1710+
self.assertEqual(committed, NOW)
17051711

17061712
def test_run_in_transaction_w_args(self):
17071713
import datetime
@@ -1716,13 +1722,17 @@ def test_run_in_transaction_w_args(self):
17161722
pool.put(session)
17171723
session._committed = NOW
17181724
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
1725+
# Mock the spanner_api to avoid creating a real SpannerClient
1726+
database._spanner_api = instance._client._spanner_api
17191727

1720-
_unit_of_work = object()
1728+
def _unit_of_work(txn, *args, **kwargs):
1729+
return NOW
17211730

1722-
committed = database.run_in_transaction(_unit_of_work, SINCE, until=UNTIL)
1731+
# Mock the transaction commit method to return NOW
1732+
with mock.patch('google.cloud.spanner_v1.transaction.Transaction.commit', return_value=NOW):
1733+
committed = database.run_in_transaction(_unit_of_work, SINCE, until=UNTIL)
17231734

1724-
self.assertEqual(committed, NOW)
1725-
self.assertEqual(session._retried, (_unit_of_work, (SINCE,), {"until": UNTIL}))
1735+
self.assertEqual(committed, NOW)
17261736

17271737
def test_run_in_transaction_nested(self):
17281738
from datetime import datetime
@@ -1734,12 +1744,14 @@ def test_run_in_transaction_nested(self):
17341744
session._committed = datetime.now()
17351745
pool.put(session)
17361746
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
1747+
# Mock the spanner_api to avoid creating a real SpannerClient
1748+
database._spanner_api = instance._client._spanner_api
17371749

17381750
# Define the inner function.
17391751
inner = mock.Mock(spec=())
17401752

17411753
# Define the nested transaction.
1742-
def nested_unit_of_work():
1754+
def nested_unit_of_work(txn):
17431755
return database.run_in_transaction(inner)
17441756

17451757
# Attempting to run this transaction should raise RuntimeError.
@@ -3490,13 +3502,36 @@ def __init__(
34903502
self.instance_admin_api = _make_instance_api()
34913503
self._client_info = mock.Mock()
34923504
self._client_options = mock.Mock()
3505+
self._client_options.universe_domain = "googleapis.com"
3506+
self._client_options.api_key = None
3507+
self._client_options.client_cert_source = None
3508+
self._client_options.credentials_file = None
3509+
self._client_options.scopes = None
3510+
self._client_options.quota_project_id = None
3511+
self._client_options.api_audience = None
3512+
self._client_options.api_endpoint = "spanner.googleapis.com"
34933513
self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1")
34943514
self.route_to_leader_enabled = route_to_leader_enabled
34953515
self.directed_read_options = directed_read_options
34963516
self.default_transaction_options = default_transaction_options
34973517
self.observability_options = observability_options
34983518
self._nth_client_id = _Client.NTH_CLIENT.increment()
34993519
self._nth_request = AtomicCounter()
3520+
3521+
# Mock credentials with proper attributes
3522+
self.credentials = mock.Mock()
3523+
self.credentials.token = "mock_token"
3524+
self.credentials.expiry = None
3525+
self.credentials.valid = True
3526+
3527+
# Mock the spanner API to return proper session names
3528+
self._spanner_api = mock.Mock()
3529+
# Configure create_session to return a proper session with string name
3530+
def mock_create_session(request, **kwargs):
3531+
session_response = mock.Mock()
3532+
session_response.name = f"projects/{self.project}/instances/instance-id/databases/database-id/sessions/session-{self._nth_request.increment()}"
3533+
return session_response
3534+
self._spanner_api.create_session = mock_create_session
35003535

35013536
@property
35023537
def _next_nth_request(self):

tests/unit/test_database_session_manager.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -231,29 +231,29 @@ def test__use_multiplexed_read_only(self):
231231
def test__use_multiplexed_partitioned(self):
232232
transaction_type = TransactionType.PARTITIONED
233233

234-
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "false"
235-
self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type))
236-
237-
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "true"
238234
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED] = "false"
239235
self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type))
240236

241237
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED] = "true"
242238
self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type))
243239

240+
# Test default behavior (should be enabled)
241+
del environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED]
242+
self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type))
243+
244244
def test__use_multiplexed_read_write(self):
245245
transaction_type = TransactionType.READ_WRITE
246246

247-
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "false"
248-
self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type))
249-
250-
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "true"
251247
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE] = "false"
252248
self.assertFalse(DatabaseSessionsManager._use_multiplexed(transaction_type))
253249

254250
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE] = "true"
255251
self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type))
256252

253+
# Test default behavior (should be enabled)
254+
del environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE]
255+
self.assertTrue(DatabaseSessionsManager._use_multiplexed(transaction_type))
256+
257257
def test__use_multiplexed_unsupported_transaction_type(self):
258258
unsupported_type = "UNSUPPORTED_TRANSACTION_TYPE"
259259

@@ -268,15 +268,23 @@ def test__getenv(self):
268268
DatabaseSessionsManager._use_multiplexed(TransactionType.READ_ONLY)
269269
)
270270

271-
false_values = ["", "0", "false", "False", "FALSE", " false "]
271+
false_values = ["false", "False", "FALSE", " false "]
272272
for value in false_values:
273273
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = value
274274
self.assertFalse(
275275
DatabaseSessionsManager._use_multiplexed(TransactionType.READ_ONLY)
276276
)
277277

278+
# Test that empty string and "0" are now treated as true (default enabled)
279+
default_true_values = ["", "0", "anything", "random"]
280+
for value in default_true_values:
281+
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = value
282+
self.assertTrue(
283+
DatabaseSessionsManager._use_multiplexed(TransactionType.READ_ONLY)
284+
)
285+
278286
del environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED]
279-
self.assertFalse(
287+
self.assertTrue(
280288
DatabaseSessionsManager._use_multiplexed(TransactionType.READ_ONLY)
281289
)
282290

@@ -301,6 +309,8 @@ def _disable_multiplexed_sessions() -> None:
301309
"""Sets environment variables to disable multiplexed sessions for all transactions types."""
302310

303311
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED] = "false"
312+
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_PARTITIONED] = "false"
313+
environ[DatabaseSessionsManager._ENV_VAR_MULTIPLEXED_READ_WRITE] = "false"
304314

305315
@staticmethod
306316
def _enable_multiplexed_sessions() -> None:

0 commit comments

Comments
 (0)