This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix tests for ONNX version 1.5.0 bump #18054

Merged · 3 commits · Jun 5, 2020
173 changes: 155 additions & 18 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -493,18 +493,129 @@ def convert_pad(node, **kwargs):

    return [node]

def create_helper_tensor_node(input_vals, output_name, kwargs):
    """create extra tensor node from numpy values"""
    data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[input_vals.dtype]

    tensor_node = onnx.helper.make_tensor_value_info(
        name=output_name,
        elem_type=data_type,
        shape=input_vals.shape
    )
    kwargs["initializer"].append(
        onnx.helper.make_tensor(
            name=output_name,
            data_type=data_type,
            dims=input_vals.shape,
            vals=input_vals.flatten(),
            raw=False,
        )
    )

    return [tensor_node]

def create_helper_reshape_node(input_name, output_name, shape, kwargs):
    """create extra reshape node with static shape"""
    shape_tensor_node, = create_helper_tensor_node(
        np.asarray(shape, dtype=np.int64), output_name + "__shape", kwargs
    )
    reshape_node = onnx.helper.make_node(
        "Reshape",
        inputs=[input_name, shape_tensor_node.name],
        outputs=[output_name],
        name=output_name
    )

    return [shape_tensor_node, reshape_node]

-def create_helper_trans_node(op_name, input_node, node_name):
-    """create extra transpose node for dot operator"""
-    node_name = op_name + "_" + node_name
+def create_helper_trans_node(input_name, output_name, perm=None):
+    """create extra transpose node"""
+    attrs = {}
+    if perm is not None:
+        attrs['perm'] = perm
    trans_node = onnx.helper.make_node(
        'Transpose',
-        inputs=[input_node],
-        outputs=[node_name],
-        name=node_name
+        inputs=[input_name],
+        outputs=[output_name],
+        name=output_name,
+        **attrs
    )
-    return trans_node
+    return [trans_node]
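For illustration, a hypothetical call (names below are not from this PR) that swaps the last two axes of a tensor; omitting `perm` falls back to ONNX Transpose's default, which reverses all axes:

# Hypothetical usage of the reworked helper: swap the last two axes of `x`.
trans_nodes = create_helper_trans_node("x", "x_transposed", perm=[0, 2, 1])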

def create_helper_concat_node(inputs, output_name, axis=0):
    """create extra concat node"""
    concat_node = onnx.helper.make_node(
        "Concat",
        inputs=inputs,
        outputs=[output_name],
        name=output_name,
        axis=axis,
    )
    return [concat_node]

def create_helper_expand_node(input_name, output_name, expand_shape):
    """create extra expand node"""
    expand_node = onnx.helper.make_node(
        "Expand",
        inputs=[input_name, expand_shape],
        outputs=[output_name],
        name=output_name,
    )
    return [expand_node]

def create_helper_gather_node(
        input_name, output_name,
        indices, kwargs,
        axis=None
):
    """create extra gather node with static indices"""
    attrs = {}
    if axis is not None:
        attrs['axis'] = axis
    gather_tensor_node, = create_helper_tensor_node(
        np.asarray(indices, np.int64), output_name + "__indices", kwargs
    )
    gather_node = onnx.helper.make_node(
        "Gather",
        inputs=[input_name, gather_tensor_node.name],
        outputs=[output_name],
        name=output_name,
        **attrs
    )
    return [gather_tensor_node, gather_node]

def create_helper_build_values_node(
        inputs, output_name,
        dtype, kwargs, axis=0
):
    """create extra node, with specified values

    (allows mixing node names and static values)
    """
    values = []
    tensor_nodes = []
    for idx, inp in enumerate(inputs):
        if not isinstance(inp, (str, bytes)):
            inp, = create_helper_tensor_node(
                np.array([inp], dtype=dtype),
                output_name + "__value" + str(idx),
                kwargs
            )
            tensor_nodes.append(inp)
            inp = inp.name
        values.append(inp)
    concat_node, = create_helper_concat_node(values, output_name, axis=axis)
    return tensor_nodes + [concat_node,]

def create_helper_shape_node(input_name, output_name):
    """create extra shape node for specified input node"""
    shape_node = onnx.helper.make_node(
        "Shape",
        inputs=[input_name],
        outputs=[output_name],
        name=output_name,
    )
    return [shape_node]
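Taken together, these helpers let a converter emit small subgraphs around an op. A minimal sketch of how they compose — every node name here is illustrative, not part of the PR — building a target shape that mixes one runtime dimension with static ones:

import numpy as np

kwargs = {"initializer": []}  # the converter threads this list through every helper
nodes = []
# shape_out = Shape(data): the input's shape, computed at runtime
nodes += create_helper_shape_node("data", "shape_out")
# dim0_out = Gather(shape_out, [0]): pick out the first (dynamic) dimension
nodes += create_helper_gather_node("shape_out", "dim0_out", [0], kwargs, axis=0)
# target_shape = Concat(dim0_out, 3, 1): mix the dynamic dim with static values
nodes += create_helper_build_values_node(["dim0_out", 3, 1], "target_shape",
                                         dtype=np.int64, kwargs=kwargs)
# expanded = Expand(data, target_shape): broadcast data to the combined shape
nodes += create_helper_expand_node("data", "expanded", "target_shape")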

@mx_op.register("dot")
def convert_dot(node, **kwargs):
@@ -524,11 +635,11 @@ def convert_dot(node, **kwargs):
    op_name = "transpose" + str(kwargs["idx"])

    if trans_a:
-        trans_a_node = create_helper_trans_node(op_name, input_nodes[0], 'a')
-        input_node_a = op_name+"_a"
+        input_node_a = op_name + "_a"
+        trans_a_node, = create_helper_trans_node(input_nodes[0], input_node_a)
    if trans_b:
-        trans_b_node = create_helper_trans_node(op_name, input_nodes[1], 'b')
-        input_node_b = op_name+"_b"
+        input_node_b = op_name + "_b"
+        trans_b_node, = create_helper_trans_node(input_nodes[1], input_node_b)

    matmul_node = onnx.helper.make_node(
        'MatMul',
@@ -1503,16 +1614,34 @@ def convert_slice_axis(node, **kwargs):
    in_shape = kwargs['in_shape'][0]
    ends = in_shape[axes]

+    export_nodes = []
+
+    starts = np.atleast_1d(np.asarray(starts, dtype=np.int))
+    ends = np.atleast_1d(np.asarray(ends, dtype=np.int))
+    axes = np.atleast_1d(np.asarray(axes, dtype=np.int))
+
+    starts_node = create_helper_tensor_node(starts, name + '__starts', kwargs)
+    export_nodes.extend(starts_node)
+    starts_node = starts_node[-1].name
+
+    ends_node = create_helper_tensor_node(ends, name + '__ends', kwargs)
+    export_nodes.extend(ends_node)
+    ends_node = ends_node[-1].name
+
+    axes_node = create_helper_tensor_node(axes, name + '__axes', kwargs)
+    export_nodes.extend(axes_node)
+    axes_node = axes_node[-1].name
+
+    input_node = input_nodes[0]
    node = onnx.helper.make_node(
        "Slice",
-        input_nodes,
+        [input_node, starts_node, ends_node, axes_node],
        [name],
-        axes=[axes],
-        starts=[starts],
-        ends=[int(ends)],
        name=name,
    )
-    return [node]
+    export_nodes.extend([node])
+
+    return export_nodes
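Context for this change: opset 10, the default in ONNX 1.5.0, turned Slice's starts/ends/axes from attributes into inputs, which is why the exporter now materializes them as initializer tensors. A standalone sketch of the new form, hand-built here for illustration rather than taken from this PR:

import onnx
from onnx import TensorProto, helper

# starts/ends/axes become initializer tensors feeding Slice's new inputs
starts = helper.make_tensor('starts', TensorProto.INT64, (1,), [3])
ends = helper.make_tensor('ends', TensorProto.INT64, (1,), [7])
axes = helper.make_tensor('axes', TensorProto.INT64, (1,), [1])
node = helper.make_node('Slice', ['data', 'starts', 'ends', 'axes'], ['out'])
graph = helper.make_graph(
    [node], 'slice_axis_sketch',
    inputs=[helper.make_tensor_value_info('data', TensorProto.FLOAT, (2, 10, 2))],
    outputs=[helper.make_tensor_value_info('out', TensorProto.FLOAT, (2, 4, 2))],
    initializer=[starts, ends, axes],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 10)])
onnx.checker.check_model(model)  # well-formed under opset 10; opset 9 expected attributes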


@mx_op.register("SliceChannel")
@@ -2070,14 +2199,22 @@ def convert_topk(node, **kwargs):
    else:
        raise NotImplementedError("ONNX expects both value and indices as output")

+    export_nodes = []
+
+    k = np.asarray([k], dtype=np.int)
+    k_node = create_helper_tensor_node(k, name + '__k', kwargs)
+    export_nodes.extend(k_node)
+    k_node = k_node[-1].name
+
+    input_node = input_nodes[0]
    topk_node = onnx.helper.make_node(
        "TopK",
-        input_nodes,
+        [input_node, k_node],
        outputs,
        axis=axis,
-        k=k,
        name=name
    )
+    export_nodes.extend([topk_node])

-    return [topk_node]
+    return export_nodes
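TopK changed the same way in opset 10: `k` moves from an attribute to a second 1-D int64 input, while `axis` remains an attribute. A minimal hand-built sketch (not code from this PR):

from onnx import TensorProto, helper

# k is now an initializer/input tensor of shape (1,) rather than an attribute
k_init = helper.make_tensor('k', TensorProto.INT64, (1,), [3])
topk = helper.make_node('TopK', ['data', 'k'], ['values', 'indices'],
                        axis=1, name='topk_sketch')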

20 changes: 8 additions & 12 deletions tests/python/unittest/onnx/test_cases.py
@@ -39,9 +39,6 @@
             'test_transpose',
             'test_globalmaxpool',
             'test_globalaveragepool',
-            'test_slice_cpu',
-            'test_slice_neg',
-            'test_slice_end',
             'test_reciprocal',
             'test_sqrt',
             'test_pow',
@@ -54,19 +51,19 @@
             'test_operator_maxpool',
             'test_operator_params',
             'test_operator_permute2',
-            'test_cos',
-            'test_sin',
+            'test_cos[^h]',
+            'test_sin[^h]',
             'test_tan',
-            'test_acos',
-            'test_asin',
-            'test_atan',
+            'test_acos[^h]',
+            'test_asin[^h]',
+            'test_atan[^h]',
             'test_squeeze',
-            'test_matmul',
+            'test_matmul_',
             'test_depthtospace',
             'test_hardsigmoid',
             'test_instancenorm',
             'test_shape',
-            'test_cast',
+            'test_cast((?!STRING).)*$',
             'test_clip',
             'test_size',
             'test_dropout',
@@ -80,7 +77,6 @@
             'test_softplus',
             'test_reduce_',
             'test_split_equal',
-            'test_top_k',
             'test_gather'
             ],
    'import': ['test_softsign',
@@ -116,7 +112,7 @@
               'test_softmax_functional',
               'test_softmax_lastdim',
               ],
-    'export': ['test_ConvTranspose2d']
+    'export': []
}
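The entries above are matched as regular expressions against ONNX backend test names (an assumption based on how onnx.backend.test handles include patterns), which is how the tightened patterns drop new ONNX 1.5.0 cases such as the cosh/asinh variants and the string Cast tests. A quick sanity check with illustrative test names:

import re

assert re.match('test_cos[^h]', 'test_cos_example')             # cos still included
assert re.match('test_cos[^h]', 'test_cosh_example') is None    # cosh now excluded
assert re.match('test_cast((?!STRING).)*$', 'test_cast_FLOAT_to_DOUBLE')
assert re.match('test_cast((?!STRING).)*$', 'test_cast_FLOAT_to_STRING') is None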

STANDARD_MODEL = {
11 changes: 6 additions & 5 deletions tests/python/unittest/onnx/test_node.py
@@ -208,9 +208,8 @@ def test_imports(self):
            npt.assert_almost_equal(np_out, mxnet_out, decimal=4)

    def test_exports(self):
-        input_shape = (2,1,3,1)
        for test in export_test_cases:
-            test_name, onnx_name, mx_op, attrs = test
+            test_name, onnx_name, mx_op, input_shape, attrs = test
            input_sym = mx.sym.var('data')
            outsym = mx_op(input_sym, **attrs)
            converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], np.float32,
@@ -287,10 +286,12 @@ def test_exports(self):
    ("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1})
]

-# test_case = ("test_case_name", "ONNX_op_name", mxnet_op, attribute map)
+# test_case = ("test_case_name", "ONNX_op_name", mxnet_op, input_shape, attribute map)
export_test_cases = [
-    ("test_expand", "Expand", mx.sym.broadcast_to, {'shape': (2,1,3,1)}),
-    ("test_tile", "Tile", mx.sym.tile, {'reps': (2,3)})
+    ("test_expand", "Expand", mx.sym.broadcast_to, (2,1,3,1), {'shape': (2,1,3,1)}),
+    ("test_tile", "Tile", mx.sym.tile, (2,1,3,1), {'reps': (2,3)}),
+    ("test_topk", "TopK", mx.sym.topk, (2, 10, 2), {'k': 3, 'axis': 1, 'ret_typ': 'both', 'dtype': np.int64}),
+    ("test_slice_axis", "Slice", mx.sym.slice_axis, (2, 10, 2), {'begin': 3, 'end': 7, 'axis': 1}),
]

if __name__ == '__main__':