Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Test for fully connected
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Nov 29, 2018
1 parent 38ef4c7 commit d70ff8a
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/python-pytest/onnx/export/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,53 @@ def test_square():

npt.assert_almost_equal(result, numpy_op)


@with_seed()
def test_fully_connected():
def random_arrays(*shapes):
"""Generate some random numpy arrays."""
arrays = [np.random.randn(*s).astype("float32")
for s in shapes]
if len(arrays) == 1:
return arrays[0]
return arrays

data_names = ['x', 'w', 'b']

dim_in, dim_out = (3, 4)
input_data = random_arrays((4, dim_in), (dim_out, dim_in), (dim_out,))

ipsym = []
data_shapes = []
data_forward = []
for idx in range(len(data_names)):
val = input_data[idx]
data_shapes.append((data_names[idx], np.shape(val)))
data_forward.append(mx.nd.array(val))
ipsym.append(mx.sym.Variable(data_names[idx]))

op = mx.sym.FullyConnected(data=ipsym[0], weight=ipsym[1], bias=ipsym[2], num_hidden=dim_out, name='FC')

model = mx.mod.Module(op, data_names=data_names, label_names=None)
model.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)

model.init_params()

args, auxs = model.get_params()
params = {}
params.update(args)
params.update(auxs)

converted_model = onnx_mxnet.export_model(op, params, [shape[1] for shape in data_shapes], np.float32, "fc.onnx")

sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
result = forward_pass(sym, arg_params, aux_params, data_names, input_data)

numpy_op = np.dot(input_data[0], input_data[1].T) + input_data[2]

npt.assert_almost_equal(result, numpy_op)


@with_seed()
def test_comparison_ops():
"""Test greater, lesser, equal"""
Expand Down

0 comments on commit d70ff8a

Please sign in to comment.