From 013b526751d0e86c2a070e02f49851511b2ecd33 Mon Sep 17 00:00:00 2001 From: vandanavk Date: Fri, 28 Sep 2018 15:24:38 -0700 Subject: [PATCH] ONNX export: Fully connected operator with no bias --- .../contrib/onnx/mx2onnx/_op_translations.py | 41 ++++++++++++++++--- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 67793251127f..3996694c780b 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -222,17 +222,44 @@ def convert_fully_connected(node, **kwargs): onnx = import_onnx_modules() name = node["name"] inputs = node["inputs"] + attrs = node["attrs"] + initializer = kwargs["initializer"] + + no_bias = convert_bool_to_int(attrs, "no_bias") + input_node_id = kwargs["index_lookup"][inputs[0][0]] weight_node_id = kwargs["index_lookup"][inputs[1][0]] - bias_node_id = kwargs["index_lookup"][inputs[2][0]] + proc_nodes = kwargs["proc_nodes"] - input_node = proc_nodes[input_node_id] - weights_node = proc_nodes[weight_node_id] - bias_node = proc_nodes[bias_node_id] + input_node = proc_nodes[input_node_id] input_name = input_node.name + + weights_node = proc_nodes[weight_node_id] weights_name = weights_node.name - bias_name = bias_node.name + + fcnode = [] + + if no_bias == 0: + bias_node_id = kwargs["index_lookup"][inputs[2][0]] + bias_node = proc_nodes[bias_node_id] + bias_name = bias_node.name + else: + np_arr = np.array([0]) + data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np_arr.dtype] + dims = np.shape(np_arr) + bias_name = "bias" + str(kwargs["idx"]) + tensor_node = onnx.helper.make_tensor_value_info(bias_name, data_type, dims) + initializer.append( + onnx.helper.make_tensor( + name=bias_name, + data_type=data_type, + dims=dims, + vals=[0], + raw=False, + ) + ) + fcnode.append(tensor_node) node = onnx.helper.make_node( "Gemm", @@ -245,7 +272,9 @@ def convert_fully_connected(node, **kwargs): name=name ) - return [node] + fcnode.append(node) + + return fcnode @mx_op.register("BatchNorm")