Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

ONNX export: Add Flatten before Gemm #13356

Merged
merged 3 commits into from
Dec 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ def convert_fully_connected(node, **kwargs):

fcnode = []

op_name = "flatten_" + str(kwargs["idx"])
flatten_node = onnx.helper.make_node(
'Flatten',
inputs=[input_nodes[0]],
outputs=[op_name],
name=op_name
)

input_nodes[0] = op_name
fcnode.append(flatten_node)

if no_bias:
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
bias_name = "bias" + str(kwargs["idx"])
Expand Down
78 changes: 71 additions & 7 deletions tests/python-pytest/onnx/export/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,33 @@ def get_test_files(name):


def forward_pass(sym, arg, aux, data_names, input_data):
""" Perform forward pass on given data"""
""" Perform forward pass on given data
:param sym: Symbol
:param arg: Arg params
:param aux: Aux params
:param data_names: Input names (list)
:param input_data: Input data (list). If there is only one input,
pass it as a list. For example, if input is [1, 2],
pass input_data=[[1, 2]]
:return: result of forward pass
"""
# create module
mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)

data_shapes = []
data_forward = []
for idx in range(len(data_names)):
val = input_data[idx]
data_shapes.append((data_names[idx], np.shape(val)))
data_forward.append(mx.nd.array(val))

mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
mod.set_params(arg_params=arg, aux_params=aux,
allow_missing=True, allow_extra=True)

# run inference
batch = namedtuple('Batch', ['data'])
mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
mod.forward(batch(data_forward), is_train=False)

return mod.get_outputs()[0].asnumpy()

Expand Down Expand Up @@ -136,7 +154,7 @@ def test_models(model_name, input_shape, output_shape):
logging.info("Running inference on onnx re-import model in mxnet")
# run test for each test file
for input_data, output_data in zip(inputs, outputs):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])

# verify the results
npt.assert_equal(result.shape, output_data.shape)
Expand All @@ -156,7 +174,7 @@ def test_model_accuracy(model_name, input_shape):

expected_result= []
for input_data, output_data in zip(inputs, outputs):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
expected_result.append(result)

params = {}
Expand All @@ -178,7 +196,7 @@ def test_model_accuracy(model_name, input_shape):

actual_result = []
for input_data, output_data in zip(inputs, outputs):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
actual_result.append(result)

# verify the results
Expand Down Expand Up @@ -235,13 +253,59 @@ def test_square():
converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, "square.onnx")

sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
result = forward_pass(sym, arg_params, aux_params, ['input1'], [input1])

numpy_op = np.square(input1)

npt.assert_almost_equal(result, numpy_op)


@with_seed()
def test_fully_connected():
def random_arrays(*shapes):
"""Generate some random numpy arrays."""
arrays = [np.random.randn(*s).astype("float32")
for s in shapes]
if len(arrays) == 1:
return arrays[0]
return arrays

data_names = ['x', 'w', 'b']

dim_in, dim_out = (3, 4)
input_data = random_arrays((4, dim_in), (dim_out, dim_in), (dim_out,))

ipsym = []
data_shapes = []
data_forward = []
for idx in range(len(data_names)):
val = input_data[idx]
data_shapes.append((data_names[idx], np.shape(val)))
data_forward.append(mx.nd.array(val))
ipsym.append(mx.sym.Variable(data_names[idx]))

op = mx.sym.FullyConnected(data=ipsym[0], weight=ipsym[1], bias=ipsym[2], num_hidden=dim_out, name='FC')

model = mx.mod.Module(op, data_names=data_names, label_names=None)
model.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)

model.init_params()

args, auxs = model.get_params()
params = {}
params.update(args)
params.update(auxs)

converted_model = onnx_mxnet.export_model(op, params, [shape[1] for shape in data_shapes], np.float32, "fc.onnx")

sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)

numpy_op = np.dot(input_data[0], input_data[1].T) + input_data[2]

npt.assert_almost_equal(result, numpy_op)


def test_softmax():
input1 = np.random.rand(1000, 1000).astype("float32")
label1 = np.random.rand(1000)
Expand Down