Skip to content

Commit 0dff8f6

Browse files
committed
feat(spanner): add lazy decode to partitioned query
1 parent a041042 commit 0dff8f6

File tree

4 files changed

+212
-5
lines changed

4 files changed

+212
-5
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1532,6 +1532,14 @@ def to_dict(self):
15321532
"transaction_id": snapshot._transaction_id,
15331533
}
15341534

1535+
def __enter__(self):
1536+
"""Begin ``with`` block."""
1537+
return self
1538+
1539+
def __exit__(self, exc_type, exc_val, exc_tb):
1540+
"""End ``with`` block."""
1541+
self.close()
1542+
15351543
@property
15361544
def observability_options(self):
15371545
return getattr(self._database, "observability_options", {})
@@ -1703,6 +1711,7 @@ def process_read_batch(
17031711
*,
17041712
retry=gapic_v1.method.DEFAULT,
17051713
timeout=gapic_v1.method.DEFAULT,
1714+
lazy_decode=False,
17061715
):
17071716
"""Process a single, partitioned read.
17081717
@@ -1844,6 +1853,7 @@ def process_query_batch(
18441853
self,
18451854
batch,
18461855
*,
1856+
lazy_decode: bool = False,
18471857
retry=gapic_v1.method.DEFAULT,
18481858
timeout=gapic_v1.method.DEFAULT,
18491859
):
@@ -1854,6 +1864,13 @@ def process_query_batch(
18541864
one of the mappings returned from an earlier call to
18551865
:meth:`generate_query_batches`.
18561866
1867+
:type lazy_decode: bool
1868+
:param lazy_decode:
1869+
(Optional) If this argument is set to ``True``, the iterator
1870+
returns the underlying protobuf values instead of decoded Python
1871+
objects. This reduces the time that is needed to iterate through
1872+
large result sets.
1873+
18571874
:type retry: :class:`~google.api_core.retry.Retry`
18581875
:param retry: (Optional) The retry settings for this request.
18591876
@@ -1870,6 +1887,7 @@ def process_query_batch(
18701887
return self._get_snapshot().execute_sql(
18711888
partition=batch["partition"],
18721889
**batch["query"],
1890+
lazy_decode=lazy_decode,
18731891
retry=retry,
18741892
timeout=timeout,
18751893
)
@@ -1883,6 +1901,7 @@ def run_partitioned_query(
18831901
max_partitions=None,
18841902
query_options=None,
18851903
data_boost_enabled=False,
1904+
lazy_decode=False,
18861905
):
18871906
"""Start a partitioned query operation to get list of partitions and
18881907
then executes each partition on a separate thread
@@ -1943,7 +1962,7 @@ def run_partitioned_query(
19431962
data_boost_enabled,
19441963
)
19451964
)
1946-
return MergedResultSet(self, partitions, 0)
1965+
return MergedResultSet(self, partitions, 0, lazy_decode=lazy_decode)
19471966

19481967
def process(self, batch):
19491968
"""Process a single, partitioned query or read.

google/cloud/spanner_v1/merged_result_set.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
2121
from google.cloud.spanner_v1.metrics.metrics_capture import MetricsCapture
22+
from google.cloud.spanner_v1._helpers import _get_type_decoder, _parse_nullable
2223

2324
if TYPE_CHECKING:
2425
from google.cloud.spanner_v1.database import BatchSnapshot
@@ -33,10 +34,13 @@ class PartitionExecutor:
3334
rows in the queue
3435
"""
3536

36-
def __init__(self, batch_snapshot, partition_id, merged_result_set):
37+
def __init__(
38+
self, batch_snapshot, partition_id, merged_result_set, lazy_decode=False
39+
):
3740
self._batch_snapshot: BatchSnapshot = batch_snapshot
3841
self._partition_id = partition_id
3942
self._merged_result_set: MergedResultSet = merged_result_set
43+
self._lazy_decode = lazy_decode
4044
self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue
4145

