
Commit a09961b

feat(spanner): add lazy decode to partitioned query (#1411)

1 parent 58e2406 commit a09961b

4 files changed: +199 −5 lines changed

google/cloud/spanner_v1/database.py

Lines changed: 28 additions & 1 deletion
@@ -1532,6 +1532,14 @@ def to_dict(self):
             "transaction_id": snapshot._transaction_id,
         }
 
+    def __enter__(self):
+        """Begin ``with`` block."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """End ``with`` block."""
+        self.close()
+
     @property
     def observability_options(self):
         return getattr(self._database, "observability_options", {})
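With `__enter__` and `__exit__` in place, a `BatchSnapshot` can be used as a context manager that closes itself (deleting its session) on exit. A minimal usage sketch, assuming an existing `Database` instance named `database` (the instance and query are illustrative, not part of this commit):

    # Hypothetical usage: `database` is an existing spanner_v1 Database.
    with database.batch_snapshot() as batch_txn:
        for batch in batch_txn.generate_query_batches("SELECT 1"):
            for row in batch_txn.process_query_batch(batch):
                print(row)
    # On exit, __exit__ calls batch_txn.close(), which deletes the session.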
@@ -1703,6 +1711,7 @@ def process_read_batch(
         *,
         retry=gapic_v1.method.DEFAULT,
         timeout=gapic_v1.method.DEFAULT,
+        lazy_decode=False,
     ):
         """Process a single, partitioned read.
 
@@ -1717,6 +1726,14 @@ def process_read_batch(
         :type timeout: float
         :param timeout: (Optional) The timeout for this request.
 
+        :type lazy_decode: bool
+        :param lazy_decode:
+            (Optional) If this argument is set to ``true``, the iterator
+            returns the underlying protobuf values instead of decoded Python
+            objects. This reduces the time that is needed to iterate through
+            large result sets. The application is responsible for decoding
+            the data that is needed.
+
         :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
         :returns: a result set instance which can be used to consume rows.
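For a single partition, the new flag might be used as in this hedged sketch (`batch_txn` and `batch` are assumed to come from `generate_query_batches`, as above): with `lazy_decode=True` the returned `StreamedResultSet` yields raw protobuf values, and the application decodes them on demand via the result set's own `decode_row` helper.

    # Sketch: consume one partition lazily; decoding is deferred to the app.
    results = batch_txn.process_query_batch(batch, lazy_decode=True)
    for raw_row in results:
        row = results.decode_row(raw_row)  # same values an eager iterator yields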
@@ -1844,6 +1861,7 @@ def process_query_batch(
         self,
         batch,
         *,
+        lazy_decode: bool = False,
         retry=gapic_v1.method.DEFAULT,
         timeout=gapic_v1.method.DEFAULT,
     ):
@@ -1854,6 +1872,13 @@ def process_query_batch(
             one of the mappings returned from an earlier call to
             :meth:`generate_query_batches`.
 
+        :type lazy_decode: bool
+        :param lazy_decode:
+            (Optional) If this argument is set to ``true``, the iterator
+            returns the underlying protobuf values instead of decoded Python
+            objects. This reduces the time that is needed to iterate through
+            large result sets.
+
         :type retry: :class:`~google.api_core.retry.Retry`
         :param retry: (Optional) The retry settings for this request.
 
@@ -1870,6 +1895,7 @@ def process_query_batch(
         return self._get_snapshot().execute_sql(
             partition=batch["partition"],
             **batch["query"],
+            lazy_decode=lazy_decode,
             retry=retry,
             timeout=timeout,
         )
@@ -1883,6 +1909,7 @@ def run_partitioned_query(
         max_partitions=None,
         query_options=None,
         data_boost_enabled=False,
+        lazy_decode=False,
     ):
         """Start a partitioned query operation to get list of partitions and
         then executes each partition on a separate thread
@@ -1943,7 +1970,7 @@ def run_partitioned_query(
                     data_boost_enabled,
                 )
             )
-        return MergedResultSet(self, partitions, 0)
+        return MergedResultSet(self, partitions, 0, lazy_decode=lazy_decode)
 
     def process(self, batch):
         """Process a single, partitioned query or read.

google/cloud/spanner_v1/merged_result_set.py

Lines changed: 37 additions & 4 deletions
@@ -33,10 +33,13 @@ class PartitionExecutor:
     rows in the queue
     """
 
-    def __init__(self, batch_snapshot, partition_id, merged_result_set):
+    def __init__(
+        self, batch_snapshot, partition_id, merged_result_set, lazy_decode=False
+    ):
         self._batch_snapshot: BatchSnapshot = batch_snapshot
         self._partition_id = partition_id
         self._merged_result_set: MergedResultSet = merged_result_set
+        self._lazy_decode = lazy_decode
         self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue
 
     def run(self):
@@ -52,7 +55,9 @@ def run(self):
     def __run(self):
         results = None
         try:
-            results = self._batch_snapshot.process_query_batch(self._partition_id)
+            results = self._batch_snapshot.process_query_batch(
+                self._partition_id, lazy_decode=self._lazy_decode
+            )
             for row in results:
                 if self._merged_result_set._metadata is None:
                     self._set_metadata(results)
@@ -75,6 +80,7 @@ def _set_metadata(self, results, is_exception=False):
         try:
             if not is_exception:
                 self._merged_result_set._metadata = results.metadata
+                self._merged_result_set._result_set = results
         finally:
             self._merged_result_set.metadata_lock.release()
             self._merged_result_set.metadata_event.set()
@@ -94,7 +100,10 @@ class MergedResultSet:
     records in the MergedResultSet is not guaranteed.
     """
 
-    def __init__(self, batch_snapshot, partition_ids, max_parallelism):
+    def __init__(
+        self, batch_snapshot, partition_ids, max_parallelism, lazy_decode=False
+    ):
+        self._result_set = None
         self._exception = None
         self._metadata = None
         self.metadata_event = Event()
@@ -110,7 +119,7 @@ def __init__(self, batch_snapshot, partition_ids, max_parallelism):
         partition_executors = []
         for partition_id in partition_ids:
             partition_executors.append(
-                PartitionExecutor(batch_snapshot, partition_id, self)
+                PartitionExecutor(batch_snapshot, partition_id, self, lazy_decode)
             )
         executor = ThreadPoolExecutor(max_workers=parallelism)
         for partition_executor in partition_executors:
@@ -144,3 +153,27 @@ def metadata(self):
     def stats(self):
         # TODO: Implement
         return None
+
+    def decode_row(self, row: []) -> []:
+        """Decodes a row from protobuf values to Python objects. This function
+        should only be called for result sets that use ``lazy_decoding=True``.
+        The array that is returned by this function is the same as the array
+        that would have been returned by the rows iterator if ``lazy_decoding=False``.
+
+        :returns: an array containing the decoded values of all the columns in the given row
+        """
+        if self._result_set is None:
+            raise ValueError("iterator not started")
+        return self._result_set.decode_row(row)
+
+    def decode_column(self, row: [], column_index: int):
+        """Decodes a column from a protobuf value to a Python object. This function
+        should only be called for result sets that use ``lazy_decoding=True``.
+        The object that is returned by this function is the same as the object
+        that would have been returned by the rows iterator if ``lazy_decoding=False``.
+
+        :returns: the decoded column value
+        """
+        if self._result_set is None:
+            raise ValueError("iterator not started")
+        return self._result_set.decode_column(row, column_index)
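Both helpers require that iteration has started (so `_set_metadata` has stored a backing result set) and then delegate to the underlying `StreamedResultSet`. A short hedged sketch, reusing the `results` iterator from the previous example:

    # Sketch: decode a whole row, or a single column, after iteration begins.
    for raw_row in results:
        values = results.decode_row(raw_row)           # full decoded row
        singer_id = results.decode_column(raw_row, 0)  # one decoded column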

tests/unit/test_database.py

Lines changed: 15 additions & 0 deletions
@@ -3141,6 +3141,7 @@ def test_process_query_batch(self):
             params=params,
             param_types=param_types,
             partition=token,
+            lazy_decode=False,
             retry=gapic_v1.method.DEFAULT,
             timeout=gapic_v1.method.DEFAULT,
         )
@@ -3170,6 +3171,7 @@ def test_process_query_batch_w_retry_timeout(self):
             params=params,
             param_types=param_types,
             partition=token,
+            lazy_decode=False,
             retry=retry,
             timeout=2.0,
         )
@@ -3193,11 +3195,23 @@ def test_process_query_batch_w_directed_read_options(self):
         snapshot.execute_sql.assert_called_once_with(
             sql=sql,
             partition=token,
+            lazy_decode=False,
             retry=gapic_v1.method.DEFAULT,
             timeout=gapic_v1.method.DEFAULT,
             directed_read_options=DIRECTED_READ_OPTIONS,
         )
 
+    def test_context_manager(self):
+        database = self._make_database()
+        batch_txn = self._make_one(database)
+        session = batch_txn._session = self._make_session()
+        session.is_multiplexed = False
+
+        with batch_txn:
+            pass
+
+        session.delete.assert_called_once_with()
+
     def test_close_wo_session(self):
         database = self._make_database()
         batch_txn = self._make_one(database)
@@ -3292,6 +3306,7 @@ def test_process_w_query_batch(self):
             params=params,
             param_types=param_types,
             partition=token,
+            lazy_decode=False,
             retry=gapic_v1.method.DEFAULT,
             timeout=gapic_v1.method.DEFAULT,
         )
tests/unit/test_merged_result_set.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+# Copyright 2025 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import mock
+from google.cloud.spanner_v1.streamed import StreamedResultSet
+
+
+class TestMergedResultSet(unittest.TestCase):
+    def _get_target_class(self):
+        from google.cloud.spanner_v1.merged_result_set import MergedResultSet
+
+        return MergedResultSet
+
+    def _make_one(self, *args, **kwargs):
+        klass = self._get_target_class()
+        obj = super(klass, klass).__new__(klass)
+        from threading import Event, Lock
+
+        obj.metadata_event = Event()
+        obj.metadata_lock = Lock()
+        obj._metadata = None
+        obj._result_set = None
+        return obj
+
+    @staticmethod
+    def _make_value(value):
+        from google.cloud.spanner_v1._helpers import _make_value_pb
+
+        return _make_value_pb(value)
+
+    @staticmethod
+    def _make_scalar_field(name, type_):
+        from google.cloud.spanner_v1 import StructType
+        from google.cloud.spanner_v1 import Type
+
+        return StructType.Field(name=name, type_=Type(code=type_))
+
+    @staticmethod
+    def _make_result_set_metadata(fields=()):
+        from google.cloud.spanner_v1 import ResultSetMetadata
+        from google.cloud.spanner_v1 import StructType
+
+        metadata = ResultSetMetadata(row_type=StructType(fields=[]))
+        for field in fields:
+            metadata.row_type.fields.append(field)
+        return metadata
+
+    def test_stats_property(self):
+        merged = self._make_one()
+        # The property is currently not implemented, so it should just return None.
+        self.assertIsNone(merged.stats)
+
+    def test_decode_row(self):
+        merged = self._make_one()
+
+        merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
+        merged._result_set.decode_row.return_value = ["Phred", 42]
+
+        raw_row = [self._make_value("Phred"), self._make_value(42)]
+        decoded_row = merged.decode_row(raw_row)
+
+        self.assertEqual(decoded_row, ["Phred", 42])
+        merged._result_set.decode_row.assert_called_once_with(raw_row)
+
+    def test_decode_row_no_result_set(self):
+        merged = self._make_one()
+        merged._result_set = None
+        with self.assertRaisesRegex(ValueError, "iterator not started"):
+            merged.decode_row([])
+
+    def test_decode_row_type_error(self):
+        merged = self._make_one()
+        merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
+        merged._result_set.decode_row.side_effect = TypeError
+
+        with self.assertRaises(TypeError):
+            merged.decode_row("not a list")
+
+    def test_decode_column(self):
+        merged = self._make_one()
+        merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
+        merged._result_set.decode_column.side_effect = ["Phred", 42]
+
+        raw_row = [self._make_value("Phred"), self._make_value(42)]
+        decoded_name = merged.decode_column(raw_row, 0)
+        decoded_age = merged.decode_column(raw_row, 1)
+
+        self.assertEqual(decoded_name, "Phred")
+        self.assertEqual(decoded_age, 42)
+        merged._result_set.decode_column.assert_has_calls(
+            [mock.call(raw_row, 0), mock.call(raw_row, 1)]
+        )
+
+    def test_decode_column_no_result_set(self):
+        merged = self._make_one()
+        merged._result_set = None
+        with self.assertRaisesRegex(ValueError, "iterator not started"):
+            merged.decode_column([], 0)
+
+    def test_decode_column_type_error(self):
+        merged = self._make_one()
+        merged._result_set = mock.create_autospec(StreamedResultSet, instance=True)
+        merged._result_set.decode_column.side_effect = TypeError
+
+        with self.assertRaises(TypeError):
+            merged.decode_column("not a list", 0)
