ONNX import/export: Add missing tests, ONNX export: LogSoftMax (#13654)
* LogSoftmax, missing tests

* Support multiple outputs in Gluon BackendRep

* Remove repeated unsqueeze test

* Allow multiple output support
vandanavk authored and sandeep-krishnamurthy committed Dec 28, 2018
1 parent 116d01e commit be3d945
Showing 4 changed files with 54 additions and 24 deletions.
29 changes: 26 additions & 3 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -810,7 +810,7 @@ def convert_l2normalization(node, **kwargs):
     mode = attrs.get("mode", "instance")
 
     if mode != "channel":
-        raise AttributeError("ONNX currently supports channel mode only")
+        raise AttributeError("L2Normalization: ONNX currently supports channel mode only")
 
     l2norm_node = onnx.helper.make_node(
         "LpNormalization",
@@ -1302,7 +1302,7 @@ def convert_reshape(node, **kwargs):
 
     for val in output_shape_list:
         if val in not_supported_shape:
-            raise AttributeError("Shape value not supported in ONNX", val)
+            raise AttributeError("Reshape: Shape value not supported in ONNX", val)
 
     reshape_node = onnx.helper.make_node(
         "Reshape",
@@ -1428,7 +1428,7 @@ def convert_squeeze(node, **kwargs):
 
     axis = attrs.get("axis", None)
     if not axis:
-        raise AttributeError("Missing axis attribute: ONNX currently requires axis to "
+        raise AttributeError("Squeeze: Missing axis attribute: ONNX currently requires axis to "
                              "be specified for squeeze operator")
     axis = convert_string_to_list(axis)
 
@@ -1666,3 +1666,26 @@ def convert_size(node, **kwargs):
     and return the created node.
     """
     return create_basic_op_node('Size', node, kwargs)
+
+
+@mx_op.register("log_softmax")
+def convert_logsoftmax(node, **kwargs):
+    """Map MXNet's log_softmax operator attributes to onnx's LogSoftMax operator
+    and return the created node.
+    """
+    name, input_nodes, attrs = get_inputs(node, kwargs)
+
+    # Converting to int
+    axis = int(attrs.get("axis", -1))
+    temp = attrs.get("temperature", 'None')
+    if temp != 'None':
+        raise AttributeError("LogSoftMax: ONNX supports only temperature=None")
+
+    node = onnx.helper.make_node(
+        'LogSoftmax',
+        input_nodes,
+        [name],
+        axis=axis,
+        name=name
+    )
+    return [node]
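
For context, a minimal sketch of how this new translation might be exercised end to end, assuming MXNet's mxnet.contrib.onnx export_model API; the toy symbol, parameter names, shapes, and file name below are illustrative only and not part of this commit:

import numpy as np
import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet

# Toy graph ending in log_softmax; exporting it routes through the
# convert_logsoftmax translation above and emits an ONNX LogSoftmax node.
data = mx.sym.Variable('data')
fc = mx.sym.FullyConnected(data, num_hidden=10, name='fc1')
out = mx.sym.log_softmax(fc, axis=-1, name='out')

# Random parameters, just enough to make the graph exportable.
params = {
    'fc1_weight': mx.nd.random.uniform(shape=(10, 4)),
    'fc1_bias': mx.nd.zeros((10,)),
}

onnx_file = onnx_mxnet.export_model(out, params, [(1, 4)], np.float32,
                                    'log_softmax_example.onnx')
print('Exported to', onnx_file)
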
2 changes: 1 addition & 1 deletion python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -146,7 +146,7 @@ def get_outputs(sym, params, in_shape, in_label):
         if name.endswith('_output'):
             out_names.append(name[:-len('_output')])
         else:
-            logging.warning("output '%s' does not end with '_output'", name)
+            logging.info("output '%s' does not end with '_output'", name)
             out_names.append(name)
 
     assert len(out_shapes) == len(out_names)
17 changes: 12 additions & 5 deletions tests/python-pytest/onnx/backend_rep.py
@@ -82,8 +82,10 @@ def run(self, inputs, **kwargs):
         args = dict(zip(data_names, data_forward))
         exe = self.symbol.bind(ctx, args=args, aux_states=self.aux_params)
         exe.forward(is_train=False)
-        result = exe.outputs[0].asnumpy()
-        return [result]
+        result = []
+        for output in exe.outputs:
+            result.append(output.asnumpy())
+        return result
 
 
 # GluonBackendRep object will be returned by GluonBackend's prepare method which is used to
@@ -124,7 +126,12 @@ def run(self, inputs, **kwargs):
         net_inputs = [nd.array(input_data, ctx=ctx) for input_data in inputs]
         net_outputs = self.net(*net_inputs)
         results = []
-        results.extend([o for o in net_outputs.asnumpy()])
-        result = np.array(results)
+        if isinstance(net_outputs, list):
+            for output in net_outputs:
+                results.append(output.asnumpy())
+            result = results
+        else:
+            results.extend([o for o in net_outputs.asnumpy()])
+            result = [np.array(results)]
 
-        return [result]
+        return result
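
As a usage note, a rough sketch of what callers see from run() after this change, namely one numpy array per graph output rather than only the first; the backend module name and model file below are assumptions for illustration:

import numpy as np
import onnx
import backend  # hypothetical: the MXNet/Gluon ONNX test backend defined alongside this file

# An ONNX model with more than one graph output (illustrative path).
model = onnx.load('multi_output_model.onnx')
rep = backend.prepare(model)  # a BackendRep whose run() now returns all outputs

x = np.random.rand(1, 3, 32, 32).astype(np.float32)
outputs = rep.run([x])

# Every output of the graph is returned, not only outputs[0].
for i, out in enumerate(outputs):
    print('output %d: shape %s' % (i, out.shape))
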
30 changes: 15 additions & 15 deletions tests/python-pytest/onnx/test_cases.py
@@ -49,7 +49,7 @@
'test_argmin',
'test_min',
# pytorch operator tests
'test_operator_exp',
'test_exp_',
'test_operator_maxpool',
'test_operator_params',
'test_operator_permute2',
@@ -60,21 +60,28 @@
'test_asin',
'test_atan',
'test_squeeze',
'test_matmul_3d',
'test_matmul_4d',
'test_matmul',
'test_depthtospace',
'test_hardsigmoid',
'test_instancenorm',
'test_shape',
'test_cast',
'test_clip',
'test_size'
'test_size',
'test_dropout',
'test_unsqueeze',
'test_log_',
'test_flatten_default_axis',
'test_leakyrelu',
'test_selu_default',
'test_elu',
'test_max_',
'test_softplus'
],
'import': ['test_unsqueeze',
'import': ['test_gather',
'test_global_lppooling',
'test_softsign',
'test_reduce_',
'test_softplus',
'test_mean',
'test_averagepool_1d',
'test_averagepool_2d_pads_count_include_pad',
@@ -84,18 +91,16 @@
'test_averagepool_3d',
'test_LpPool_',
'test_split_equal'
'test_random_',
],
'export': ['test_random_uniform',
'test_random_normal',
'test_reduce_min',
'test_reduce_max',
'test_squeeze',
'test_reduce_mean',
'test_reduce_prod',
'test_reduce_sum_d',
'test_reduce_sum_keepdims_random',
'test_max_',
'test_lrn'
]
}

@@ -104,17 +109,12 @@
'test_BatchNorm',
'test_ConstantPad2d'
'test_Conv2d',
'test_ELU',
'test_LeakyReLU',
'test_MaxPool',
'test_PReLU',
'test_ReLU',
'test_selu_default',
'test_Sigmoid',
'test_Softmax',
'test_softmax_functional',
'test_softmax_lastdim',
'test_Tanh']
]
}

STANDARD_MODEL = {
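
For orientation, a rough sketch of how whitelists like these are typically fed to the ONNX backend test runner; the dict name IMPLEMENTED_OPERATORS_TEST, the test_cases import, and the backend module are assumptions here, not guaranteed by this diff:

import onnx.backend.test

import test_cases  # assumed: this module, exposing the dicts edited above
import backend     # hypothetical: MXNet/Gluon ONNX backend with prepare()/run()

BACKEND_TESTS = onnx.backend.test.BackendTest(backend, __name__)

# Whitelist the operator tests that apply to both import and export paths,
# plus the import-only ones (entries are patterns matched against ONNX test names).
for name in test_cases.IMPLEMENTED_OPERATORS_TEST.get('both', []):
    BACKEND_TESTS.include(name)
for name in test_cases.IMPLEMENTED_OPERATORS_TEST.get('import', []):
    BACKEND_TESTS.include(name)

# Expose the generated unittest cases so pytest can discover them.
globals().update(BACKEND_TESTS.enable_report().test_cases)
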
