From ea33835803c51e43ad29b6cafa03d222241a6ead Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Thu, 17 Dec 2020 12:55:21 -0800 Subject: [PATCH] add test --- tests/python-pytest/onnx/test_operators.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 0ed1188dc9f6..b3e73d7a0732 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -51,7 +51,8 @@ def export_to_onnx(model, model_name, inputs): def onnx_rt(onnx_file, inputs): sess = rt.InferenceSession(onnx_file) - input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy()) for i in range(len(inputs))) + dtype_0 = inputs[0].asnumpy().dtype + input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy().astype(dtype_0)) for i in range(len(inputs))) pred = sess.run(None, input_dict)[0] return pred @@ -205,3 +206,15 @@ def test_onnx_export_fully_connected(tmp_path, dtype, num_hidden, no_bias, flatt if not no_bias: args.append(mx.nd.random.uniform(0,1,(num_hidden,))) op_export_test('FullyConnected', M, args, tmp_path) + +@pytest.mark.parametrize('dtype', ['float32', 'float64']) +def test_onnx_export_softmax(tmp_path, dtype): + x = mx.nd.random.uniform(0, 1, (2, 3, 4), dtype=dtype) + M1 = def_model('softmax') + op_export_test('softmax_1', M1, [x], tmp_path) + M2 = def_model('softmax', use_length=True, axis=0) + l2 = mx.nd.array([[2,0,3,1],[1,3,2,0], [0,0,0,1]], dtype=int) + op_export_test('softmax_2', M1, [x, l2], tmp_path) + M3 = def_model('softmax', use_length=True, axis=-1, temperature=0.5) + l3 = mx.nd.array([[2,0,1],[0,0,0]], dtype=int) + op_export_test('softmax_3', M1, [x, l3], tmp_path)