diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 30e8188a8312..997aa6240e9e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2420,6 +2420,8 @@ def _impl_v1(cls, inputs, attr, params): axis += ndim if axis == 0: reshape_shape = [-1] + elif axis == ndim - 1: + return _op.nn.softmax(inputs[0], axis=axis) else: axis_val = [in_shape[i] for i in range(axis)] reshape_shape = [np.prod(axis_val)] + [-1] diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 41123a254825..55f8f6c48255 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1589,26 +1589,45 @@ def test_upsample3d_trilinear(target, dev): tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5) +# TODO: Fix softmax with dynamic input on cuda and enable this test +@tvm.testing.known_failing_targets("cuda") @tvm.testing.parametrize_targets def test_softmax(target, dev): - def verify_softmax(inshape, axis): + def verify_softmax(inshape, axis, opset=None, dynamic=False): opname = "Softmax" - indata = np.random.uniform(size=inshape).astype(np.float32) outshape = inshape - y = helper.make_node(opname, ["in"], ["out"]) + node_list = [] + input_node_list = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(inshape))] + output_node_list = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))] + input_list = [np.random.uniform(size=inshape).astype(np.float32)] + softmax_inputs = ["in"] + + if dynamic: + input_node_list.append( + helper.make_tensor_value_info("shape", TensorProto.INT64, [len(inshape)]) + ) + input_list.append(np.asarray(inshape)) + reshape_node = helper.make_node("Reshape", ["in", "shape"], ["dynamic_in"]) + softmax_inputs[0] = "dynamic_in" + node_list += [reshape_node] + + y = helper.make_node(opname, softmax_inputs, ["out"]) if axis is not None: axis_attr = helper.make_attribute("axis", axis) y.attribute.append(axis_attr) + node_list.append(y) graph = helper.make_graph( - [y], + node_list, opname + "_test", - inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(outshape))], + inputs=input_node_list, + outputs=output_node_list, ) model = helper.make_model(graph, producer_name=opname + "_test") - verify_with_ort_with_inputs(model, [indata], target=target, dev=dev) + verify_with_ort_with_inputs( + model, input_list, use_vm=True, opset=opset, target=target, dev=dev + ) verify_softmax((1, 10), None) verify_softmax((1, 10), 1) @@ -1616,6 +1635,10 @@ def verify_softmax(inshape, axis): verify_softmax((1, 2, 3, 10), 2) verify_softmax((1, 2, 3, 4, 10), 3) verify_softmax((1, 2, 3, 4, 10), 4) + verify_softmax((1, 10), -1, dynamic=True) + verify_softmax((1, 2, 3, 10), -1, dynamic=True) + verify_softmax((1, 10), -1, opset=8, dynamic=True) + verify_softmax((1, 2, 3, 10), -1, opset=8, dynamic=True) @tvm.testing.parametrize_targets