Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Onnx export support for batch_dot #19775

Merged
merged 9 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3555,3 +3555,104 @@ def convert_broadcast_like(node, **kwargs):
]

return nodes


@mx_op.register("batch_dot")
def convert_batch_dot(node, **kwargs):
"""Map MXNet's batch_dot operator attributes to onnx's operator.
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

lhs = input_nodes[0]
rhs = input_nodes[1]
transpose_a = str(attrs.get('transpose_a', 'False'))
transpose_b = str(attrs.get('transpose_b', 'False'))
perm = [0, 2, 1]

if transpose_a == 'False' and transpose_b == 'False':
nodes = [
make_node('MatMul', [lhs, rhs], [name]),
]
return nodes

nodes = [
create_tensor([-2], name+'_-2', kwargs['initializer']),
create_tensor([-1], name+'_-1', kwargs['initializer']),
create_tensor([0], name+'_0', kwargs['initializer']),
create_tensor([100], name+'_100', kwargs['initializer']),
]

if transpose_a != 'False' and transpose_b == 'False':
nodes += [
make_node('Shape', [lhs], [name+'_lhs_shape']),
make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']),
make_node('Slice', [name+'_lhs_shape', name+'_0', name+'_-2'],
[name+'_lhs_slice0']),
make_node('Slice', [name+'_lhs_shape', name+'_-2', name+'_100'],
[name+'_lhs_slice1']),
make_node('Concat', [name+'_-1', name+'_lhs_slice1'], [name+'_lhs_concat1'], axis=0),
make_node('Reshape', [lhs, name+'_lhs_concat1'], [name+'_lhs_3d']),
make_node('Transpose', [name+'_lhs_3d'], [name+'_lhs_3d_transpose'], perm=perm),
make_node('Shape', [name+'_lhs_3d_transpose'], [name+'_lhs_shape_3d']),
make_node('Slice', [name+'_lhs_shape_3d', name+'_-2', name+'_100'],
[name+'_lhs_slice2']),
make_node('Concat', [name+'_lhs_slice0', name+'_lhs_slice2'], [name+'_lhs_concat2'], axis=0),
make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']),
make_node('MatMul', [name+'_lhs', rhs], [name]),
]

elif transpose_a == 'False' and transpose_b != 'False':
nodes += [
make_node('Shape', [rhs], [name+'_rhs_shape']),
make_node('Shape', [name+'_rhs_shape'], [name+'_rhs_dim']),
make_node('Slice', [name+'_rhs_shape', name+'_0', name+'_-2'],
[name+'_rhs_slice0']),
make_node('Slice', [name+'_rhs_shape', name+'_-2', name+'_100'],
[name+'_rhs_slice1']),
make_node('Concat', [name+'_-1', name+'_rhs_slice1'], [name+'_rhs_concat1'], axis=0),
make_node('Reshape', [rhs, name+'_rhs_concat1'], [name+'_rhs_3d']),
make_node('Transpose', [name+'_rhs_3d'], [name+'_rhs_3d_transpose'], perm=perm),
make_node('Shape', [name+'_rhs_3d_transpose'], [name+'_rhs_shape_3d']),
make_node('Slice', [name+'_rhs_shape_3d', name+'_-2', name+'_100'],
[name+'_rhs_slice2']),
make_node('Concat', [name+'_rhs_slice0', name+'_rhs_slice2'], [name+'_rhs_concat2'], axis=0),
make_node('Reshape', [name+'_rhs_3d_transpose', name+'_rhs_concat2'], [name+'_rhs']),
make_node('MatMul', [lhs, name+'_rhs'], [name]),
]

else:
nodes += [
make_node('Shape', [lhs], [name+'_lhs_shape']),
make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']),
make_node('Slice', [name+'_lhs_shape', name+'_0', name+'_-2'],
[name+'_lhs_slice0']),
make_node('Slice', [name+'_lhs_shape', name+'_-2', name+'_100'],
[name+'_lhs_slice1']),
make_node('Concat', [name+'_-1', name+'_lhs_slice1'], [name+'_lhs_concat1'], axis=0),
make_node('Reshape', [lhs, name+'_lhs_concat1'], [name+'_lhs_3d']),
make_node('Transpose', [name+'_lhs_3d'], [name+'_lhs_3d_transpose'], perm=perm),
make_node('Shape', [name+'_lhs_3d_transpose'], [name+'_lhs_shape_3d']),
make_node('Slice', [name+'_lhs_shape_3d', name+'_-2', name+'_100'],
[name+'_lhs_slice2']),
make_node('Concat', [name+'_lhs_slice0', name+'_lhs_slice2'], [name+'_lhs_concat2'], axis=0),
make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']),

make_node('Shape', [rhs], [name+'_rhs_shape']),
make_node('Shape', [name+'_rhs_shape'], [name+'_rhs_dim']),
make_node('Slice', [name+'_rhs_shape', name+'_0', name+'_-2'],
[name+'_rhs_slice0']),
make_node('Slice', [name+'_rhs_shape', name+'_-2', name+'_100'],
[name+'_rhs_slice1']),
make_node('Concat', [name+'_-1', name+'_rhs_slice1'], [name+'_rhs_concat1'], axis=0),
make_node('Reshape', [rhs, name+'_rhs_concat1'], [name+'_rhs_3d']),
make_node('Transpose', [name+'_rhs_3d'], [name+'_rhs_3d_transpose'], perm=perm),
make_node('Shape', [name+'_rhs_3d_transpose'], [name+'_rhs_shape_3d']),
make_node('Slice', [name+'_rhs_shape_3d', name+'_-2', name+'_100'],
[name+'_rhs_slice2']),
make_node('Concat', [name+'_rhs_slice0', name+'_rhs_slice2'], [name+'_rhs_concat2'], axis=0),
make_node('Reshape', [name+'_rhs_3d_transpose', name+'_rhs_concat2'], [name+'_rhs']),
make_node('MatMul', [name+'_lhs', name+'_rhs'], [name]),
]

return nodes
13 changes: 13 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,3 +644,16 @@ def test_onnx_export_broadcast_like(tmp_path, dtype, lhs_axes, rhs_axes):
op_export_test('broadcast_like1', M1, [x, y], tmp_path)
M2 = def_model('broadcast_like', lhs_axes=lhs_axes, rhs_axes=rhs_axes)
op_export_test('broadcast_like2', M2, [x, y], tmp_path)

@pytest.mark.parametrize('dtype', ['float32', 'float64'])
@pytest.mark.parametrize('transpose_a', [True, False])
@pytest.mark.parametrize('transpose_b', [True, False])
def test_onnx_export_batch_dot(tmp_path, dtype, transpose_a, transpose_b):
x1 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 6), dtype=dtype)
y1 = mx.nd.random.normal(0, 10, (2, 3, 4, 6, 5), dtype=dtype)
M1 = def_model('batch_dot')
op_export_test('batch_dot1', M1, [x1, y1], tmp_path)
x2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype)
y2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype)
M2 = def_model('batch_dot', transpose_a=transpose_a, transpose_b=transpose_b)
op_export_test('batch_dot2', M2, [x2, y2], tmp_path)