@@ -576,6 +576,7 @@ def _retry(
576576
577577
578578def _check_rst_stream_error (exc ):
579+ print ("\033 [31mrst_" , exc , "\033 [00m" )
579580 resumable_error = (
580581 any (
581582 resumable_message in exc .message
@@ -589,6 +590,11 @@ def _check_rst_stream_error(exc):
589590 raise
590591
591592
593+ def _check_unavailable (exc ):
594+ print ("\033 [31mcheck_unavailable" , exc , "\033 [00m" )
595+ raise
596+
597+
592598def _metadata_with_leader_aware_routing (value , ** kw ):
593599 """Create RPC metadata containing a leader aware routing header
594600
@@ -763,96 +769,164 @@ def __init__(self, original_callable: Callable):
763769
764770
765771def inject_retry_header_control (api ):
766- return
767- monkey_patch (type (api ))
772+ # monkey_patch(type(api))
773+ # monkey_patch(api)
774+ pass
768775
769- memoize_map = dict ()
770776
771- def monkey_patch (obj ):
772- return
777+ def monkey_patch (typ ):
778+ keys = dir (typ )
779+ attempts = dict ()
780+ for key in keys :
781+ if key .startswith ("_" ):
782+ continue
773783
774- """
775- klass = obj
776- attrs = dir(klass)
777- for attr_key in attrs:
778- if attr_key.startswith('_'):
784+ if key != "batch_create_sessions" :
779785 continue
780786
781- attr_value = getattr(obj, attr_key)
782- if not callable(attr_value):
787+ fn = getattr (typ , key )
788+
789+ signature = inspect .signature (fn )
790+ if signature .parameters .get ("metadata" , None ) is None :
783791 continue
784792
785- signature = inspect.signature(attr_value)
786- print(attr_key, signature.parameters)
793+ print ("fn.__call__" , inspect .getsource (fn ))
787794
788- call = attr_value
789- # Our goal is to replace the runtime pass through.
790- def wrapped(*args, **kwargs):
791- print(attr_key, 'called')
792- return call(*args, **kwargs)
795+ def as_proxy (db , * args , ** kwargs ):
796+ print ("db_key" , hex (id (db )))
797+ print ("as_proxy" , args , kwargs )
798+ metadata = kwargs .get ("metadata" , None )
799+ if not metadata :
800+ return fn (db , * args , ** kwargs )
793801
794- setattr(klass, attr_key, wrapped)
802+ hash_key = hex (id (db )) + "." + hex (id (key ))
803+ attempts .setdefault (hash_key , 0 )
804+ attempts [hash_key ] += 1
805+ # 4. Find all the headers that match the target header key.
806+ all_metadata = []
807+ for mkey , value in metadata :
808+ if mkey is not REQ_ID_HEADER_KEY :
809+ continue
795810
796- return
797- """
811+ splits = value .split ("." )
812+ # 5. Increment the original_attempt with that of our re-invocation count.
813+ print ("\033 [34mkey" , mkey , "\033 [00m" , splits )
814+ hdr_attempt_plus_reinvocation = int (splits [- 1 ]) + attempts [hash_key ]
815+ splits [- 1 ] = str (hdr_attempt_plus_reinvocation )
816+ value = "." .join (splits )
798817
818+ all_metadata .append ((mkey , value ))
819+
820+ kwargs ["metadata" ] = all_metadata
821+ return fn (db , * args , ** kwargs )
822+
823+ setattr (typ , key , as_proxy )
824+
825+
826+ def alt_foo ():
827+ memoize_map = dict ()
799828 orig_get_attr = getattr (obj , "__getattribute__" )
829+ hex_orig = hex (id (orig_get_attr ))
830+ hex_patched = None
831+
800832 def patched_getattribute (obj , key , * args , ** kwargs ):
801- if key .startswith ('_' ):
833+ if key .startswith ("_" ):
802834 return orig_get_attr (obj , key , * args , ** kwargs )
803835
804- orig_value = orig_get_attr (obj , key , * args , ** kwargs )
805- if not callable (orig_value ):
806- return orig_value
836+ if key != "batch_create_sessions" :
837+ return orig_get_attr (obj , key , * args , ** kwargs )
807838
808839 map_key = hex (id (key )) + hex (id (obj ))
809840 memoized = memoize_map .get (map_key , None )
810841 if memoized :
811- print ("memoized_hit" , key , '\033 [35m' , inspect .getsource (orig_value ), '\033 [00m' )
842+ if False :
843+ print (
844+ "memoized_hit" ,
845+ key ,
846+ "\033 [35m" ,
847+ inspect .getsource (orig_value ),
848+ "\033 [00m" ,
849+ )
850+ print ("memoized_hit" , key , "\033 [35m" , map_key , "\033 [00m" )
812851 return memoized
813852
853+ orig_value = orig_get_attr (obj , key , * args , ** kwargs )
854+ if not callable (orig_value ):
855+ return orig_value
856+
814857 signature = inspect .signature (orig_value )
815- if signature .parameters .get (' metadata' , None ) is None :
858+ if signature .parameters .get (" metadata" , None ) is None :
816859 return orig_value
817860
818- print (key , '\033 [34m' , map_key , '\033 [00m' , signature , signature .parameters .get ('metadata' , None ))
861+ if False :
862+ print (
863+ key ,
864+ "\033 [34m" ,
865+ map_key ,
866+ "\033 [00m" ,
867+ signature ,
868+ signature .parameters .get ("metadata" , None ),
869+ )
870+
871+ if False :
872+ stack = inspect .stack ()
873+ ends = stack [- 50 :- 20 ]
874+ for i , st in enumerate (ends ):
875+ print (i , st .filename , st .lineno )
876+
877+ print (
878+ "\033 [33mmonkey patching now\033 [00m" ,
879+ key ,
880+ "hex_orig" ,
881+ hex_orig ,
882+ "hex_patched" ,
883+ hex_patched ,
884+ )
819885 counters = dict (attempt = 0 )
886+
820887 def patched_method (* aargs , ** kkwargs ):
821- counters ['attempt' ] += 1
822- metadata = kkwargs .get ('metadata' , None )
888+ counters ["attempt" ] += 1
889+ print ("counters" , counters )
890+ metadata = kkwargs .get ("metadata" , None )
823891 if not metadata :
824892 return orig_value (* aargs , ** kkwargs )
825893
826894 # 4. Find all the headers that match the target header key.
827895 all_metadata = []
828896 for mkey , value in metadata :
829897 if mkey is REQ_ID_HEADER_KEY :
830- attempt = counters [' attempt' ]
898+ attempt = counters [" attempt" ]
831899 if attempt > 1 :
832900 # 5. Increment the original_attempt with that of our re-invocation count.
833901 splits = value .split ("." )
834- print ('\033 [34mkey' , mkey , '\033 [00m' , splits )
835- hdr_attempt_plus_reinvocation = (
836- int (splits [- 1 ]) + attempt
837- )
902+ print ("\033 [34mkey" , mkey , "\033 [00m" , splits )
903+ hdr_attempt_plus_reinvocation = int (splits [- 1 ]) + attempt
838904 splits [- 1 ] = str (hdr_attempt_plus_reinvocation )
839905 value = "." .join (splits )
840906
841907 all_metadata .append ((mkey , value ))
842908
843909 kwargs ["metadata" ] = all_metadata
844- return orig_value (* aargs , ** kkwargs )
910+
911+ try :
912+ return orig_value (* aargs , ** kkwargs )
913+
914+ except (InternalServerError , ServiceUnavailable ) as exc :
915+ print ("caught this exception, incrementing" , exc )
916+ counters ["attempt" ] += 1
917+ raise exc
845918
846919 memoize_map [map_key ] = patched_method
847920 return patched_method
848921
849- setattr (obj , '__getattribute__' , patched_getattribute )
922+ hex_patched = hex (id (patched_getattribute ))
923+ setattr (obj , "__getattribute__" , patched_getattribute )
850924
851925
852926def foo (api ):
853927 global patched
854928 global patched_mu
855-
929+
856930 # For each method, add an _attempt value that'll then be
857931 # retrieved for each retry.
858932 # 1. Patch the __getattribute__ method to match items in our manifest.
@@ -878,20 +952,29 @@ def patched_getattribute(obj, key, *args, **kwargs):
878952 patched_key = hex (id (key )) + hex (id (obj ))
879953 patched_mu .acquire ()
880954 already_patched = patched .get (patched_key , None )
881-
955+
882956 other_attempts = dict (attempts = 0 )
957+
883958 # 3. Wrap the callable attribute and then capture its metadata keyed argument.
884959 def wrapped_attr (* args , ** kwargs ):
885- print ("\033 [31m" , key , "attempt" , other_attempts [' attempts' ], "\033 [00m" )
886- other_attempts [' attempts' ] += 1
960+ print ("\033 [31m" , key , "attempt" , other_attempts [" attempts" ], "\033 [00m" )
961+ other_attempts [" attempts" ] += 1
887962
888963 metadata = kwargs .get ("metadata" , [])
889964 if not metadata :
890965 # Increment the reinvocation count.
891966 wrapped_attr ._attempt += 1
892967 return attr (* args , ** kwargs )
893968
894- print ("\033 [35mwrapped_attr" , key , args , kwargs , 'attempt' , wrapped_attr ._attempt , "\033 [00m" )
969+ print (
970+ "\033 [35mwrapped_attr" ,
971+ key ,
972+ args ,
973+ kwargs ,
974+ "attempt" ,
975+ wrapped_attr ._attempt ,
976+ "\033 [00m" ,
977+ )
895978
896979 # 4. Find all the headers that match the target header key.
897980 all_metadata = []
@@ -900,7 +983,7 @@ def wrapped_attr(*args, **kwargs):
900983 if wrapped_attr ._attempt > 0 :
901984 # 5. Increment the original_attempt with that of our re-invocation count.
902985 splits = value .split ("." )
903- print (' \033 [34mkey' , mkey , ' \033 [00m' , splits )
986+ print (" \033 [34mkey" , mkey , " \033 [00m" , splits )
904987 hdr_attempt_plus_reinvocation = (
905988 int (splits [- 1 ]) + wrapped_attr ._attempt
906989 )
@@ -916,13 +999,13 @@ def wrapped_attr(*args, **kwargs):
916999
9171000 if already_patched :
9181001 print ("patched_key \033 [32m" , patched_key , key , "\033 [00m" , already_patched )
919- setattr (attr , ' patched' , True )
1002+ setattr (attr , " patched" , True )
9201003 # Increment the reinvocation count.
9211004 patched_mu .release ()
9221005 return already_patched
9231006
9241007 patched [patched_key ] = wrapped_attr
925- setattr (wrapped_attr , ' _attempt' , 0 )
1008+ setattr (wrapped_attr , " _attempt" , 0 )
9261009 patched_mu .release ()
9271010 return wrapped_attr
9281011
0 commit comments