Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] Onnx export support for batch_dot (#19775)
Browse files Browse the repository at this point in the history
* batch_dot

* remove print

* update default values

* fix sanity

* split usecases

* fix sanity

* refactor constant nodes

* remove dtype

* remove in_type

Co-authored-by: Wei Chu <[email protected]>
  • Loading branch information
waytrue17 and Wei Chu authored Feb 2, 2021
1 parent 5637325 commit 6c025e0
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 0 deletions.
101 changes: 101 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3615,3 +3615,104 @@ def convert_broadcast_like(node, **kwargs):
]

return nodes


@mx_op.register("batch_dot")
def convert_batch_dot(node, **kwargs):
"""Map MXNet's batch_dot operator attributes to onnx's operator.
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

lhs = input_nodes[0]
rhs = input_nodes[1]
transpose_a = str(attrs.get('transpose_a', 'False'))
transpose_b = str(attrs.get('transpose_b', 'False'))
perm = [0, 2, 1]

if transpose_a == 'False' and transpose_b == 'False':
nodes = [
make_node('MatMul', [lhs, rhs], [name]),
]
return nodes

nodes = [
create_tensor([-2], name+'_-2', kwargs['initializer']),
create_tensor([-1], name+'_-1', kwargs['initializer']),
create_tensor([0], name+'_0', kwargs['initializer']),
create_tensor([100], name+'_100', kwargs['initializer']),
]

if transpose_a != 'False' and transpose_b == 'False':
nodes += [
make_node('Shape', [lhs], [name+'_lhs_shape']),
make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']),
make_node('Slice', [name+'_lhs_shape', name+'_0', name+'_-2'],
[name+'_lhs_slice0']),
make_node('Slice', [name+'_lhs_shape', name+'_-2', name+'_100'],
[name+'_lhs_slice1']),
make_node('Concat', [name+'_-1', name+'_lhs_slice1'], [name+'_lhs_concat1'], axis=0),
make_node('Reshape', [lhs, name+'_lhs_concat1'], [name+'_lhs_3d']),
make_node('Transpose', [name+'_lhs_3d'], [name+'_lhs_3d_transpose'], perm=perm),
make_node('Shape', [name+'_lhs_3d_transpose'], [name+'_lhs_shape_3d']),
make_node('Slice', [name+'_lhs_shape_3d', name+'_-2', name+'_100'],
[name+'_lhs_slice2']),
make_node('Concat', [name+'_lhs_slice0', name+'_lhs_slice2'], [name+'_lhs_concat2'], axis=0),
make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']),
make_node('MatMul', [name+'_lhs', rhs], [name]),
]

elif transpose_a == 'False' and transpose_b != 'False':
nodes += [
make_node('Shape', [rhs], [name+'_rhs_shape']),
make_node('Shape', [name+'_rhs_shape'], [name+'_rhs_dim']),
make_node('Slice', [name+'_rhs_shape', name+'_0', name+'_-2'],
[name+'_rhs_slice0']),
make_node('Slice', [name+'_rhs_shape', name+'_-2', name+'_100'],
[name+'_rhs_slice1']),
make_node('Concat', [name+'_-1', name+'_rhs_slice1'], [name+'_rhs_concat1'], axis=0),
make_node('Reshape', [rhs, name+'_rhs_concat1'], [name+'_rhs_3d']),
make_node('Transpose', [name+'_rhs_3d'], [name+'_rhs_3d_transpose'], perm=perm),
make_node('Shape', [name+'_rhs_3d_transpose'], [name+'_rhs_shape_3d']),
make_node('Slice', [name+'_rhs_shape_3d', name+'_-2', name+'_100'],
[name+'_rhs_slice2']),
make_node('Concat', [name+'_rhs_slice0', name+'_rhs_slice2'], [name+'_rhs_concat2'], axis=0),
make_node('Reshape', [name+'_rhs_3d_transpose', name+'_rhs_concat2'], [name+'_rhs']),
make_node('MatMul', [lhs, name+'_rhs'], [name]),
]

else:
nodes += [
make_node('Shape', [lhs], [name+'_lhs_shape']),
make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']),
make_node('Slice', [name+'_lhs_shape', name+'_0', name+'_-2'],
[name+'_lhs_slice0']),
make_node('Slice', [name+'_lhs_shape', name+'_-2', name+'_100'],
[name+'_lhs_slice1']),
make_node('Concat', [name+'_-1', name+'_lhs_slice1'], [name+'_lhs_concat1'], axis=0),
make_node('Reshape', [lhs, name+'_lhs_concat1'], [name+'_lhs_3d']),
make_node('Transpose', [name+'_lhs_3d'], [name+'_lhs_3d_transpose'], perm=perm),
make_node('Shape', [name+'_lhs_3d_transpose'], [name+'_lhs_shape_3d']),
make_node('Slice', [name+'_lhs_shape_3d', name+'_-2', name+'_100'],
[name+'_lhs_slice2']),
make_node('Concat', [name+'_lhs_slice0', name+'_lhs_slice2'], [name+'_lhs_concat2'], axis=0),
make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']),

make_node('Shape', [rhs], [name+'_rhs_shape']),
make_node('Shape', [name+'_rhs_shape'], [name+'_rhs_dim']),
make_node('Slice', [name+'_rhs_shape', name+'_0', name+'_-2'],
[name+'_rhs_slice0']),
make_node('Slice', [name+'_rhs_shape', name+'_-2', name+'_100'],
[name+'_rhs_slice1']),
make_node('Concat', [name+'_-1', name+'_rhs_slice1'], [name+'_rhs_concat1'], axis=0),
make_node('Reshape', [rhs, name+'_rhs_concat1'], [name+'_rhs_3d']),
make_node('Transpose', [name+'_rhs_3d'], [name+'_rhs_3d_transpose'], perm=perm),
make_node('Shape', [name+'_rhs_3d_transpose'], [name+'_rhs_shape_3d']),
make_node('Slice', [name+'_rhs_shape_3d', name+'_-2', name+'_100'],
[name+'_rhs_slice2']),
make_node('Concat', [name+'_rhs_slice0', name+'_rhs_slice2'], [name+'_rhs_concat2'], axis=0),
make_node('Reshape', [name+'_rhs_3d_transpose', name+'_rhs_concat2'], [name+'_rhs']),
make_node('MatMul', [name+'_lhs', name+'_rhs'], [name]),
]

return nodes
13 changes: 13 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,3 +686,16 @@ def test_onnx_export_broadcast_like(tmp_path, dtype, lhs_axes, rhs_axes):
op_export_test('broadcast_like1', M1, [x, y], tmp_path)
M2 = def_model('broadcast_like', lhs_axes=lhs_axes, rhs_axes=rhs_axes)
op_export_test('broadcast_like2', M2, [x, y], tmp_path)

@pytest.mark.parametrize('dtype', ['float32', 'float64'])
@pytest.mark.parametrize('transpose_a', [True, False])
@pytest.mark.parametrize('transpose_b', [True, False])
def test_onnx_export_batch_dot(tmp_path, dtype, transpose_a, transpose_b):
x1 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 6), dtype=dtype)
y1 = mx.nd.random.normal(0, 10, (2, 3, 4, 6, 5), dtype=dtype)
M1 = def_model('batch_dot')
op_export_test('batch_dot1', M1, [x1, y1], tmp_path)
x2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype)
y2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype)
M2 = def_model('batch_dot', transpose_a=transpose_a, transpose_b=transpose_b)
op_export_test('batch_dot2', M2, [x2, y2], tmp_path)

0 comments on commit 6c025e0

Please sign in to comment.