diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index f1af74677bf6..59082c961475 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -179,21 +179,21 @@ def create_const_node(input_name, value, kwargs): initializer.append(tensor_node) return value_node -def create_tensor(shape_list, shape_name, initializer, dtype='int64'): +def create_tensor(tensor_list, tensor_name, initializer, dtype='int64'): """Helper function to create a tensor value node and a initializer tensor node with constant value.""" - shape_np = np.array(shape_list, dtype=dtype) - data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[shape_np.dtype] - dims = np.shape(shape_np) - tensor_node = onnx.helper.make_tensor_value_info(shape_name, data_type, dims) + tensor_np = np.array(tensor_list, dtype=dtype) + data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[tensor_np.dtype] + dims = np.shape(tensor_np) + tensor_node = onnx.helper.make_tensor_value_info(tensor_name, data_type, dims) if dtype == np.float16: - shape_list = shape_np.view(dtype=np.uint16).flatten().tolist() + tensor_list = tensor_np.view(dtype=np.uint16).flatten().tolist() initializer.append( onnx.helper.make_tensor( - name=shape_name, + name=tensor_name, data_type=data_type, dims=dims, - vals=shape_list, + vals=tensor_list, raw=False ) ) @@ -2955,6 +2955,31 @@ def convert_where(node, **kwargs): ] return nodes + +@mx_op.register('_maximum_scalar') +def convert_maximum_scalar(node, **kwargs): + """Map MXNet's _maximum_scalar + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + input_type = int(kwargs['in_type']) + dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] + + scalar = None + if 'float' in str(dtype): + scalar = float(attrs.get('scalar', '0')) + else: + scalar = int(attrs.get('scalar', '0')) + + nodes = [ + create_tensor([scalar], name+'_scalar', kwargs['initializer'], dtype=dtype), + make_node('Max', [input_nodes[0], name+'_scalar'], [name], name=name) + ] + + return nodes + + @mx_op.register("_contrib_box_decode") def convert_contrib_box_decode(node, **kwargs): """Map MXNet's _contrib_box_decode operator attributes to onnx's operator. @@ -3049,6 +3074,7 @@ def convert_contrib_AdaptiveAvgPooling2D(node, **kwargs): nodes = [ make_node("GlobalAveragePool", [input_nodes[0]], [name], name=name) ] + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 50ce0cbe62a8..1c4629d6f871 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -410,6 +410,15 @@ def test_onnx_export_where(tmp_path, dtype, shape): cond = mx.nd.random.randint(low=0, high=1, shape=shape, dtype='int32') op_export_test('where', M, [cond, x, y], tmp_path) + +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) +@pytest.mark.parametrize('shape', [(3, 4, 5), (1, 4, 1, 7)]) +def test_onnx_maximum_scalar(tmp_path, dtype, shape): + x = mx.random.uniform(0, 10, shape).astype(dtype) + M = def_model('maximum', right=5) + op_export_test('_maximum_scalar', M, [x], tmp_path) + + @pytest.mark.parametrize('dtype', ['float16', 'float32']) @pytest.mark.parametrize('fmt', ['corner', 'center']) @pytest.mark.parametrize('clip', [-1., 0., .5, 5.]) @@ -423,6 +432,7 @@ def test_onnx_export_contrib_box_decode(tmp_path, dtype, fmt, clip): M2 = def_model('contrib.box_decode', format=fmt, clip=clip, std0=0.3, std1=1.4, std2=0.5, std3=1.6) op_export_test('contrib_box_decode', M1, [data, anchors], tmp_path) + @pytest.mark.parametrize('dtype', ['float16', 'float32']) def test_onnx_export_contrib_AdaptiveAvgPooling2D(tmp_path, dtype): x = mx.nd.random.uniform(0, 1, (1, 2, 3, 4), dtype=dtype) @@ -452,3 +462,4 @@ def test_onnx_export_reshape_like(tmp_path, dtype): op_export_test('reshape_like3', M3, [x, y], tmp_path) M4 = def_model('reshape_like', lhs_begin=0, lhs_end=None, rhs_begin=1, rhs_end=None) op_export_test('reshape_like4', M4, [x, y], tmp_path) +