Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] Onnx Reshpe support for special caes (#19804)
Browse files Browse the repository at this point in the history
* add special cases to Reshape

* one more special case

* Update _op_translations.py
  • Loading branch information
Zha0q1 committed Feb 2, 2021
1 parent 6c025e0 commit 787416b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
26 changes: 26 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,6 +1605,7 @@ def convert_floor(node, **kwargs):
"""
return create_basic_op_node('Floor', node, kwargs)


# Changing shape and type.
@mx_op.register("Reshape")
def convert_reshape(node, **kwargs):
Expand All @@ -1619,6 +1620,31 @@ def convert_reshape(node, **kwargs):
reverse = attrs.get('reverse', 'False')
targ_shape = convert_string_to_list(attrs["shape"])

# In general -2, -3, -4 in the target shape are not supoorted, but there are
# a few special cases that we can convert to supported scenarios

# If -2 and -3 are not used, then we can just remove the -4
if -4 in targ_shape and -3 not in targ_shape and -2 not in targ_shape and reverse != 'True':
targ_shape = [i for i in targ_shape if i != -4]

if targ_shape == [-3, 0] and reverse != 'True':
targ_shape = [-1, 0]
reverse = 'True'

if targ_shape == [0, 0, -3, -3] and reverse != 'True':
nodes = [
make_node('Shape', [input_nodes[0]], [name+'_shape']),
make_node('Split', [name+'_shape'], [name+'_dim0', name+'_dim1', name+'_dim2',
name+'_dim3', name+'_dim4', name+'_dim5'],
axis=0),
make_node('Mul', [name+'_dim2', name+'_dim3'], [name+'_mul_1']),
make_node('Mul', [name+'_dim4', name+'_dim5'], [name+'_mul_2']),
make_node('Concat', [name+'_dim0', name+'_dim1', name+'_mul_1', name+'_mul_2'],
[name+'_shape_new'], axis=0),
make_node('Reshape', [input_nodes[0], name+'_shape_new'], [name], name=name)
]
return nodes

not_supported_shape = [-2, -3, -4]
for val in targ_shape:
if val in not_supported_shape:
Expand Down
21 changes: 21 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,27 @@ def test_onnx_export_reshape(tmp_path, dtype):
op_export_test('reshape_3', M3, [x], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
def test_onnx_export_reshape_special_cases(tmp_path, dtype):
x1 = mx.nd.ones((8, 9), dtype=dtype)
M1 = def_model('reshape', shape=(0, -4, 1, -1))
op_export_test('reshape_spec_1', M1, [x1], tmp_path)

x2 = mx.nd.ones((8, 9, 10), dtype=dtype)

M2 = def_model('reshape', shape=(0, -4, 3, -1, 10))
op_export_test('reshape_spec_2', M2, [x2], tmp_path)
M3 = def_model('reshape', shape=(-4, 2, -1, 10, 9))
op_export_test('reshape_spec_3', M3, [x2], tmp_path)

M4 = def_model('reshape', shape=(-3, 0))
op_export_test('reshape_spec_4', M4, [x2], tmp_path)

x3 = mx.nd.ones((1, 2, 3, 4, 5, 6), dtype=dtype)
M5 = def_model('reshape', shape=(0, 0, -3, -3))
op_export_test('reshape_spec_5', M5, [x3], tmp_path)


@pytest.mark.parametrize('dtype', ['int32', 'int64'])
def test_onnx_export_embedding(tmp_path, dtype):
x = mx.nd.array([[ 1., 3.],
Expand Down

0 comments on commit 787416b

Please sign in to comment.