From 7ba3f953bd6aad1030781890efe3f9d82f234c30 Mon Sep 17 00:00:00 2001 From: vandanavk Date: Tue, 18 Sep 2018 16:44:37 -0700 Subject: [PATCH] Fully connected operator with no bias --- .../contrib/onnx/mx2onnx/_op_translations.py | 38 ++++++++++++++++--- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 3ffac96a14e1..5adade8487b7 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -212,20 +212,46 @@ def convert_fully_connected(node, **kwargs): """Map MXNet's FullyConnected operator attributes to onnx's Gemm operator and return the created node. """ - helper, _, _ = import_onnx_modules() + helper, _, mapping = import_onnx_modules() name = node["name"] inputs = node["inputs"] + attrs = node["attrs"] + initializer = kwargs["initializer"] + + no_bias = attrs.get("no_bias", "False") + input_node_id = kwargs["index_lookup"][inputs[0][0]] weight_node_id = kwargs["index_lookup"][inputs[1][0]] - bias_node_id = kwargs["index_lookup"][inputs[2][0]] + proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_id] - weights_node = proc_nodes[weight_node_id] - bias_node = proc_nodes[bias_node_id] + input_node = proc_nodes[input_node_id] input_name = input_node.name + + weights_node = proc_nodes[weight_node_id] weights_name = weights_node.name - bias_name = bias_node.name + + if no_bias == "False": + bias_node_id = kwargs["index_lookup"][inputs[2][0]] + bias_node = proc_nodes[bias_node_id] + bias_name = bias_node.name + else: + default_bias = [0] + np_arr = np.array(default_bias) + data_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype] + dims = np.shape(np_arr) + + bias_name = "bias" + str(kwargs["idx"]) + + initializer.append( + helper.make_tensor( + name=bias_name, + data_type=data_type, + dims=dims, + vals=default_bias, + raw=False, + ) + ) node = helper.make_node( "Gemm",