Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX import: Remove extra flatten from Linalg_gemm
Browse files Browse the repository at this point in the history
ONNX Gemm spec expects the input to be 2D. Flatten is not required
while importing Gemm, it will only lead to redundant flatten.
  • Loading branch information
vandanavk committed Dec 7, 2018
1 parent 039548c commit 2975ab8
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,7 @@ def linalg_gemm(attrs, inputs, proto_obj):
alpha = attrs['alpha']
if 'beta' in attrs:
beta = attrs['beta']
flatten_a = symbol.flatten(inputs[0])
matmul_op = symbol.linalg_gemm2(A=flatten_a, B=inputs[1],
matmul_op = symbol.linalg_gemm2(A=inputs[0], B=inputs[1],
transpose_a=trans_a, transpose_b=trans_b,
alpha=alpha)
gemm_op = symbol.broadcast_add(matmul_op, beta*inputs[2])
Expand Down

0 comments on commit 2975ab8

Please sign in to comment.