Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Square operator
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Oct 16, 2018
1 parent a20b496 commit 302651f
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 1 deletion.
36 changes: 36 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2185,6 +2185,42 @@ def convert_spacetodepth(node, **kwargs):
)
return [node]

@mx_op.register("square")
def convert_square(node, **kwargs):
"""Map MXNet's square operator attributes to onnx's Pow operator
and return the created node.
"""
onnx = import_onnx_modules()
name = node["name"]
proc_nodes = kwargs["proc_nodes"]
inputs = node["inputs"]

input_node_a_id = kwargs["index_lookup"][inputs[0][0]]
input_node_a = proc_nodes[input_node_a_id].name

initializer = kwargs["initializer"]
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]

power2_name = "square_tensor" + str(kwargs["idx"])
tensor_node = onnx.helper.make_tensor_value_info(power2_name, data_type, (1,))
initializer.append(
onnx.helper.make_tensor(
name=power2_name,
data_type=data_type,
dims=(1,),
vals=[2],
raw=False,
)
)

node = onnx.helper.make_node(
"Pow",
[input_node_a, power2_name],
[name],
name=name
)
return [tensor_node, node]

@mx_op.register("sum")
def convert_sum(node, **kwargs):
"""Map MXNet's sum operator attributes to onnx's ReduceSum operator
Expand Down
26 changes: 25 additions & 1 deletion tests/python-pytest/onnx/export/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,30 @@ def test_spacetodepth():

npt.assert_almost_equal(output[0], numpy_op)

@with_seed()
def test_square():
input1 = np.random.randint(1, 10, (2, 3)).astype("float32")

ipsym = mx.sym.Variable("input1")
square = mx.sym.square(data=ipsym)
model = mx.mod.Module(symbol=square, data_names=['input1'], label_names=None)
model.bind(for_training=False, data_shapes=[('input1', np.shape(input1))], label_shapes=None)
model.init_params()

args, auxs = model.get_params()
params = {}
params.update(args)
params.update(auxs)

converted_model = onnx_mxnet.export_model(square, params, [np.shape(input1)], np.float32, "square.onnx")

sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)

numpy_op = np.square(input1)

npt.assert_almost_equal(result, numpy_op)

if __name__ == '__main__':
test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))
Expand All @@ -224,4 +248,4 @@ def test_spacetodepth():
test_model_accuracy("inception_v1", (1, 3, 224, 224))
test_model_accuracy("inception_v2", (1, 3, 224, 224))

unittest.main()
unittest.main()

0 comments on commit 302651f

Please sign in to comment.