From 08b83b64f54ee5934b74fcb3826c6f66c3d5e141 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Sat, 30 Jan 2021 00:22:23 +0000 Subject: [PATCH 1/3] add special cases to Reshape --- .../contrib/onnx/mx2onnx/_op_translations.py | 12 ++++++++++++ tests/python-pytest/onnx/test_operators.py | 17 +++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 9a6f290cdf80..c6925579c777 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1598,6 +1598,7 @@ def convert_floor(node, **kwargs): """ return create_basic_op_node('Floor', node, kwargs) + # Changing shape and type. @mx_op.register("Reshape") def convert_reshape(node, **kwargs): @@ -1612,6 +1613,17 @@ def convert_reshape(node, **kwargs): reverse = attrs.get('reverse', 'False') targ_shape = convert_string_to_list(attrs["shape"]) + # In general -2, -3, -4 in the target shape are not supoorted, but there are + # a few special cases that we can convert to supported scenarios + + # If -2 and -4 are not used, then we can just remove the -4 + if -4 in targ_shape and -3 not in targ_shape and -2 not in targ_shape: + targ_shape = [i for i in targ_shape if i != -4] + + if targ_shape == [-3, 0]: + targ_shape = [-1, 0] + reverse = 'True' + not_supported_shape = [-2, -3, -4] for val in targ_shape: if val in not_supported_shape: diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 74e850be995e..eafc7e583121 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -216,6 +216,23 @@ def test_onnx_export_reshape(tmp_path, dtype): op_export_test('reshape_3', M3, [x], tmp_path) +@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64']) +def test_onnx_export_reshape_special_cases(tmp_path, dtype): + x1 = mx.nd.ones((8, 9), dtype=dtype) + M1 = def_model('reshape', shape=(0, -4, 1, -1)) + op_export_test('reshape_spec_1', M1, [x1], tmp_path) + + x2 = mx.nd.ones((8, 9, 10), dtype=dtype) + + M2 = def_model('reshape', shape=(0, -4, 3, -1, 10)) + op_export_test('reshape_spec_2', M2, [x2], tmp_path) + M3 = def_model('reshape', shape=(-4, 2, -1, 10, 9)) + op_export_test('reshape_spec_3', M3, [x2], tmp_path) + + M4 = def_model('reshape', shape=(-3, 0)) + op_export_test('reshape_spec_4', M4, [x2], tmp_path) + + @pytest.mark.parametrize('dtype', ['int32', 'int64']) def test_onnx_export_embedding(tmp_path, dtype): x = mx.nd.array([[ 1., 3.], From 9ddce679ef1c4838c12a62dcd6c77a1e282a7bc1 Mon Sep 17 00:00:00 2001 From: zha0q1 Date: Sat, 30 Jan 2021 01:14:32 +0000 Subject: [PATCH 2/3] one more special case --- .../contrib/onnx/mx2onnx/_op_translations.py | 19 ++++++++++++++++--- tests/python-pytest/onnx/test_operators.py | 4 ++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index c6925579c777..5290ae6daa23 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1616,14 +1616,27 @@ def convert_reshape(node, **kwargs): # In general -2, -3, -4 in the target shape are not supoorted, but there are # a few special cases that we can convert to supported scenarios - # If -2 and -4 are not used, then we can just remove the -4 - if -4 in targ_shape and -3 not in targ_shape and -2 not in targ_shape: + # If -2 and -3 are not used, then we can just remove the -4 + if -4 in targ_shape and -3 not in targ_shape and -2 not in targ_shape and reverse != 'True': targ_shape = [i for i in targ_shape if i != -4] - if targ_shape == [-3, 0]: + if targ_shape == [-3, 0] and reverse != 'True': targ_shape = [-1, 0] reverse = 'True' + if targ_shape == [0, 0, -3, -3] and reverse != 'True': + nodes = [ + make_node('Shape', [input_nodes[0]], [name+'_shape']), + make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2', + name+'_dim3', name+'_dim4', name+'_dim5'], axis=0), + make_node('Mul', [name+'_dim2', name+'_dim3'], [name+'_mul_1']), + make_node('Mul', [name+'_dim4', name+'_dim5'], [name+'_mul_2']), + make_node('Concat', [name+'_dim0', name+'_dim1', name+'_mul_1', name+'_mul_2'], + [name+'_shape_new'], axis=0), + make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name) + ] + return nodes + not_supported_shape = [-2, -3, -4] for val in targ_shape: if val in not_supported_shape: diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index eafc7e583121..bacfbdac6000 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -232,6 +232,10 @@ def test_onnx_export_reshape_special_cases(tmp_path, dtype): M4 = def_model('reshape', shape=(-3, 0)) op_export_test('reshape_spec_4', M4, [x2], tmp_path) + x3 = mx.nd.ones((1, 2, 3, 4, 5, 6), dtype=dtype) + M5 = def_model('reshape', shape=(0, 0, -3, -3)) + op_export_test('reshape_spec_5', M5, [x3], tmp_path) + @pytest.mark.parametrize('dtype', ['int32', 'int64']) def test_onnx_export_embedding(tmp_path, dtype): From f637ff9e3e3489104611436fb69699b868e52724 Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Fri, 29 Jan 2021 20:19:45 -0800 Subject: [PATCH 3/3] Update _op_translations.py --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 5290ae6daa23..b47899be066c 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -1628,7 +1628,8 @@ def convert_reshape(node, **kwargs): nodes = [ make_node('Shape', [input_nodes[0]], [name+'_shape']), make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2', - name+'_dim3', name+'_dim4', name+'_dim5'], axis=0), + name+'_dim3', name+'_dim4', name+'_dim5'], + axis=0), make_node('Mul', [name+'_dim2', name+'_dim3'], [name+'_mul_1']), make_node('Mul', [name+'_dim4', name+'_dim5'], [name+'_mul_2']), make_node('Concat', [name+'_dim0', name+'_dim1', name+'_mul_1', name+'_mul_2'],