Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
scalar ops added
Browse files Browse the repository at this point in the history
  • Loading branch information
Roshrini committed Jun 14, 2018
1 parent 6a7f412 commit 4e46acd
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 12 deletions.
65 changes: 53 additions & 12 deletions python/mxnet/contrib/onnx/_export/op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,18 +837,13 @@ def convert_flatten(node, **kwargs):
return [flatten_node]


# Convert scalar value into node and pass it as input to mul_node
@mx_op.register("_mul_scalar")
def convert_mul_scalar(node, **kwargs):
"""Map MXNet's _mul_scalar operator attributes to onnx's Mul operator.
Creates a new node for the input scalar value, adds it to the initializer
and return multiple created nodes.
"""
def scalar_op_helper(node, op_name, **kwargs):
"""Helper function for scalar arithmetic operations"""
helper, numpy_helper, mapping = import_onnx_modules()
name = node["name"]
proc_nodes = kwargs["proc_nodes"]
inputs = node["inputs"]
scalar_mul_value = [int(node.get("attrs", {}).get("scalar", 1))]
scalar_value = [float(node.get("attrs", {}).get("scalar", 1))]

input_name_id = kwargs["index_lookup"][inputs[0][0]]
input_node = proc_nodes[input_name_id].name
Expand All @@ -859,13 +854,20 @@ def convert_mul_scalar(node, **kwargs):
# and create a new initializer
for i in initializer:
if i.name == input_node:
new_initializer = scalar_mul_value[0] * numpy_helper.to_array(i)
if op_name == 'Mul':
new_initializer = numpy_helper.to_array(i) * scalar_value[0]
elif op_name == 'Sub':
new_initializer = numpy_helper.to_array(i) - scalar_value[0]
elif op_name == 'Add':
new_initializer = numpy_helper.to_array(i) + scalar_value[0]
elif op_name == 'Div':
new_initializer = numpy_helper.to_array(i) / scalar_value[0]
flag = False
break

# else create a new tensor of the scalar value, add it in initializer
if flag is True:
np_arr = np.array(scalar_mul_value)
np_arr = np.array(scalar_value)
data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype]
dims = np.shape(np_arr)

Expand All @@ -877,13 +879,13 @@ def convert_mul_scalar(node, **kwargs):
name=scalar_op_name,
data_type=data_type,
dims=dims,
vals=scalar_mul_value,
vals=scalar_value,
raw=False,
)
)

mul_node = helper.make_node(
"Mul",
op_name,
[input_node, scalar_op_name],
[name],
name=name
Expand All @@ -908,6 +910,45 @@ def convert_mul_scalar(node, **kwargs):
)
return [tensor_node]

# Convert scalar value into node and pass it as input to mul_node
@mx_op.register("_mul_scalar")
def convert_mul_scalar(node, **kwargs):
"""Map MXNet's _mul_scalar operator attributes to onnx's Mul operator.
Creates a new node for the input scalar value, adds it to the initializer
and return multiple created nodes.
"""
return scalar_op_helper(node, 'Mul', **kwargs)


# Convert scalar value into node and pass it as input to mul_node
@mx_op.register("_minus_scalar")
def convert_minus_scalar(node, **kwargs):
"""Map MXNet's _minus_scalar operator attributes to onnx's Minus operator.
Creates a new node for the input scalar value, adds it to the initializer
and return multiple created nodes.
"""
return scalar_op_helper(node, 'Sub', **kwargs)


# Convert scalar value into node and pass it as input to mul_node
@mx_op.register("_plus_scalar")
def convert_add_scalar(node, **kwargs):
"""Map MXNet's _plus_scalar operator attributes to onnx's Add operator.
Creates a new node for the input scalar value, adds it to the initializer
and return multiple created nodes.
"""
return scalar_op_helper(node, 'Add', **kwargs)

# Convert scalar value into node and pass it as input to mul_node
@mx_op.register("_div_scalar")
def convert_div_scalar(node, **kwargs):
"""Map MXNet's _div_scalar operator attributes to onnx's Div operator.
Creates a new node for the input scalar value, adds it to the initializer
and return multiple created nodes.
"""
return scalar_op_helper(node, 'Div', **kwargs)


# Sorting and Searching
@mx_op.register("argmax")
def convert_argmax(node, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/contrib/onnx/_import/import_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
'Squeeze' : squeeze,
'Unsqueeze' : unsqueeze,
'Flatten' : flatten,
'Identity' : identity,
#Powers
'Reciprocal' : reciprocal,
'Sqrt' : squareroot,
Expand Down
1 change: 1 addition & 0 deletions tests/python-pytest/onnx/export/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
'test_ceil',
'test_floor',
'test_concat',
'test_identity',
'test_sigmoid',
'test_relu',
'test_constant_pad',
Expand Down
1 change: 1 addition & 0 deletions tests/python-pytest/onnx/import/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
'test_ceil',
'test_floor',
'test_concat',
'test_identity',
'test_sigmoid',
'test_relu',
'test_constant_pad',
Expand Down

0 comments on commit 4e46acd

Please sign in to comment.