Merged
6 changes: 6 additions & 0 deletions google/cloud/spanner_v1/streamed.py
@@ -54,6 +54,7 @@ def __init__(
self._column_info = column_info # Column information
self._field_decoders = None
self._lazy_decode = lazy_decode # Return protobuf values
self._done = False

@property
def fields(self):
@@ -159,11 +160,16 @@ def _consume_next(self):

self._merge_values(values)

if response_pb.last:
self._done = True

def __iter__(self):
while True:
iter_rows, self._rows[:] = self._rows[:], ()
while iter_rows:
yield iter_rows.pop(0)
if self._done:
return
try:
self._consume_next()
except StopIteration:
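The streamed.py change lets iteration stop as soon as a PartialResultSet arrives with last=True, instead of issuing one more read that would only end in StopIteration. A minimal, self-contained sketch of the resulting iteration contract (the dict-based responses stand in for the real gRPC stream and are hypothetical):

class _Stream:
    def __init__(self, responses):
        self._responses = iter(responses)  # stand-in for the server stream
        self._rows = []
        self._done = False

    def _consume_next(self):
        # Pull one response; remember when the server says it was the last one.
        response = next(self._responses)  # raises StopIteration when exhausted
        self._rows.extend(response["values"])
        if response.get("last"):
            self._done = True

    def __iter__(self):
        while True:
            iter_rows, self._rows[:] = self._rows[:], ()
            while iter_rows:
                yield iter_rows.pop(0)
            if self._done:
                return  # last=True seen: skip the extra round trip
            try:
                self._consume_next()
            except StopIteration:
                return

# The response queued after last=True is never requested:
assert list(_Stream([{"values": [1, 2]}, {"values": [3], "last": True}, {"values": [99]}])) == [1, 2, 3]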
26 changes: 21 additions & 5 deletions google/cloud/spanner_v1/testing/mock_spanner.py
@@ -35,11 +35,17 @@
class MockSpanner:
def __init__(self):
self.results = {}
self.execute_streaming_sql_results = {}
self.errors = {}

def add_result(self, sql: str, result: result_set.ResultSet):
self.results[sql.lower().strip()] = result

def add_execute_streaming_sql_results(
self, sql: str, partial_result_sets: list[result_set.PartialResultSet]
):
self.execute_streaming_sql_results[sql.lower().strip()] = partial_result_sets

def get_result(self, sql: str) -> result_set.ResultSet:
result = self.results.get(sql.lower().strip())
if result is None:
@@ -55,9 +61,20 @@ def pop_error(self, context):
if error:
context.abort_with_status(error)

def get_result_as_partial_result_sets(
def get_execute_streaming_sql_results(
self, sql: str, started_transaction: transaction.Transaction
) -> [result_set.PartialResultSet]:
) -> list[result_set.PartialResultSet]:
if self.execute_streaming_sql_results.get(sql.lower().strip()):
partials = self.execute_streaming_sql_results[sql.lower().strip()]
else:
partials = self.get_result_as_partial_result_sets(sql)
if started_transaction:
partials[0].metadata.transaction = started_transaction
return partials

def get_result_as_partial_result_sets(
self, sql: str
) -> list[result_set.PartialResultSet]:
result: result_set.ResultSet = self.get_result(sql)
partials = []
first = True
@@ -70,11 +87,10 @@ def get_result_as_partial_result_sets(
partial = result_set.PartialResultSet()
if first:
partial.metadata = ResultSetMetadata(result.metadata)
first = False
partial.values.extend(row)
partials.append(partial)
partials[len(partials) - 1].stats = result.stats
if started_transaction:
partials[0].metadata.transaction = started_transaction
return partials


@@ -149,7 +165,7 @@ def ExecuteStreamingSql(self, request, context):
self._requests.append(request)
self.mock_spanner.pop_error(context)
started_transaction = self.__maybe_create_transaction(request)
partials = self.mock_spanner.get_result_as_partial_result_sets(
partials = self.mock_spanner.get_execute_streaming_sql_results(
request.sql, started_transaction
)
for result in partials:
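With this hook in place, a test can register the exact PartialResultSet stream that ExecuteStreamingSql should replay; the mock falls back to splitting a plain ResultSet only when no stream was registered for the statement. A sketch of the registration, using the proto types exported by google.cloud.spanner_v1 (the SQL string and values are illustrative):

from google.cloud.spanner_v1 import PartialResultSet
from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1.testing.mock_spanner import MockSpanner

mock = MockSpanner()

first = PartialResultSet()  # a real test would also attach row-type metadata here
first.values.extend([_make_value_pb("1"), _make_value_pb("ABC")])

final = PartialResultSet(last=True)  # last=True lets the client stop reading
final.values.extend([_make_value_pb("2"), _make_value_pb("DEF")])

mock.add_execute_streaming_sql_results("select * from my_table", [first, final])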
66 changes: 51 additions & 15 deletions tests/mockserver_tests/mock_server_test_base.py
@@ -14,27 +14,34 @@

import unittest

from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
from google.cloud.spanner_v1.testing.mock_spanner import (
start_mock_server,
SpannerServicer,
)
import google.cloud.spanner_v1.types.type as spanner_type
import google.cloud.spanner_v1.types.result_set as result_set
import grpc
from google.api_core.client_options import ClientOptions
from google.auth.credentials import AnonymousCredentials
from google.cloud.spanner_v1 import Client, TypeCode, FixedSizePool
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
import grpc
from google.rpc import code_pb2
from google.rpc import status_pb2
from google.rpc.error_details_pb2 import RetryInfo
from google.cloud.spanner_v1 import Type

from google.cloud.spanner_v1 import StructType
from google.cloud.spanner_v1._helpers import _make_value_pb

from google.cloud.spanner_v1 import PartialResultSet
from google.protobuf.duration_pb2 import Duration
from google.rpc import code_pb2, status_pb2

from google.rpc.error_details_pb2 import RetryInfo
from grpc_status._common import code_to_grpc_status_code
from grpc_status.rpc_status import _Status

import google.cloud.spanner_v1.types.result_set as result_set
import google.cloud.spanner_v1.types.type as spanner_type
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from google.cloud.spanner_v1 import Client, FixedSizePool, ResultSetMetadata, TypeCode
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
from google.cloud.spanner_v1.testing.mock_spanner import (
SpannerServicer,
start_mock_server,
)


# Creates an aborted status with the smallest possible retry delay.
def aborted_status() -> _Status:
@@ -57,6 +64,27 @@ def aborted_status() -> _Status:
return status


def _make_partial_result_sets(
fields: list[tuple[str, TypeCode]], results: list[dict]
) -> list[result_set.PartialResultSet]:
partial_result_sets = []
for result in results:
partial_result_set = PartialResultSet()
if len(partial_result_sets) == 0:
# Only the first PartialResultSet in the stream carries the row-type metadata.
metadata = ResultSetMetadata(row_type=StructType(fields=[]))
for field in fields:
metadata.row_type.fields.append(
StructType.Field(name=field[0], type_=Type(code=field[1]))
)
partial_result_set.metadata = metadata
for value in result["values"]:
partial_result_set.values.append(_make_value_pb(value))
partial_result_set.last = result.get("last") or False
partial_result_sets.append(partial_result_set)
return partial_result_sets
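
Note that each values list is flat: a PartialResultSet carries a plain sequence of values and the client reassembles rows from the row-type metadata, so {"values": ["1", "ABC", "2", "DEF"]} encodes two two-column rows. An illustrative call, mirroring the test in test_basics.py below:

partials = _make_partial_result_sets(
    [("ID", TypeCode.INT64), ("NAME", TypeCode.STRING)],
    [
        {"values": ["1", "ABC", "2", "DEF"]},
        {"values": ["3", "GHI"], "last": True},
    ],
)
assert len(partials) == 2 and len(partials[0].values) == 4
assert partials[1].last and not partials[0].last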


# Creates an UNAVAILABLE status with the smallest possible retry delay.
def unavailable_status() -> _Status:
error = status_pb2.Status(
@@ -101,6 +129,14 @@ def add_select1_result():
add_single_result("select 1", "c", TypeCode.INT64, [("1",)])


def add_execute_streaming_sql_results(
sql: str, partial_result_sets: list[result_set.PartialResultSet]
):
MockServerTestBase.spanner_service.mock_spanner.add_execute_streaming_sql_results(
sql, partial_result_sets
)


def add_single_result(
sql: str, column_name: str, type_code: spanner_type.TypeCode, row
):
39 changes: 35 additions & 4 deletions tests/mockserver_tests/test_basics.py
@@ -17,26 +17,32 @@
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from google.cloud.spanner_v1 import (
BatchCreateSessionsRequest,
ExecuteSqlRequest,
BeginTransactionRequest,
TransactionOptions,
ExecuteBatchDmlRequest,
ExecuteSqlRequest,
TransactionOptions,
TypeCode,
)
from google.cloud.spanner_v1.transaction import Transaction
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
from google.cloud.spanner_v1.transaction import Transaction

from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
_make_partial_result_sets,
add_select1_result,
add_single_result,
add_update_count,
add_error,
unavailable_status,
add_single_result,
add_execute_streaming_sql_results,
)


class TestBasics(MockServerTestBase):
def setUp(self):
super().setUp()
super().setup_class()
Contributor:
Why do we need to call setup_class() for each test method?

Contributor (Author):
Let me remove this. This is needed only when I am running individual tests in PyCharm.


def test_select1(self):
add_select1_result()
with self.database.snapshot() as snapshot:
@@ -176,6 +182,31 @@ def test_last_statement_query(self):
self.assertEqual(1, len(requests), msg=requests)
self.assertTrue(requests[0].last_statement, requests[0])

def test_execute_streaming_sql_last_field(self):
partial_result_sets = _make_partial_result_sets(
[("ID", TypeCode.INT64), ("NAME", TypeCode.STRING)],
[
{"values": ["1", "ABC", "2", "DEF"]},
{"values": ["3", "GHI"], "last": True},
],
)

sql = "select * from my_table"
add_execute_streaming_sql_results(sql, partial_result_sets)
count = 1
with self.database.snapshot() as snapshot:
results = snapshot.execute_sql(sql)
result_list = []
for row in results:
result_list.append(row)
self.assertEqual(count, row[0])
count += 1
self.assertEqual(3, len(result_list))
requests = self.spanner_service.requests
self.assertEqual(2, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
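
# The request-count assertion is the crux of the test. A single-use snapshot
# sends its transaction selector inline, so no BeginTransactionRequest appears,
# and the last=True message means the client never issues a second read:
#   requests[0]  BatchCreateSessionsRequest  (session pool creation)
#   requests[1]  ExecuteSqlRequest           (the one ExecuteStreamingSql call)
# Nothing follows the PartialResultSet that carries last=True.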


def _execute_query(transaction: Transaction, sql: str):
rows = transaction.execute_sql(sql, last_statement=True)
38 changes: 36 additions & 2 deletions tests/unit/test_streamed.py
@@ -124,12 +124,12 @@ def _make_result_set_stats(query_plan=None, **kw):

@staticmethod
def _make_partial_result_set(
values, metadata=None, stats=None, chunked_value=False
values, metadata=None, stats=None, chunked_value=False, last=False
):
from google.cloud.spanner_v1 import PartialResultSet

results = PartialResultSet(
metadata=metadata, stats=stats, chunked_value=chunked_value
metadata=metadata, stats=stats, chunked_value=chunked_value, last=last
)
for v in values:
results.values.append(v)
@@ -164,6 +164,40 @@ def test__merge_chunk_bool(self):
with self.assertRaises(Unmergeable):
streamed._merge_chunk(chunk)

def test__partial_result_set_with_last_flag(self):
from google.cloud.spanner_v1 import TypeCode

fields = [
self._make_scalar_field("ID", TypeCode.INT64),
self._make_scalar_field("NAME", TypeCode.STRING),
]
for length in range(4, 6):
metadata = self._make_result_set_metadata(fields)
result_sets = [
self._make_partial_result_set(
[self._make_value(0), "google_0"], metadata=metadata
)
]
for i in range(1, 5):
bares = [i]
values = [
[self._make_value(bare), "google_" + str(bare)] for bare in bares
]
result_sets.append(
self._make_partial_result_set(
*values, metadata=metadata, last=(i == length - 1)
)
)

iterator = _MockCancellableIterator(*result_sets)
streamed = self._make_one(iterator)
count = 0
for row in streamed:
self.assertEqual(row[0], count)
self.assertEqual(row[1], "google_" + str(count))
count += 1
self.assertEqual(count, length)

def test__merge_chunk_numeric(self):
from google.cloud.spanner_v1 import TypeCode

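The loop over length in range(4, 6) in the new unit test covers both boundary cases for the last flag. A condensed trace of the two streams it builds (values abbreviated):

# length == 4: last=True arrives with one message still queued behind it.
#   PRS([0, "google_0"])             yielded
#   PRS([1, "google_1"])             yielded
#   PRS([2, "google_2"])             yielded
#   PRS([3, "google_3"], last=True)  yielded, then iteration stops
#   PRS([4, "google_4"])             never consumed
#
# length == 5: last=True arrives on the final message of the stream.
#   PRS([0.."google_3"]) yielded as above, then
#   PRS([4, "google_4"], last=True)  yielded, then iteration stops
#
# In both cases the streamed result set yields exactly `length` rows.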