This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[WIP] Onnx export API optional args and additional operator support, Fixes #12682 #12946

Closed · wants to merge 7 commits
87 changes: 70 additions & 17 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -218,7 +218,6 @@ def convert_convolution(node, **kwargs):

return [conv_node]


@mx_op.register("FullyConnected")
def convert_fully_connected(node, **kwargs):
"""Map MXNet's FullyConnected operator attributes to onnx's Gemm operator
@@ -443,20 +442,19 @@ def convert_dot(node, **kwargs):
transpose_a, transpose_b attributes."""
name, input_nodes, attrs = get_inputs(node, kwargs)

trans_a_node = None
trans_b_node = None
trans_a_node = trans_b_node = None

trans_a = get_boolean_attribute_value(attrs, "transpose_a")
trans_b = get_boolean_attribute_value(attrs, "transpose_b")

op_name = "transpose" + str(kwargs["idx"])
input_node_a = op_name + "_a"
input_node_b = op_name + "_b"

if trans_a:
trans_a_node = create_helper_trans_node(op_name, input_nodes[0], 'a')
input_node_a = op_name+"_a"
if trans_b:
trans_b_node = create_helper_trans_node(op_name, input_nodes[1], 'b')
input_node_b = op_name+"_b"

matmul_node = onnx.helper.make_node(
'MatMul',
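For context, create_helper_trans_node is defined elsewhere in this module; a minimal sketch of what it presumably emits, with the signature and naming inferred from the call sites above (both are assumptions, not the file's verified contents):

```python
import onnx

def create_helper_trans_node(op_name, input_node, node_name):
    # Hypothetical sketch: emit an ONNX Transpose node whose output name
    # (op_name + "_" + node_name) matches input_node_a / input_node_b above.
    node_name = op_name + "_" + node_name
    trans_node = onnx.helper.make_node(
        'Transpose',
        inputs=[input_node],
        outputs=[node_name],
        name=node_name
    )
    return trans_node
```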
@@ -617,13 +615,21 @@ def convert_exp(node, **kwargs):
return create_basic_op_node('Exp', node, kwargs)

@mx_op.register("_copy")
def convert_identity(node, **kwargs):
def convert_copy(node, **kwargs):
"""Map MXNet's _copy operator attributes to onnx's Identity operator
and return the created node.
"""
return create_basic_op_node('Identity', node, kwargs)


@mx_op.register("identity")
def convert_identity(node, **kwargs):
"""Map MXNet's identity operator attributes to onnx's ConstantFill operator
and return the created node.
"""
return create_basic_op_node('ConstantFill', node, kwargs)


@mx_op.register("LeakyReLU")
def convert_leakyrelu(node, **kwargs):
"""Map MXNet's LeakyReLU operator attributes to onnx's Elu/LeakyRelu/PRelu operators
@@ -681,7 +687,7 @@ def convert_softmax_output(node, **kwargs):
"""Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator
and return the created node.
"""
name, _, _ = get_inputs(node, kwargs)
name = node["name"]

input1_idx = kwargs["index_lookup"][node["inputs"][0][0]]
input1 = kwargs["proc_nodes"][input1_idx]
@@ -693,10 +699,38 @@ def convert_softmax_output(node, **kwargs):
axis=1,
name=name
)

return [softmax_node]


@mx_op.register("LogisticRegressionOutput")
def convert_logistic_regression_output(node, **kwargs):
"""Map MXNet's SoftmaxOutput operator attributes to onnx's Softmax operator
and return the created node.
"""
name = node["name"]
input1_idx = kwargs["index_lookup"][node["inputs"][0][0]]
input1 = kwargs["proc_nodes"][input1_idx]

sigmoid_node = onnx.helper.make_node(
"Sigmoid",
[input1.output[0]],
[name],
name=name
)
return [sigmoid_node]
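# Like SoftmaxOutput above, only the inference-time transform (Sigmoid) is exported;
# the label input and training-time loss semantics are dropped.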

@mx_op.register("BlockGrad")
def convert_blockgrad(node, **kwargs):
""" Skip operator """
return create_basic_op_node('ConstantFill', node, kwargs)


@mx_op.register("make_loss")
def convert_makeloss(node, **kwargs):
""" Skip operator """
return create_basic_op_node('ConstantFill', node, kwargs)


@mx_op.register("Concat")
def convert_concat(node, **kwargs):
"""Map MXNet's Concat operator attributes to onnx's Concat operator
@@ -843,6 +877,8 @@ def scalar_op_helper(node, op_name, **kwargs):
"""Helper function for scalar arithmetic operations"""
name, input_nodes, attrs = get_inputs(node, kwargs)

from onnx import numpy_helper

scalar_value = [float(attrs.get("scalar", 1))]

initializer = kwargs["initializer"]
@@ -852,13 +888,19 @@ def scalar_op_helper(node, op_name, **kwargs):
for i in initializer:
if i.name == input_nodes[0]:
if op_name == 'Mul':
new_initializer = onnx.numpy_helper.to_array(i) * scalar_value[0]
new_initializer = numpy_helper.to_array(i) * scalar_value[0]
elif op_name == 'Sub':
new_initializer = onnx.numpy_helper.to_array(i) - scalar_value[0]
if name.startswith("_rminusscalar"):
new_initializer = scalar_value[0] - numpy_helper.to_array(i)
else:
new_initializer = numpy_helper.to_array(i) - scalar_value[0]
elif op_name == 'Add':
new_initializer = onnx.numpy_helper.to_array(i) + scalar_value[0]
new_initializer = numpy_helper.to_array(i) + scalar_value[0]
elif op_name == 'Div':
new_initializer = onnx.numpy_helper.to_array(i) / scalar_value[0]
if name.startswith("_rdivscalar"):
new_initializer = scalar_value[0] / numpy_helper.to_array(i)
else:
new_initializer = numpy_helper.to_array(i) / scalar_value[0]
flag = False
break

@@ -869,6 +911,7 @@ def scalar_op_helper(node, op_name, **kwargs):
dims = np.shape(np_arr)

scalar_op_name = "scalar_op" + str(kwargs["idx"])
# Convert scalar value into node
tensor_node = onnx.helper.make_tensor_value_info(scalar_op_name, data_type, dims)

initializer.append(
@@ -907,7 +950,6 @@ def scalar_op_helper(node, op_name, **kwargs):
)
return [tensor_node]

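To make the operand ordering concrete, here is a small numpy illustration of the four folding branches above (the tensor and scalar values are made up):

```python
import numpy as np

# Made-up initializer tensor and scalar, mirroring the branches above.
tensor = np.array([4.0, 10.0], dtype=np.float32)
scalar = 2.0

print(scalar - tensor)  # _rminus_scalar -> [-2. -8.]
print(tensor - scalar)  # _minus_scalar  -> [ 2.  8.]
print(scalar / tensor)  # _rdiv_scalar   -> [0.5  0.2]
print(tensor / scalar)  # _div_scalar    -> [2.   5. ]
```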
# Convert scalar value into node and pass it as input to mul_node
@mx_op.register("_mul_scalar")
def convert_mul_scalar(node, **kwargs):
"""Map MXNet's _mul_scalar operator attributes to onnx's Mul operator.
@@ -916,8 +958,6 @@ def convert_mul_scalar(node, **kwargs):
"""
return scalar_op_helper(node, 'Mul', **kwargs)


# Convert scalar value into node and pass it as input to mul_node
@mx_op.register("_minus_scalar")
def convert_minus_scalar(node, **kwargs):
"""Map MXNet's _minus_scalar operator attributes to onnx's Minus operator.
@@ -926,8 +966,14 @@ def convert_minus_scalar(node, **kwargs):
"""
return scalar_op_helper(node, 'Sub', **kwargs)

@mx_op.register("_rminus_scalar")
def convert_rminus_scalar(node, **kwargs):
"""Map MXNet's _rminus_scalar operator attributes to onnx's Minus operator.
Creates a new node for the input scalar value, adds it to the initializer
and return multiple created nodes.
"""
return scalar_op_helper(node, 'Sub', **kwargs)

# Convert scalar value into node and pass it as input to mul_node
@mx_op.register("_plus_scalar")
def convert_add_scalar(node, **kwargs):
"""Map MXNet's _plus_scalar operator attributes to onnx's Add operator.
@@ -936,7 +982,6 @@ def convert_add_scalar(node, **kwargs):
"""
return scalar_op_helper(node, 'Add', **kwargs)

# Convert scalar value into node and pass it as input to mul_node
@mx_op.register("_div_scalar")
def convert_div_scalar(node, **kwargs):
"""Map MXNet's _div_scalar operator attributes to onnx's Div operator.
@@ -945,6 +990,14 @@ def convert_div_scalar(node, **kwargs):
"""
return scalar_op_helper(node, 'Div', **kwargs)

@mx_op.register("_rdiv_scalar")
def convert_rdiv_scalar(node, **kwargs):
"""Map MXNet's _rdiv_scalar operator attributes to onnx's Div operator.
Creates a new node for the input scalar value, adds it to the initializer
and return multiple created nodes.
"""
return scalar_op_helper(node, 'Div', **kwargs)


# Sorting and Searching
@mx_op.register("argmax")
10 changes: 7 additions & 3 deletions python/mxnet/contrib/onnx/mx2onnx/export_model.py
@@ -32,7 +32,7 @@
from ._export_helper import load_module


def export_model(sym, params, input_shape, input_type=np.float32,
def export_model(sym, params, input_shape, input_type=np.float32, label_names=None, label_shapes=None,
onnx_file_path='model.onnx', verbose=False):
"""Exports the MXNet model file, passed as a parameter, into ONNX model.
Accepts both symbol,parameter objects as well as json and params filepaths as input.
@@ -49,6 +49,10 @@ def export_model(sym, params, input_shape, input_type=np.float32,
Input shape of the model e.g [(1,3,224,224)]
input_type : data type
Input data type e.g. np.float32
label_names : List of str
Optional list of label names e.g. ['regression_label']
label_shapes : List of tuple
Optional list of (name, shape) pairs e.g. [('regression_label', (1,3,224,224))]
onnx_file_path : str
Path where to save the generated onnx file
verbose : Boolean
@@ -75,11 +79,11 @@ def export_model(sym, params, input_shape, input_type=np.float32,
sym_obj, params_obj = load_module(sym, params)
onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
verbose=verbose)
label_names, label_shapes, verbose=verbose)
elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
verbose=verbose)
label_names, label_shapes, verbose=verbose)
else:
raise ValueError("Input sym and params should either be files or objects")

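A usage sketch of the extended API; the keyword names follow the new signature above, while the file names and shapes are hypothetical:

```python
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet

# Hypothetical checkpoint files; any saved sym/params pair works.
onnx_file = onnx_mxnet.export_model(
    sym='resnet-symbol.json',
    params='resnet-0000.params',
    input_shape=[(1, 3, 224, 224)],
    input_type=np.float32,
    label_names=['softmax_label'],
    label_shapes=[('softmax_label', (1, 1000))],
    onnx_file_path='resnet.onnx',
    verbose=True
)
```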
58 changes: 42 additions & 16 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -96,7 +96,7 @@ def convert_layer(node, **kwargs):
return convert_func(node, **kwargs)

@staticmethod
def forward_pass(inputs, sym, arg_params, aux_params, output_label):
def forward_pass(inputs, sym, arg_params, aux_params, label_name):
"""Do a forward pass based on the sym and params to get the shape
of the output using dummy data

@@ -120,7 +120,7 @@ def forward_pass(inputs, sym, arg_params, aux_params, output_label):
# while running load_checkpoint which is not actually a graph input. So ignoring it here
data_names = [graph_input for graph_input in sym.list_inputs()
if graph_input not in arg_params and graph_input not in aux_params
and graph_input != output_label]
and graph_input not in label_name]

data_shapes = []
# Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs.
@@ -144,9 +144,13 @@ def forward_pass(inputs, sym, arg_params, aux_params, output_label):
data_forward.append(nd.array(val))

test_mod.forward(io.DataBatch(data_forward))
result = test_mod.get_outputs()[0].asnumpy()
result = [i.asnumpy().shape for i in test_mod.get_outputs()]

return result.shape
result_shape = []
for idx, label in enumerate(label_name):
result_shape.append((label, result[idx]))

return result_shape
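# For example (hypothetical shapes): a single softmax output of shape (1, 10)
# yields [('softmax_label', (1, 10))] rather than the bare tuple returned before.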


@staticmethod
@@ -179,12 +183,12 @@ def split_params(sym, params):


@staticmethod
def infer_output_shape(sym, params, in_shape, output_label):
def infer_output_shape(sym, params, in_shape, label_name):
"""Infer output shape by doing a forward pass using dummy inputs """
# create dummy input
inputs = [np.random.randn(*input_shape) for input_shape in in_shape]
arg, aux = MXNetGraph.split_params(sym, params)
return MXNetGraph.forward_pass(inputs, sym, arg, aux, output_label)
return MXNetGraph.forward_pass(inputs, sym, arg, aux, label_name)


@staticmethod
@@ -193,7 +197,18 @@ def convert_weights_to_numpy(weights_dict):
return dict([(k.replace("arg:", "").replace("aux:", ""), v.asnumpy())
for k, v in weights_dict.items()])

def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False):
@staticmethod
def verify_provided_labels(data_names, data_shapes, name, throw):
"""Check that input labels matches input data shape."""
actual = [x[0] for x in data_shapes]
if sorted(data_names) != sorted(actual):
msg = "Data provided by %s_shapes don't match names specified by %s_names (%s vs. %s)" % (
name, name, str(data_shapes), str(data_names))
if throw:
raise ValueError(msg)
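# Hypothetical example: verify_provided_labels(['softmax_label'],
# [('regression_label', (1, 1))], 'label', True) raises ValueError, since the
# provided names and the names in the shapes disagree.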

def create_onnx_graph_proto(self, sym, params, in_shape, in_type,
label_names=None, label_shapes=None, verbose=False):
"""Convert MXNet graph to ONNX graph

Parameters
@@ -206,6 +221,10 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
Input shape of the model e.g [(1,3,224,224)]
in_type : data type
Input data type e.g. np.float32
label_names : List of str
Optional list of label names
label_shapes : List of tuple
Optional list of (name, shape) pairs e.g. [('regression_label', (1,3,224,224))]
verbose : Boolean
If true will print logs of the model conversion

@@ -226,10 +245,17 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# name is "Softmax", this node will have a name "Softmax_label". Also, the new node
# will always be second last node in the json graph.
# Deriving the output_label name.
output_label = sym.get_internals()[len(sym.get_internals()) - 1].name + "_label"
output_suffix = '_output'
output_names = [o[:-len(output_suffix)] for o in sym.list_outputs() if o.endswith(output_suffix)]

if not label_names:
label_names = [output_name + '_label' for output_name in output_names]
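# Example (hypothetical): sym.list_outputs() == ['softmax_output'] gives
# output_names == ['softmax'] and, absent user input, label_names == ['softmax_label'].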

# Determine output shape
output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label)
if not label_shapes:
label_shapes = MXNetGraph.infer_output_shape(sym, params, in_shape, label_names)
else:
MXNetGraph.verify_provided_labels(label_names, label_shapes, 'label', True)

weights = MXNetGraph.convert_weights_to_numpy(params)

@@ -253,10 +279,9 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# in params dict
if op == "null" and name not in params:
# Handling graph input

# Skipping output_label node, as this node is not part of graph
# Refer "output_label" assignment above for more details.
if name == output_label:
if name in label_names:
continue
converted = MXNetGraph.convert_layer(
node,
@@ -294,22 +319,23 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# If converted node is NodeProto, add it in processed nodes list
elif isinstance(converted_node, NodeProto):
onnx_processed_nodes.append(converted_node)
if idx == (len(mx_graph) - 1):
if converted_node.name in output_names:
label_shape = [i[1] for i in label_shapes if converted_node.name + "_label" == i[0]]
# If converted node doesn't have a name, use it from the output field
if not converted_node.name:
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.output[0],
elem_type=in_type,
shape=output_shape
shape=label_shape[0]
)
)
else:
onnx_processed_outputs.append(
make_tensor_value_info(
name=converted_node.name,
elem_type=in_type,
shape=output_shape
shape=label_shape[0]
)
)
if verbose:
@@ -327,12 +353,12 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False)
# refer "output_label" initialization above for more details.
# if extra node was added then prev_index to the last node is adjusted.
if idx == (len(mx_graph) - 1) and \
mx_graph[len(mx_graph)-2]["name"] == output_label:
mx_graph[len(mx_graph) - 2]["name"] in label_names:
prev_index = index_lookup[idx - 2]
else:
prev_index = index_lookup[idx - 1]

index_lookup.append(prev_index+len(converted))
index_lookup.append(prev_index + len(converted))
else:
index_lookup.append(len(converted) - 1)
else:
1 change: 1 addition & 0 deletions python/mxnet/contrib/onnx/onnx2mx/_import_helper.py
@@ -44,6 +44,7 @@
_convert_map = {
# Generator Functions
'Constant' : identity,
'ConstantFill' : identity,
'RandomUniform' : random_uniform,
'RandomNormal' : random_normal,
'RandomUniformLike' : random_uniform,