From 4e46acd99b24ba922cac8e7c5752c1579bbb470c Mon Sep 17 00:00:00 2001 From: Roshani Nagmote Date: Wed, 13 Jun 2018 17:18:20 -0700 Subject: [PATCH] scalar ops added --- .../contrib/onnx/_export/op_translations.py | 65 +++++++++++++++---- .../contrib/onnx/_import/import_helper.py | 1 + .../onnx/export/onnx_backend_test.py | 1 + tests/python-pytest/onnx/import/test_cases.py | 1 + 4 files changed, 56 insertions(+), 12 deletions(-) diff --git a/python/mxnet/contrib/onnx/_export/op_translations.py b/python/mxnet/contrib/onnx/_export/op_translations.py index effbe17b5ffc..b33d7a72e7f6 100644 --- a/python/mxnet/contrib/onnx/_export/op_translations.py +++ b/python/mxnet/contrib/onnx/_export/op_translations.py @@ -837,18 +837,13 @@ def convert_flatten(node, **kwargs): return [flatten_node] -# Convert scalar value into node and pass it as input to mul_node -@mx_op.register("_mul_scalar") -def convert_mul_scalar(node, **kwargs): - """Map MXNet's _mul_scalar operator attributes to onnx's Mul operator. - Creates a new node for the input scalar value, adds it to the initializer - and return multiple created nodes. - """ +def scalar_op_helper(node, op_name, **kwargs): + """Helper function for scalar arithmetic operations""" helper, numpy_helper, mapping = import_onnx_modules() name = node["name"] proc_nodes = kwargs["proc_nodes"] inputs = node["inputs"] - scalar_mul_value = [int(node.get("attrs", {}).get("scalar", 1))] + scalar_value = [float(node.get("attrs", {}).get("scalar", 1))] input_name_id = kwargs["index_lookup"][inputs[0][0]] input_node = proc_nodes[input_name_id].name @@ -859,13 +854,20 @@ def convert_mul_scalar(node, **kwargs): # and create a new initializer for i in initializer: if i.name == input_node: - new_initializer = scalar_mul_value[0] * numpy_helper.to_array(i) + if op_name == 'Mul': + new_initializer = numpy_helper.to_array(i) * scalar_value[0] + elif op_name == 'Sub': + new_initializer = numpy_helper.to_array(i) - scalar_value[0] + elif op_name == 'Add': + new_initializer = numpy_helper.to_array(i) + scalar_value[0] + elif op_name == 'Div': + new_initializer = numpy_helper.to_array(i) / scalar_value[0] flag = False break # else create a new tensor of the scalar value, add it in initializer if flag is True: - np_arr = np.array(scalar_mul_value) + np_arr = np.array(scalar_value) data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype] dims = np.shape(np_arr) @@ -877,13 +879,13 @@ def convert_mul_scalar(node, **kwargs): name=scalar_op_name, data_type=data_type, dims=dims, - vals=scalar_mul_value, + vals=scalar_value, raw=False, ) ) mul_node = helper.make_node( - "Mul", + op_name, [input_node, scalar_op_name], [name], name=name @@ -908,6 +910,45 @@ def convert_mul_scalar(node, **kwargs): ) return [tensor_node] +# Convert scalar value into node and pass it as input to mul_node +@mx_op.register("_mul_scalar") +def convert_mul_scalar(node, **kwargs): + """Map MXNet's _mul_scalar operator attributes to onnx's Mul operator. + Creates a new node for the input scalar value, adds it to the initializer + and return multiple created nodes. + """ + return scalar_op_helper(node, 'Mul', **kwargs) + + +# Convert scalar value into node and pass it as input to mul_node +@mx_op.register("_minus_scalar") +def convert_minus_scalar(node, **kwargs): + """Map MXNet's _minus_scalar operator attributes to onnx's Minus operator. + Creates a new node for the input scalar value, adds it to the initializer + and return multiple created nodes. + """ + return scalar_op_helper(node, 'Sub', **kwargs) + + +# Convert scalar value into node and pass it as input to mul_node +@mx_op.register("_plus_scalar") +def convert_add_scalar(node, **kwargs): + """Map MXNet's _plus_scalar operator attributes to onnx's Add operator. + Creates a new node for the input scalar value, adds it to the initializer + and return multiple created nodes. + """ + return scalar_op_helper(node, 'Add', **kwargs) + +# Convert scalar value into node and pass it as input to mul_node +@mx_op.register("_div_scalar") +def convert_div_scalar(node, **kwargs): + """Map MXNet's _div_scalar operator attributes to onnx's Div operator. + Creates a new node for the input scalar value, adds it to the initializer + and return multiple created nodes. + """ + return scalar_op_helper(node, 'Div', **kwargs) + + # Sorting and Searching @mx_op.register("argmax") def convert_argmax(node, **kwargs): diff --git a/python/mxnet/contrib/onnx/_import/import_helper.py b/python/mxnet/contrib/onnx/_import/import_helper.py index c8d452167299..0b8564e43fd8 100644 --- a/python/mxnet/contrib/onnx/_import/import_helper.py +++ b/python/mxnet/contrib/onnx/_import/import_helper.py @@ -85,6 +85,7 @@ 'Squeeze' : squeeze, 'Unsqueeze' : unsqueeze, 'Flatten' : flatten, + 'Identity' : identity, #Powers 'Reciprocal' : reciprocal, 'Sqrt' : squareroot, diff --git a/tests/python-pytest/onnx/export/onnx_backend_test.py b/tests/python-pytest/onnx/export/onnx_backend_test.py index 0f2301edd8ea..803d290b9c69 100644 --- a/tests/python-pytest/onnx/export/onnx_backend_test.py +++ b/tests/python-pytest/onnx/export/onnx_backend_test.py @@ -48,6 +48,7 @@ 'test_ceil', 'test_floor', 'test_concat', + 'test_identity', 'test_sigmoid', 'test_relu', 'test_constant_pad', diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py index 8e6dc443bbaa..f96e4ba6889b 100644 --- a/tests/python-pytest/onnx/import/test_cases.py +++ b/tests/python-pytest/onnx/import/test_cases.py @@ -31,6 +31,7 @@ 'test_ceil', 'test_floor', 'test_concat', + 'test_identity', 'test_sigmoid', 'test_relu', 'test_constant_pad',