Skip to content

Commit 108d965

Browse files
authored
feat: support Partitioned DML (#541)
Adds tests and samples for executing Partitioned DML using SQLAlchemy. Fixes #496
1 parent 0f33c1c commit 108d965

File tree

4 files changed

+106
-7
lines changed

4 files changed

+106
-7
lines changed

samples/partitioned_dml_sample.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
16+
from sqlalchemy import create_engine, text
17+
18+
from sample_helper import run_sample
19+
20+
# Shows how to execute Partitioned DML with SQLAlchemy and Spanner.
21+
def partitioned_dml_sample():
    """Back-fill a column using a single Partitioned DML statement.

    Opens a connection to the sample database in auto-commit mode, sets the
    auto-commit DML mode of the underlying connection to
    ``PARTITIONED_NON_ATOMIC`` and executes one bulk ``UPDATE``. The row
    count of the statement is printed; for Partitioned DML this value is a
    lower bound rather than an exact count.
    """
    engine = create_engine(
        "spanner:///projects/sample-project/"
        "instances/sample-instance/"
        "databases/sample-database",
        echo=True,
    )
    # Get a connection in auto-commit mode.
    # Partitioned DML can only be executed in auto-commit mode, as each
    # Partitioned DML transaction can only consist of one statement.
    with engine.connect().execution_options(
        isolation_level="AUTOCOMMIT"
    ) as connection:
        # Set the DML mode to PARTITIONED_NON_ATOMIC.
        connection.connection.set_autocommit_dml_mode(
            AutocommitDmlMode.PARTITIONED_NON_ATOMIC
        )
        # Use a bulk update statement to back-fill a column.
        lower_bound_rowcount = connection.execute(
            text("update venues set active=true where active is null")
        ).rowcount
        # Partitioned DML returns the lower-bound update count.
        # Format the message as one string: passing the pieces as separate
        # print() arguments produced doubled spaces around the count.
        print(f"Updated at least {lower_bound_rowcount} venue records")
42+
43+
44+
if __name__ == "__main__":
45+
run_sample(partitioned_dml_sample)

test/mockserver_tests/mock_server_test_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
1415
from sqlalchemy import Engine, create_engine
1516
from sqlalchemy.testing.plugin.plugin_base import fixtures
1617
import google.cloud.spanner_v1.types.type as spanner_type
@@ -35,6 +36,17 @@ def add_result(sql: str, result: ResultSet):
3536
MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result)
3637

3738

39+
def add_update_count(
    sql: str, count: int, dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL
):
    """Register a mock result for ``sql`` that carries only an update count.

    Args:
        sql: The SQL string the mock result is registered for.
        count: The update count to place in the result's stats.
        dml_mode: Selects which stats field receives ``count``:
            ``row_count_lower_bound`` for ``PARTITIONED_NON_ATOMIC``,
            ``row_count_exact`` for any other mode.
    """
    is_partitioned = dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC
    stats_field = "row_count_lower_bound" if is_partitioned else "row_count_exact"
    update_stats = result_set.ResultSetStats({stats_field: count})
    # The result has no rows or metadata, only the stats built above.
    add_result(sql, result_set.ResultSet(dict(stats=update_stats)))
48+
49+
3850
def add_select1_result():
3951
result = result_set.ResultSet(
4052
dict(

test/mockserver_tests/mock_spanner.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from google.cloud.spanner_v1 import TransactionOptions, ResultSetMetadata
15+
from google.cloud.spanner_v1 import (
16+
TransactionOptions,
17+
ResultSetMetadata,
18+
ExecuteSqlRequest,
19+
)
1620
from google.protobuf import empty_pb2
1721
import test.mockserver_tests.spanner_pb2_grpc as spanner_grpc
1822
import test.mockserver_tests.spanner_database_admin_pb2_grpc as database_admin_grpc
@@ -40,23 +44,25 @@ def get_result(self, sql: str) -> result_set.ResultSet:
4044
return result
4145

4246
def get_result_as_partial_result_sets(
43-
self, sql: str
47+
self, sql: str, started_transaction: transaction.Transaction
4448
) -> [result_set.PartialResultSet]:
4549
result: result_set.ResultSet = self.get_result(sql)
4650
partials = []
4751
first = True
4852
if len(result.rows) == 0:
4953
partial = result_set.PartialResultSet()
50-
partial.metadata = result.metadata
54+
partial.metadata = ResultSetMetadata(result.metadata)
5155
partials.append(partial)
5256
else:
5357
for row in result.rows:
5458
partial = result_set.PartialResultSet()
5559
if first:
56-
partial.metadata = result.metadata
60+
partial.metadata = ResultSetMetadata(result.metadata)
5761
partial.values.extend(row)
5862
partials.append(partial)
5963
partials[len(partials) - 1].stats = result.stats
64+
if started_transaction:
65+
partials[0].metadata.transaction = started_transaction
6066
return partials
6167

6268

@@ -120,9 +126,16 @@ def ExecuteSql(self, request, context):
120126
self._requests.append(request)
121127
return result_set.ResultSet()
122128

123-
def ExecuteStreamingSql(self, request, context):
129+
def ExecuteStreamingSql(self, request: ExecuteSqlRequest, context):
124130
self._requests.append(request)
125-
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql)
131+
started_transaction = None
132+
if not request.transaction.begin == TransactionOptions():
133+
started_transaction = self.__create_transaction(
134+
request.session, request.transaction.begin
135+
)
136+
partials = self.mock_spanner.get_result_as_partial_result_sets(
137+
request.sql, started_transaction
138+
)
126139
for result in partials:
127140
yield result
128141

test/mockserver_tests/test_basics.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,17 @@
1313
# limitations under the License.
1414

1515
from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest
16-
from sqlalchemy import create_engine, select, MetaData, Table, Column, Integer, String
16+
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
17+
from sqlalchemy import (
18+
create_engine,
19+
select,
20+
MetaData,
21+
Table,
22+
Column,
23+
Integer,
24+
String,
25+
text,
26+
)
1727
from sqlalchemy.testing import eq_, is_instance_of
1828
from google.cloud.spanner_v1 import (
1929
FixedSizePool,
@@ -26,6 +36,7 @@
2636
MockServerTestBase,
2737
add_select1_result,
2838
add_result,
39+
add_update_count,
2940
)
3041

3142

@@ -127,3 +138,21 @@ def test_create_multiple_tables(self):
127138
"\n) PRIMARY KEY (id)",
128139
requests[0].statements[i],
129140
)
141+
142+
def test_partitioned_dml(self):
    """Check that a Partitioned DML update reports the mocked row count."""
    sql = "UPDATE singers SET checked=true WHERE active = true"
    # Register the statement with a lower-bound update count of 100.
    add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
    connect_args = {"client": self.client, "pool": PingingPool(size=10)}
    engine = create_engine(
        "spanner:///projects/p/instances/i/databases/d",
        connect_args=connect_args,
    )
    # TODO: Support autocommit_dml_mode as a connection variable in execution
    #       options.
    autocommit_connection = engine.connect().execution_options(
        isolation_level="AUTOCOMMIT"
    )
    with autocommit_connection as connection:
        connection.connection.set_autocommit_dml_mode(
            AutocommitDmlMode.PARTITIONED_NON_ATOMIC
        )
        update_result = connection.execute(text(sql))
        row_count = update_result.rowcount
        eq_(100, row_count)

0 commit comments

Comments
 (0)