Skip to content

Commit 2bfdcbe

Browse files
[Relay] Convert negative axes to positive when importing ONNX Unsqueeze (#13846)
1 parent 56771a8 commit 2bfdcbe

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2540,6 +2540,8 @@ class Unsqueeze(OnnxOpConverter):
25402540
def run_calculation(cls, tensor, axes):
25412541
axes = sorted(axes)
25422542
for axis in axes:
2543+
if axis < 0 and isinstance(tensor, _expr.Var):
2544+
axis = len(tensor.type_annotation.concrete_shape) + len(axes) + axis
25432545
tensor = _op.expand_dims(tensor, axis=axis, num_newaxis=1)
25442546
return tensor
25452547

@@ -2558,6 +2560,7 @@ def _impl_v13(cls, inputs, attr, params):
25582560
num_new_axis = int(infer_type(inputs[1]).checked_type.shape[0])
25592561
axes = relay.sort(inputs[1])
25602562
axes = relay.split(axes, num_new_axis).astuple()
2563+
rank_output = rank_input + num_new_axis
25612564
result = inputs[0]
25622565

25632566
# TODO (AndrewZhaoLuo): investigate performance issues with consecutive
@@ -2567,7 +2570,7 @@ def _impl_v13(cls, inputs, attr, params):
25672570
# Unpack scalar
25682571
axis = relay.reshape(axis, [])
25692572
axis = relay.where(
2570-
axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_input, "int64")
2573+
axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_output, "int64")
25712574
)
25722575
result = _op.expand_dims(result, axis)
25732576
return result

tests/python/frontend/onnx/test_forward.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,45 @@ def test_unsqueeze(target, dev):
676676
verify_with_ort(model, [in_shape], target=target, dev=dev, opset=11)
677677

678678

679+
@tvm.testing.parametrize_targets
680+
def test_unsqueeze_with_neg_axes(target, dev):
681+
def verify_unsqueeze_with_neg_axes(opset=11):
682+
in_shape = (2, 3, 4)
683+
axis = (-2, -1)
684+
out_shape = (2, 3, 4, 1, 1)
685+
if opset < 13:
686+
y = helper.make_node("Unsqueeze", ["in"], ["out"], axes=list(axis))
687+
nodes = [y]
688+
else:
689+
axes = np.array(list(axis)).astype(np.int64)
690+
axes = helper.make_node(
691+
"Constant",
692+
inputs=[],
693+
outputs=["axes"],
694+
value=onnx.helper.make_tensor(
695+
name="const_axes",
696+
data_type=onnx.TensorProto.INT64,
697+
dims=axes.shape,
698+
vals=axes.flatten().astype(int),
699+
),
700+
)
701+
y = helper.make_node("Unsqueeze", ["in", "axes"], ["out"])
702+
nodes = [axes, y]
703+
704+
graph = helper.make_graph(
705+
nodes,
706+
"squeeze_test",
707+
inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
708+
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
709+
)
710+
711+
model = helper.make_model(graph, producer_name="squeeze_test")
712+
verify_with_ort(model, [in_shape], target=target, dev=dev, opset=opset)
713+
714+
verify_unsqueeze_with_neg_axes()
715+
verify_unsqueeze_with_neg_axes(opset=13)
716+
717+
679718
@tvm.testing.parametrize_targets
680719
def test_gather(target, dev):
681720
"""test_gather"""

0 commit comments

Comments
 (0)