Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix tests for ONNX version 1.5.0 bump (#18054)
Browse files Browse the repository at this point in the history
* implement onnx translation helpers

* bump onnx version to 1.5

* add export only test cases for topk and slice_axis
  • Loading branch information
RuRo committed Jun 5, 2020
1 parent 4be0955 commit deae9b8
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 35 deletions.
173 changes: 155 additions & 18 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,18 +493,129 @@ def convert_pad(node, **kwargs):

return [node]

def create_helper_tensor_node(input_vals, output_name, kwargs):
"""create extra tensor node from numpy values"""
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[input_vals.dtype]

tensor_node = onnx.helper.make_tensor_value_info(
name=output_name,
elem_type=data_type,
shape=input_vals.shape
)
kwargs["initializer"].append(
onnx.helper.make_tensor(
name=output_name,
data_type=data_type,
dims=input_vals.shape,
vals=input_vals.flatten(),
raw=False,
)
)

return [tensor_node]

def create_helper_reshape_node(input_name, output_name, shape, kwargs):
"""create extra reshape node with static shape"""
shape_tensor_node, = create_helper_tensor_node(
np.asarray(shape, dtype=np.int64), output_name + "__shape", kwargs
)
reshape_node = onnx.helper.make_node(
"Reshape",
inputs=[input_name, shape_tensor_node.name],
outputs=[output_name],
name=output_name
)

return [shape_tensor_node, reshape_node]

def create_helper_trans_node(op_name, input_node, node_name):
"""create extra transpose node for dot operator"""
node_name = op_name + "_" + node_name
def create_helper_trans_node(input_name, output_name, perm=None):
"""create extra transpose node"""
attrs = {}
if perm is not None:
attrs['perm'] = perm
trans_node = onnx.helper.make_node(
'Transpose',
inputs=[input_node],
outputs=[node_name],
name=node_name
inputs=[input_name],
outputs=[output_name],
name=output_name,
**attrs
)
return trans_node
return [trans_node]

def create_helper_concat_node(inputs, output_name, axis=0):
"""create extra concat node"""
concat_node = onnx.helper.make_node(
"Concat",
inputs=inputs,
outputs=[output_name],
name=output_name,
axis=axis,
)
return [concat_node]

def create_helper_expand_node(input_name, output_name, expand_shape):
"""create extra expand node"""
expand_node = onnx.helper.make_node(
"Expand",
inputs=[input_name, expand_shape],
outputs=[output_name],
name=output_name,
)
return [expand_node]

def create_helper_gather_node(
input_name, output_name,
indices, kwargs,
axis=None
):
"""create extra gather node with static indices"""
attrs = {}
if axis is not None:
attrs['axis'] = axis
gather_tensor_node, = create_helper_tensor_node(
np.asarray(indices, np.int64), output_name + "__indices", kwargs
)
gather_node = onnx.helper.make_node(
"Gather",
inputs=[input_name, gather_tensor_node.name],
outputs=[output_name],
name=output_name,
**attrs
)
return [gather_tensor_node, gather_node]

def create_helper_build_values_node(
inputs, output_name,
dtype, kwargs, axis=0
):
"""create extra node, with specified values
(allows mixing node names and static values)
"""
values = []
tensor_nodes = []
for idx, inp in enumerate(inputs):
if not isinstance(inp, (str, bytes)):
inp, = create_helper_tensor_node(
np.array([inp], dtype=dtype),
output_name + "__value" + str(idx),
kwargs
)
tensor_nodes.append(inp)
inp = inp.name
values.append(inp)
concat_node, = create_helper_concat_node(values, output_name, axis=axis)
return tensor_nodes + [concat_node,]

def create_helper_shape_node(input_name, output_name):
"""create extra shape node for specified input node"""
shape_node = onnx.helper.make_node(
"Shape",
inputs=[input_name],
outputs=[output_name],
name=output_name,
)
return [shape_node]

@mx_op.register("dot")
def convert_dot(node, **kwargs):
Expand All @@ -524,11 +635,11 @@ def convert_dot(node, **kwargs):
op_name = "transpose" + str(kwargs["idx"])

if trans_a:
trans_a_node = create_helper_trans_node(op_name, input_nodes[0], 'a')
input_node_a = op_name+"_a"
input_node_a = op_name + "_a"
trans_a_node, = create_helper_trans_node(input_nodes[0], input_node_a)
if trans_b:
trans_b_node = create_helper_trans_node(op_name, input_nodes[1], 'b')
input_node_b = op_name+"_b"
input_node_b = op_name + "_b"
trans_b_node, = create_helper_trans_node(input_nodes[1], input_node_b)

matmul_node = onnx.helper.make_node(
'MatMul',
Expand Down Expand Up @@ -1503,16 +1614,34 @@ def convert_slice_axis(node, **kwargs):
in_shape = kwargs['in_shape'][0]
ends = in_shape[axes]

export_nodes = []

starts = np.atleast_1d(np.asarray(starts, dtype=np.int))
ends = np.atleast_1d(np.asarray(ends, dtype=np.int))
axes = np.atleast_1d(np.asarray(axes, dtype=np.int))

starts_node = create_helper_tensor_node(starts, name + '__starts', kwargs)
export_nodes.extend(starts_node)
starts_node = starts_node[-1].name

ends_node = create_helper_tensor_node(ends, name + '__ends', kwargs)
export_nodes.extend(ends_node)
ends_node = ends_node[-1].name

axes_node = create_helper_tensor_node(axes, name + '__axes', kwargs)
export_nodes.extend(axes_node)
axes_node = axes_node[-1].name

input_node = input_nodes[0]
node = onnx.helper.make_node(
"Slice",
input_nodes,
[input_node, starts_node, ends_node, axes_node],
[name],
axes=[axes],
starts=[starts],
ends=[int(ends)],
name=name,
)
return [node]
export_nodes.extend([node])

return export_nodes


@mx_op.register("SliceChannel")
Expand Down Expand Up @@ -2070,14 +2199,22 @@ def convert_topk(node, **kwargs):
else:
raise NotImplementedError("ONNX expects both value and indices as output")

export_nodes = []

k = np.asarray([k], dtype=np.int)
k_node = create_helper_tensor_node(k, name + '__k', kwargs)
export_nodes.extend(k_node)
k_node = k_node[-1].name

input_node = input_nodes[0]
topk_node = onnx.helper.make_node(
"TopK",
input_nodes,
[input_node, k_node],
outputs,
axis=axis,
k=k,
name=name
)
export_nodes.extend([topk_node])

return [topk_node]

Expand Down
20 changes: 8 additions & 12 deletions tests/python/unittest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@
'test_transpose',
'test_globalmaxpool',
'test_globalaveragepool',
'test_slice_cpu',
'test_slice_neg',
'test_slice_end',
'test_reciprocal',
'test_sqrt',
'test_pow',
Expand All @@ -54,19 +51,19 @@
'test_operator_maxpool',
'test_operator_params',
'test_operator_permute2',
'test_cos',
'test_sin',
'test_cos[^h]',
'test_sin[^h]',
'test_tan',
'test_acos',
'test_asin',
'test_atan',
'test_acos[^h]',
'test_asin[^h]',
'test_atan[^h]',
'test_squeeze',
'test_matmul',
'test_matmul_',
'test_depthtospace',
'test_hardsigmoid',
'test_instancenorm',
'test_shape',
'test_cast',
'test_cast((?!STRING).)*$',
'test_clip',
'test_size',
'test_dropout',
Expand All @@ -80,7 +77,6 @@
'test_softplus',
'test_reduce_',
'test_split_equal',
'test_top_k',
'test_gather'
],
'import': ['test_softsign',
Expand Down Expand Up @@ -116,7 +112,7 @@
'test_softmax_functional',
'test_softmax_lastdim',
],
'export': ['test_ConvTranspose2d']
'export': []
}

