@@ -122,6 +122,8 @@ class _ProgramState:
122122 # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
123123 # and should be copied to Program.backend_delegate_data.
124124 backend_delegate_data : List [BackendDelegateInlineData ] = field (default_factory = list )
125+ # Delegate cache that is used across all entry points. Key is the hash of the delegated payload.
126+ backend_delegate_data_cache : Dict [str , int ] = field (default_factory = dict )
125127
126128 # Constants are optionally stored in external files.
127129 # Aggregate unique external constants into one buffer.
@@ -144,7 +146,8 @@ class _EmitterState:
144146 operators : List [Operator ]
145147 delegates : List [BackendDelegate ]
146148 operator_cache : Dict [Tuple [str , str ], int ]
147- delegate_cache : Dict [bytes , int ]
149+ # delegate_cache: (key: hash(delegated_payload), value: index in delegates)
150+ delegate_cache : Dict [str , int ]
148151 emit_stacktrace : bool
149152
150153 spec2id_dict : Dict [TensorSpec , int ] = field (default_factory = dict )
@@ -1073,8 +1076,8 @@ def _emit_delegate(
10731076 """Emit the delegates inputs and outputs as specified by the schema, then emit the
10741077 delegate's blob."""
10751078 processed_bytes = lowered_module .processed_bytes
1076-
1077- delegate_index = self .emitter_state .delegate_cache .get (processed_bytes )
1079+ hashed = hashlib . sha256 ( processed_bytes ). hexdigest ()
1080+ delegate_index = self .emitter_state .delegate_cache .get (hashed )
10781081 delegate_ret = None
10791082
10801083 if isinstance (self .node .meta ["spec" ], list ):
@@ -1112,10 +1115,16 @@ def _emit_delegate(
11121115 if delegate_index is None :
11131116 # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
11141117 # present.
1115- data_index : int = len ( self . program_state . backend_delegate_data )
1116- self . program_state . backend_delegate_data . append (
1117- BackendDelegateInlineData ( data = processed_bytes )
1118+ hashed = hashlib . sha256 ( processed_bytes ). hexdigest ( )
1119+ data_index : Optional [ int ] = (
1120+ self . program_state . backend_delegate_data_cache . get ( hashed )
11181121 )
1122+ if data_index is None :
1123+ data_index = len (self .program_state .backend_delegate_data )
1124+ self .program_state .backend_delegate_data_cache [hashed ] = data_index
1125+ self .program_state .backend_delegate_data .append (
1126+ BackendDelegateInlineData (data = processed_bytes )
1127+ )
11191128
11201129 backend_delegate = BackendDelegate (
11211130 id = lowered_module .backend_id ,
@@ -1126,7 +1135,7 @@ def _emit_delegate(
11261135 )
11271136 delegate_index = len (self .emitter_state .delegate_cache )
11281137 self .emitter_state .delegates .append (backend_delegate )
1129- self .emitter_state .delegate_cache [processed_bytes ] = delegate_index
1138+ self .emitter_state .delegate_cache [hashed ] = delegate_index
11301139
11311140 # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
11321141 # function's spec and with default arguments. This requires us to store the function's spec
0 commit comments