diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index ffd31317e9f5..eb385786bc32 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2540,6 +2540,8 @@ class Unsqueeze(OnnxOpConverter): def run_calculation(cls, tensor, axes): axes = sorted(axes) for axis in axes: + if axis < 0 and isinstance(tensor, _expr.Var): + axis = len(tensor.type_annotation.concrete_shape) + len(axes) + axis tensor = _op.expand_dims(tensor, axis=axis, num_newaxis=1) return tensor @@ -2558,6 +2560,7 @@ def _impl_v13(cls, inputs, attr, params): num_new_axis = int(infer_type(inputs[1]).checked_type.shape[0]) axes = relay.sort(inputs[1]) axes = relay.split(axes, num_new_axis).astuple() + rank_output = rank_input + num_new_axis result = inputs[0] # TODO (AndrewZhaoLuo): investigate performance issues with consecutive @@ -2567,7 +2570,7 @@ def _impl_v13(cls, inputs, attr, params): # Unpack scalar axis = relay.reshape(axis, []) axis = relay.where( - axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_input, "int64") + axis >= relay.const(0, "int64"), axis, axis + relay.const(rank_output, "int64") ) result = _op.expand_dims(result, axis) return result diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f5b5f7c65cb5..ee16d039d741 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -676,6 +676,45 @@ def test_unsqueeze(target, dev): verify_with_ort(model, [in_shape], target=target, dev=dev, opset=11) +@tvm.testing.parametrize_targets +def test_unsqueeze_with_neg_axes(target, dev): + def verify_unsqueeze_with_neg_axes(opset=11): + in_shape = (2, 3, 4) + axis = (-2, -1) + out_shape = (2, 3, 4, 1, 1) + if opset < 13: + y = helper.make_node("Unsqueeze", ["in"], ["out"], axes=list(axis)) + nodes = [y] + else: + axes = np.array(list(axis)).astype(np.int64) + axes = helper.make_node( + "Constant", + inputs=[], + outputs=["axes"], + value=onnx.helper.make_tensor( + name="const_axes", + data_type=onnx.TensorProto.INT64, + dims=axes.shape, + vals=axes.flatten().astype(int), + ), + ) + y = helper.make_node("Unsqueeze", ["in", "axes"], ["out"]) + nodes = [axes, y] + + graph = helper.make_graph( + nodes, + "squeeze_test", + inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))], + ) + + model = helper.make_model(graph, producer_name="squeeze_test") + verify_with_ort(model, [in_shape], target=target, dev=dev, opset=opset) + + verify_unsqueeze_with_neg_axes() + verify_unsqueeze_with_neg_axes(opset=13) + + @tvm.testing.parametrize_targets def test_gather(target, dev): """test_gather"""