From 91d986ee54af9816a832cff885107ff80480d015 Mon Sep 17 00:00:00 2001 From: Wei Chu Date: Tue, 26 Jan 2021 11:49:46 -0800 Subject: [PATCH] split usecases --- .../contrib/onnx/mx2onnx/_op_translations.py | 47 +++++++++++++++---- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index ec82895e59b5..8bdfa913b785 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -3301,11 +3301,12 @@ def convert_batch_dot(node, **kwargs): create_tensor([0], name+'_0f', kwargs['initializer'], dtype=dtype), ] - if transpose_a == 'False': + if transpose_a == 'False' and transpose_b == 'False': nodes += [ - make_node('Add', [lhs, name+'_0f'], [name+'_lhs']), + make_node('MatMul', [lhs, rhs], [name]), ] - else: + + elif transpose_a != 'False' and transpose_a == 'False': nodes += [ make_node('Shape', [lhs], [name+'_lhs_shape']), make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']), @@ -3321,14 +3322,45 @@ def convert_batch_dot(node, **kwargs): [name+'_lhs_slice2']), make_node('Concat', [name+'_lhs_slice0', name+'_lhs_slice2'], [name+'_lhs_concat2'], axis=0), make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']), + make_node('MatMul', [name+'_lhs', rhs], [name]), ] - if transpose_b == 'False': + elif transpose_a == 'False' and transpose_b != 'False': nodes += [ - make_node('Add', [rhs, name+'_0f'], [name+'_rhs']), + make_node('Shape', [rhs], [name+'_rhs_shape']), + make_node('Shape', [name+'_rhs_shape'], [name+'_rhs_dim']), + make_node('Slice', [name+'_rhs_shape', name+'_0', name+'_-2'], + [name+'_rhs_slice0']), + make_node('Slice', [name+'_rhs_shape', name+'_-2', name+'_100'], + [name+'_rhs_slice1']), + make_node('Concat', [name+'_-1', name+'_rhs_slice1'], [name+'_rhs_concat1'], axis=0), + make_node('Reshape', [rhs, name+'_rhs_concat1'], [name+'_rhs_3d']), + make_node('Transpose', [name+'_rhs_3d'], [name+'_rhs_3d_transpose'], perm=perm), + make_node('Shape', [name+'_rhs_3d_transpose'], [name+'_rhs_shape_3d']), + make_node('Slice', [name+'_rhs_shape_3d', name+'_-2', name+'_100'], + [name+'_rhs_slice2']), + make_node('Concat', [name+'_rhs_slice0', name+'_rhs_slice2'], [name+'_rhs_concat2'], axis=0), + make_node('Reshape', [name+'_rhs_3d_transpose', name+'_rhs_concat2'], [name+'_rhs']), + make_node('MatMul', [lhs, name+'_rhs'], [name]), ] + else: nodes += [ + make_node('Shape', [lhs], [name+'_lhs_shape']), + make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']), + make_node('Slice', [name+'_lhs_shape', name+'_0', name+'_-2'], + [name+'_lhs_slice0']), + make_node('Slice', [name+'_lhs_shape', name+'_-2', name+'_100'], + [name+'_lhs_slice1']), + make_node('Concat', [name+'_-1', name+'_lhs_slice1'], [name+'_lhs_concat1'], axis=0), + make_node('Reshape', [lhs, name+'_lhs_concat1'], [name+'_lhs_3d']), + make_node('Transpose', [name+'_lhs_3d'], [name+'_lhs_3d_transpose'], perm=perm), + make_node('Shape', [name+'_lhs_3d_transpose'], [name+'_lhs_shape_3d']), + make_node('Slice', [name+'_lhs_shape_3d', name+'_-2', name+'_100'], + [name+'_lhs_slice2']), + make_node('Concat', [name+'_lhs_slice0', name+'_lhs_slice2'], [name+'_lhs_concat2'], axis=0), + make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']), + make_node('Shape', [rhs], [name+'_rhs_shape']), make_node('Shape', [name+'_rhs_shape'], [name+'_rhs_dim']), make_node('Slice', [name+'_rhs_shape', name+'_0', name+'_-2'], @@ -3343,9 +3375,8 @@ def convert_batch_dot(node, **kwargs): [name+'_rhs_slice2']), make_node('Concat', [name+'_rhs_slice0', name+'_rhs_slice2'], [name+'_rhs_concat2'], axis=0), make_node('Reshape', [name+'_rhs_3d_transpose', name+'_rhs_concat2'], [name+'_rhs']), + make_node('MatMul', [name+'_lhs', name+'_rhs'], [name]), ] - nodes += [ - make_node('MatMul', [name+'_lhs', name+'_rhs'], [name]), - ] return nodes +