Skip to content

Commit f651c88

Browse files
committed
single pending ops queue
process pending ops recursively
1 parent cfd9890 commit f651c88

File tree

1 file changed

+46
-73
lines changed

1 file changed

+46
-73
lines changed

tensorflow/lite/micro/compression/relocate_ops.py

+46-73
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@
5555
{ ( subgraph_index, tensor_index ) : var_handle_id }
5656
"""
5757

58-
ReadVarOps = List[Dict[int, model_facade.Operator]]
59-
60-
ConcatOps = List[Dict[int, model_facade.Operator]]
58+
PendingOps = List[Dict[TensorIndex, model_facade.Operator]]
59+
"""PendingOps
60+
[ { output_tensor_index : operator }]
61+
"""
6162

6263

6364
class Context:
@@ -73,8 +74,7 @@ def __init__(self, model: model_facade.Model) -> None:
7374
self._subgraph_processed: List[bool] = [False] * len(model.subgraphs)
7475
self._subgraph_modified_vars: List[VarHandles] = [set()] * len(
7576
model.subgraphs)
76-
self._subgraph_read_var_ops: ReadVarOps = [{}] * len(model.subgraphs)
77-
self._subgraph_concat_ops: ConcatOps = [{}] * len(model.subgraphs)
77+
self._pending_ops: PendingOps = [{}] * len(model.subgraphs)
7878
self._var_handles_by_name: VarHandleByName = {}
7979
self._var_handles_by_tensor: VarHandleByTensor = {}
8080
self._current_var_handle_id: VarHandleId = 0
@@ -116,47 +116,35 @@ def set_subgraph_var_handles(self, subgraph_index: SubgraphIndex,
116116
handles: VarHandles) -> None:
117117
self._subgraph_modified_vars[subgraph_index] = handles
118118

119-
def add_read_var_op(self, op: model_facade.Operator) -> None:
120-
assert op.builtin_opcode == tflite.BuiltinOperator.READ_VARIABLE
119+
def add_pending_op(self, op: model_facade.Operator) -> None:
120+
assert len(op.outputs_indices) == 1
121121
key: TensorIndex = op.outputs_indices[0]
122-
self._subgraph_read_var_ops[op.subgraph.index][key] = op
122+
assert self._pending_ops[op.subgraph.index].get(key) is None
123+
self._pending_ops[op.subgraph.index][key] = op
123124

124-
def remove_read_var_op(self, op: model_facade.Operator) -> None:
125-
assert op.builtin_opcode == tflite.BuiltinOperator.READ_VARIABLE
125+
def remove_pending_op(self, op: model_facade.Operator) -> None:
126126
key: TensorIndex = op.outputs_indices[0]
127-
del self._subgraph_read_var_ops[op.subgraph.index][key]
127+
assert self._pending_ops[op.subgraph.index][key].index == op.index
128+
del self._pending_ops[op.subgraph.index][key]
128129

129-
def get_read_var_op_by_tensor(
130+
def get_pending_op(
130131
self, tensor_index: TensorIndex,
131132
subgraph_index: SubgraphIndex) -> model_facade.Operator | None:
132-
return self._subgraph_read_var_ops[subgraph_index].get(tensor_index, None)
133+
return self._pending_ops[subgraph_index].get(tensor_index, None)
133134

134135
def get_read_var_op_by_handle(
135136
self, resource_tensor_index: TensorIndex,
136137
subgraph_index: SubgraphIndex) -> List[model_facade.Operator]:
137138
result: List[model_facade.Operator] = []
138139
var_handle_id = self.get_var_handle(subgraph_index, resource_tensor_index)
139-
for op in self._subgraph_read_var_ops[subgraph_index].values():
140+
for op in self._pending_ops[subgraph_index].values():
141+
if op.builtin_opcode != tflite.BuiltinOperator.READ_VARIABLE:
142+
continue
140143
if self.get_var_handle(op.subgraph.index,
141144
op.inputs_indices[0]) == var_handle_id:
142145
result.append(op)
143146
return result
144147

145-
def add_concat_op(self, op: model_facade.Operator) -> None:
146-
assert op.builtin_opcode == tflite.BuiltinOperator.CONCATENATION
147-
key: TensorIndex = op.outputs_indices[0]
148-
self._subgraph_concat_ops[op.subgraph.index][key] = op
149-
150-
def remove_concat_op(self, op: model_facade.Operator) -> None:
151-
assert op.builtin_opcode == tflite.BuiltinOperator.CONCATENATION
152-
key: TensorIndex = op.outputs_indices[0]
153-
del self._subgraph_concat_ops[op.subgraph.index][key]
154-
155-
def get_concat_op_by_tensor(
156-
self, tensor_index: TensorIndex,
157-
subgraph_index: SubgraphIndex) -> model_facade.Operator | None:
158-
return self._subgraph_concat_ops[subgraph_index].get(tensor_index, None)
159-
160148
def create_var_handle(self, container_name: str | None, resource_name: str,
161149
subgraph_index: SubgraphIndex,
162150
resource_tensor_index: TensorIndex) -> VarHandleId:
@@ -201,27 +189,15 @@ def process_operator_var_handle(context: Context) -> VarHandles:
201189

202190
def process_operator_assign_variable(context: Context) -> VarHandles:
203191
assign_op = context.current_op()
204-
pending_concat_op = context.get_concat_op_by_tensor(
205-
assign_op.inputs_indices[1], assign_op.subgraph.index)
206-
assert pending_concat_op is None
207-
read_var_op = context.get_read_var_op_by_tensor(assign_op.inputs_indices[1],
208-
assign_op.subgraph.index)
209-
if read_var_op is not None:
210-
context.append_to_reordered_operations(read_var_op)
211-
context.remove_read_var_op(read_var_op)
212-
213192
for read_var_op in context.get_read_var_op_by_handle(
214193
assign_op.inputs_indices[0], assign_op.subgraph.index):
215194
context.append_to_reordered_operations(read_var_op)
216-
context.remove_read_var_op(read_var_op)
195+
context.remove_pending_op(read_var_op)
217196

218-
context.append_to_reordered_operations(assign_op)
219-
return set()
220-
221-
222-
def process_operator_read_variable(context: Context) -> VarHandles:
223-
context.add_read_var_op(context.current_op())
224-
return set()
197+
process_pending_ops(context)
198+
var_handle_id = context.get_var_handle(assign_op.subgraph.index,
199+
assign_op.inputs_indices[0])
200+
return set([var_handle_id])
225201

226202

227203
def process_operator_call_once(context: Context) -> VarHandles:
@@ -239,45 +215,42 @@ def process_operator_while(context: Context) -> VarHandles:
239215
return set()
240216

241217

242-
def process_operator_concatenation(context: Context) -> VarHandles:
243-
context.add_concat_op(context.current_op())
218+
def process_operator_as_pending(context: Context) -> VarHandles:
219+
context.add_pending_op(context.current_op())
244220
return set()
245221

246222

223+
def process_pending_ops(context: Context) -> None:
224+
op = context.current_op()
225+
for tensor_input in op.inputs_indices:
226+
pending_op = context.get_pending_op(tensor_input, op.subgraph.index)
227+
if pending_op is not None:
228+
context.remove_pending_op(pending_op)
229+
context.push_current_op(pending_op)
230+
process_pending_ops(context)
231+
context.pop_current_op()
232+
233+
context.append_to_reordered_operations(op)
234+
235+
247236
def process_operator(context: Context) -> VarHandles:
248237
op = context.current_op()
249238
if op.builtin_opcode == tflite.BuiltinOperator.VAR_HANDLE:
250239
return process_operator_var_handle(context)
251240
elif op.builtin_opcode == tflite.BuiltinOperator.ASSIGN_VARIABLE:
252241
return process_operator_assign_variable(context)
253242
elif op.builtin_opcode == tflite.BuiltinOperator.READ_VARIABLE:
254-
return process_operator_read_variable(context)
243+
return process_operator_as_pending(context)
255244
elif op.builtin_opcode == tflite.BuiltinOperator.CONCATENATION:
256-
return process_operator_concatenation(context)
245+
return process_operator_as_pending(context)
246+
elif op.builtin_opcode == tflite.BuiltinOperator.IF:
247+
return process_operator_if(context)
248+
elif op.builtin_opcode == tflite.BuiltinOperator.WHILE:
249+
return process_operator_while(context)
250+
elif op.builtin_opcode == tflite.BuiltinOperator.CALL_ONCE:
251+
return process_operator_call_once(context)
257252
else:
258-
for tensor_input in op.inputs_indices:
259-
concat_op = context.get_concat_op_by_tensor(tensor_input,
260-
op.subgraph.index)
261-
if concat_op is not None:
262-
for concat_tensor_input in concat_op.inputs_indices:
263-
pending_concat_op = context.get_concat_op_by_tensor(
264-
concat_tensor_input, op.subgraph.index)
265-
assert pending_concat_op is None
266-
read_var_op = context.get_read_var_op_by_tensor(
267-
concat_tensor_input, op.subgraph.index)
268-
if read_var_op is not None:
269-
context.append_to_reordered_operations(read_var_op)
270-
context.remove_read_var_op(read_var_op)
271-
272-
context.append_to_reordered_operations(concat_op)
273-
context.remove_concat_op(concat_op)
274-
275-
read_var_op = context.get_read_var_op_by_tensor(tensor_input,
276-
op.subgraph.index)
277-
if read_var_op is not None:
278-
context.append_to_reordered_operations(read_var_op)
279-
context.remove_read_var_op(read_var_op)
280-
context.append_to_reordered_operations(op)
253+
process_pending_ops(context)
281254

282255
return set()
283256

0 commit comments

Comments
 (0)