Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Fully connected operator with no bias
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Oct 8, 2018
1 parent f5a84e8 commit 00892b4
Showing 1 changed file with 36 additions and 6 deletions.
42 changes: 36 additions & 6 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,17 +222,45 @@ def convert_fully_connected(node, **kwargs):
onnx = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]
attrs = node["attrs"]
initializer = kwargs["initializer"]

no_bias = convert_bool_to_int(attrs, "no_bias")

input_node_id = kwargs["index_lookup"][inputs[0][0]]
weight_node_id = kwargs["index_lookup"][inputs[1][0]]
bias_node_id = kwargs["index_lookup"][inputs[2][0]]

proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[input_node_id]
weights_node = proc_nodes[weight_node_id]
bias_node = proc_nodes[bias_node_id]

input_node = proc_nodes[input_node_id]
input_name = input_node.name

weights_node = proc_nodes[weight_node_id]
weights_name = weights_node.name
bias_name = bias_node.name

fcnode = []

if no_bias == 0:
bias_node_id = kwargs["index_lookup"][inputs[2][0]]
bias_node = proc_nodes[bias_node_id]
bias_name = bias_node.name
else:
default_bias = [0]
np_arr = np.array(default_bias)
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype]
dims = np.shape(np_arr)
bias_name = "bias" + str(kwargs["idx"])
tensor_node = onnx.helper.make_tensor_value_info(bias_name, data_type, dims)
initializer.append(
onnx.helper.make_tensor(
name=bias_name,
data_type=data_type,
dims=dims,
vals=default_bias,
raw=False,
)
)
fcnode.append(tensor_node)

node = onnx.helper.make_node(
"Gemm",
Expand All @@ -245,7 +273,9 @@ def convert_fully_connected(node, **kwargs):
name=name
)

return [node]
fcnode.append(node)

return fcnode


@mx_op.register("BatchNorm")
Expand Down

0 comments on commit 00892b4

Please sign in to comment.