Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions google/cloud/spanner_dbapi/batch_dml_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
for statement in statements:
statements_tuple.append(statement.get_tuple())
if not connection._client_transaction_started:
res = connection.database.run_in_transaction(_do_batch_update, statements_tuple)
res = connection.database.run_in_transaction(
_do_batch_update_autocommit, statements_tuple
)
many_result_set.add_iter(res)
cursor._row_count = sum([max(val, 0) for val in res])
else:
Expand All @@ -113,10 +115,10 @@ def run_batch_dml(cursor: "Cursor", statements: List[Statement]):
connection._transaction_helper.retry_transaction()


def _do_batch_update(transaction, statements):
def _do_batch_update_autocommit(transaction, statements):
from google.cloud.spanner_dbapi import OperationalError

status, res = transaction.batch_update(statements)
status, res = transaction.batch_update(statements, last_statement=True)
if status.code == ABORTED:
raise Aborted(status.message)
elif status.code != OK:
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ def _do_execute_update_in_autocommit(self, transaction, sql, params):
self.connection._transaction = transaction
self.connection._snapshot = None
self._result_set = transaction.execute_sql(
sql, params=params, param_types=get_param_types(params)
sql,
params=params,
param_types=get_param_types(params),
last_statement=True,
)
self._itr = PeekIterator(self._result_set)
self._row_count = None
Expand Down
15 changes: 15 additions & 0 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def execute_sql(
query_mode=None,
query_options=None,
request_options=None,
last_statement=False,
partition=None,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
Expand Down Expand Up @@ -432,6 +433,19 @@ def execute_sql(
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.

:type last_statement: bool
:param last_statement:
If set to true, this option marks the end of the transaction. The
transaction should be committed or aborted after this statement
executes, and attempts to execute any other requests against this
transaction (including reads and queries) will be rejected. Mixing
mutations with statements that are marked as the last statement is
not allowed.
For DML statements, setting this option may cause some error
reporting to be deferred until commit time (e.g. validation of
unique constraints). Given this, successful execution of a DML
statement should not be assumed until the transaction commits.

:type partition: bytes
:param partition: (Optional) one of the partition tokens returned
from :meth:`partition_query`.
Expand Down Expand Up @@ -536,6 +550,7 @@ def execute_sql(
seqno=self._execute_sql_count,
query_options=query_options,
request_options=request_options,
last_statement=last_statement,
data_boost_enabled=data_boost_enabled,
directed_read_options=directed_read_options,
)
Expand Down
30 changes: 30 additions & 0 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def execute_update(
query_mode=None,
query_options=None,
request_options=None,
last_statement=False,
*,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
Expand Down Expand Up @@ -385,6 +386,19 @@ def execute_update(
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.

:type last_statement: bool
:param last_statement:
If set to true, this option marks the end of the transaction. The
transaction should be committed or aborted after this statement
executes, and attempts to execute any other requests against this
transaction (including reads and queries) will be rejected. Mixing
mutations with statements that are marked as the last statement is
not allowed.
For DML statements, setting this option may cause some error
reporting to be deferred until commit time (e.g. validation of
unique constraints). Given this, successful execution of a DML
statement should not be assumed until the transaction commits.

:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.

Expand Down Expand Up @@ -433,6 +447,7 @@ def execute_update(
query_options=query_options,
seqno=seqno,
request_options=request_options,
last_statement=last_statement,
)

method = functools.partial(
Expand Down Expand Up @@ -478,6 +493,7 @@ def batch_update(
self,
statements,
request_options=None,
last_statement=False,
*,
retry=gapic_v1.method.DEFAULT,
timeout=gapic_v1.method.DEFAULT,
Expand All @@ -502,6 +518,19 @@ def batch_update(
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.

:type last_statement: bool
:param last_statement:
If set to true, this option marks the end of the transaction. The
transaction should be committed or aborted after this statement
executes, and attempts to execute any other requests against this
transaction (including reads and queries) will be rejected. Mixing
mutations with statements that are marked as the last statement is
not allowed.
For DML statements, setting this option may cause some error
reporting to be deferred until commit time (e.g. validation of
unique constraints). Given this, successful execution of a DML
statement should not be assumed until the transaction commits.

:type retry: :class:`~google.api_core.retry.Retry`
:param retry: (Optional) The retry settings for this request.

Expand Down Expand Up @@ -558,6 +587,7 @@ def batch_update(
statements=parsed,
seqno=seqno,
request_options=request_options,
last_statements=last_statement,
)

method = functools.partial(
Expand Down
57 changes: 57 additions & 0 deletions tests/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
ExecuteSqlRequest,
BeginTransactionRequest,
TransactionOptions,
ExecuteBatchDmlRequest,
TypeCode,
)
from google.cloud.spanner_v1.transaction import Transaction
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer

from tests.mockserver_tests.mock_server_test_base import (
Expand All @@ -29,6 +32,7 @@
add_update_count,
add_error,
unavailable_status,
add_single_result,
)


Expand Down Expand Up @@ -107,3 +111,56 @@ def test_execute_streaming_sql_unavailable(self):
# The ExecuteStreamingSql call should be retried.
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))

def test_last_statement_update(self):
sql = "update my_table set my_col=1 where id=2"
add_update_count(sql, 1)
self.database.run_in_transaction(
lambda transaction: transaction.execute_update(sql, last_statement=True)
)
requests = list(
filter(
lambda msg: isinstance(msg, ExecuteSqlRequest),
self.spanner_service.requests,
)
)
self.assertEqual(1, len(requests), msg=requests)
self.assertTrue(requests[0].last_statement, requests[0])

def test_last_statement_batch_update(self):
sql = "update my_table set my_col=1 where id=2"
add_update_count(sql, 1)
self.database.run_in_transaction(
lambda transaction: transaction.batch_update(
[sql, sql], last_statement=True
)
)
requests = list(
filter(
lambda msg: isinstance(msg, ExecuteBatchDmlRequest),
self.spanner_service.requests,
)
)
self.assertEqual(1, len(requests), msg=requests)
self.assertTrue(requests[0].last_statements, requests[0])

def test_last_statement_query(self):
sql = "insert into my_table (value) values ('One') then return id"
add_single_result(sql, "c", TypeCode.INT64, [("1",)])
self.database.run_in_transaction(
lambda transaction: _execute_query(transaction, sql)
)
requests = list(
filter(
lambda msg: isinstance(msg, ExecuteSqlRequest),
self.spanner_service.requests,
)
)
self.assertEqual(1, len(requests), msg=requests)
self.assertTrue(requests[0].last_statement, requests[0])


def _execute_query(transaction: Transaction, sql: str):
rows = transaction.execute_sql(sql, last_statement=True)
for _ in rows:
pass
127 changes: 127 additions & 0 deletions tests/mockserver_tests/test_dbapi_autocommit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright 2025 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_v1 import (
ExecuteSqlRequest,
TypeCode,
CommitRequest,
ExecuteBatchDmlRequest,
)
from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_single_result,
add_update_count,
)


class TestDbapiAutoCommit(MockServerTestBase):
@classmethod
def setup_class(cls):
super().setup_class()
add_single_result(
"select name from singers", "name", TypeCode.STRING, [("Some Singer",)]
)
add_update_count("insert into singers (id, name) values (1, 'Some Singer')", 1)

def test_select_autocommit(self):
connection = Connection(self.instance, self.database)
connection.autocommit = True
with connection.cursor() as cursor:
cursor.execute("select name from singers")
result_list = cursor.fetchall()
for _ in result_list:
pass
requests = list(
filter(
lambda msg: isinstance(msg, ExecuteSqlRequest),
self.spanner_service.requests,
)
)
self.assertEqual(1, len(requests))
self.assertFalse(requests[0].last_statement, requests[0])
self.assertIsNotNone(requests[0].transaction, requests[0])
self.assertIsNotNone(requests[0].transaction.single_use, requests[0])
self.assertTrue(requests[0].transaction.single_use.read_only, requests[0])

def test_dml_autocommit(self):
connection = Connection(self.instance, self.database)
connection.autocommit = True
with connection.cursor() as cursor:
cursor.execute("insert into singers (id, name) values (1, 'Some Singer')")
self.assertEqual(1, cursor.rowcount)
requests = list(
filter(
lambda msg: isinstance(msg, ExecuteSqlRequest),
self.spanner_service.requests,
)
)
self.assertEqual(1, len(requests))
self.assertTrue(requests[0].last_statement, requests[0])
commit_requests = list(
filter(
lambda msg: isinstance(msg, CommitRequest),
self.spanner_service.requests,
)
)
self.assertEqual(1, len(commit_requests))

def test_executemany_autocommit(self):
connection = Connection(self.instance, self.database)
connection.autocommit = True
with connection.cursor() as cursor:
cursor.executemany(
"insert into singers (id, name) values (1, 'Some Singer')", [(), ()]
)
self.assertEqual(2, cursor.rowcount)
requests = list(
filter(
lambda msg: isinstance(msg, ExecuteBatchDmlRequest),
self.spanner_service.requests,
)
)
self.assertEqual(1, len(requests))
self.assertTrue(requests[0].last_statements, requests[0])
commit_requests = list(
filter(
lambda msg: isinstance(msg, CommitRequest),
self.spanner_service.requests,
)
)
self.assertEqual(1, len(commit_requests))

def test_batch_dml_autocommit(self):
connection = Connection(self.instance, self.database)
connection.autocommit = True
with connection.cursor() as cursor:
cursor.execute("start batch dml")
cursor.execute("insert into singers (id, name) values (1, 'Some Singer')")
cursor.execute("insert into singers (id, name) values (1, 'Some Singer')")
cursor.execute("run batch")
self.assertEqual(2, cursor.rowcount)
requests = list(
filter(
lambda msg: isinstance(msg, ExecuteBatchDmlRequest),
self.spanner_service.requests,
)
)
self.assertEqual(1, len(requests))
self.assertTrue(requests[0].last_statements, requests[0])
commit_requests = list(
filter(
lambda msg: isinstance(msg, CommitRequest),
self.spanner_service.requests,
)
)
self.assertEqual(1, len(commit_requests))
12 changes: 8 additions & 4 deletions tests/unit/spanner_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def test_do_batch_update(self):
("DELETE FROM table WHERE col1 = @a0", {"a0": 1}, {"a0": INT64}),
("DELETE FROM table WHERE col1 = @a0", {"a0": 2}, {"a0": INT64}),
("DELETE FROM table WHERE col1 = @a0", {"a0": 3}, {"a0": INT64}),
]
],
last_statement=True,
)
self.assertEqual(cursor._row_count, 3)

Expand Down Expand Up @@ -539,7 +540,8 @@ def test_executemany_delete_batch_autocommit(self):
("DELETE FROM table WHERE col1 = @a0", {"a0": 1}, {"a0": INT64}),
("DELETE FROM table WHERE col1 = @a0", {"a0": 2}, {"a0": INT64}),
("DELETE FROM table WHERE col1 = @a0", {"a0": 3}, {"a0": INT64}),
]
],
last_statement=True,
)

def test_executemany_update_batch_autocommit(self):
Expand Down Expand Up @@ -582,7 +584,8 @@ def test_executemany_update_batch_autocommit(self):
{"a0": 3, "a1": "c"},
{"a0": INT64, "a1": STRING},
),
]
],
last_statement=True,
)

def test_executemany_insert_batch_non_autocommit(self):
Expand Down Expand Up @@ -659,7 +662,8 @@ def test_executemany_insert_batch_autocommit(self):
{"a0": 5, "a1": 6, "a2": 7, "a3": 8},
{"a0": INT64, "a1": INT64, "a2": INT64, "a3": INT64},
),
]
],
last_statement=True,
)
transaction.commit.assert_called_once()

Expand Down
Loading