Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion onnxruntime/python/tools/quantization/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,14 @@ def get_op_version(op_type, model):
return opset_import.version
raise RuntimeError(f"Model does not contain a version for '{op_type}'.")

def insert_nodes(tensor_name, new_nodes):
    """Insert ``new_nodes`` into the model graph ahead of the first consumer of ``tensor_name``.

    The graph's node list is scanned in order for the first node that lists
    ``tensor_name`` among its inputs, and ``new_nodes`` are placed
    immediately before it, keeping their relative order. If no node consumes
    the tensor, the new nodes end up at the end of the node list.

    Args:
        tensor_name: Name of the tensor whose first consumer marks the
            insertion point.
        new_nodes: Nodes to insert, in the order they should appear.
    """
    # NOTE(review): presumably this keeps the node list topologically sorted
    # when the new nodes read ``tensor_name`` - confirm against the callers.
    graph_nodes = self.model.graph.node

    # Default to the end of the list when nothing consumes the tensor.
    position = len(graph_nodes)
    for candidate_index, candidate in enumerate(graph_nodes):
        if tensor_name in candidate.input:
            position = candidate_index
            break

    # Consecutive slots preserve the caller-supplied ordering of new_nodes.
    for offset, new_node in enumerate(new_nodes):
        graph_nodes.insert(position + offset, new_node)

def add_reduce_min_max(tensor_name, reduce_op_name):
# When doing ReduceMax/ReduceMin, ORT can't reduce on dim with value of 0 if 'keepdims' is false.
# To make the code simple, we always let keepdims to be 1.
Expand Down Expand Up @@ -396,7 +404,7 @@ def add_reduce_min_max(tensor_name, reduce_op_name):
reduce_node.input.append(reduce_axes_name)
self.model.graph.initializer.append(reduce_axes)

self.model.graph.node.extend([reduce_node, reshape_node])
insert_nodes(tensor_name, [reduce_node, reshape_node])
self.model.graph.output.append(helper.make_tensor_value_info(reduce_output, onnx_type, [None]))

for tensor in tensors:
Expand Down
Loading