55
55
{ ( subgraph_index, tensor_index ) : var_handle_id }
56
56
"""
57
57
58
- ReadVarOps = List [Dict [int , model_facade .Operator ]]
59
-
60
- ConcatOps = List [Dict [int , model_facade .Operator ]]
58
+ PendingOps = List [Dict [TensorIndex , model_facade .Operator ]]
59
+ """PendingOps
60
+ [ { output_tensor_index : operator }]
61
+ """
61
62
62
63
63
64
class Context :
@@ -73,8 +74,7 @@ def __init__(self, model: model_facade.Model) -> None:
73
74
self ._subgraph_processed : List [bool ] = [False ] * len (model .subgraphs )
74
75
self ._subgraph_modified_vars : List [VarHandles ] = [set ()] * len (
75
76
model .subgraphs )
76
- self ._subgraph_read_var_ops : ReadVarOps = [{}] * len (model .subgraphs )
77
- self ._subgraph_concat_ops : ConcatOps = [{}] * len (model .subgraphs )
77
+ self ._pending_ops : PendingOps = [{}] * len (model .subgraphs )
78
78
self ._var_handles_by_name : VarHandleByName = {}
79
79
self ._var_handles_by_tensor : VarHandleByTensor = {}
80
80
self ._current_var_handle_id : VarHandleId = 0
@@ -116,47 +116,35 @@ def set_subgraph_var_handles(self, subgraph_index: SubgraphIndex,
116
116
handles : VarHandles ) -> None :
117
117
self ._subgraph_modified_vars [subgraph_index ] = handles
118
118
119
- def add_read_var_op (self , op : model_facade .Operator ) -> None :
120
- assert op .builtin_opcode == tflite . BuiltinOperator . READ_VARIABLE
119
+ def add_pending_op (self , op : model_facade .Operator ) -> None :
120
+ assert len ( op .outputs_indices ) == 1
121
121
key : TensorIndex = op .outputs_indices [0 ]
122
- self ._subgraph_read_var_ops [op .subgraph .index ][key ] = op
122
+ assert self ._pending_ops [op .subgraph .index ].get (key ) is None
123
+ self ._pending_ops [op .subgraph .index ][key ] = op
123
124
124
- def remove_read_var_op (self , op : model_facade .Operator ) -> None :
125
- assert op .builtin_opcode == tflite .BuiltinOperator .READ_VARIABLE
125
+ def remove_pending_op (self , op : model_facade .Operator ) -> None :
126
126
key : TensorIndex = op .outputs_indices [0 ]
127
- del self ._subgraph_read_var_ops [op .subgraph .index ][key ]
127
+ assert self ._pending_ops [op .subgraph .index ][key ].index == op .index
128
+ del self ._pending_ops [op .subgraph .index ][key ]
128
129
129
- def get_read_var_op_by_tensor (
130
+ def get_pending_op (
130
131
self , tensor_index : TensorIndex ,
131
132
subgraph_index : SubgraphIndex ) -> model_facade .Operator | None :
132
- return self ._subgraph_read_var_ops [subgraph_index ].get (tensor_index , None )
133
+ return self ._pending_ops [subgraph_index ].get (tensor_index , None )
133
134
134
135
def get_read_var_op_by_handle (
135
136
self , resource_tensor_index : TensorIndex ,
136
137
subgraph_index : SubgraphIndex ) -> List [model_facade .Operator ]:
137
138
result : List [model_facade .Operator ] = []
138
139
var_handle_id = self .get_var_handle (subgraph_index , resource_tensor_index )
139
- for op in self ._subgraph_read_var_ops [subgraph_index ].values ():
140
+ for op in self ._pending_ops [subgraph_index ].values ():
141
+ if op .builtin_opcode != tflite .BuiltinOperator .READ_VARIABLE :
142
+ continue
140
143
if self .get_var_handle (op .subgraph .index ,
141
144
op .inputs_indices [0 ]) == var_handle_id :
142
145
result .append (op )
143
146
return result
144
147
145
- def add_concat_op (self , op : model_facade .Operator ) -> None :
146
- assert op .builtin_opcode == tflite .BuiltinOperator .CONCATENATION
147
- key : TensorIndex = op .outputs_indices [0 ]
148
- self ._subgraph_concat_ops [op .subgraph .index ][key ] = op
149
-
150
- def remove_concat_op (self , op : model_facade .Operator ) -> None :
151
- assert op .builtin_opcode == tflite .BuiltinOperator .CONCATENATION
152
- key : TensorIndex = op .outputs_indices [0 ]
153
- del self ._subgraph_concat_ops [op .subgraph .index ][key ]
154
-
155
- def get_concat_op_by_tensor (
156
- self , tensor_index : TensorIndex ,
157
- subgraph_index : SubgraphIndex ) -> model_facade .Operator | None :
158
- return self ._subgraph_concat_ops [subgraph_index ].get (tensor_index , None )
159
-
160
148
def create_var_handle (self , container_name : str | None , resource_name : str ,
161
149
subgraph_index : SubgraphIndex ,
162
150
resource_tensor_index : TensorIndex ) -> VarHandleId :
@@ -201,27 +189,15 @@ def process_operator_var_handle(context: Context) -> VarHandles:
201
189
202
190
def process_operator_assign_variable (context : Context ) -> VarHandles :
203
191
assign_op = context .current_op ()
204
- pending_concat_op = context .get_concat_op_by_tensor (
205
- assign_op .inputs_indices [1 ], assign_op .subgraph .index )
206
- assert pending_concat_op is None
207
- read_var_op = context .get_read_var_op_by_tensor (assign_op .inputs_indices [1 ],
208
- assign_op .subgraph .index )
209
- if read_var_op is not None :
210
- context .append_to_reordered_operations (read_var_op )
211
- context .remove_read_var_op (read_var_op )
212
-
213
192
for read_var_op in context .get_read_var_op_by_handle (
214
193
assign_op .inputs_indices [0 ], assign_op .subgraph .index ):
215
194
context .append_to_reordered_operations (read_var_op )
216
- context .remove_read_var_op (read_var_op )
195
+ context .remove_pending_op (read_var_op )
217
196
218
- context .append_to_reordered_operations (assign_op )
219
- return set ()
220
-
221
-
222
- def process_operator_read_variable (context : Context ) -> VarHandles :
223
- context .add_read_var_op (context .current_op ())
224
- return set ()
197
+ process_pending_ops (context )
198
+ var_handle_id = context .get_var_handle (assign_op .subgraph .index ,
199
+ assign_op .inputs_indices [0 ])
200
+ return set ([var_handle_id ])
225
201
226
202
227
203
def process_operator_call_once (context : Context ) -> VarHandles :
@@ -239,45 +215,42 @@ def process_operator_while(context: Context) -> VarHandles:
239
215
return set ()
240
216
241
217
242
- def process_operator_concatenation (context : Context ) -> VarHandles :
243
- context .add_concat_op (context .current_op ())
218
+ def process_operator_as_pending (context : Context ) -> VarHandles :
219
+ context .add_pending_op (context .current_op ())
244
220
return set ()
245
221
246
222
223
+ def process_pending_ops (context : Context ) -> None :
224
+ op = context .current_op ()
225
+ for tensor_input in op .inputs_indices :
226
+ pending_op = context .get_pending_op (tensor_input , op .subgraph .index )
227
+ if pending_op is not None :
228
+ context .remove_pending_op (pending_op )
229
+ context .push_current_op (pending_op )
230
+ process_pending_ops (context )
231
+ context .pop_current_op ()
232
+
233
+ context .append_to_reordered_operations (op )
234
+
235
+
247
236
def process_operator (context : Context ) -> VarHandles :
248
237
op = context .current_op ()
249
238
if op .builtin_opcode == tflite .BuiltinOperator .VAR_HANDLE :
250
239
return process_operator_var_handle (context )
251
240
elif op .builtin_opcode == tflite .BuiltinOperator .ASSIGN_VARIABLE :
252
241
return process_operator_assign_variable (context )
253
242
elif op .builtin_opcode == tflite .BuiltinOperator .READ_VARIABLE :
254
- return process_operator_read_variable (context )
243
+ return process_operator_as_pending (context )
255
244
elif op .builtin_opcode == tflite .BuiltinOperator .CONCATENATION :
256
- return process_operator_concatenation (context )
245
+ return process_operator_as_pending (context )
246
+ elif op .builtin_opcode == tflite .BuiltinOperator .IF :
247
+ return process_operator_if (context )
248
+ elif op .builtin_opcode == tflite .BuiltinOperator .WHILE :
249
+ return process_operator_while (context )
250
+ elif op .builtin_opcode == tflite .BuiltinOperator .CALL_ONCE :
251
+ return process_operator_call_once (context )
257
252
else :
258
- for tensor_input in op .inputs_indices :
259
- concat_op = context .get_concat_op_by_tensor (tensor_input ,
260
- op .subgraph .index )
261
- if concat_op is not None :
262
- for concat_tensor_input in concat_op .inputs_indices :
263
- pending_concat_op = context .get_concat_op_by_tensor (
264
- concat_tensor_input , op .subgraph .index )
265
- assert pending_concat_op is None
266
- read_var_op = context .get_read_var_op_by_tensor (
267
- concat_tensor_input , op .subgraph .index )
268
- if read_var_op is not None :
269
- context .append_to_reordered_operations (read_var_op )
270
- context .remove_read_var_op (read_var_op )
271
-
272
- context .append_to_reordered_operations (concat_op )
273
- context .remove_concat_op (concat_op )
274
-
275
- read_var_op = context .get_read_var_op_by_tensor (tensor_input ,
276
- op .subgraph .index )
277
- if read_var_op is not None :
278
- context .append_to_reordered_operations (read_var_op )
279
- context .remove_read_var_op (read_var_op )
280
- context .append_to_reordered_operations (op )
253
+ process_pending_ops (context )
281
254
282
255
return set ()
283
256
0 commit comments