Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
split usecases
Browse files Browse the repository at this point in the history
  • Loading branch information
Wei Chu committed Jan 26, 2021
1 parent 993ea85 commit 91d986e
Showing 1 changed file with 39 additions and 8 deletions.
47 changes: 39 additions & 8 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3301,11 +3301,12 @@ def convert_batch_dot(node, **kwargs):
create_tensor([0], name+'_0f', kwargs['initializer'], dtype=dtype),
]

if transpose_a == 'False':
if transpose_a == 'False' and transpose_b == 'False':
nodes += [
make_node('Add', [lhs, name+'_0f'], [name+'_lhs']),
make_node('MatMul', [lhs, rhs], [name]),
]
else:

elif transpose_a != 'False' and transpose_a == 'False':
nodes += [
make_node('Shape', [lhs], [name+'_lhs_shape']),
make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']),
Expand All @@ -3321,14 +3322,45 @@ def convert_batch_dot(node, **kwargs):
[name+'_lhs_slice2']),
make_node('Concat', [name+'_lhs_slice0', name+'_lhs_slice2'], [name+'_lhs_concat2'], axis=0),
make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']),
make_node('MatMul', [name+'_lhs', rhs], [name]),
]

if transpose_b == 'False':
elif transpose_a == 'False' and transpose_b != 'False':
nodes += [
make_node('Add', [rhs, name+'_0f'], [name+'_rhs']),
make_node('Shape', [rhs], [name+'_rhs_shape']),
make_node('Shape', [name+'_rhs_shape'], [name+'_rhs_dim']),
make_node('Slice', [name+'_rhs_shape', name+'_0', name+'_-2'],
[name+'_rhs_slice0']),
make_node('Slice', [name+'_rhs_shape', name+'_-2', name+'_100'],
[name+'_rhs_slice1']),
make_node('Concat', [name+'_-1', name+'_rhs_slice1'], [name+'_rhs_concat1'], axis=0),
make_node('Reshape', [rhs, name+'_rhs_concat1'], [name+'_rhs_3d']),
make_node('Transpose', [name+'_rhs_3d'], [name+'_rhs_3d_transpose'], perm=perm),
make_node('Shape', [name+'_rhs_3d_transpose'], [name+'_rhs_shape_3d']),
make_node('Slice', [name+'_rhs_shape_3d', name+'_-2', name+'_100'],
[name+'_rhs_slice2']),
make_node('Concat', [name+'_rhs_slice0', name+'_rhs_slice2'], [name+'_rhs_concat2'], axis=0),
make_node('Reshape', [name+'_rhs_3d_transpose', name+'_rhs_concat2'], [name+'_rhs']),
make_node('MatMul', [lhs, name+'_rhs'], [name]),
]

else:
nodes += [
make_node('Shape', [lhs], [name+'_lhs_shape']),
make_node('Shape', [name+'_lhs_shape'], [name+'_lhs_dim']),
make_node('Slice', [name+'_lhs_shape', name+'_0', name+'_-2'],
[name+'_lhs_slice0']),
make_node('Slice', [name+'_lhs_shape', name+'_-2', name+'_100'],
[name+'_lhs_slice1']),
make_node('Concat', [name+'_-1', name+'_lhs_slice1'], [name+'_lhs_concat1'], axis=0),
make_node('Reshape', [lhs, name+'_lhs_concat1'], [name+'_lhs_3d']),
make_node('Transpose', [name+'_lhs_3d'], [name+'_lhs_3d_transpose'], perm=perm),
make_node('Shape', [name+'_lhs_3d_transpose'], [name+'_lhs_shape_3d']),
make_node('Slice', [name+'_lhs_shape_3d', name+'_-2', name+'_100'],
[name+'_lhs_slice2']),
make_node('Concat', [name+'_lhs_slice0', name+'_lhs_slice2'], [name+'_lhs_concat2'], axis=0),
make_node('Reshape', [name+'_lhs_3d_transpose', name+'_lhs_concat2'], [name+'_lhs']),

make_node('Shape', [rhs], [name+'_rhs_shape']),
make_node('Shape', [name+'_rhs_shape'], [name+'_rhs_dim']),
make_node('Slice', [name+'_rhs_shape', name+'_0', name+'_-2'],
Expand All @@ -3343,9 +3375,8 @@ def convert_batch_dot(node, **kwargs):
[name+'_rhs_slice2']),
make_node('Concat', [name+'_rhs_slice0', name+'_rhs_slice2'], [name+'_rhs_concat2'], axis=0),
make_node('Reshape', [name+'_rhs_3d_transpose', name+'_rhs_concat2'], [name+'_rhs']),
make_node('MatMul', [name+'_lhs', name+'_rhs'], [name]),
]

nodes += [
make_node('MatMul', [name+'_lhs', name+'_rhs'], [name]),
]
return nodes

0 comments on commit 91d986e

Please sign in to comment.