4246
def run(self):
@@ -52,7 +56,9 @@ def run(self):
5256
def __run(self):
5357
results = None
5458
try:
55-
results = self._batch_snapshot.process_query_batch(self._partition_id)
59+
results = self._batch_snapshot.process_query_batch(
60+
self._partition_id, lazy_decode=self._lazy_decode
61+
)
5662
for row in results:
5763
if self._merged_result_set._metadata is None:
5864
self._set_metadata(results)
@@ -94,7 +100,9 @@ class MergedResultSet:
94100
records in the MergedResultSet is not guaranteed.
95101
"""
96102

97-
def __init__(self, batch_snapshot, partition_ids, max_parallelism):
103+
def __init__(
104+
self, batch_snapshot, partition_ids, max_parallelism, lazy_decode=False
105+
):
98106
self._exception = None
99107
self._metadata = None
100108
self.metadata_event = Event()
@@ -110,7 +118,7 @@ def __init__(self, batch_snapshot, partition_ids, max_parallelism):
110118
partition_executors = []
111119
for partition_id in partition_ids:
112120
partition_executors.append(
113-
PartitionExecutor(batch_snapshot, partition_id, self)
121+
PartitionExecutor(batch_snapshot, partition_id, self, lazy_decode)
114122
)
115123
executor = ThreadPoolExecutor(max_workers=parallelism)
116124
for partition_executor in partition_executors:
@@ -144,3 +152,40 @@ def metadata(self):
144152
def stats(self):
145153
# TODO: Implement
146154
return None
155+
156+
def decode_row(self, row: []) -> []:
157+
"""Decodes a row from protobuf values to Python objects. This function
158+
should only be called for result sets that use ``lazy_decode=True``.
159+
The array that is returned by this function is the same as the array
160+
that would have been returned by the rows iterator if ``lazy_decode=False``.
161+
162+
:returns: an array containing the decoded values of all the columns in the given row
163+
"""
164+
if not isinstance(row, (list, tuple)):
165+
raise TypeError("row must be an array of protobuf values")
166+
decoders = self._decoders
167+
return [
168+
_parse_nullable(row[index], decoders[index]) for index in range(len(row))
169+
]
170+
171+
def decode_column(self, row: [], column_index: int):
172+
"""Decodes a column from a protobuf value to a Python object. This function
173+
should only be called for result sets that use ``lazy_decode=True``.
174+
The object that is returned by this function is the same as the object
175+
that would have been returned by the rows iterator if ``lazy_decode=False``.
176+
177+
:returns: the decoded column value
178+
"""
179+
if not isinstance(row, (list, tuple)):
180+
raise TypeError("row must be an array of protobuf values")
181+
decoders = self._decoders
182+
return _parse_nullable(row[column_index], decoders[column_index])
183+
184+
@property
185+
def _decoders(self):
186+
if self.metadata is None:
187+
raise ValueError("iterator not started")
188+
return [
189+
_get_type_decoder(field.type_, field.name, None)
190+
for field in self.metadata.row_type.fields
191+
]

tests/unit/test_database.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3141,6 +3141,7 @@ def test_process_query_batch(self):
31413141
params=params,
31423142
param_types=param_types,
31433143
partition=token,
3144+
lazy_decode=False,
31443145
retry=gapic_v1.method.DEFAULT,
31453146
timeout=gapic_v1.method.DEFAULT,
31463147
)
@@ -3170,6 +3171,7 @@ def test_process_query_batch_w_retry_timeout(self):
31703171
params=params,
31713172
param_types=param_types,
31723173
partition=token,
3174+
lazy_decode=False,
31733175
retry=retry,
31743176
timeout=2.0,
31753177
)
@@ -3193,11 +3195,23 @@ def test_process_query_batch_w_directed_read_options(self):
31933195
snapshot.execute_sql.assert_called_once_with(
31943196
sql=sql,
31953197
partition=token,
3198+
lazy_decode=False,
31963199
retry=gapic_v1.method.DEFAULT,
31973200
timeout=gapic_v1.method.DEFAULT,
31983201
directed_read_options=DIRECTED_READ_OPTIONS,
31993202
)
32003203

3204+
def test_context_manager(self):
3205+
database = self._make_database()
3206+
batch_txn = self._make_one(database)
3207+
session = batch_txn._session = self._make_session()
3208+
session.is_multiplexed = False
3209+
3210+
with batch_txn:
3211+
pass
3212+
3213+
session.delete.assert_called_once_with()
3214+
32013215
def test_close_wo_session(self):
32023216
database = self._make_database()
32033217
batch_txn = self._make_one(database)
@@ -3292,6 +3306,7 @@ def test_process_w_query_batch(self):
32923306
params=params,
32933307
param_types=param_types,
32943308
partition=token,
3309+
lazy_decode=False,
32953310
retry=gapic_v1.method.DEFAULT,
32963311
timeout=gapic_v1.method.DEFAULT,
32973312
)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
18+
class TestMergedResultSet(unittest.TestCase):
19+
def _get_target_class(self):
20+
from google.cloud.spanner_v1.merged_result_set import MergedResultSet
21+
22+
return MergedResultSet
23+
24+
def _make_one(self, *args, **kwargs):
25+
klass = self._get_target_class()
26+
obj = super(klass, klass).__new__(klass)
27+
from threading import Event, Lock
28+
29+
obj.metadata_event = Event()
30+
obj.metadata_lock = Lock()
31+
obj._metadata = None
32+
return obj
33+
34+
@staticmethod
35+
def _make_value(value):
36+
from google.cloud.spanner_v1._helpers import _make_value_pb
37+
38+
return _make_value_pb(value)
39+
40+
@staticmethod
41+
def _make_scalar_field(name, type_):
42+
from google.cloud.spanner_v1 import StructType
43+
from google.cloud.spanner_v1 import Type
44+
45+
return StructType.Field(name=name, type_=Type(code=type_))
46+
47+
@staticmethod
48+
def _make_result_set_metadata(fields=()):
49+
from google.cloud.spanner_v1 import ResultSetMetadata
50+
from google.cloud.spanner_v1 import StructType
51+
52+
metadata = ResultSetMetadata(row_type=StructType(fields=[]))
53+
for field in fields:
54+
metadata.row_type.fields.append(field)
55+
return metadata
56+
57+
def test_decoders_property_no_metadata(self):
58+
merged = self._make_one()
59+
merged._metadata = None
60+
merged.metadata_event.set()
61+
with self.assertRaises(ValueError):
62+
getattr(merged, "_decoders")
63+
64+
def test_decoders_property_with_metadata(self):
65+
from google.cloud.spanner_v1 import TypeCode
66+
67+
merged = self._make_one()
68+
fields = [
69+
self._make_scalar_field("full_name", TypeCode.STRING),
70+
self._make_scalar_field("age", TypeCode.INT64),
71+
]
72+
merged._metadata = self._make_result_set_metadata(fields)
73+
merged.metadata_event.set()
74+
75+
decoders = merged._decoders
76+
self.assertEqual(len(decoders), 2)
77+
self.assertTrue(callable(decoders[0]))
78+
self.assertTrue(callable(decoders[1]))
79+
80+
def test_decode_row(self):
81+
from google.cloud.spanner_v1 import TypeCode
82+
83+
merged = self._make_one()
84+
fields = [
85+
self._make_scalar_field("full_name", TypeCode.STRING),
86+
self._make_scalar_field("age", TypeCode.INT64),
87+
]
88+
merged._metadata = self._make_result_set_metadata(fields)
89+
merged.metadata_event.set()
90+
91+
raw_row = [self._make_value("Phred"), self._make_value(42)]
92+
decoded_row = merged.decode_row(raw_row)
93+
94+
self.assertEqual(decoded_row, ["Phred", 42])
95+
96+
def test_decode_row_type_error(self):
97+
merged = self._make_one()
98+
# The _decoders property requires metadata, even for a type error check.
99+
merged._metadata = self._make_result_set_metadata()
100+
merged.metadata_event.set()
101+
with self.assertRaises(TypeError):
102+
merged.decode_row("not a list")
103+
104+
def test_decode_column(self):
105+
from google.cloud.spanner_v1 import TypeCode
106+
107+
merged = self._make_one()
108+
fields = [
109+
self._make_scalar_field("full_name", TypeCode.STRING),
110+
self._make_scalar_field("age", TypeCode.INT64),
111+
]
112+
merged._metadata = self._make_result_set_metadata(fields)
113+
merged.metadata_event.set()
114+
115+
raw_row = [self._make_value("Phred"), self._make_value(42)]
116+
decoded_name = merged.decode_column(raw_row, 0)
117+
decoded_age = merged.decode_column(raw_row, 1)
118+
119+
self.assertEqual(decoded_name, "Phred")
120+
self.assertEqual(decoded_age, 42)
121+
122+
def test_decode_column_type_error(self):
123+
merged = self._make_one()
124+
# The _decoders property requires metadata, even for a type error check.
125+
merged._metadata = self._make_result_set_metadata()
126+
merged.metadata_event.set()
127+
with self.assertRaises(TypeError):
128+
merged.decode_column("not a list", 0)

0 commit comments

Comments
 (0)