diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 9a6f290cdf80..b47899be066c 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1598,6 +1598,7 @@ def convert_floor(node, **kwargs): """ return create_basic_op_node('Floor', node, kwargs) + # Changing shape and type. @mx_op.register("Reshape") def convert_reshape(node, **kwargs): @@ -1612,6 +1613,31 @@ def convert_reshape(node, **kwargs): reverse = attrs.get('reverse', 'False') targ_shape = convert_string_to_list(attrs["shape"]) + # In general -2, -3, -4 in the target shape are not supoorted, but there are + # a few special cases that we can convert to supported scenarios + + # If -2 and -3 are not used, then we can just remove the -4 + if -4 in targ_shape and -3 not in targ_shape and -2 not in targ_shape and reverse != 'True': + targ_shape = [i for i in targ_shape if i != -4] + + if targ_shape == [-3, 0] and reverse != 'True': + targ_shape = [-1, 0] + reverse = 'True' + + if targ_shape == [0, 0, -3, -3] and reverse != 'True': + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2', + name+'_dim3', name+'_dim4', name+'_dim5'], + axis=0), + make_node('Mul', [name+'_dim2', name+'_dim3'], [name+'_mul_1']), + make_node('Mul', [name+'_dim4', name+'_dim5'], [name+'_mul_2']), + make_node('Concat', [name+'_dim0', name+'_dim1', name+'_mul_1', name+'_mul_2'], + [name+'_shape_new'], axis=0), + make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) + ] + return nodes + not_supported_shape = [-2, -3, -4] for val in targ_shape: if val in not_supported_shape: diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 74e850be995e..bacfbdac6000 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -216,6 +216,27 @@ def test_onnx_export_reshape(tmp_path, dtype): op_export_test('reshape_3', M3, [x], tmp_path) +@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64']) +def test_onnx_export_reshape_special_cases(tmp_path, dtype): + x1 = mx.nd.ones((8, 9), dtype=dtype) + M1 = def_model('reshape', shape=(0, -4, 1, -1)) + op_export_test('reshape_spec_1', M1, [x1], tmp_path) + + x2 = mx.nd.ones((8, 9, 10), dtype=dtype) + + M2 = def_model('reshape', shape=(0, -4, 3, -1, 10)) + op_export_test('reshape_spec_2', M2, [x2], tmp_path) + M3 = def_model('reshape', shape=(-4, 2, -1, 10, 9)) + op_export_test('reshape_spec_3', M3, [x2], tmp_path) + + M4 = def_model('reshape', shape=(-3, 0)) + op_export_test('reshape_spec_4', M4, [x2], tmp_path) + + x3 = mx.nd.ones((1, 2, 3, 4, 5, 6), dtype=dtype) + M5 = def_model('reshape', shape=(0, 0, -3, -3)) + op_export_test('reshape_spec_5', M5, [x3], tmp_path) + + @pytest.mark.parametrize('dtype', ['int32', 'int64']) def test_onnx_export_embedding(tmp_path, dtype): x = mx.nd.array([[ 1., 3.],