From 74bc18fbfb9201d341119c7bad5833ef4aa8024d Mon Sep 17 00:00:00 2001 From: vandanavk Date: Fri, 28 Sep 2018 15:24:38 -0700 Subject: [PATCH] ONNX export: Fully connected operator with no bias --- .../contrib/onnx/mx2onnx/_op_translations.py | 34 ++++++++++++++++--- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index bc7d17e3c927..eceee0e5b388 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -222,17 +222,41 @@ def convert_fully_connected(node, **kwargs): onnx = import_onnx_modules() name = node["name"] inputs = node["inputs"] + attrs = node["attrs"] + initializer = kwargs["initializer"] + + no_bias = convert_bool_to_int(attrs, "no_bias") + input_node_id = kwargs["index_lookup"][inputs[0][0]] weight_node_id = kwargs["index_lookup"][inputs[1][0]] - bias_node_id = kwargs["index_lookup"][inputs[2][0]] + proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_id] - weights_node = proc_nodes[weight_node_id] - bias_node = proc_nodes[bias_node_id] + input_node = proc_nodes[input_node_id] input_name = input_node.name + + weights_node = proc_nodes[weight_node_id] weights_name = weights_node.name - bias_name = bias_node.name + + if no_bias == 0: + bias_node_id = kwargs["index_lookup"][inputs[2][0]] + bias_node = proc_nodes[bias_node_id] + bias_name = bias_node.name + else: + default_bias = [0] + np_arr = np.array(default_bias) + data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype] + dims = np.shape(np_arr) + bias_name = "bias" + str(kwargs["idx"]) + initializer.append( + onnx.helper.make_tensor( + name=bias_name, + data_type=data_type, + dims=dims, + vals=default_bias, + raw=False, + ) + ) node = onnx.helper.make_node( "Gemm",