
Commit 24cd6db
more test cases
Roshrini committed Dec 19, 2018
1 parent 14d0579 commit 24cd6db
Showing 3 changed files with 23 additions and 2 deletions.
python/mxnet/contrib/onnx/mx2onnx/_op_translations.py (1 addition, 1 deletion)
@@ -744,7 +744,7 @@ def convert_blockgrad(node, **kwargs):
""" Skip operator """
return create_basic_op_node('ConstantFill', node, kwargs)

-@mx_op.register("make_loss")
+@mx_op.register("MakeLoss")
def convert_makeloss(node, **kwargs):
""" Skip operator """
return create_basic_op_node('ConstantFill', node, kwargs)
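
Both converters in this hunk follow the same skip-operator pattern: the MXNet op is registered with the module's mx_op.register decorator and emitted as a ConstantFill node through create_basic_op_node. As a minimal sketch, assuming it were added inside _op_translations.py where mx_op and create_basic_op_node are already in scope, and with the purely hypothetical operator name "_my_noop", another skip operator would look like this:

@mx_op.register("_my_noop")   # hypothetical operator name, for illustration only
def convert_my_noop(node, **kwargs):
    """ Skip operator """
    return create_basic_op_node('ConstantFill', node, kwargs)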
python/mxnet/contrib/onnx/onnx2mx/_translation_utils.py (1 addition, 1 deletion)
@@ -158,7 +158,7 @@ def _fix_broadcast(op_name, inputs, broadcast_axis, proto_obj):
assert len(list(inputs)) == 2

input0_shape = get_input_shape(inputs[0], proto_obj)
-#creating reshape shape
+# creating reshape shape
reshape_shape = list(len(input0_shape) * (1,))
reshape_shape[broadcast_axis] = -1
reshape_shape = tuple(reshape_shape)
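
As a worked example of the reshape-shape computation above, here is a standalone sketch assuming a rank-4 first input and broadcast_axis=1 (the concrete shape is illustrative, not taken from the commit):

# Illustrative values: a rank-4 first input, broadcasting along axis 1.
input0_shape = (2, 3, 4, 5)
broadcast_axis = 1

reshape_shape = list(len(input0_shape) * (1,))   # [1, 1, 1, 1]
reshape_shape[broadcast_axis] = -1               # [1, -1, 1, 1]
reshape_shape = tuple(reshape_shape)

# The second input is reshaped to this shape so that it broadcasts
# against the rank-4 first input along axis 1.
print(reshape_shape)   # (1, -1, 1, 1)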
tests/python-pytest/onnx/export/mxnet_export_test.py (21 additions, 0 deletions)
@@ -330,6 +330,27 @@ def test_scalarops():
out = ipsym / 2
_test_scalar_op(input1, out, np_out)

def test_makeloss():
v1 = mx.nd.array([1, 2])
v2 = mx.nd.array([0, 1])
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
sym = mx.sym.MakeLoss(b + a)
ex = sym.bind(ctx=mx.cpu(0), args={'a': v1, 'b': v2})
ex.forward(is_train=True)
makeloss_out = ex.outputs[0].asnumpy()

converted_model = onnx_mxnet.export_model(sym, {}, [(2,),(2,)], np.float32, "makelossop.onnx")

# sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
# result = forward_pass(sym, arg_params, aux_params, ['ipsym'], input1)
#
# # Comparing result of forward pass before using onnx export, import
# npt.assert_almost_equal(result, makeloss_out)
#
# executor = sym.simple_bind(ctx=mx.cpu(0), a=(2,), b=(2,))
# executor.forward(is_train=True, a=v1, b=v2)
# print(executor.outputs[0])

@with_seed()
def test_comparison_ops():
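
For context on the new test_makeloss test above: MakeLoss is an identity mapping on the forward pass, so the forward output captured in makeloss_out is just the elementwise sum of the two inputs. A standalone NumPy sketch of that expected value (not part of the commit):

import numpy as np

v1 = np.array([1, 2], dtype=np.float32)
v2 = np.array([0, 1], dtype=np.float32)

# MakeLoss(b + a) passes its input through unchanged on the forward pass,
# so the expected forward output is the elementwise sum of a and b.
expected_makeloss_out = v2 + v1
print(expected_makeloss_out)   # [1. 3.]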
