From d974cba400add86d07beee4ac64121992401709b Mon Sep 17 00:00:00 2001 From: waytrue17 <52505574+waytrue17@users.noreply.github.com> Date: Mon, 15 Feb 2021 15:02:19 -0800 Subject: [PATCH] [v1.x] ONNX export rewrite tile (#19868) * fix tile * fix sanity Co-authored-by: Wei Chu --- .../contrib/onnx/mx2onnx/_op_translations.py | 47 +++++++++---------- tests/python-pytest/onnx/test_operators.py | 8 ++++ 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index e5bbb1f66000..c5e42a06a343 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2262,37 +2262,32 @@ def convert_tile(node, **kwargs): """Map MXNet's Tile operator attributes to onnx's Tile operator and return the created node. """ + from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) - reps_list = convert_string_to_list(attrs["reps"]) - - initializer = kwargs["initializer"] - reps_shape_np = np.array(reps_list, dtype='int64') - data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[reps_shape_np.dtype] - dims = np.shape(reps_shape_np) - - output_shape_name = "reps_attr_tensor" + str(kwargs["idx"]) - tensor_node = onnx.helper.make_tensor_value_info(output_shape_name, data_type, dims) + data = input_nodes[0] + reps = convert_string_to_list(attrs["reps"]) - initializer.append( - onnx.helper.make_tensor( - name=output_shape_name, - data_type=data_type, - dims=dims, - vals=reps_list, - raw=False, - ) - ) + create_tensor([0], name+'_0', kwargs['initializer']) + create_tensor([1], name+'_1', kwargs['initializer']) + create_tensor(reps, name+'_reps', kwargs['initializer'], dtype='int64') + create_tensor([len(reps)], name+'_reps_len', kwargs['initializer']) - input_nodes.append(output_shape_name) - tile_node = onnx.helper.make_node( - "Tile", - input_nodes, - [name], - name=name - ) + nodes = [ + make_node('Shape', [data], [name+'_data_shape']), + make_node('Shape', [name+'_data_shape'], [name+'_data_dim']), + make_node('Max', [name+'_data_dim', name+'_reps_len'], [name+'_max']), + make_node('Sub', [name+'_max', name+'_data_dim'], [name+'_data_diff']), + make_node('Concat', [name+'_data_diff', name+'_0'], [name+'_concat0_out'], axis=0), + make_node('Pad', [name+'_data_shape', name+'_concat0_out', name+'_1'], [name+'_data_shape_pad']), + make_node('Reshape', [data, name+'_data_shape_pad'], [name+'_data']), + make_node('Sub', [name+'_max', name+'_reps_len'], [name+'_reps_diff']), + make_node('Concat', [name+'_reps_diff', name+'_0'], [name+'_concat1_out'], axis=0), + make_node('Pad', [name+'_reps', name+'_concat1_out', name+'_1'], [name+'_reps_pad']), + make_node('Tile', [name+'_data', name+'_reps_pad'], [name], name=name), + ] - return [tensor_node, tile_node] + return nodes @mx_op.register("broadcast_to") diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 6a5cfd7f0b17..90ec8f58b250 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -1131,3 +1131,11 @@ def test_onnx_export_argsort(tmp_path, dtype, axis, is_ascend, dtype_i): kwargs['is_ascend'] = is_ascend M = def_model('argsort', axis=axis, dtype=dtype_i, **kwargs) op_export_test('argsort', M, [A], tmp_path) + + +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +@pytest.mark.parametrize('reps', [(2, 3), (2, ), (2, 3, 4)]) +def test_onnx_export_tile(tmp_path, dtype, reps): + x = mx.nd.random.normal(0, 100, (5, 6)).astype(dtype) + M = def_model('tile', reps=reps) + op_export_test('tile', M, [x], tmp_path)