Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export test: Allow multiple inputs in forward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Dec 10, 2018
1 parent 7ec99cd commit d15724c
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions tests/python-pytest/onnx/export/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,33 @@ def get_test_files(name):


def forward_pass(sym, arg, aux, data_names, input_data):
""" Perform forward pass on given data"""
""" Perform forward pass on given data
:param sym: Symbol
:param arg: Arg params
:param aux: Aux params
:param data_names: Input names (list)
:param input_data: Input data (list). If there is only one input,
pass it as a list. For example, if input is [1, 2],
pass input_data=[[1, 2]]
:return: result of forward pass
"""
# create module
mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None)

data_shapes = []
data_forward = []
for idx in range(len(data_names)):
val = input_data[idx]
data_shapes.append((data_names[idx], np.shape(val)))
data_forward.append(mx.nd.array(val))

mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
mod.set_params(arg_params=arg, aux_params=aux,
allow_missing=True, allow_extra=True)

# run inference
batch = namedtuple('Batch', ['data'])
mod.forward(batch([mx.nd.array(input_data)]), is_train=False)
mod.forward(batch(data_forward), is_train=False)

return mod.get_outputs()[0].asnumpy()

Expand Down Expand Up @@ -136,7 +154,7 @@ def test_models(model_name, input_shape, output_shape):
logging.info("Running inference on onnx re-import model in mxnet")
# run test for each test file
for input_data, output_data in zip(inputs, outputs):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])

# verify the results
npt.assert_equal(result.shape, output_data.shape)
Expand All @@ -156,7 +174,7 @@ def test_model_accuracy(model_name, input_shape):

expected_result= []
for input_data, output_data in zip(inputs, outputs):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
expected_result.append(result)

params = {}
Expand All @@ -178,7 +196,7 @@ def test_model_accuracy(model_name, input_shape):

actual_result = []
for input_data, output_data in zip(inputs, outputs):
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)
result = forward_pass(sym, arg_params, aux_params, data_names, [input_data])
actual_result.append(result)

# verify the results
Expand Down Expand Up @@ -235,7 +253,7 @@ def test_square():
converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, "square.onnx")

sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
result = forward_pass(sym, arg_params, aux_params, ['input1'], [input1])

numpy_op = np.square(input1)

Expand Down

0 comments on commit d15724c

Please sign in to comment.