Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] Add ONNX export support for equal_scalar operator (#19824)
Browse files Browse the repository at this point in the history
* Allow axis to be an optional parameter to squeeze, since onnx supports it now.

* Add onnx export function for equal_scalar, add unit test.

* Use Constant instead of ConstantOfShape for scalar functions.

* Check for 'True' value for squeeze_axis.

Co-authored-by: Joe Evans <[email protected]>
  • Loading branch information
josephevans and Joe Evans authored Feb 3, 2021
1 parent 193d3db commit 90836ad
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 15 deletions.
61 changes: 46 additions & 15 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,7 +1758,7 @@ def convert_slice_channel(node, **kwargs):

num_outputs = int(attrs.get("num_outputs"))
axis = int(attrs.get("axis", 1))
squeeze_axis = int(attrs.get("squeeze_axis", 0))
squeeze_axis = int(attrs.get("squeeze_axis", 0) in [1, 'True'])

if squeeze_axis == 1 and num_outputs == 1:
node = onnx.helper.make_node(
Expand Down Expand Up @@ -1810,17 +1810,22 @@ def convert_squeeze(node, **kwargs):

axis = attrs.get("axis", None)
if not axis:
raise AttributeError("Squeeze: Missing axis attribute: ONNX currently requires axis to "
"be specified for squeeze operator")
axis = convert_string_to_list(axis)
node = onnx.helper.make_node(
"Squeeze",
input_nodes,
[name],
name=name
)
else:
axis = convert_string_to_list(axis)

node = onnx.helper.make_node(
"Squeeze",
input_nodes,
[name],
axes=axis,
name=name,
)
node = onnx.helper.make_node(
"Squeeze",
input_nodes,
[name],
axes=axis,
name=name,
)
return [node]


Expand Down Expand Up @@ -3141,8 +3146,7 @@ def convert_greater_scalar(node, **kwargs):

tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
nodes = [
make_node("Shape", [input_nodes[0]], [name+"_shape"]),
make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value),
make_node("Constant", [], [name+"_rhs"], value=tensor_value),
make_node("Greater", [input_nodes[0], name+"_rhs"], [name+"_gt"]),
make_node("Cast", [name+"_gt"], [name], to=input_type, name=name)
]
Expand Down Expand Up @@ -3171,14 +3175,41 @@ def convert_lesser_scalar(node, **kwargs):

tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
nodes = [
make_node("Shape", [input_nodes[0]], [name+"_shape"]),
make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value),
make_node("Constant", [], [name+"_rhs"], value=tensor_value),
make_node("Less", [input_nodes[0], name+"_rhs"], [name+"_lt"]),
make_node("Cast", [name+"_lt"], [name], to=input_type, name=name)
]
return nodes


@mx_op.register("_equal_scalar")
def convert_equal_scalar(node, **kwargs):
    """Map MXNet's equal_scalar operator attributes to onnx.

    Emits three ONNX nodes: a Constant holding the scalar as a 1-element
    tensor, an elementwise Equal against the input, and a Cast of the
    boolean comparison result back to the input's type.
    """
    from onnx.helper import make_node, make_tensor
    name, input_nodes, attrs = get_inputs(node, kwargs)

    input_type = kwargs['in_type']
    np_dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]
    scalar = float(attrs.get('scalar'))

    if str(np_dtype).startswith('int'):
        # integer tensors need an integer-valued constant
        scalar = int(scalar)
    elif np_dtype == 'float16':
        # when using float16, we must convert it to np.uint16 view first
        # pylint: disable=too-many-function-args
        scalar = np.float16(scalar).view(np.uint16)

    rhs_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
    const_node = make_node("Constant", [], [name+"_rhs"], value=rhs_value)
    equal_node = make_node("Equal", [input_nodes[0], name+"_rhs"], [name+"_eq"])
    cast_node = make_node("Cast", [name+"_eq"], [name], to=input_type, name=name)
    return [const_node, equal_node, cast_node]


@mx_op.register("where")
def convert_where(node, **kwargs):
"""Map MXNet's where operator attributes to onnx's Where
Expand Down
12 changes: 12 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,18 @@ def test_onnx_export_lesser_scalar(tmp_path, dtype, scalar):
op_export_test('_internal._lesser_scalar', M, [x], tmp_path)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.])
def test_onnx_export_equal_scalar(tmp_path, dtype, scalar):
    """Round-trip ONNX export check for the _equal_scalar operator."""
    is_int_dtype = 'int' in dtype
    if is_int_dtype:
        # integer dtypes take an integer scalar and a deterministic input
        scalar = int(scalar)
        data = mx.nd.arange(0, 12, dtype=dtype).reshape((3, 4))
    else:
        data = mx.random.uniform(0, 9999, (5,10), dtype=dtype)
    model = def_model('_internal._equal_scalar', scalar=scalar)
    op_export_test('_internal._equal_scalar', model, [data], tmp_path)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)])
def test_onnx_export_where(tmp_path, dtype, shape):
Expand Down

0 comments on commit 90836ad

Please sign in to comment.