Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fully connected operator with no bias
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Sep 22, 2018
1 parent 504d24c commit 7ba3f95
Showing 1 changed file with 32 additions and 6 deletions.
38 changes: 32 additions & 6 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,20 +212,46 @@ def convert_fully_connected(node, **kwargs):
"""Map MXNet's FullyConnected operator attributes to onnx's Gemm operator
and return the created node.
"""
helper, _, _ = import_onnx_modules()
helper, _, mapping = import_onnx_modules()
name = node["name"]
inputs = node["inputs"]
attrs = node["attrs"]
initializer = kwargs["initializer"]

no_bias = attrs.get("no_bias", "False")

input_node_id = kwargs["index_lookup"][inputs[0][0]]
weight_node_id = kwargs["index_lookup"][inputs[1][0]]
bias_node_id = kwargs["index_lookup"][inputs[2][0]]

proc_nodes = kwargs["proc_nodes"]
input_node = proc_nodes[input_node_id]
weights_node = proc_nodes[weight_node_id]
bias_node = proc_nodes[bias_node_id]

input_node = proc_nodes[input_node_id]
input_name = input_node.name

weights_node = proc_nodes[weight_node_id]
weights_name = weights_node.name
bias_name = bias_node.name

if no_bias == "False":
bias_node_id = kwargs["index_lookup"][inputs[2][0]]
bias_node = proc_nodes[bias_node_id]
bias_name = bias_node.name
else:
default_bias = [0]
np_arr = np.array(default_bias)
data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype]
dims = np.shape(np_arr)

bias_name = "bias" + str(kwargs["idx"])

initializer.append(
helper.make_tensor(
name=bias_name,
data_type=data_type,
dims=dims,
vals=default_bias,
raw=False,
)
)

node = helper.make_node(
"Gemm",
Expand Down

0 comments on commit 7ba3f95

Please sign in to comment.