diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
index 0d20c76240bd..15624b6c3a22 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -232,6 +232,17 @@ def convert_fully_connected(node, **kwargs):
 
     fcnode = []
 
+    op_name = "flatten_" + str(kwargs["idx"])
+    flatten_node = onnx.helper.make_node(
+        'Flatten',
+        inputs=[input_nodes[0]],
+        outputs=[op_name],
+        name=op_name
+    )
+
+    input_nodes[0] = op_name
+    fcnode.append(flatten_node)
+
     if no_bias:
         data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
         bias_name = "bias" + str(kwargs["idx"])
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index 22db0d637a3a..b4fa4b12c781 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -94,15 +94,33 @@ def get_test_files(name):
 
 
 def forward_pass(sym, arg, aux, data_names, input_data):
-    """ Perform forward pass on given data"""
+    """ Perform forward pass on given data
+    :param sym: Symbol
+    :param arg: Arg params
+    :param aux: Aux params
+    :param data_names: Input names (list)
+    :param input_data: Input data (list). If there is only one input,
+                       pass it as a list. For example, if input is [1, 2],
+                       pass input_data=[[1, 2]]
+    :return: result of forward pass
+    """
     # create module
     mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
-    mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
+
+    data_shapes = []
+    data_forward = []
+    for idx in range(len(data_names)):
+        val = input_data[idx]
+        data_shapes.append((data_names[idx], np.shape(val)))
+        data_forward.append(mx.nd.array(val))
+
+    mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
     mod.set_params(arg_params=arg, aux_params=aux,
                    allow_missing=True, allow_extra=True)
+
     # run inference
     batch = namedtuple('Batch', ['data'])
-    mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
+    mod.forward(batch(data_forward), is_train=False)
 
     return mod.get_outputs()[0].asnumpy()
 
@@ -136,7 +154,7 @@ def test_models(model_name, input_shape, output_shape):
     logging.info("Running inference on onnx re-import model in mxnet")
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
 
         # verify the results
         npt.assert_equal(result.shape, output_data.shape)
@@ -156,7 +174,7 @@ def test_model_accuracy(model_name, input_shape):
 
     expected_result= []
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
         expected_result.append(result)
 
     params = {}
@@ -178,7 +196,7 @@ def test_model_accuracy(model_name, input_shape):
 
     actual_result = []
     for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
         actual_result.append(result)
 
     # verify the results
@@ -235,13 +253,59 @@ def test_square():
     converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, "square.onnx")
 
     sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
-    result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
+    result = forward_pass(sym, arg_params, aux_params, ['input1'], [input1])
 
     numpy_op = np.square(input1)
 
     npt.assert_almost_equal(result, numpy_op)
 
 
+@with_seed()
+def test_fully_connected():
+    def random_arrays(*shapes):
+        """Generate some random numpy arrays."""
+        arrays = [np.random.randn(*s).astype("float32")
+                  for s in shapes]
+        if len(arrays) == 1:
+            return arrays[0]
+        return arrays
+
+    data_names = ['x', 'w', 'b']
+
+    dim_in, dim_out = (3, 4)
+    input_data = random_arrays((4, dim_in), (dim_out, dim_in), (dim_out,))
+
+    ipsym = []
+    data_shapes = []
+    data_forward = []
+    for idx in range(len(data_names)):
+        val = input_data[idx]
+        data_shapes.append((data_names[idx], np.shape(val)))
+        data_forward.append(mx.nd.array(val))
+        ipsym.append(mx.sym.Variable(data_names[idx]))
+
+    op = mx.sym.FullyConnected(data=ipsym[0], weight=ipsym[1], bias=ipsym[2], num_hidden=dim_out, name='FC')
+
+    model = mx.mod.Module(op, data_names=data_names, label_names=None)
+    model.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
+
+    model.init_params()
+
+    args, auxs = model.get_params()
+    params = {}
+    params.update(args)
+    params.update(auxs)
+
+    converted_model = onnx_mxnet.export_model(op, params, [shape[1] for shape in data_shapes], np.float32, "fc.onnx")
+
+    sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
+    result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+
+    numpy_op = np.dot(input_data[0], input_data[1].T) + input_data[2]
+
+    npt.assert_almost_equal(result, numpy_op)
+
+
 def test_softmax():
     input1 = np.random.rand(1000, 1000).astype("float32")
     label1 = np.random.rand(1000)
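
Context for the converter change: MXNet's FullyConnected implicitly collapses trailing input
dimensions (with the default flatten=True), while the 2-D matmul node the exporter emits has no
such implicit reshape, which is why the translation now prepends an explicit Flatten. A quick
standalone check of that implicit-flatten behavior (shapes here are illustrative only):

    import mxnet as mx

    x = mx.nd.ones((2, 3, 5))    # batched input with trailing dims (N, d1, d2)
    w = mx.nd.ones((4, 15))      # weight shaped (num_hidden, d1*d2)
    b = mx.nd.zeros((4,))
    y = mx.nd.FullyConnected(data=x, weight=w, bias=b, num_hidden=4)
    print(y.shape)               # (2, 4): the (3, 5) tail was flattened to 15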
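For reference, a minimal sketch of the node sequence the updated convert_fully_connected should
produce, assuming the existing FC translation emits a Gemm node downstream of the new Flatten
(tensor/node names, shapes, and Gemm attributes here are illustrative, not copied from the
converter):

    import onnx
    from onnx import helper, TensorProto

    # Flatten collapses (N, d1, d2, ...) to (N, d1*d2*...), matching MXNet semantics.
    flatten = helper.make_node('Flatten', inputs=['x'], outputs=['flatten_0'],
                               name='flatten_0')
    # MXNet FC computes x.W^T + b; Gemm with transB=1 expresses the same product.
    gemm = helper.make_node('Gemm', inputs=['flatten_0', 'w', 'b'],
                            outputs=['fc_0'], name='fc_0',
                            alpha=1.0, beta=1.0, transA=0, transB=1)

    graph = helper.make_graph(
        [flatten, gemm], 'fc_sketch',
        inputs=[helper.make_tensor_value_info('x', TensorProto.FLOAT, [4, 3]),
                helper.make_tensor_value_info('w', TensorProto.FLOAT, [4, 3]),
                helper.make_tensor_value_info('b', TensorProto.FLOAT, [4])],
        outputs=[helper.make_tensor_value_info('fc_0', TensorProto.FLOAT, [4, 4])])
    onnx.checker.check_model(helper.make_model(graph))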
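The forward_pass rework generalizes the old single-input bind to N inputs matched positionally
against data_names, which is what lets test_fully_connected feed x, w, and b as separate graph
inputs. A self-contained sketch of the same bind/forward pattern on a toy two-input symbol (the
add symbol and names are illustrative, not from the patch):

    from collections import namedtuple
    import numpy as np
    import mxnet as mx

    net = mx.sym.Variable('a') + mx.sym.Variable('b')
    data_names = ['a', 'b']
    input_data = [np.ones((2, 3), dtype='float32'),
                  np.full((2, 3), 2.0, dtype='float32')]

    mod = mx.mod.Module(symbol=net, data_names=data_names,
                        context=mx.cpu(), label_names=None)
    # One (name, shape) pair per input, in the same order as data_names.
    data_shapes = [(name, np.shape(val))
                   for name, val in zip(data_names, input_data)]
    mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
    mod.init_params()

    Batch = namedtuple('Batch', ['data'])
    mod.forward(Batch([mx.nd.array(val) for val in input_data]), is_train=False)
    assert (mod.get_outputs()[0].asnumpy() == 3.0).all()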