@@ -122,6 +122,8 @@ class _ProgramState:
122122 # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
123123 # and should be copied to Program.backend_delegate_data.
124124 backend_delegate_data : List [BackendDelegateInlineData ] = field (default_factory = list )
125+ # Delegate cache that is used across all entry points. Key is the hash of the delegated payload.
126+ backend_delegate_data_cache : Dict [str , int ] = field (default_factory = dict )
125127
126128 # Constants are optionally stored in external files.
127129 # Aggregate unique external constants into one buffer.
@@ -144,7 +146,8 @@ class _EmitterState:
144146 operators : List [Operator ]
145147 delegates : List [BackendDelegate ]
146148 operator_cache : Dict [Tuple [str , str ], int ]
147- delegate_cache : Dict [bytes , int ]
149+ # delegate_cache: (key: hash(delegated_payload), value: index in delegates)
150+ delegate_cache : Dict [str , int ]
148151 emit_stacktrace : bool
149152
150153 spec2id_dict : Dict [TensorSpec , int ] = field (default_factory = dict )
@@ -1073,8 +1076,8 @@ def _emit_delegate(
10731076 """Emit the delegates inputs and outputs as specified by the schema, then emit the
10741077 delegate's blob."""
10751078 processed_bytes = lowered_module .processed_bytes
1076-
1077- delegate_index = self .emitter_state .delegate_cache .get (processed_bytes )
1079+ hashed = hashlib . sha256 ( processed_bytes ). hexdigest ()
1080+ delegate_index = self .emitter_state .delegate_cache .get (hashed )
10781081 delegate_ret = None
10791082
10801083 if isinstance (self .node .meta ["spec" ], list ):
@@ -1112,10 +1115,14 @@ def _emit_delegate(
11121115 if delegate_index is None :
11131116 # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
11141117 # present.
1115- data_index : int = len (self .program_state .backend_delegate_data )
1116- self .program_state .backend_delegate_data .append (
1117- BackendDelegateInlineData (data = processed_bytes )
1118- )
1118+ hashed = hashlib .sha256 (processed_bytes ).hexdigest ()
1119+ data_index : Optional [int ] = self .program_state .backend_delegate_data_cache .get (hashed )
1120+ if data_index is None :
1121+ data_index = len (self .program_state .backend_delegate_data )
1122+ self .program_state .backend_delegate_data_cache [hashed ] = data_index
1123+ self .program_state .backend_delegate_data .append (
1124+ BackendDelegateInlineData (data = processed_bytes )
1125+ )
11191126
11201127 backend_delegate = BackendDelegate (
11211128 id = lowered_module .backend_id ,
@@ -1126,7 +1133,7 @@ def _emit_delegate(
11261133 )
11271134 delegate_index = len (self .emitter_state .delegate_cache )
11281135 self .emitter_state .delegates .append (backend_delegate )
1129- self .emitter_state .delegate_cache [processed_bytes ] = delegate_index
1136+ self .emitter_state .delegate_cache [hashed ] = delegate_index
11301137
11311138 # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
11321139 # function's spec and with default arguments. This requires us to store the function's spec
0 commit comments