Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] ONNX export rewrite tile (#19868)
Browse files Browse the repository at this point in the history
* fix tile

* fix sanity

Co-authored-by: Wei Chu <[email protected]>
  • Loading branch information
waytrue17 and Wei Chu committed Feb 15, 2021
1 parent 26afc44 commit d974cba
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 26 deletions.
47 changes: 21 additions & 26 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2262,37 +2262,32 @@ def convert_tile(node, **kwargs):
"""Map MXNet's Tile operator attributes to onnx's Tile
operator and return the created node.
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

reps_list = convert_string_to_list(attrs["reps"])

initializer = kwargs["initializer"]
reps_shape_np = np.array(reps_list, dtype='int64')
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[reps_shape_np.dtype]
dims = np.shape(reps_shape_np)

output_shape_name = "reps_attr_tensor" + str(kwargs["idx"])
tensor_node = onnx.helper.make_tensor_value_info(output_shape_name, data_type, dims)
data = input_nodes[0]
reps = convert_string_to_list(attrs["reps"])

initializer.append(
onnx.helper.make_tensor(
name=output_shape_name,
data_type=data_type,
dims=dims,
vals=reps_list,
raw=False,
)
)
create_tensor([0], name+'_0', kwargs['initializer'])
create_tensor([1], name+'_1', kwargs['initializer'])
create_tensor(reps, name+'_reps', kwargs['initializer'], dtype='int64')
create_tensor([len(reps)], name+'_reps_len', kwargs['initializer'])

input_nodes.append(output_shape_name)
tile_node = onnx.helper.make_node(
"Tile",
input_nodes,
[name],
name=name
)
nodes = [
make_node('Shape', [data], [name+'_data_shape']),
make_node('Shape', [name+'_data_shape'], [name+'_data_dim']),
make_node('Max', [name+'_data_dim', name+'_reps_len'], [name+'_max']),
make_node('Sub', [name+'_max', name+'_data_dim'], [name+'_data_diff']),
make_node('Concat', [name+'_data_diff', name+'_0'], [name+'_concat0_out'], axis=0),
make_node('Pad', [name+'_data_shape', name+'_concat0_out', name+'_1'], [name+'_data_shape_pad']),
make_node('Reshape', [data, name+'_data_shape_pad'], [name+'_data']),
make_node('Sub', [name+'_max', name+'_reps_len'], [name+'_reps_diff']),
make_node('Concat', [name+'_reps_diff', name+'_0'], [name+'_concat1_out'], axis=0),
make_node('Pad', [name+'_reps', name+'_concat1_out', name+'_1'], [name+'_reps_pad']),
make_node('Tile', [name+'_data', name+'_reps_pad'], [name], name=name),
]

return [tensor_node, tile_node]
return nodes


@mx_op.register("broadcast_to")
Expand Down
8 changes: 8 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,3 +1131,11 @@ def test_onnx_export_argsort(tmp_path, dtype, axis, is_ascend, dtype_i):
kwargs['is_ascend'] = is_ascend
M = def_model('argsort', axis=axis, dtype=dtype_i, **kwargs)
op_export_test('argsort', M, [A], tmp_path)


@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
@pytest.mark.parametrize('reps', [(2, 3), (2, ), (2, 3, 4)])
def test_onnx_export_tile(tmp_path, dtype, reps):
x = mx.nd.random.normal(0, 100, (5, 6)).astype(dtype)
M = def_model('tile', reps=reps)
op_export_test('tile', M, [x], tmp_path)

0 comments on commit d974cba

Please sign in to comment.