ONNX export: Add Flatten before Gemm (apache#13356)
* Add Flatten before Gemm

* ONNX export test: Allow multiple inputs in forward pass

* ONNX export: Test for fully connected
vandanavk authored and haohuw committed Jun 23, 2019
1 parent e341368 commit aac961f
Showing 2 changed files with 82 additions and 7 deletions.
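Why the change: MXNet's FullyConnected implicitly flattens any trailing input dimensions before the matrix multiply, while the ONNX Gemm operator it is exported to accepts only 2-D operands, so the exporter now emits a Flatten node ahead of each Gemm. A minimal numpy sketch of the semantics being preserved (shapes are illustrative, not taken from the commit):

import numpy as np

x = np.random.randn(4, 2, 3).astype("float32")   # batch of 4, trailing dims 2x3
w = np.random.randn(5, 6).astype("float32")      # num_hidden=5, flattened dim 2*3=6
b = np.random.randn(5).astype("float32")

flat = x.reshape(x.shape[0], -1)    # what the inserted Flatten node computes
out = np.dot(flat, w.T) + b         # what the Gemm node then computes
assert out.shape == (4, 5)          # matches FullyConnected's output shape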
python/mxnet/contrib/onnx/mx2onnx/_op_translations.py (11 additions, 0 deletions)
@@ -232,6 +232,17 @@ def convert_fully_connected(node, **kwargs):

    fcnode = []

+    op_name = "flatten_" + str(kwargs["idx"])
+    flatten_node = onnx.helper.make_node(
+        'Flatten',
+        inputs=[input_nodes[0]],
+        outputs=[op_name],
+        name=op_name
+    )
+
+    input_nodes[0] = op_name
+    fcnode.append(flatten_node)
+
    if no_bias:
        data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
        bias_name = "bias" + str(kwargs["idx"])
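The translator now prepends a Flatten whose output feeds the Gemm. A sketch of the resulting node pair, with hypothetical tensor names; the Gemm attributes shown are the ONNX defaults plus transB=1 for MXNet's (num_hidden, dim_in) weight layout, not values quoted from this file:

import onnx

flatten = onnx.helper.make_node(
    'Flatten', inputs=['data'], outputs=['flatten_0'], name='flatten_0')
gemm = onnx.helper.make_node(
    'Gemm', inputs=['flatten_0', 'weight', 'bias'], outputs=['fc0'],
    name='fc0', alpha=1.0, beta=1.0, transA=0, transB=1)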
tests/python-pytest/onnx/export/mxnet_export_test.py (71 additions, 7 deletions)
@@ -94,15 +94,33 @@ def get_test_files(name):


def forward_pass(sym, arg, aux, data_names, input_data):
-    """ Perform forward pass on given data"""
+    """ Perform forward pass on given data
+    :param sym: Symbol
+    :param arg: Arg params
+    :param aux: Aux params
+    :param data_names: Input names (list)
+    :param input_data: Input data (list). If there is only one input,
+                       pass it as a list. For example, if input is [1, 2],
+                       pass input_data=[[1, 2]]
+    :return: result of forward pass
+    """
    # create module
    mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
-    mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)
+
+    data_shapes = []
+    data_forward = []
+    for idx in range(len(data_names)):
+        val = input_data[idx]
+        data_shapes.append((data_names[idx], np.shape(val)))
+        data_forward.append(mx.nd.array(val))
+
+    mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
    mod.set_params(arg_params=arg, aux_params=aux,
                   allow_missing=True, allow_extra=True)

    # run inference
    batch = namedtuple('Batch', ['data'])
-    mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
+    mod.forward(batch(data_forward), is_train=False)

    return mod.get_outputs()[0].asnumpy()
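With this change every caller wraps its inputs in a list, even when there is only one. A usage sketch (net, arg_params, and aux_params are assumed to already exist):

# Hypothetical two-input call; each array is matched by position to a name.
out = forward_pass(net, arg_params, aux_params,
                   data_names=['x', 'w'],
                   input_data=[np.zeros((4, 3), dtype='float32'),
                               np.zeros((5, 3), dtype='float32')])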

@@ -136,7 +154,7 @@ def test_models(model_name, input_shape, output_shape):
    logging.info("Running inference on onnx re-import model in mxnet")
    # run test for each test file
    for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])

        # verify the results
        npt.assert_equal(result.shape, output_data.shape)
@@ -156,7 +174,7 @@ def test_model_accuracy(model_name, input_shape):

    expected_result= []
    for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
        expected_result.append(result)

    params = {}
@@ -178,7 +196,7 @@ def test_model_accuracy(model_name, input_shape):

    actual_result = []
    for input_data, output_data in zip(inputs, outputs):
-        result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+        result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
        actual_result.append(result)

    # verify the results
@@ -235,13 +253,59 @@ def test_square():
    converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, "square.onnx")

    sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
-    result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
+    result = forward_pass(sym, arg_params, aux_params, ['input1'], [input1])

    numpy_op = np.square(input1)

    npt.assert_almost_equal(result, numpy_op)


+@with_seed()
+def test_fully_connected():
+    def random_arrays(*shapes):
+        """Generate some random numpy arrays."""
+        arrays = [np.random.randn(*s).astype("float32")
+                  for s in shapes]
+        if len(arrays) == 1:
+            return arrays[0]
+        return arrays
+
+    data_names = ['x', 'w', 'b']
+
+    dim_in, dim_out = (3, 4)
+    input_data = random_arrays((4, dim_in), (dim_out, dim_in), (dim_out,))
+
+    ipsym = []
+    data_shapes = []
+    data_forward = []
+    for idx in range(len(data_names)):
+        val = input_data[idx]
+        data_shapes.append((data_names[idx], np.shape(val)))
+        data_forward.append(mx.nd.array(val))
+        ipsym.append(mx.sym.Variable(data_names[idx]))
+
+    op = mx.sym.FullyConnected(data=ipsym[0], weight=ipsym[1], bias=ipsym[2], num_hidden=dim_out, name='FC')
+
+    model = mx.mod.Module(op, data_names=data_names, label_names=None)
+    model.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
+
+    model.init_params()
+
+    args, auxs = model.get_params()
+    params = {}
+    params.update(args)
+    params.update(auxs)
+
+    converted_model = onnx_mxnet.export_model(op, params, [shape[1] for shape in data_shapes], np.float32, "fc.onnx")
+
+    sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
+    result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
+
+    numpy_op = np.dot(input_data[0], input_data[1].T) + input_data[2]
+
+    npt.assert_almost_equal(result, numpy_op)


def test_softmax():
    input1 = np.random.rand(1000, 1000).astype("float32")
    label1 = np.random.rand(1000)
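A quick way to confirm the new export behavior (a sketch, assuming the "fc.onnx" file written by test_fully_connected is still on disk):

import onnx

model = onnx.load("fc.onnx")
# With this commit the op list should include a Flatten feeding the Gemm.
print([node.op_type for node in model.graph.node])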
