Skip to content

Commit 9d4d942

Browse files
committed
Monkey patch updates
1 parent 06c12a2 commit 9d4d942

File tree

6 files changed

+151
-38
lines changed

6 files changed

+151
-38
lines changed

google/cloud/spanner_v1/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
from google.cloud.spanner_v1.pool import FixedSizePool
7676
from google.cloud.spanner_v1.pool import PingingPool
7777
from google.cloud.spanner_v1.pool import TransactionPingingPool
78+
from google.cloud.spanner_v1._helpers import monkey_patch
7879

7980

8081
COMMIT_TIMESTAMP = "spanner.commit_timestamp()"
@@ -83,6 +84,8 @@
8384
``(allow_commit_timestamp=true)`` in the schema.
8485
"""
8586

87+
monkey_patch(Transaction)
88+
8689

8790
__all__ = (
8891
# google.cloud.spanner_v1

google/cloud/spanner_v1/_helpers.py

Lines changed: 130 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import math
2020
import time
2121
import base64
22+
import inspect
2223
import threading
2324

2425
from google.protobuf.struct_pb2 import ListValue
@@ -739,9 +740,100 @@ def __init__(self, original_callable: Callable):
739740

740741

741742
patched = {}
743+
patched_mu = threading.Lock()
742744

743745

744746
def inject_retry_header_control(api):
747+
return
748+
monkey_patch(type(api))
749+
750+
memoize_map = dict()
751+
752+
def monkey_patch(obj):
753+
return
754+
755+
"""
756+
klass = obj
757+
attrs = dir(klass)
758+
for attr_key in attrs:
759+
if attr_key.startswith('_'):
760+
continue
761+
762+
attr_value = getattr(obj, attr_key)
763+
if not callable(attr_value):
764+
continue
765+
766+
signature = inspect.signature(attr_value)
767+
print(attr_key, signature.parameters)
768+
769+
call = attr_value
770+
# Our goal is to replace the runtime pass through.
771+
def wrapped(*args, **kwargs):
772+
print(attr_key, 'called')
773+
return call(*args, **kwargs)
774+
775+
setattr(klass, attr_key, wrapped)
776+
777+
return
778+
"""
779+
780+
orig_get_attr = getattr(obj, "__getattribute__")
781+
def patched_getattribute(obj, key, *args, **kwargs):
782+
if key.startswith('_'):
783+
return orig_get_attr(obj, key, *args, **kwargs)
784+
785+
orig_value = orig_get_attr(obj, key, *args, **kwargs)
786+
if not callable(orig_value):
787+
return orig_value
788+
789+
map_key = hex(id(key)) + hex(id(obj))
790+
memoized = memoize_map.get(map_key, None)
791+
if memoized:
792+
print("memoized_hit", key, '\033[35m', inspect.getsource(orig_value), '\033[00m')
793+
return memoized
794+
795+
signature = inspect.signature(orig_value)
796+
if signature.parameters.get('metadata', None) is None:
797+
return orig_value
798+
799+
print(key, '\033[34m', map_key, '\033[00m', signature, signature.parameters.get('metadata', None))
800+
counters = dict(attempt=0)
801+
def patched_method(*aargs, **kkwargs):
802+
counters['attempt'] += 1
803+
metadata = kkwargs.get('metadata', None)
804+
if not metadata:
805+
return orig_value(*aargs, **kkwargs)
806+
807+
# 4. Find all the headers that match the target header key.
808+
all_metadata = []
809+
for mkey, value in metadata:
810+
if mkey is REQ_ID_HEADER_KEY:
811+
attempt = counters['attempt']
812+
if attempt > 1:
813+
# 5. Increment the original_attempt with that of our re-invocation count.
814+
splits = value.split(".")
815+
print('\033[34mkey', mkey, '\033[00m', splits)
816+
hdr_attempt_plus_reinvocation = (
817+
int(splits[-1]) + attempt
818+
)
819+
splits[-1] = str(hdr_attempt_plus_reinvocation)
820+
value = ".".join(splits)
821+
822+
all_metadata.append((mkey, value))
823+
824+
kwargs["metadata"] = all_metadata
825+
return orig_value(*aargs, **kkwargs)
826+
827+
memoize_map[map_key] = patched_method
828+
return patched_method
829+
830+
setattr(obj, '__getattribute__', patched_getattribute)
831+
832+
833+
def foo(api):
834+
global patched
835+
global patched_mu
836+
745837
# For each method, add an _attempt value that'll then be
746838
# retrieved for each retry.
747839
# 1. Patch the __getattribute__ method to match items in our manifest.
@@ -753,55 +845,66 @@ def inject_retry_header_control(api):
753845
orig_getattribute = getattr(target, "__getattribute__")
754846

755847
def patched_getattribute(obj, key, *args, **kwargs):
848+
# 1. Skip modifying private and mangled methods.
756849
if key.startswith("_"):
757850
return orig_getattribute(obj, key, *args, **kwargs)
758851

759852
attr = orig_getattribute(obj, key, *args, **kwargs)
760853

761-
# 0. If we already patched it, we can return immediately.
762-
if getattr(attr, "_patched", None) is not None:
763-
return attr
764-
765-
# 1. Skip over non-methods.
854+
# 2. Skip over non-methods.
766855
if not callable(attr):
856+
patched_mu.release()
767857
return attr
768858

769-
# 2. Skip modifying private and mangled methods.
770-
mangled_or_private = attr.__name__.startswith("_")
771-
if mangled_or_private:
772-
return attr
773-
859+
patched_key = hex(id(key)) + hex(id(obj))
860+
patched_mu.acquire()
861+
already_patched = patched.get(patched_key, None)
862+
863+
other_attempts = dict(attempts=0)
774864
# 3. Wrap the callable attribute and then capture its metadata keyed argument.
775865
def wrapped_attr(*args, **kwargs):
866+
print("\033[31m", key, "attempt", other_attempts['attempts'], "\033[00m")
867+
other_attempts['attempts'] += 1
868+
776869
metadata = kwargs.get("metadata", [])
777870
if not metadata:
778871
# Increment the reinvocation count.
779872
wrapped_attr._attempt += 1
780873
return attr(*args, **kwargs)
781874

875+
print("\033[35mwrapped_attr", key, args, kwargs, 'attempt', wrapped_attr._attempt, "\033[00m")
876+
782877
# 4. Find all the headers that match the target header key.
783878
all_metadata = []
784-
for key, value in metadata:
785-
if key is REQ_ID_HEADER_KEY:
786-
# 5. Increment the original_attempt with that of our re-invocation count.
787-
splits = value.split(".")
788-
hdr_attempt_plus_reinvocation = (
789-
int(splits[-1]) + wrapped_attr._attempt
790-
)
791-
splits[-1] = str(hdr_attempt_plus_reinvocation)
792-
value = ".".join(splits)
793-
794-
all_metadata.append((key, value))
795-
796-
# Increment the reinvocation count.
797-
wrapped_attr._attempt += 1
879+
for mkey, value in metadata:
880+
if mkey is REQ_ID_HEADER_KEY:
881+
if wrapped_attr._attempt > 0:
882+
# 5. Increment the original_attempt with that of our re-invocation count.
883+
splits = value.split(".")
884+
print('\033[34mkey', mkey, '\033[00m', splits)
885+
hdr_attempt_plus_reinvocation = (
886+
int(splits[-1]) + wrapped_attr._attempt
887+
)
888+
splits[-1] = str(hdr_attempt_plus_reinvocation)
889+
value = ".".join(splits)
890+
891+
all_metadata.append((mkey, value))
798892

799893
kwargs["metadata"] = all_metadata
894+
wrapped_attr._attempt += 1
895+
print(key, "\033[36mreplaced_all_metadata", all_metadata, "\033[00m")
800896
return attr(*args, **kwargs)
801897

802-
wrapped_attr._attempt = 0
803-
wrapped_attr._patched = True
898+
if already_patched:
899+
print("patched_key \033[32m", patched_key, key, "\033[00m", already_patched)
900+
setattr(attr, 'patched', True)
901+
# Increment the reinvocation count.
902+
patched_mu.release()
903+
return already_patched
904+
905+
patched[patched_key] = wrapped_attr
906+
setattr(wrapped_attr, '_attempt', 0)
907+
patched_mu.release()
804908
return wrapped_attr
805909

806910
setattr(target, "__getattribute__", patched_getattribute)
807-
patched[hex_id] = True

google/cloud/spanner_v1/database.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
_metadata_with_prefix,
5656
_metadata_with_leader_aware_routing,
5757
_metadata_with_request_id,
58-
inject_retry_header_control,
58+
monkey_patch,
5959
)
6060
from google.cloud.spanner_v1.batch import Batch
6161
from google.cloud.spanner_v1.batch import MutationGroups
@@ -438,7 +438,7 @@ def spanner_api(self):
438438
if not api:
439439
return api
440440

441-
inject_retry_header_control(api)
441+
monkey_patch(api)
442442
return api
443443

444444
def __generate_spanner_api(self):
@@ -813,6 +813,7 @@ def execute_pdml():
813813
def _next_nth_request(self):
814814
if self._instance and self._instance._client:
815815
return self._instance._client._next_nth_request
816+
raise Exception("returning 1 for next_nth_request")
816817
return 1
817818

818819
@property

google/cloud/spanner_v1/pool.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def bind(self, database):
249249
attempt = 1
250250
returned_session_count = 0
251251
while not self._sessions.full():
252+
print("fixedPool.batchCreateSessions")
252253
request.session_count = requested_session_count - self._sessions.qsize()
253254
add_span_event(
254255
span,
@@ -562,6 +563,7 @@ def bind(self, database):
562563
) as span, MetricsCapture():
563564
returned_session_count = 0
564565
while returned_session_count < self.size:
566+
print("pingingPool.batchCreateSessions")
565567
resp = api.batch_create_sessions(
566568
request=request,
567569
metadata=database.metadata_with_request_id(

google/cloud/spanner_v1/snapshot.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -591,21 +591,24 @@ def execute_sql(
591591
directed_read_options=directed_read_options,
592592
)
593593

594-
restart = functools.partial(
595-
api.execute_streaming_sql,
596-
request=request,
597-
metadata=metadata,
598-
retry=retry,
599-
timeout=timeout,
600-
)
594+
def wrapped_restart(*args, **kwargs):
595+
restart = functools.partial(
596+
api.execute_streaming_sql,
597+
request=request,
598+
metadata=kwargs.get('metadata', metadata),
599+
retry=retry,
600+
timeout=timeout,
601+
)
602+
return restart(*args, **kwargs)
603+
601604
trace_attributes = {"db.statement": sql}
602605
observability_options = getattr(database, "observability_options", None)
603606

604607
if self._transaction_id is None:
605608
# lock is added to handle the inline begin for first rpc
606609
with self._lock:
607610
return self._get_streamed_result_set(
608-
restart,
611+
wrapped_restart,
609612
request,
610613
metadata,
611614
trace_attributes,
@@ -615,7 +618,7 @@ def execute_sql(
615618
)
616619
else:
617620
return self._get_streamed_result_set(
618-
restart,
621+
wrapped_restart,
619622
request,
620623
metadata,
621624
trace_attributes,

tests/mockserver_tests/test_request_id_header.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ def test_unary_retryable_error(self):
310310
]
311311

312312
print("got_unaries", got_unary_segments)
313+
print("got_stream", got_stream_segments)
313314
assert got_unary_segments == want_unary_segments
314315
assert got_stream_segments == want_stream_segments
315316

0 commit comments

Comments
 (0)