From 739f181c2e3bbedc43e79fda7a46419840ad48ad Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Fri, 7 Dec 2018 10:20:54 -0800 Subject: [PATCH] ONNX export: Test for fully connected --- .../onnx/export/mxnet_export_test.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py index a8b99e644129..b4fa4b12c781 100644 --- a/tests/python-pytest/onnx/export/mxnet_export_test.py +++ b/tests/python-pytest/onnx/export/mxnet_export_test.py @@ -260,6 +260,52 @@ def test_square(): npt.assert_almost_equal(result, numpy_op) +@with_seed() +def test_fully_connected(): + def random_arrays(*shapes): + """Generate some random numpy arrays.""" + arrays = [np.random.randn(*s).astype("float32") + for s in shapes] + if len(arrays) == 1: + return arrays[0] + return arrays + + data_names = ['x', 'w', 'b'] + + dim_in, dim_out = (3, 4) + input_data = random_arrays((4, dim_in), (dim_out, dim_in), (dim_out,)) + + ipsym = [] + data_shapes = [] + data_forward = [] + for idx in range(len(data_names)): + val = input_data[idx] + data_shapes.append((data_names[idx], np.shape(val))) + data_forward.append(mx.nd.array(val)) + ipsym.append(mx.sym.Variable(data_names[idx])) + + op = mx.sym.FullyConnected(data=ipsym[0], weight=ipsym[1], bias=ipsym[2], num_hidden=dim_out, name='FC') + + model = mx.mod.Module(op, data_names=data_names, label_names=None) + model.bind(for_training=False, data_shapes=data_shapes, label_shapes=None) + + model.init_params() + + args, auxs = model.get_params() + params = {} + params.update(args) + params.update(auxs) + + converted_model = onnx_mxnet.export_model(op, params, [shape[1] for shape in data_shapes], np.float32, "fc.onnx") + + sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model) + result = forward_pass(sym, arg_params, aux_params, data_names, input_data) + + numpy_op = np.dot(input_data[0], input_data[1].T) + input_data[2] + + npt.assert_almost_equal(result, numpy_op) + + def test_softmax(): input1 = np.random.rand(1000, 1000).astype("float32") label1 = np.random.rand(1000)