From 05de69b7277a2ccf63d34fd7322ace0a75f4f113 Mon Sep 17 00:00:00 2001 From: waytrue17 <52505574+waytrue17@users.noreply.github.com> Date: Wed, 27 Jan 2021 14:51:58 -0800 Subject: [PATCH] broadcast_like (#19791) Co-authored-by: Wei Chu --- .../contrib/onnx/mx2onnx/_op_translations.py | 42 +++++++++++++++++++ tests/python-pytest/onnx/test_operators.py | 11 +++++ 2 files changed, 53 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 46767980f4c3..0ca1ef50ca2f 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3348,3 +3348,45 @@ def convert_slice_like(node, **kwargs): ] return nodes + + +@mx_op.register("broadcast_like") +def convert_broadcast_like(node, **kwargs): + """Map MXNet's broadcast_like operator attributes to onnx's operator. + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + lhs = input_nodes[0] + rhs = input_nodes[1] + lhs_axes = convert_string_to_list(str(attrs.get('lhs_axes', 'None'))) + rhs_axes = convert_string_to_list(str(attrs.get('rhs_axes', 'None'))) + + if lhs_axes[0] is None or rhs_axes[0] is None: + nodes = [ + make_node('Shape', [rhs], [name+'_rhs_shape']), + make_node('Expand', [lhs, name+'_rhs_shape'], [name], name=name) + ] + return nodes + + lhs_axes = [[i] for i in lhs_axes] + rhs_axes = [[i] for i in rhs_axes] + + nodes = [ + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor(lhs_axes, name+'_lhs_axes', kwargs['initializer']), + create_tensor(rhs_axes, name+'_rhs_axes', kwargs['initializer']), + make_node('Shape', [lhs], [name+'_lhs_shape']), + make_node('Shape', [rhs], [name+'_rhs_shape']), + make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']), + make_node('Less', [name+'_lhs_axes', name+'_0'], [name+'_less']), + make_node('Cast', [name+'_less'], [name+'_mask'], to=int(onnx.TensorProto.INT64)), + make_node('Mul', [name+'_mask', name+'_lhs_dim'], [name+'_mul']), + make_node('Add', [name+'_lhs_axes', name+'_mul'], [name+'_lhs_axes_positive']), + make_node('GatherND', [name+'_rhs_shape', name+'_rhs_axes'], [name+'_gather']), + make_node('ScatterND', [name+'_lhs_shape', name+'_lhs_axes_positive', name+'_gather'], + [name+'_scatter']), + make_node('Expand', [lhs, name+'_scatter'], [name], name=name) + ] + + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 21d46703f5e8..71629728d14c 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -557,3 +557,14 @@ def test_onnx_export_slice_like(tmp_path, dtype, axes): op_export_test('slice_like_2', M, [x, y2], tmp_path) op_export_test('slice_like_3', M, [x, y3], tmp_path) + +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +@pytest.mark.parametrize('lhs_axes', [[1, 3], [3, 1], [-2, -4], [-4, -2]]) +@pytest.mark.parametrize('rhs_axes', [[1, 3], [3, 1], [-2, -4], [-4, -2]]) +def test_onnx_export_broadcast_like(tmp_path, dtype, lhs_axes, rhs_axes): + x = mx.random.normal(0, 10, (2, 1, 1, 1, 6)).astype(dtype) + y = mx.random.normal(0, 10, (2, 3, 4, 5, 6)).astype(dtype) + M1 = def_model('broadcast_like') + op_export_test('broadcast_like1', M1, [x, y], tmp_path) + M2 = def_model('broadcast_like', lhs_axes=lhs_axes, rhs_axes=rhs_axes) + op_export_test('broadcast_like2', M2, [x, y], tmp_path)