Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
broadcast_like
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei Chu committed Jan 26, 2021
1 parent 49b26b6 commit c075a7c
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
42 changes: 42 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3275,3 +3275,45 @@ def convert_gather_nd(node, **kwargs):
]

return nodes


@mx_op.register("broadcast_like")
def convert_broadcast_like(node, **kwargs):
"""Map MXNet's broadcast_like operator attributes to onnx's operator.
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

lhs = input_nodes[0]
rhs = input_nodes[1]
lhs_axes = convert_string_to_list(str(attrs.get('lhs_axes', 'None')))
rhs_axes = convert_string_to_list(str(attrs.get('rhs_axes', 'None')))

if lhs_axes[0] is None or rhs_axes[0] is None:
nodes = [
make_node('Shape', [rhs], [name+'_rhs_shape']),
make_node('Expand', [lhs, name+'_rhs_shape'], [name], name=name)
]
return nodes

lhs_axes = [[i] for i in lhs_axes]
rhs_axes = [[i] for i in rhs_axes]

nodes = [
create_tensor([0], name+'_0', kwargs['initializer']),
create_tensor(lhs_axes, name+'_lhs_axes', kwargs['initializer']),
create_tensor(rhs_axes, name+'_rhs_axes', kwargs['initializer']),
make_node('Shape', [lhs], [name+'_lhs_shape']),
make_node('Shape', [rhs], [name+'_rhs_shape']),
make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']),
make_node('Less', [name+'_lhs_axes', name+'_0'], [name+'_less']),
make_node('Cast', [name+'_less'], [name+'_mask'], to=int(onnx.TensorProto.INT64)),
make_node('Mul', [name+'_mask', name+'_lhs_dim'], [name+'_mul']),
make_node('Add', [name+'_lhs_axes', name+'_mul'], [name+'_lhs_axes_positive']),
make_node('GatherND', [name+'_rhs_shape', name+'_rhs_axes'], [name+'_gather']),
make_node('ScatterND', [name+'_lhs_shape', name+'_lhs_axes_positive', name+'_gather'],
[name+'_scatter']),
make_node('Expand', [lhs, name+'_scatter'], [name], name=name)
]

return nodes
11 changes: 11 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,14 @@ def test_onnx_export_gather_nd(tmp_path, dtype):
M2 = def_model('gather_nd')
op_export_test('gather_nd2', M2, [x2, y2], tmp_path)


@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64'])
@pytest.mark.parametrize('lhs_axes', [[1, 3], [3, 1], [-2, -4], [-4, -2]])
@pytest.mark.parametrize('rhs_axes', [[1, 3], [3, 1], [-2, -4], [-4, -2]])
def test_onnx_export_broadcast_like(tmp_path, dtype, lhs_axes, rhs_axes):
x = mx.random.normal(0, 10, (2, 1, 1, 1, 6)).astype(dtype)
y = mx.random.normal(0, 10, (2, 3, 4, 5, 6)).astype(dtype)
M1 = def_model('broadcast_like')
op_export_test('broadcast_like1', M1, [x, y], tmp_path)
M2 = def_model('broadcast_like', lhs_axes=lhs_axes, rhs_axes=rhs_axes)
op_export_test('broadcast_like2', M2, [x, y], tmp_path)

0 comments on commit c075a7c

Please sign in to comment.