Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,18 @@ def _restart_on_unavailable(
)

request.transaction = transaction_selector
iterator = None

with trace_call(
trace_name, session, attributes, observability_options=observability_options
):
iterator = method(request=request)
while True:
try:
if iterator is None:
with trace_call(
trace_name,
session,
attributes,
observability_options=observability_options,
):
iterator = method(request=request)
for item in iterator:
item_buffer.append(item)
# Setting the transaction id because the transaction begin was inlined for first rpc.
Expand Down
21 changes: 21 additions & 0 deletions tests/mockserver_tests/mock_server_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@ def aborted_status() -> _Status:
return status


# Creates an UNAVAILABLE status with the smallest possible retry delay.
def unavailable_status() -> _Status:
error = status_pb2.Status(
code=code_pb2.UNAVAILABLE,
message="Service unavailable.",
)
retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1))
status = _Status(
code=code_to_grpc_status_code(error.code),
details=error.message,
trailing_metadata=(
("grpc-status-details-bin", error.SerializeToString()),
(
"google.rpc.retryinfo-bin",
retry_info.SerializeToString(),
),
),
)
return status


def add_error(method: str, error: status_pb2.Status):
MockServerTestBase.spanner_service.mock_spanner.add_error(method, error)

Expand Down
22 changes: 22 additions & 0 deletions tests/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
BeginTransactionRequest,
TransactionOptions,
)
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer

from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_select1_result,
add_update_count,
add_error,
unavailable_status,
)


Expand Down Expand Up @@ -85,3 +88,22 @@ def test_dbapi_partitioned_dml(self):
self.assertEqual(
TransactionOptions(dict(partitioned_dml={})), begin_request.options
)

def test_execute_streaming_sql_unavailable(self):
add_select1_result()
# Add an UNAVAILABLE error that is returned the first time the
# ExecuteStreamingSql RPC is called.
add_error(SpannerServicer.ExecuteStreamingSql.__name__, unavailable_status())
with self.database.snapshot() as snapshot:
results = snapshot.execute_sql("select 1")
result_list = []
for row in results:
result_list.append(row)
self.assertEqual(1, row[0])
self.assertEqual(1, len(result_list))
requests = self.spanner_service.requests
self.assertEqual(3, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
# The ExecuteStreamingSql call should be retried.
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
Loading