1919import math
2020import time
2121import base64
22+ import inspect
2223import threading
2324
2425from google .protobuf .struct_pb2 import ListValue
@@ -739,9 +740,100 @@ def __init__(self, original_callable: Callable):
739740
740741
741742patched = {}
743+ patched_mu = threading .Lock ()
742744
743745
744746def inject_retry_header_control (api ):
747+ return
748+ monkey_patch (type (api ))
749+
750+ memoize_map = dict ()
751+
752+ def monkey_patch (obj ):
753+ return
754+
755+ """
756+ klass = obj
757+ attrs = dir(klass)
758+ for attr_key in attrs:
759+ if attr_key.startswith('_'):
760+ continue
761+
762+ attr_value = getattr(obj, attr_key)
763+ if not callable(attr_value):
764+ continue
765+
766+ signature = inspect.signature(attr_value)
767+ print(attr_key, signature.parameters)
768+
769+ call = attr_value
770+ # Our goal is to replace the runtime pass through.
771+ def wrapped(*args, **kwargs):
772+ print(attr_key, 'called')
773+ return call(*args, **kwargs)
774+
775+ setattr(klass, attr_key, wrapped)
776+
777+ return
778+ """
779+
780+ orig_get_attr = getattr (obj , "__getattribute__" )
781+ def patched_getattribute (obj , key , * args , ** kwargs ):
782+ if key .startswith ('_' ):
783+ return orig_get_attr (obj , key , * args , ** kwargs )
784+
785+ orig_value = orig_get_attr (obj , key , * args , ** kwargs )
786+ if not callable (orig_value ):
787+ return orig_value
788+
789+ map_key = hex (id (key )) + hex (id (obj ))
790+ memoized = memoize_map .get (map_key , None )
791+ if memoized :
792+ print ("memoized_hit" , key , '\033 [35m' , inspect .getsource (orig_value ), '\033 [00m' )
793+ return memoized
794+
795+ signature = inspect .signature (orig_value )
796+ if signature .parameters .get ('metadata' , None ) is None :
797+ return orig_value
798+
799+ print (key , '\033 [34m' , map_key , '\033 [00m' , signature , signature .parameters .get ('metadata' , None ))
800+ counters = dict (attempt = 0 )
801+ def patched_method (* aargs , ** kkwargs ):
802+ counters ['attempt' ] += 1
803+ metadata = kkwargs .get ('metadata' , None )
804+ if not metadata :
805+ return orig_value (* aargs , ** kkwargs )
806+
807+ # 4. Find all the headers that match the target header key.
808+ all_metadata = []
809+ for mkey , value in metadata :
810+ if mkey is REQ_ID_HEADER_KEY :
811+ attempt = counters ['attempt' ]
812+ if attempt > 1 :
813+ # 5. Increment the original_attempt with that of our re-invocation count.
814+ splits = value .split ("." )
815+ print ('\033 [34mkey' , mkey , '\033 [00m' , splits )
816+ hdr_attempt_plus_reinvocation = (
817+ int (splits [- 1 ]) + attempt
818+ )
819+ splits [- 1 ] = str (hdr_attempt_plus_reinvocation )
820+ value = "." .join (splits )
821+
822+ all_metadata .append ((mkey , value ))
823+
824+ kwargs ["metadata" ] = all_metadata
825+ return orig_value (* aargs , ** kkwargs )
826+
827+ memoize_map [map_key ] = patched_method
828+ return patched_method
829+
830+ setattr (obj , '__getattribute__' , patched_getattribute )
831+
832+
833+ def foo (api ):
834+ global patched
835+ global patched_mu
836+
745837 # For each method, add an _attempt value that'll then be
746838 # retrieved for each retry.
747839 # 1. Patch the __getattribute__ method to match items in our manifest.
@@ -753,55 +845,66 @@ def inject_retry_header_control(api):
753845 orig_getattribute = getattr (target , "__getattribute__" )
754846
755847 def patched_getattribute (obj , key , * args , ** kwargs ):
848+ # 1. Skip modifying private and mangled methods.
756849 if key .startswith ("_" ):
757850 return orig_getattribute (obj , key , * args , ** kwargs )
758851
759852 attr = orig_getattribute (obj , key , * args , ** kwargs )
760853
761- # 0. If we already patched it, we can return immediately.
762- if getattr (attr , "_patched" , None ) is not None :
763- return attr
764-
765- # 1. Skip over non-methods.
854+ # 2. Skip over non-methods.
766855 if not callable (attr ):
856+ patched_mu .release ()
767857 return attr
768858
769- # 2. Skip modifying private and mangled methods.
770- mangled_or_private = attr . __name__ . startswith ( "_" )
771- if mangled_or_private :
772- return attr
773-
859+ patched_key = hex ( id ( key )) + hex ( id ( obj ))
860+ patched_mu . acquire ( )
861+ already_patched = patched . get ( patched_key , None )
862+
863+ other_attempts = dict ( attempts = 0 )
774864 # 3. Wrap the callable attribute and then capture its metadata keyed argument.
775865 def wrapped_attr (* args , ** kwargs ):
866+ print ("\033 [31m" , key , "attempt" , other_attempts ['attempts' ], "\033 [00m" )
867+ other_attempts ['attempts' ] += 1
868+
776869 metadata = kwargs .get ("metadata" , [])
777870 if not metadata :
778871 # Increment the reinvocation count.
779872 wrapped_attr ._attempt += 1
780873 return attr (* args , ** kwargs )
781874
875+ print ("\033 [35mwrapped_attr" , key , args , kwargs , 'attempt' , wrapped_attr ._attempt , "\033 [00m" )
876+
782877 # 4. Find all the headers that match the target header key.
783878 all_metadata = []
784- for key , value in metadata :
785- if key is REQ_ID_HEADER_KEY :
786- # 5. Increment the original_attempt with that of our re-invocation count.
787- splits = value .split ("." )
788- hdr_attempt_plus_reinvocation = (
789- int (splits [- 1 ]) + wrapped_attr ._attempt
790- )
791- splits [- 1 ] = str (hdr_attempt_plus_reinvocation )
792- value = "." .join (splits )
793-
794- all_metadata .append ((key , value ))
795-
796- # Increment the reinvocation count.
797- wrapped_attr ._attempt += 1
879+ for mkey , value in metadata :
880+ if mkey is REQ_ID_HEADER_KEY :
881+ if wrapped_attr ._attempt > 0 :
882+ # 5. Increment the original_attempt with that of our re-invocation count.
883+ splits = value .split ("." )
884+ print ('\033 [34mkey' , mkey , '\033 [00m' , splits )
885+ hdr_attempt_plus_reinvocation = (
886+ int (splits [- 1 ]) + wrapped_attr ._attempt
887+ )
888+ splits [- 1 ] = str (hdr_attempt_plus_reinvocation )
889+ value = "." .join (splits )
890+
891+ all_metadata .append ((mkey , value ))
798892
799893 kwargs ["metadata" ] = all_metadata
894+ wrapped_attr ._attempt += 1
895+ print (key , "\033 [36mreplaced_all_metadata" , all_metadata , "\033 [00m" )
800896 return attr (* args , ** kwargs )
801897
802- wrapped_attr ._attempt = 0
803- wrapped_attr ._patched = True
898+ if already_patched :
899+ print ("patched_key \033 [32m" , patched_key , key , "\033 [00m" , already_patched )
900+ setattr (attr , 'patched' , True )
901+ # Increment the reinvocation count.
902+ patched_mu .release ()
903+ return already_patched
904+
905+ patched [patched_key ] = wrapped_attr
906+ setattr (wrapped_attr , '_attempt' , 0 )
907+ patched_mu .release ()
804908 return wrapped_attr
805909
806910 setattr (target , "__getattribute__" , patched_getattribute )
807- patched [hex_id ] = True
0 commit comments