From d289ef3e4bbc6dbea37aec7f22f0720c06423254 Mon Sep 17 00:00:00 2001 From: Roshani Nagmote Date: Thu, 6 Dec 2018 11:52:01 -0800 Subject: [PATCH] Adding test for softmaxoutput --- .../contrib/onnx/mx2onnx/_op_translations.py | 2 +- .../onnx/export/mxnet_export_test.py | 22 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 86767a667128..e605e824be43 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -705,7 +705,7 @@ def convert_softmax_output(node, **kwargs): softmax_node = onnx.helper.make_node( "Softmax", - [input1.output[0]], + [input1.name], [name], axis=1, name=name diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py index 6b858f05e24f..22db0d637a3a 100644 --- a/tests/python-pytest/onnx/export/mxnet_export_test.py +++ b/tests/python-pytest/onnx/export/mxnet_export_test.py @@ -241,6 +241,28 @@ def test_square(): npt.assert_almost_equal(result, numpy_op) + +def test_softmax(): + input1 = np.random.rand(1000, 1000).astype("float32") + label1 = np.random.rand(1000) + input_nd = mx.nd.array(input1) + label_nd = mx.nd.array(label1) + + ipsym = mx.sym.Variable("ipsym") + label = mx.sym.Variable('label') + sym = mx.sym.SoftmaxOutput(data=ipsym, label=label, ignore_label=0, use_ignore=False) + ex = sym.bind(ctx=mx.cpu(0), args={'ipsym': input_nd, 'label': label_nd}) + ex.forward(is_train=True) + softmax_out = ex.outputs[0].asnumpy() + + converted_model = onnx_mxnet.export_model(sym, {}, [(1000, 1000), (1000,)], np.float32, "softmaxop.onnx") + + sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model) + result = forward_pass(sym, arg_params, aux_params, ['ipsym'], input1) + + # Comparing result of forward pass before using onnx export, import + npt.assert_almost_equal(result, softmax_out) + @with_seed() def test_comparison_ops(): """Test greater, lesser, equal"""