Commit ac31c75

feat(spanner): add lazy decode to partitioned query

1 parent a041042

File tree: 4 files changed (+218 −5 lines)

google/cloud/spanner_v1/database.py

Lines changed: 20 additions & 1 deletion

@@ -1532,6 +1532,14 @@ def to_dict(self):
             "transaction_id": snapshot._transaction_id,
         }
 
+    def __enter__(self):
+        """Begin ``with`` block."""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """End ``with`` block."""
+        self.close()
+
     @property
     def observability_options(self):
         return getattr(self._database, "observability_options", {})
@@ -1703,6 +1711,7 @@ def process_read_batch(
         *,
         retry=gapic_v1.method.DEFAULT,
         timeout=gapic_v1.method.DEFAULT,
+        lazy_decode=False,
     ):
         """Process a single, partitioned read.
 
@@ -1844,6 +1853,7 @@ def process_query_batch(
         self,
         batch,
         *,
+        lazy_decode: bool = False,
         retry=gapic_v1.method.DEFAULT,
         timeout=gapic_v1.method.DEFAULT,
     ):
@@ -1854,6 +1864,13 @@ def process_query_batch(
             one of the mappings returned from an earlier call to
             :meth:`generate_query_batches`.
 
+        :type lazy_decode: bool
+        :param lazy_decode:
+            (Optional) If this argument is set to ``True``, the iterator
+            returns the underlying protobuf values instead of decoded Python
+            objects. This reduces the time that is needed to iterate through
+            large result sets.
+
         :type retry: :class:`~google.api_core.retry.Retry`
         :param retry: (Optional) The retry settings for this request.
 
@@ -1870,6 +1887,7 @@ def process_query_batch(
         return self._get_snapshot().execute_sql(
             partition=batch["partition"],
             **batch["query"],
+            lazy_decode=lazy_decode,
             retry=retry,
             timeout=timeout,
         )
@@ -1883,6 +1901,7 @@ def run_partitioned_query(
         max_partitions=None,
         query_options=None,
         data_boost_enabled=False,
+        lazy_decode=False,
     ):
         """Start a partitioned query operation to get a list of partitions and
         then execute each partition on a separate thread.
@@ -1943,7 +1962,7 @@ def run_partitioned_query(
                     data_boost_enabled,
                 )
             )
-        return MergedResultSet(self, partitions, 0)
+        return MergedResultSet(self, partitions, 0, lazy_decode=lazy_decode)
 
     def process(self, batch):
         """Process a single, partitioned query or read.

google/cloud/spanner_v1/merged_result_set.py

Lines changed: 49 additions & 4 deletions

@@ -19,6 +19,8 @@
 
 from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
 from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
+from google.cloud.spanner_v1._helpers import _get_type_decoder, _parse_nullable
+from google.cloud.spanner_v1.streamed import StreamedResultSet
 
 if TYPE_CHECKING:
     from google.cloud.spanner_v1.database import BatchSnapshot
@@ -33,10 +35,11 @@ class PartitionExecutor:
     rows in the queue
     """
 
-    def __init__(self, batch_snapshot, partition_id, merged_result_set):
+    def __init__(self, batch_snapshot, partition_id, merged_result_set, lazy_decode=False):
         self._batch_snapshot: BatchSnapshot = batch_snapshot
         self._partition_id = partition_id
         self._merged_result_set: MergedResultSet = merged_result_set
+        self._lazy_decode = lazy_decode
         self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue
 
     def run(self):
@@ -52,7 +55,9 @@ def run(self):
     def __run(self):
         results = None
         try:
-            results = self._batch_snapshot.process_query_batch(self._partition_id)
+            results = self._batch_snapshot.process_query_batch(
+                self._partition_id, lazy_decode=self._lazy_decode
+            )
             for row in results:
                 if self._merged_result_set._metadata is None:
                     self._set_metadata(results)
@@ -94,7 +99,7 @@ class MergedResultSet:
     records in the MergedResultSet is not guaranteed.
     """
 
-    def __init__(self, batch_snapshot, partition_ids, max_parallelism):
+    def __init__(self, batch_snapshot, partition_ids, max_parallelism, lazy_decode=False):
         self._exception = None
         self._metadata = None
         self.metadata_event = Event()
@@ -110,7 +115,7 @@ def __init__(self, batch_snapshot, partition_ids, max_parallelism):
         partition_executors = []
         for partition_id in partition_ids:
             partition_executors.append(
-                PartitionExecutor(batch_snapshot, partition_id, self)
+                PartitionExecutor(batch_snapshot, partition_id, self, lazy_decode)
             )
         executor = ThreadPoolExecutor(max_workers=parallelism)
         for partition_executor in partition_executors:
@@ -144,3 +149,43 @@ def metadata(self):
     def stats(self):
         # TODO: Implement
         return None
+
+    def decode_row(self, row: []) -> []:
+        """Decodes a row from protobuf values to Python objects. This function
+        should only be called for result sets that use ``lazy_decode=True``.
+        The array that is returned by this function is the same as the array
+        that would have been returned by the rows iterator if ``lazy_decode=False``.
+
+        :returns: an array containing the decoded values of all the columns in the given row
+        """
+        if not isinstance(row, (list, tuple)):
+            raise TypeError("row must be an array of protobuf values")
+        decoders = self._decoders
+        return [
+            _parse_nullable(row[index], decoders[index]) for index in range(len(row))
+        ]
+
+    def decode_column(self, row: [], column_index: int):
+        """Decodes a column from a protobuf value to a Python object. This function
+        should only be called for result sets that use ``lazy_decode=True``.
+        The object that is returned by this function is the same as the object
+        that would have been returned by the rows iterator if ``lazy_decode=False``.
+
+        :returns: the decoded column value
+        """
+        if not isinstance(row, (list, tuple)):
+            raise TypeError("row must be an array of protobuf values")
+        decoders = self._decoders
+        return _parse_nullable(row[column_index], decoders[column_index])
+
+    @property
+    def _decoders(self):
+        if self.metadata is None:
+            raise ValueError("iterator not started")
+        # This logic is borrowed from StreamedResultSet._decoders.
+        # We assume `column_info` is None as it's not available here and
+        # is only used for custom decoders, which are not typical for this path.
+        return [
+            _get_type_decoder(field.type_, field.name, None)
+            for field in self.metadata.row_type.fields
+        ]
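
The threading in PartitionExecutor and MergedResultSet amounts to a fan-out/fan-in queue: one worker per partition pushes rows into a shared queue, and the consuming iterator drains the queue until every worker has signaled completion, which is why row order across partitions is not guaranteed. A self-contained sketch of that pattern (illustrative only, not library code):

from concurrent.futures import ThreadPoolExecutor
from queue import Queue

def merge_partitions(partitions, fetch_rows, max_workers=4):
    """Yield rows from all partitions as they arrive, in no particular order."""
    queue = Queue()
    done = object()  # sentinel: one per finished partition

    def drain(partition):
        # Mirrors PartitionExecutor.__run: push rows, then signal completion.
        for row in fetch_rows(partition):
            queue.put(row)
        queue.put(done)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for partition in partitions:
            executor.submit(drain, partition)
        finished = 0
        while finished < len(partitions):
            item = queue.get()
            if item is done:
                finished += 1
            else:
                yield item

# Example: rows from "p1" and "p2" may interleave arbitrarily.
rows = list(merge_partitions(["p1", "p2"], lambda p: [(p, i) for i in range(3)]))

The real implementation also pushes exceptions through the queue so a failing partition surfaces to the consumer instead of deadlocking; the sketch omits that for brevity.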

tests/unit/test_database.py

Lines changed: 16 additions & 0 deletions

@@ -3141,6 +3141,7 @@ def test_process_query_batch(self):
             params=params,
             param_types=param_types,
             partition=token,
+            lazy_decode=False,
             retry=gapic_v1.method.DEFAULT,
             timeout=gapic_v1.method.DEFAULT,
         )
@@ -3170,6 +3171,7 @@ def test_process_query_batch_w_retry_timeout(self):
             params=params,
             param_types=param_types,
             partition=token,
+            lazy_decode=False,
             retry=retry,
             timeout=2.0,
         )
@@ -3193,11 +3195,24 @@ def test_process_query_batch_w_directed_read_options(self):
         snapshot.execute_sql.assert_called_once_with(
             sql=sql,
             partition=token,
+            lazy_decode=False,
             retry=gapic_v1.method.DEFAULT,
             timeout=gapic_v1.method.DEFAULT,
             directed_read_options=DIRECTED_READ_OPTIONS,
         )
 
+    def test_context_manager(self):
+        database = self._make_database()
+        batch_txn = self._make_one(database)
+        session = batch_txn._session = self._make_session()
+        session.is_multiplexed = False
+
+        with batch_txn:
+            pass
+
+        session.delete.assert_called_once_with()
+
+
     def test_close_wo_session(self):
         database = self._make_database()
         batch_txn = self._make_one(database)
@@ -3292,6 +3307,7 @@ def test_process_w_query_batch(self):
             params=params,
             param_types=param_types,
             partition=token,
+            lazy_decode=False,
             retry=gapic_v1.method.DEFAULT,
             timeout=gapic_v1.method.DEFAULT,
         )
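
The updated assertions pin ``lazy_decode=False`` as the default pass-through. A hedged sketch of a companion test (not part of this commit) for the opposite direction; it stubs ``_get_snapshot`` directly rather than relying on the class's usual fixtures, and assumes the ``_make_one``/``_make_database`` helpers from the surrounding test class:

def test_process_query_batch_w_lazy_decode(self):
    import mock
    from google.api_core import gapic_v1

    sql = "SELECT first_name FROM citizens"
    token = b"TOKEN"
    batch_txn = self._make_one(self._make_database())
    snapshot = mock.Mock()
    batch_txn._get_snapshot = mock.Mock(return_value=snapshot)

    batch_txn.process_query_batch(
        {"partition": token, "query": {"sql": sql}}, lazy_decode=True
    )

    snapshot.execute_sql.assert_called_once_with(
        sql=sql,
        partition=token,
        lazy_decode=True,
        retry=gapic_v1.method.DEFAULT,
        timeout=gapic_v1.method.DEFAULT,
    )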
tests/unit/test_merged_result_set.py

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
+# Copyright 2024 Google LLC All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import mock
+
+
+class TestMergedResultSet(unittest.TestCase):
+    def _get_target_class(self):
+        from google.cloud.spanner_v1.merged_result_set import MergedResultSet
+
+        return MergedResultSet
+
+    def _make_one(self, *args, **kwargs):
+        # We are not creating a real MergedResultSet, just an object
+        # that can be used for testing the decode methods.
+        # The constructor starts threads, which we want to avoid.
+        klass = self._get_target_class()
+        obj = super(klass, klass).__new__(klass)
+        # Manually initialize attributes that __init__ would set.
+        from threading import Event, Lock
+
+        obj.metadata_event = Event()
+        obj.metadata_lock = Lock()
+        obj._metadata = None
+        return obj
+    @staticmethod
+    def _make_value(value):
+        from google.cloud.spanner_v1._helpers import _make_value_pb
+
+        return _make_value_pb(value)
+
+    @staticmethod
+    def _make_scalar_field(name, type_):
+        from google.cloud.spanner_v1 import StructType
+        from google.cloud.spanner_v1 import Type
+
+        return StructType.Field(name=name, type_=Type(code=type_))
+
+    @staticmethod
+    def _make_result_set_metadata(fields=()):
+        from google.cloud.spanner_v1 import ResultSetMetadata
+        from google.cloud.spanner_v1 import StructType
+
+        metadata = ResultSetMetadata(row_type=StructType(fields=[]))
+        for field in fields:
+            metadata.row_type.fields.append(field)
+        return metadata
+
+    def test_decoders_property_no_metadata(self):
+        merged = self._make_one()
+        merged._metadata = None
+        merged.metadata_event.set()
+        with self.assertRaises(ValueError):
+            getattr(merged, "_decoders")
+
+    def test_decoders_property_with_metadata(self):
+        from google.cloud.spanner_v1 import TypeCode
+
+        merged = self._make_one()
+        fields = [
+            self._make_scalar_field("full_name", TypeCode.STRING),
+            self._make_scalar_field("age", TypeCode.INT64),
+        ]
+        merged._metadata = self._make_result_set_metadata(fields)
+        merged.metadata_event.set()
+
+        decoders = merged._decoders
+        self.assertEqual(len(decoders), 2)
+        self.assertTrue(callable(decoders[0]))
+        self.assertTrue(callable(decoders[1]))
+
+    def test_decode_row(self):
+        from google.cloud.spanner_v1 import TypeCode
+
+        merged = self._make_one()
+        fields = [
+            self._make_scalar_field("full_name", TypeCode.STRING),
+            self._make_scalar_field("age", TypeCode.INT64),
+        ]
+        merged._metadata = self._make_result_set_metadata(fields)
+        merged.metadata_event.set()
+
+        raw_row = [self._make_value("Phred"), self._make_value(42)]
+        decoded_row = merged.decode_row(raw_row)
+
+        self.assertEqual(decoded_row, ["Phred", 42])
+
+    def test_decode_row_type_error(self):
+        merged = self._make_one()
+        # The _decoders property requires metadata, even for a type error check.
+        merged._metadata = self._make_result_set_metadata()
+        merged.metadata_event.set()
+        with self.assertRaises(TypeError):
+            merged.decode_row("not a list")
+
+    def test_decode_column(self):
+        from google.cloud.spanner_v1 import TypeCode
+
+        merged = self._make_one()
+        fields = [
+            self._make_scalar_field("full_name", TypeCode.STRING),
+            self._make_scalar_field("age", TypeCode.INT64),
+        ]
+        merged._metadata = self._make_result_set_metadata(fields)
+        merged.metadata_event.set()
+
+        raw_row = [self._make_value("Phred"), self._make_value(42)]
+        decoded_name = merged.decode_column(raw_row, 0)
+        decoded_age = merged.decode_column(raw_row, 1)
+
+        self.assertEqual(decoded_name, "Phred")
+        self.assertEqual(decoded_age, 42)
+
+    def test_decode_column_type_error(self):
+        merged = self._make_one()
+        # The _decoders property requires metadata, even for a type error check.
+        merged._metadata = self._make_result_set_metadata()
+        merged.metadata_event.set()
+        with self.assertRaises(TypeError):
+            merged.decode_column("not a list", 0)