STANDARD_MODEL = {
Expand Down
11 changes: 6 additions & 5 deletions tests/python/unittest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,8 @@ def test_imports(self):
npt.assert_almost_equal(np_out, mxnet_out, decimal=4)

def test_exports(self):
input_shape = (2,1,3,1)
for test in export_test_cases:
test_name, onnx_name, mx_op, attrs = test
test_name, onnx_name, mx_op, input_shape, attrs = test
input_sym = mx.sym.var('data')
outsym = mx_op(input_sym, **attrs)
converted_model = onnx_mxnet.export_model(outsym, {}, [input_shape], np.float32,
Expand Down Expand Up @@ -287,10 +286,12 @@ def test_exports(self):
("test_lpnormalization_ord2", "LpNormalization", [get_rnd([5, 3, 3, 2])], np.linalg.norm, {'ord':2, 'axis':1})
]

# test_case = ("test_case_name", "ONNX_op_name", mxnet_op, attribute map)
# test_case = ("test_case_name", "ONNX_op_name", mxnet_op, input_shape, attribute map)
export_test_cases = [
("test_expand", "Expand", mx.sym.broadcast_to, {'shape': (2,1,3,1)}),
("test_tile", "Tile", mx.sym.tile, {'reps': (2,3)})
("test_expand", "Expand", mx.sym.broadcast_to, (2,1,3,1), {'shape': (2,1,3,1)}),
("test_tile", "Tile", mx.sym.tile, (2,1,3,1), {'reps': (2,3)}),
("test_topk", "TopK", mx.sym.topk, (2, 10, 2), {'k': 3, 'axis': 1, 'ret_typ': 'both', 'dtype': np.int64}),
("test_slice_axis", "Slice", mx.sym.slice_axis, (2, 10, 2), {'begin': 3, 'end': 7, 'axis': 1}),
]

if __name__ == '__main__':
Expand Down

0 comments on commit deae9b8

Please sign in to comment.