Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] Add onnx export operator for minimum_scalar. (#19888)
Browse files Browse the repository at this point in the history
* Add onnx export operator for minimum_scalar.

* Fix function name.

Co-authored-by: Joe Evans <[email protected]>
  • Loading branch information
josephevans and Joe Evans authored Feb 15, 2021
1 parent 58e7d0e commit 26afc44
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
22 changes: 22 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3293,6 +3293,28 @@ def convert_maximum_scalar(node, **kwargs):

return nodes

@mx_op.register('_minimum_scalar')
def convert_minimum_scalar(node, **kwargs):
"""Map MXNet's _minimum_scalar
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

input_type = int(kwargs['in_type'])
dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]

scalar = None
if 'float' in str(dtype):
scalar = float(attrs.get('scalar', '0'))
else:
scalar = int(attrs.get('scalar', '0'))

nodes = [
create_tensor([scalar], name+'_scalar', kwargs['initializer'], dtype=dtype),
make_node('Min', [input_nodes[0], name+'_scalar'], [name], name=name)
]

return nodes

@mx_op.register("_contrib_box_decode")
def convert_contrib_box_decode(node, **kwargs):
Expand Down
10 changes: 10 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,16 @@ def test_onnx_maximum_scalar(tmp_path, dtype, shape):
op_export_test('_maximum_scalar', M, [x], tmp_path)


# opset 8 Min only supports float types
# opset 12 and up suppots float and int
@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64'])
@pytest.mark.parametrize('shape', [(3, 4, 5), (1, 4, 1, 7)])
def test_onnx_minimum_scalar(tmp_path, dtype, shape):
x = mx.random.uniform(0, 10, shape).astype(dtype)
M = def_model('minimum', right=5)
op_export_test('_minimum_scalar', M, [x], tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32'])
@pytest.mark.parametrize('fmt', ['corner', 'center'])
@pytest.mark.parametrize('clip', [-1., 0., .5, 5.])
Expand Down

0 comments on commit 26afc44

Please sign in to comment.