Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
_maximum_scalar (#19763)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zha0q1 committed Jan 21, 2021
1 parent 0e59517 commit c86bf0e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
42 changes: 34 additions & 8 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,21 +179,21 @@ def create_const_node(input_name, value, kwargs):
initializer.append(tensor_node)
return value_node

def create_tensor(shape_list, shape_name, initializer, dtype='int64'):
def create_tensor(tensor_list, tensor_name, initializer, dtype='int64'):
"""Helper function to create a tensor value node and a
initializer tensor node with constant value."""
shape_np = np.array(shape_list, dtype=dtype)
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[shape_np.dtype]
dims = np.shape(shape_np)
tensor_node = onnx.helper.make_tensor_value_info(shape_name, data_type, dims)
tensor_np = np.array(tensor_list, dtype=dtype)
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[tensor_np.dtype]
dims = np.shape(tensor_np)
tensor_node = onnx.helper.make_tensor_value_info(tensor_name, data_type, dims)
if dtype == np.float16:
shape_list = shape_np.view(dtype=np.uint16).flatten().tolist()
tensor_list = tensor_np.view(dtype=np.uint16).flatten().tolist()
initializer.append(
onnx.helper.make_tensor(
name=shape_name,
name=tensor_name,
data_type=data_type,
dims=dims,
vals=shape_list,
vals=tensor_list,
raw=False
)
)
Expand Down Expand Up @@ -2955,6 +2955,31 @@ def convert_where(node, **kwargs):
]
return nodes


@mx_op.register('_maximum_scalar')
def convert_maximum_scalar(node, **kwargs):
"""Map MXNet's _maximum_scalar
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

input_type = int(kwargs['in_type'])
dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]

scalar = None
if 'float' in str(dtype):
scalar = float(attrs.get('scalar', '0'))
else:
scalar = int(attrs.get('scalar', '0'))

nodes = [
create_tensor([scalar], name+'_scalar', kwargs['initializer'], dtype=dtype),
make_node('Max', [input_nodes[0], name+'_scalar'], [name], name=name)
]

return nodes


@mx_op.register("_contrib_box_decode")
def convert_contrib_box_decode(node, **kwargs):
"""Map MXNet's _contrib_box_decode operator attributes to onnx's operator.
Expand Down Expand Up @@ -3049,6 +3074,7 @@ def convert_contrib_AdaptiveAvgPooling2D(node, **kwargs):
nodes = [
make_node("GlobalAveragePool", [input_nodes[0]], [name], name=name)
]

return nodes


Expand Down
11 changes: 11 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,15 @@ def test_onnx_export_where(tmp_path, dtype, shape):
cond = mx.nd.random.randint(low=0, high=1, shape=shape, dtype='int32')
op_export_test('where', M, [cond, x, y], tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64'])
@pytest.mark.parametrize('shape', [(3, 4, 5), (1, 4, 1, 7)])
def test_onnx_maximum_scalar(tmp_path, dtype, shape):
x = mx.random.uniform(0, 10, shape).astype(dtype)
M = def_model('maximum', right=5)
op_export_test('_maximum_scalar', M, [x], tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32'])
@pytest.mark.parametrize('fmt', ['corner', 'center'])
@pytest.mark.parametrize('clip', [-1., 0., .5, 5.])
Expand All @@ -423,6 +432,7 @@ def test_onnx_export_contrib_box_decode(tmp_path, dtype, fmt, clip):
M2 = def_model('contrib.box_decode', format=fmt, clip=clip, std0=0.3, std1=1.4, std2=0.5, std3=1.6)
op_export_test('contrib_box_decode', M1, [data, anchors], tmp_path)


@pytest.mark.parametrize('dtype', ['float16', 'float32'])
def test_onnx_export_contrib_AdaptiveAvgPooling2D(tmp_path, dtype):
x = mx.nd.random.uniform(0, 1, (1, 2, 3, 4), dtype=dtype)
Expand Down Expand Up @@ -452,3 +462,4 @@ def test_onnx_export_reshape_like(tmp_path, dtype):
op_export_test('reshape_like3', M3, [x, y], tmp_path)
M4 = def_model('reshape_like', lhs_begin=0, lhs_end=None, rhs_begin=1, rhs_end=None)
op_export_test('reshape_like4', M4, [x, y], tmp_path)

0 comments on commit c86bf0e

Please sign in to comment.