Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def _make_params_pb(params, param_types):
:raises ValueError:
If ``params`` is None but ``param_types`` is not None.
"""
if params is not None:
if params:
return Struct(
fields={key: _make_value_pb(value) for key, value in params.items()}
)
Expand Down
45 changes: 45 additions & 0 deletions tests/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import unittest

from google.cloud.spanner_admin_database_v1.types import spanner_database_admin
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
from google.cloud.spanner_v1.testing.mock_spanner import (
start_mock_server,
Expand All @@ -29,6 +31,8 @@
FixedSizePool,
BatchCreateSessionsRequest,
ExecuteSqlRequest,
BeginTransactionRequest,
TransactionOptions,
)
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
Expand Down Expand Up @@ -62,6 +66,10 @@ def tearDownClass(cls):
TestBasics.server.stop(grace=None)
TestBasics.server = None

def teardown_method(self, *args, **kwargs):
TestBasics.spanner_service.clear_requests()
TestBasics.database_admin_service.clear_requests()

def _add_select1_result(self):
result = result_set.ResultSet(
dict(
Expand All @@ -88,6 +96,19 @@ def _add_select1_result(self):
result.rows.extend(["1"])
TestBasics.spanner_service.mock_spanner.add_result("select 1", result)

def add_update_count(
self,
sql: str,
count: int,
dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL,
):
if dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC:
stats = dict(row_count_lower_bound=count)
else:
stats = dict(row_count_exact=count)
result = result_set.ResultSet(dict(stats=result_set.ResultSetStats(stats)))
TestBasics.spanner_service.mock_spanner.add_result(sql, result)

@property
def client(self) -> Client:
if self._client is None:
Expand Down Expand Up @@ -145,3 +166,27 @@ def test_create_table(self):
)
operation = database_admin_api.update_database_ddl(request)
operation.result(1)

# TODO: Move this to a separate class once the mock server test setup has
# been re-factored to use a base class for the boiler plate code.
def test_dbapi_partitioned_dml(self):
sql = "UPDATE singers SET foo='bar' WHERE active = true"
self.add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
connection = Connection(self.instance, self.database)
connection.autocommit = True
connection.set_autocommit_dml_mode(AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
with connection.cursor() as cursor:
# Note: SQLAlchemy uses [] as the list of parameters for statements
# with no parameters.
cursor.execute(sql, [])
self.assertEqual(100, cursor.rowcount)

requests = self.spanner_service.requests
self.assertEqual(3, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
begin_request: BeginTransactionRequest = requests[1]
self.assertEqual(
TransactionOptions(dict(partitioned_dml={})), begin_request.options
)
Loading