Skip to content

Commit eb2c3d2

Browse files
committed
chore(x-goog-request-id): commit testing scaffold
This change commits the scaffolding for which testing will be used. This is a carve out of PRs #1264 and #1364, meant to make those changes lighter and much easier to review then merge. Updates #1261
1 parent e064474 commit eb2c3d2

File tree

5 files changed

+104
-7
lines changed

5 files changed

+104
-7
lines changed

google/cloud/spanner_v1/request_id_header.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,6 @@ def generate_rand_uint64():
3737

3838
def with_request_id(client_id, channel_id, nth_request, attempt, other_metadata=[]):
3939
req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}"
40-
all_metadata = other_metadata.copy()
40+
all_metadata = (other_metadata or []).copy()
4141
all_metadata.append((REQ_ID_HEADER_KEY, req_id))
4242
return all_metadata

google/cloud/spanner_v1/testing/database_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from google.cloud.spanner_v1.testing.interceptors import (
2626
MethodCountInterceptor,
2727
MethodAbortInterceptor,
28+
XGoogRequestIDHeaderInterceptor,
2829
)
2930

3031

@@ -34,6 +35,8 @@ class TestDatabase(Database):
3435
currently, and we don't want to make changes in the Database class for
3536
testing purpose as this is a hack to use interceptors in tests."""
3637

38+
_interceptors = []
39+
3740
def __init__(
3841
self,
3942
database_id,
@@ -74,6 +77,8 @@ def spanner_api(self):
7477
client_options = client._client_options
7578
if self._instance.emulator_host is not None:
7679
channel = grpc.insecure_channel(self._instance.emulator_host)
80+
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
81+
self._interceptors.append(self._x_goog_request_id_interceptor)
7782
channel = grpc.intercept_channel(channel, *self._interceptors)
7883
transport = SpannerGrpcTransport(channel=channel)
7984
self._spanner_api = SpannerClient(
@@ -110,3 +115,7 @@ def _create_spanner_client_for_tests(self, client_options, credentials):
110115
client_options=client_options,
111116
transport=transport,
112117
)
118+
119+
def reset(self):
120+
if self._x_goog_request_id_interceptor:
121+
self._x_goog_request_id_interceptor.reset()

google/cloud/spanner_v1/testing/interceptors.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
from collections import defaultdict
16+
import threading
17+
1618
from grpc_interceptor import ClientInterceptor
1719
from google.api_core.exceptions import Aborted
1820

@@ -63,3 +65,67 @@ def reset(self):
6365
self._method_to_abort = None
6466
self._count = 0
6567
self._connection = None
68+
69+
70+
X_GOOG_REQUEST_ID = "x-goog-spanner-request-id"
71+
72+
73+
class XGoogRequestIDHeaderInterceptor(ClientInterceptor):
74+
def __init__(self):
75+
self._unary_req_segments = []
76+
self._stream_req_segments = []
77+
self.__lock = threading.Lock()
78+
79+
def intercept(self, method, request_or_iterator, call_details):
80+
metadata = call_details.metadata
81+
x_goog_request_id = None
82+
for key, value in metadata:
83+
if key == X_GOOG_REQUEST_ID:
84+
x_goog_request_id = value
85+
break
86+
87+
if not x_goog_request_id:
88+
raise Exception(
89+
f"Missing {X_GOOG_REQUEST_ID} header in {call_details.method}"
90+
)
91+
92+
response_or_iterator = method(request_or_iterator, call_details)
93+
streaming = getattr(response_or_iterator, "__iter__", None) is not None
94+
with self.__lock:
95+
if streaming:
96+
self._stream_req_segments.append(
97+
(call_details.method, parse_request_id(x_goog_request_id))
98+
)
99+
else:
100+
self._unary_req_segments.append(
101+
(call_details.method, parse_request_id(x_goog_request_id))
102+
)
103+
104+
return response_or_iterator
105+
106+
@property
107+
def unary_request_ids(self):
108+
return self._unary_req_segments
109+
110+
@property
111+
def stream_request_ids(self):
112+
return self._stream_req_segments
113+
114+
def reset(self):
115+
self._stream_req_segments.clear()
116+
self._unary_req_segments.clear()
117+
118+
119+
def parse_request_id(request_id_str):
120+
splits = request_id_str.split(".")
121+
version, rand_process_id, client_id, channel_id, nth_request, nth_attempt = list(
122+
map(lambda v: int(v), splits)
123+
)
124+
return (
125+
version,
126+
rand_process_id,
127+
client_id,
128+
channel_id,
129+
nth_request,
130+
nth_attempt,
131+
)

google/cloud/spanner_v1/testing/mock_spanner.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
from google.cloud.spanner_v1 import (
2323
TransactionOptions,
2424
ResultSetMetadata,
25-
ExecuteSqlRequest,
26-
ExecuteBatchDmlRequest,
2725
)
2826
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
2927
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc
@@ -107,6 +105,7 @@ def CreateSession(self, request, context):
107105

108106
def BatchCreateSessions(self, request, context):
109107
self._requests.append(request)
108+
self.mock_spanner.pop_error(context)
110109
sessions = []
111110
for i in range(request.session_count):
112111
sessions.append(
@@ -186,9 +185,7 @@ def BeginTransaction(self, request, context):
186185
self._requests.append(request)
187186
return self.__create_transaction(request.session, request.options)
188187

189-
def __maybe_create_transaction(
190-
self, request: ExecuteSqlRequest | ExecuteBatchDmlRequest
191-
):
188+
def __maybe_create_transaction(self, request):
192189
started_transaction = None
193190
if not request.transaction.begin == TransactionOptions():
194191
started_transaction = self.__create_transaction(

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
start_mock_server,
2121
SpannerServicer,
2222
)
23+
from google.cloud.spanner_v1.client import Client
2324
import google.cloud.spanner_v1.types.type as spanner_type
2425
import google.cloud.spanner_v1.types.result_set as result_set
2526
from google.api_core.client_options import ClientOptions
@@ -78,6 +79,27 @@ def unavailable_status() -> _Status:
7879
return status
7980

8081

82+
# Creates an UNAVAILABLE status with the smallest possible retry delay.
83+
def unavailable_status() -> _Status:
84+
error = status_pb2.Status(
85+
code=code_pb2.UNAVAILABLE,
86+
message="Service unavailable.",
87+
)
88+
retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1))
89+
status = _Status(
90+
code=code_to_grpc_status_code(error.code),
91+
details=error.message,
92+
trailing_metadata=(
93+
("grpc-status-details-bin", error.SerializeToString()),
94+
(
95+
"google.rpc.retryinfo-bin",
96+
retry_info.SerializeToString(),
97+
),
98+
),
99+
)
100+
return status
101+
102+
81103
def add_error(method: str, error: status_pb2.Status):
82104
MockServerTestBase.spanner_service.mock_spanner.add_error(method, error)
83105

@@ -153,6 +175,7 @@ def setup_class(cls):
153175
def teardown_class(cls):
154176
if MockServerTestBase.server is not None:
155177
MockServerTestBase.server.stop(grace=None)
178+
Client.NTH_CLIENT.reset()
156179
MockServerTestBase.server = None
157180

158181
def setup_method(self, *args, **kwargs):
@@ -186,6 +209,8 @@ def instance(self) -> Instance:
186209
def database(self) -> Database:
187210
if self._database is None:
188211
self._database = self.instance.database(
189-
"test-database", pool=FixedSizePool(size=10)
212+
"test-database",
213+
pool=FixedSizePool(size=10),
214+
enable_interceptors_in_tests=True,
190215
)
191216
return self._database

0 commit comments

Comments
 (0)