Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Add Flatten before Gemm
Browse files · Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Nov 21, 2018
1 parent f838bb5 commit 372cb0a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
11 changes: 11 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ def convert_fully_connected(node, **kwargs):

fcnode = []

op_name = "flatten_" + str(kwargs["idx"])
flatten_node = onnx.helper.make_node(
'Flatten',
inputs=[input_nodes[0]],
outputs=[op_name],
name=op_name
)

input_nodes[0] = op_name
fcnode.append(flatten_node)

if no_bias:
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('int64')]
bias_name = "bias" + str(kwargs["idx"])
Expand Down
39 changes: 21 additions & 18 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# Determine output shape
output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label)

output_suffix = '_output'
output_names = [
o[:-len(output_suffix)] for o in sym.list_outputs() if o.endswith(output_suffix)]

weights = MXNetGraph.convert_weights_to_numpy(params)

mx_graph = json.loads(sym.tojson())["nodes"]
Expand Down Expand Up @@ -294,26 +298,25 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# If converted node is NodeProto, add it in processed nodes list
elif isinstance(converted_node, NodeProto):
onnx_processed_nodes.append(converted_node)
if idx == (len(mx_graph) - 1):
# If converted node doesnt have name, use it from output field
if not converted_node.name:
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.output[0],
elem_type=in_type,
shape=output_shape
)
# If converted node doesnt have name, use it from output field
if not converted_node.name and idx == (len(mx_graph) - 1):
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.output[0],
elem_type=in_type,
shape=output_shape
)
else:
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.name,
elem_type=in_type,
shape=output_shape
)
)
elif converted_node.name in output_names:
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.name,
elem_type=in_type,
shape=output_shape
)
if verbose:
logging.info("Output node is: %s", converted_node.name)
)
if verbose:
logging.info("Output node is: %s", converted_node.name)
elif isinstance(converted_node, TensorProto):
raise ValueError("Did not expect TensorProto")
else:
Expand Down

0 comments on commit 372cb0a

Please sign in to comment.