Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Fully connected operator with no bias
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Oct 11, 2018
1 parent ad1e75d commit c80b64f
Showing 1 changed file with 33 additions and 6 deletions.
39 changes: 33 additions & 6 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,17 +222,42 @@ def convert_fully_connected(node, **kwargs):
onnx = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]
attrs = node["attrs"]
initializer = kwargs["initializer"]

no_bias = get_boolean_attribute_value(attrs, "no_bias")

input_node_id = kwargs["index_lookup"][inputs[0][0]]
weight_node_id = kwargs["index_lookup"][inputs[1][0]]
bias_node_id = kwargs["index_lookup"][inputs[2][0]]

proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[input_node_id]
weights_node = proc_nodes[weight_node_id]
bias_node = proc_nodes[bias_node_id]

input_node = proc_nodes[input_node_id]
input_name = input_node.name

weights_node = proc_nodes[weight_node_id]
weights_name = weights_node.name
bias_name = bias_node.name

fcnode = []

if no_bias == 0:
bias_node_id = kwargs["index_lookup"][inputs[2][0]]
bias_node = proc_nodes[bias_node_id]
bias_name = bias_node.name
else:
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
bias_name = "bias" + str(kwargs["idx"])
tensor_node = onnx.helper.make_tensor_value_info(bias_name, data_type, (1,))
initializer.append(
onnx.helper.make_tensor(
name=bias_name,
data_type=data_type,
dims=(1,),
vals=[0],
raw=False,
)
)
fcnode.append(tensor_node)

node = onnx.helper.make_node(
"Gemm",
Expand All @@ -245,7 +270,9 @@ def convert_fully_connected(node, **kwargs):
name=name
)

return [node]
fcnode.append(node)

return fcnode


@mx_op.register("BatchNorm")
Expand Down

0 comments on commit c80b64f

Please sign in to comment.