Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Fully connected operator with no bias
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Oct 5, 2018
1 parent 83a8831 commit 22ea71c
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,17 +222,41 @@ def convert_fully_connected(node, **kwargs):
onnx = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]
attrs = node["attrs"]
initializer = kwargs["initializer"]

no_bias = convert_bool_to_int(attrs, "no_bias")

input_node_id = kwargs["index_lookup"][inputs[0][0]]
weight_node_id = kwargs["index_lookup"][inputs[1][0]]
bias_node_id = kwargs["index_lookup"][inputs[2][0]]

proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[input_node_id]
weights_node = proc_nodes[weight_node_id]
bias_node = proc_nodes[bias_node_id]

input_node = proc_nodes[input_node_id]
input_name = input_node.name

weights_node = proc_nodes[weight_node_id]
weights_name = weights_node.name
bias_name = bias_node.name

if no_bias == 0:
bias_node_id = kwargs["index_lookup"][inputs[2][0]]
bias_node = proc_nodes[bias_node_id]
bias_name = bias_node.name
else:
default_bias = [0]
np_arr = np.array(default_bias)
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype]
dims = np.shape(np_arr)
bias_name = "bias" + str(kwargs["idx"])
initializer.append(
onnx.helper.make_tensor(
name=bias_name,
data_type=data_type,
dims=dims,
vals=default_bias,
raw=False,
)
)

node = onnx.helper.make_node(
"Gemm",
Expand Down

0 comments on commit 22ea71c

Please sign in to comment.