Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Add ONNX export support for equal_scalar operator #19824

Merged
merged 4 commits into from
Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 46 additions & 15 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,7 +1758,7 @@ def convert_slice_channel(node, **kwargs):

num_outputs = int(attrs.get("num_outputs"))
axis = int(attrs.get("axis", 1))
squeeze_axis = int(attrs.get("squeeze_axis", 0))
squeeze_axis = int(attrs.get("squeeze_axis", 0) in [1, 'True'])

if squeeze_axis == 1 and num_outputs == 1:
node = onnx.helper.make_node(
Expand Down Expand Up @@ -1810,17 +1810,22 @@ def convert_squeeze(node, **kwargs):

axis = attrs.get("axis", None)
if not axis:
raise AttributeError("Squeeze: Missing axis attribute: ONNX currently requires axis to "
"be specified for squeeze operator")
axis = convert_string_to_list(axis)
node = onnx.helper.make_node(
"Squeeze",
input_nodes,
[name],
name=name
)
else:
axis = convert_string_to_list(axis)

node = onnx.helper.make_node(
"Squeeze",
input_nodes,
[name],
axes=axis,
name=name,
)
node = onnx.helper.make_node(
"Squeeze",
input_nodes,
[name],
axes=axis,
name=name,
)
return [node]


Expand Down Expand Up @@ -3141,8 +3146,7 @@ def convert_greater_scalar(node, **kwargs):

tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
nodes = [
make_node("Shape", [input_nodes[0]], [name+"_shape"]),
make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value),
make_node("Constant", [], [name+"_rhs"], value=tensor_value),
make_node("Greater", [input_nodes[0], name+"_rhs"], [name+"_gt"]),
make_node("Cast", [name+"_gt"], [name], to=input_type, name=name)
]
Expand Down Expand Up @@ -3171,14 +3175,41 @@ def convert_lesser_scalar(node, **kwargs):

tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
nodes = [
make_node("Shape", [input_nodes[0]], [name+"_shape"]),
make_node("ConstantOfShape", [name+"_shape"], [name+"_rhs"], value=tensor_value),
make_node("Constant", [], [name+"_rhs"], value=tensor_value),
make_node("Less", [input_nodes[0], name+"_rhs"], [name+"_lt"]),
make_node("Cast", [name+"_lt"], [name], to=input_type, name=name)
]
return nodes


@mx_op.register("_equal_scalar")
def convert_equal_scalar(node, **kwargs):
"""Map MXNet's equal_scalar operator attributes to onnx.
"""
from onnx.helper import make_node, make_tensor
name, input_nodes, attrs = get_inputs(node, kwargs)

scalar = float(attrs.get('scalar'))
input_type = kwargs['in_type']
dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]

if str(dtype).startswith('int'):
scalar = int(scalar)
else:
if dtype == 'float16':
# when using float16, we must convert it to np.uint16 view first
# pylint: disable=too-many-function-args
scalar = np.float16(scalar).view(np.uint16)

tensor_value = make_tensor(name+"_scalar", input_type, [1], [scalar])
nodes = [
make_node("Constant", [], [name+"_rhs"], value=tensor_value),
make_node("Equal", [input_nodes[0], name+"_rhs"], [name+"_eq"]),
make_node("Cast", [name+"_eq"], [name], to=input_type, name=name)
]
return nodes


@mx_op.register("where")
def convert_where(node, **kwargs):
"""Map MXNet's where operator attributes to onnx's Where
Expand Down
12 changes: 12 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,18 @@ def test_onnx_export_lesser_scalar(tmp_path, dtype, scalar):
op_export_test('_internal._lesser_scalar', M, [x], tmp_path)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.])
def test_onnx_export_equal_scalar(tmp_path, dtype, scalar):
if 'int' in dtype:
scalar = int(scalar)
x = mx.nd.arange(0, 12, dtype=dtype).reshape((3, 4))
else:
x = mx.random.uniform(0, 9999, (5,10), dtype=dtype)
M = def_model('_internal._equal_scalar', scalar=scalar)
op_export_test('_internal._equal_scalar', M, [x], tmp_path)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)])
def test_onnx_export_where(tmp_path, dtype, shape):
Expand Down