From 08585fe680cde8f311423e778c4ba8f3cb47adac Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Thu, 21 Jan 2021 15:19:23 -0800 Subject: [PATCH 1/9] batch_dot --- .../contrib/onnx/mx2onnx/_op_translations.py | 75 +++++++++++++++++++ tests/python-pytest/onnx/test_operators.py | 13 ++++ 2 files changed, 88 insertions(+) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 39723321923e..c05810efddc2 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3555,3 +3555,78 @@ def convert_broadcast_like(node, **kwargs): ] return nodes + + +@mx_op.register("batch_dot") +def convert_batch_dot(node, **kwargs): + """Map MXNet's gather_ND operator attributes to onnx's operator. + """ + from onnx.helper import make_node + name, input_nodes, attrs = get_inputs(node, kwargs) + + lhs = input_nodes[0] + rhs = input_nodes[1] + input_type = kwargs['in_type'] + dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] + transpose_a = str(attrs.get('transpose_a', '0')) + transpose_b = str(attrs.get('transpose_b', '0')) + perm = [0, 2, 1] + + nodes = [ + create_tensor([-2], name+'_-2', kwargs['initializer']), + create_tensor([-1], name+'_-1', kwargs['initializer']), + create_tensor([0], name+'_0', kwargs['initializer']), + create_tensor([2], name+'_2', kwargs['initializer']), + create_tensor([100], name+'_100', kwargs['initializer']), + create_tensor([0], name+'_0f', kwargs['initializer'], dtype=dtype), + ] + + if transpose_a in ['0', 'False']: + nodes += [ + make_node('Add', [lhs, name+'_0f'], [name+'_lhs']), + ] + else: + nodes += [ + make_node('Shape', [lhs], [name+'_lhs_shape']), + make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']), + # make_node('Sub', [name+'_lhs_dim', name+'_2'], [name+'_lhs_sub']), + make_node('Slice', [name+'_lhs_shape', name+'_0', name+'_-2'], + [name+'_lhs_slice0']), + make_node('Slice', [name+'_lhs_shape', name+'_-2', name+'_100'], + [name+'_lhs_slice1']), + make_node('Concat', [name+'_-1', name+'_lhs_slice1'], [name+'_lhs_concat1'], axis=0), + make_node('Reshape', [lhs, name+'_lhs_concat1'], [name+'_lhs_3d']), + make_node('Transpose', [name+'_lhs_3d'], [name+'_lhs_3d_transpose'], perm=perm), + make_node('Shape', [name+'_lhs_3d_transpose'], [name+'_lhs_shape_3d']), + make_node('Slice', [name+'_lhs_shape_3d', name+'_-2', name+'_100'], + [name+'_lhs_slice2']), + make_node('Concat', [name+'_lhs_slice0', name+'_lhs_slice2'], [name+'_lhs_concat2'], axis=0), + make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']), + ] + + if transpose_b in ['0', 'False']: + nodes += [ + make_node('Add', [rhs, name+'_0f'], [name+'_rhs']), + ] + else: + nodes += [ + make_node('Shape', [rhs], [name+'_rhs_shape']), + make_node('Shape', [name+'_rhs_shape'], [name+'_rhs_dim']), + make_node('Slice', [name+'_rhs_shape', name+'_0', name+'_-2'], + [name+'_rhs_slice0']), + make_node('Slice', [name+'_rhs_shape', name+'_-2', name+'_100'], + [name+'_rhs_slice1']), + make_node('Concat', [name+'_-1', name+'_rhs_slice1'], [name+'_rhs_concat1'], axis=0), + make_node('Reshape', [rhs, name+'_rhs_concat1'], [name+'_rhs_3d']), + make_node('Transpose', [name+'_rhs_3d'], [name+'_rhs_3d_transpose'], perm=perm), + make_node('Shape', [name+'_rhs_3d_transpose'], [name+'_rhs_shape_3d']), + make_node('Slice', [name+'_rhs_shape_3d', name+'_-2', name+'_100'], + [name+'_rhs_slice2']), + make_node('Concat', [name+'_rhs_slice0', name+'_rhs_slice2'], [name+'_rhs_concat2'], axis=0), + make_node('Reshape', [name+'_rhs_3d_transpose', name+'_rhs_concat2'], [name+'_rhs']), + ] + + nodes += [ + make_node('MatMul', [name+'_lhs', name+'_rhs'], [name]), + ] + return nodes diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index 204fe57f5930..9686103791c9 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -644,3 +644,16 @@ def test_onnx_export_broadcast_like(tmp_path, dtype, lhs_axes, rhs_axes): op_export_test('broadcast_like1', M1, [x, y], tmp_path) M2 = def_model('broadcast_like', lhs_axes=lhs_axes, rhs_axes=rhs_axes) op_export_test('broadcast_like2', M2, [x, y], tmp_path) + +@pytest.mark.parametrize('dtype', ['float32', 'float64']) +@pytest.mark.parametrize('transpose_a', [True, False]) +@pytest.mark.parametrize('transpose_b', [True, False]) +def test_onnx_export_batch_dot(tmp_path, dtype, transpose_a, transpose_b): + x1 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 6), dtype=dtype) + y1 = mx.nd.random.normal(0, 10, (2, 3, 4, 6, 5), dtype=dtype) + M1 = def_model('batch_dot') + op_export_test('batch_dot1', M1, [x1, y1], tmp_path) + x2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype) + y2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype) + M2 = def_model('batch_dot', transpose_a=transpose_a, transpose_b=transpose_b) + op_export_test('batch_dot2', M2, [x2, y2], tmp_path) From 5a81f50d1ce363de32ecebaffc37ad53fe40bafe Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 22 Jan 2021 14:00:58 -0800 Subject: [PATCH 2/9] remove print --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index c05810efddc2..9ee9525bc368 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3559,7 +3559,7 @@ def convert_broadcast_like(node, **kwargs): @mx_op.register("batch_dot") def convert_batch_dot(node, **kwargs): - """Map MXNet's gather_ND operator attributes to onnx's operator. + """Map MXNet's batch_dot operator attributes to onnx's operator. """ from onnx.helper import make_node name, input_nodes, attrs = get_inputs(node, kwargs) @@ -3581,7 +3581,7 @@ def convert_batch_dot(node, **kwargs): create_tensor([0], name+'_0f', kwargs['initializer'], dtype=dtype), ] - if transpose_a in ['0', 'False']: + if transpose_a == 'False': nodes += [ make_node('Add', [lhs, name+'_0f'], [name+'_lhs']), ] @@ -3589,7 +3589,6 @@ def convert_batch_dot(node, **kwargs): nodes += [ make_node('Shape', [lhs], [name+'_lhs_shape']), make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']), - # make_node('Sub', [name+'_lhs_dim', name+'_2'], [name+'_lhs_sub']), make_node('Slice', [name+'_lhs_shape', name+'_0', name+'_-2'], [name+'_lhs_slice0']), make_node('Slice', [name+'_lhs_shape', name+'_-2', name+'_100'], @@ -3604,7 +3603,7 @@ def convert_batch_dot(node, **kwargs): make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']), ] - if transpose_b in ['0', 'False']: + if transpose_b =='False': nodes += [ make_node('Add', [rhs, name+'_0f'], [name+'_rhs']), ] From 2904abe9ce497210f755e2392a38a0fca4b4b5e4 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 22 Jan 2021 14:02:54 -0800 Subject: [PATCH 3/9] update default values --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 9ee9525bc368..9aff700dc40e 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3568,8 +3568,8 @@ def convert_batch_dot(node, **kwargs): rhs = input_nodes[1] input_type = kwargs['in_type'] dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] - transpose_a = str(attrs.get('transpose_a', '0')) - transpose_b = str(attrs.get('transpose_b', '0')) + transpose_a = str(attrs.get('transpose_a', 'False')) + transpose_b = str(attrs.get('transpose_b', 'False')) perm = [0, 2, 1] nodes = [ From b78db15cf604374a9b0310c0907d54ccdf685ce8 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 22 Jan 2021 14:40:25 -0800 Subject: [PATCH 4/9] fix sanity --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 9aff700dc40e..b5a825d6dc3f 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3603,7 +3603,7 @@ def convert_batch_dot(node, **kwargs): make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']), ] - if transpose_b =='False': + if transpose_b == 'False': nodes += [ make_node('Add', [rhs, name+'_0f'], [name+'_rhs']), ] From dec8fcd7e09d31d648847c25a3208fe03afd5f6c Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 26 Jan 2021 11:49:46 -0800 Subject: [PATCH 5/9] split usecases --- .../contrib/onnx/mx2onnx/_op_translations.py | 47 +++++++++++++++---- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index b5a825d6dc3f..84447d734f4a 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3581,11 +3581,12 @@ def convert_batch_dot(node, **kwargs): create_tensor([0], name+'_0f', kwargs['initializer'], dtype=dtype), ] - if transpose_a == 'False': + if transpose_a == 'False' and transpose_b == 'False': nodes += [ - make_node('Add', [lhs, name+'_0f'], [name+'_lhs']), + make_node('MatMul', [lhs, rhs], [name]), ] - else: + + elif transpose_a != 'False' and transpose_a == 'False': nodes += [ make_node('Shape', [lhs], [name+'_lhs_shape']), make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']), @@ -3601,14 +3602,45 @@ def convert_batch_dot(node, **kwargs): [name+'_lhs_slice2']), make_node('Concat', [name+'_lhs_slice0', name+'_lhs_slice2'], [name+'_lhs_concat2'], axis=0), make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']), + make_node('MatMul', [name+'_lhs', rhs], [name]), ] - if transpose_b == 'False': + elif transpose_a == 'False' and transpose_b != 'False': nodes += [ - make_node('Add', [rhs, name+'_0f'], [name+'_rhs']), + make_node('Shape', [rhs], [name+'_rhs_shape']), + make_node('Shape', [name+'_rhs_shape'], [name+'_rhs_dim']), + make_node('Slice', [name+'_rhs_shape', name+'_0', name+'_-2'], + [name+'_rhs_slice0']), + make_node('Slice', [name+'_rhs_shape', name+'_-2', name+'_100'], + [name+'_rhs_slice1']), + make_node('Concat', [name+'_-1', name+'_rhs_slice1'], [name+'_rhs_concat1'], axis=0), + make_node('Reshape', [rhs, name+'_rhs_concat1'], [name+'_rhs_3d']), + make_node('Transpose', [name+'_rhs_3d'], [name+'_rhs_3d_transpose'], perm=perm), + make_node('Shape', [name+'_rhs_3d_transpose'], [name+'_rhs_shape_3d']), + make_node('Slice', [name+'_rhs_shape_3d', name+'_-2', name+'_100'], + [name+'_rhs_slice2']), + make_node('Concat', [name+'_rhs_slice0', name+'_rhs_slice2'], [name+'_rhs_concat2'], axis=0), + make_node('Reshape', [name+'_rhs_3d_transpose', name+'_rhs_concat2'], [name+'_rhs']), + make_node('MatMul', [lhs, name+'_rhs'], [name]), ] + else: nodes += [ + make_node('Shape', [lhs], [name+'_lhs_shape']), + make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']), + make_node('Slice', [name+'_lhs_shape', name+'_0', name+'_-2'], + [name+'_lhs_slice0']), + make_node('Slice', [name+'_lhs_shape', name+'_-2', name+'_100'], + [name+'_lhs_slice1']), + make_node('Concat', [name+'_-1', name+'_lhs_slice1'], [name+'_lhs_concat1'], axis=0), + make_node('Reshape', [lhs, name+'_lhs_concat1'], [name+'_lhs_3d']), + make_node('Transpose', [name+'_lhs_3d'], [name+'_lhs_3d_transpose'], perm=perm), + make_node('Shape', [name+'_lhs_3d_transpose'], [name+'_lhs_shape_3d']), + make_node('Slice', [name+'_lhs_shape_3d', name+'_-2', name+'_100'], + [name+'_lhs_slice2']), + make_node('Concat', [name+'_lhs_slice0', name+'_lhs_slice2'], [name+'_lhs_concat2'], axis=0), + make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']), + make_node('Shape', [rhs], [name+'_rhs_shape']), make_node('Shape', [name+'_rhs_shape'], [name+'_rhs_dim']), make_node('Slice', [name+'_rhs_shape', name+'_0', name+'_-2'], @@ -3623,9 +3655,8 @@ def convert_batch_dot(node, **kwargs): [name+'_rhs_slice2']), make_node('Concat', [name+'_rhs_slice0', name+'_rhs_slice2'], [name+'_rhs_concat2'], axis=0), make_node('Reshape', [name+'_rhs_3d_transpose', name+'_rhs_concat2'], [name+'_rhs']), + make_node('MatMul', [name+'_lhs', name+'_rhs'], [name]), ] - nodes += [ - make_node('MatMul', [name+'_lhs', name+'_rhs'], [name]), - ] return nodes + From b2117592e863c390fd4db9e25902cb8bc964ecc0 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Fri, 29 Jan 2021 20:09:01 -0800 Subject: [PATCH 6/9] fix sanity --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 84447d734f4a..07fa7aa016ef 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3659,4 +3659,3 @@ def convert_batch_dot(node, **kwargs): ] return nodes - From f22c0ee61e44b0cf47706940e52eab0bbc7b509b Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Sun, 31 Jan 2021 13:50:23 -0800 Subject: [PATCH 7/9] refactor constant nodes --- .../contrib/onnx/mx2onnx/_op_translations.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 07fa7aa016ef..fdc19fde0079 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3572,21 +3572,20 @@ def convert_batch_dot(node, **kwargs): transpose_b = str(attrs.get('transpose_b', 'False')) perm = [0, 2, 1] + if transpose_a == 'False' and transpose_b == 'False': + nodes = [ + make_node('MatMul', [lhs, rhs], [name]), + ] + return nodes + nodes = [ create_tensor([-2], name+'_-2', kwargs['initializer']), create_tensor([-1], name+'_-1', kwargs['initializer']), create_tensor([0], name+'_0', kwargs['initializer']), - create_tensor([2], name+'_2', kwargs['initializer']), create_tensor([100], name+'_100', kwargs['initializer']), - create_tensor([0], name+'_0f', kwargs['initializer'], dtype=dtype), ] - if transpose_a == 'False' and transpose_b == 'False': - nodes += [ - make_node('MatMul', [lhs, rhs], [name]), - ] - - elif transpose_a != 'False' and transpose_a == 'False': + if transpose_a != 'False' and transpose_b == 'False': nodes += [ make_node('Shape', [lhs], [name+'_lhs_shape']), make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']), From 6ac079183c0b4925d4ad64a0bc5c4496fb705c3c Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Mon, 1 Feb 2021 09:53:56 -0800 Subject: [PATCH 8/9] remove dtype --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index fdc19fde0079..78080a9e6a1d 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3567,7 +3567,6 @@ def convert_batch_dot(node, **kwargs): lhs = input_nodes[0] rhs = input_nodes[1] input_type = kwargs['in_type'] - dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type] transpose_a = str(attrs.get('transpose_a', 'False')) transpose_b = str(attrs.get('transpose_b', 'False')) perm = [0, 2, 1] From 29149b33c81dcf6cace911f8f5794110a2c68a6c Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Mon, 1 Feb 2021 12:04:32 -0800 Subject: [PATCH 9/9] remove in_type --- python/mxnet/contrib/onnx/mx2onnx/_op_translations.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 78080a9e6a1d..3eb8f48be52a 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3566,7 +3566,6 @@ def convert_batch_dot(node, **kwargs): lhs = input_nodes[0] rhs = input_nodes[1] - input_type = kwargs['in_type'] transpose_a = str(attrs.get('transpose_a', 'False')) transpose_b = str(attrs.get('transpose_b', 'False')) perm = [0, 2, 1]