diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 843d5e2a2873..39723321923e 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3400,6 +3400,35 @@ def convert_gather_nd(node, **kwargs): return nodes +@mx_op.register('UpSampling') +def convert_upsampling(node, **kwargs): + """Map MXNet's UpSampling operator to onnx. + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + scale = int(attrs.get('scale', '1')) + sample_type = attrs.get('sample_type') + num_args = int(attrs.get('num_args', '1')) + + if num_args > 1: + raise NotImplementedError('Upsampling conversion does not currently support num_args > 1') + + if sample_type != 'nearest': + raise NotImplementedError('Upsampling conversion does not currently support \ + sample_type != nearest') + + nodes = [ + create_tensor([], name+'_roi', kwargs['initializer'], dtype='float32'), + create_tensor([1, 1, scale, scale], name+'_scales', kwargs['initializer'], + dtype='float32'), + make_node('Resize', [input_nodes[0], name+'_roi', name+'_scales'], [name], mode='nearest', + coordinate_transformation_mode='half_pixel') + ] + + return nodes + + @mx_op.register('SwapAxis') def convert_swapaxis(node, **kwargs): """Map MXNet's SwapAxis operator diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 038541abe953..204fe57f5930 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -595,6 +595,15 @@ def test_onnx_export_gather_nd(tmp_path, dtype): op_export_test('gather_nd2', M2, [x2, y2], tmp_path) +@pytest.mark.parametrize('dtype', ['float16', 'float32']) +@pytest.mark.parametrize('shape', [(3, 4, 5, 6), (1, 1, 1, 1)]) +@pytest.mark.parametrize('scale', [1, 2, 3]) +def test_onnx_export_upsampling(tmp_path, dtype, shape, scale): + A = mx.random.uniform(0, 1, shape).astype(dtype) + M = def_model('UpSampling', scale=scale, sample_type='nearest', num_args=1) + op_export_test('UpSampling', M, [A], tmp_path) + + @pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) @pytest.mark.parametrize('params', [((4, 5, 6), (0, 2)), ((4, 5, 6), (0, 1)), ((1, 2, 3, 4, 1), (0, 4)),