Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] ONNX contrib_box_nms (#19755)
Browse files Browse the repository at this point in the history
* fix multiple output bug

* basic implementation of nms conversion

* tweak

* Update _op_translations.py

* Update _op_translations.py

* fix
  • Loading branch information
Zha0q1 authored Jan 29, 2021
1 parent 05de69b commit 5c2de1d
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 0 deletions.
92 changes: 92 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2971,6 +2971,98 @@ def convert_repeat(node, **kwargs):

return nodes


@mx_op.register('_contrib_box_nms')
def convert_contrib_box_nms(node, **kwargs):
"""Map MXNet's _contrib_box_nms operator to ONNX
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

opset_version = kwargs['opset_version']
if opset_version < 11:
raise AttributeError('ONNX opset 11 or greater is required to export this operator')

input_type = kwargs['in_type']
dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[input_type]

overlap_thresh = float(attrs.get('overlap_thresh', '0.5'))
valid_thresh = float(attrs.get('valid_thresh', '0'))
topk = int(attrs.get('topk', '-1'))
coord_start = int(attrs.get('coord_start', '2'))
score_index = int(attrs.get('score_index', '1'))
id_index = int(attrs.get('id_index', '-1'))
background_id = int(attrs.get('background_id', '-1'))
in_format = attrs.get('in_format', 'corner')
out_format = attrs.get('out_format', 'corner')

center_point_box = 0 if in_format == 'corner' else 1

if in_format != out_format:
raise NotImplementedError('box_nms does not currently support in_fomat != out_format')

if background_id != -1:
raise NotImplementedError('box_nms does not currently support background_id != -1')

if id_index != -1:
raise NotImplementedError('box_nms does not currently support id_index != -1')

nodes = [
create_tensor([coord_start], name+'_cs', kwargs['initializer']),
create_tensor([coord_start+4], name+'_cs_p4', kwargs['initializer']),
create_tensor([score_index], name+'_si', kwargs['initializer']),
create_tensor([score_index+1], name+'_si_p1', kwargs['initializer']),
create_tensor([topk], name+'_topk', kwargs['initializer']),
create_tensor([overlap_thresh], name+'_ot', kwargs['initializer'], dtype=np.float32),
create_tensor([valid_thresh], name+'_vt', kwargs['initializer'], dtype=np.float32),
create_tensor([-1], name+'_m1', kwargs['initializer']),
create_tensor([-1], name+'_m1_f', kwargs['initializer'], dtype=dtype),
create_tensor([0], name+'_0', kwargs['initializer']),
create_tensor([1], name+'_1', kwargs['initializer']),
create_tensor([2], name+'_2', kwargs['initializer']),
create_tensor([3], name+'_3', kwargs['initializer']),
create_tensor([], name+'_void', kwargs['initializer']),
create_tensor([0, 1, -1], name+'_scores_shape', kwargs['initializer']),
create_tensor([0, 0, 1, 0], name+'_pad', kwargs['initializer']),
create_tensor([0, -1], name+'_bat_spat_helper', kwargs['initializer']),
make_node('Shape', [input_nodes[0]], [name+'_shape']),
make_node('Shape', [name+'_shape'], [name+'_dim']),
make_node('Sub', [name+'_dim', name+'_2'], [name+'_dim_m2']),
make_node('Slice', [name+'_shape', name+'_dim_m2', name+'_dim'], [name+'_shape_last2']),
make_node('Concat', [name+'_m1', name+'_shape_last2'], [name+'_shape_3d'], axis=0),
make_node('Reshape', [input_nodes[0], name+'_shape_3d'], [name+'_data_3d']),
make_node('Slice', [name+'_data_3d', name+'_cs', name+'_cs_p4', name+'_m1'],
[name+'_boxes']),
make_node('Slice', [name+'_data_3d', name+'_si', name+'_si_p1', name+'_m1'],
[name+'_scores_raw']),
make_node('Reshape', [name+'_scores_raw', name+'_scores_shape'], [name+'_scores']),
make_node('Shape', [name+'_scores'], [name+'_scores_shape_actual']),
make_node('NonMaxSuppression',
[name+'_boxes', name+'_scores', name+'_topk', name+'_ot', name+'_vt'],
[name+'_nms'], center_point_box=center_point_box),
make_node('Slice', [name+'_nms', name+'_0', name+'_3', name+'_m1', name+'_2'],
[name+'_nms_sliced']),
make_node('GatherND', [name+'_data_3d', name+'_nms_sliced'], [name+'_candidates']),
make_node('Pad', [name+'_candidates', name+'_pad', name+'_m1_f'], [name+'_cand_padded']),
make_node('Shape', [name+'_nms'], [name+'_nms_shape']),
make_node('Slice', [name+'_nms_shape', name+'_0', name+'_1'], [name+'_cand_cnt']),
make_node('Reshape', [name+'_cand_cnt', name+'_void'], [name+'_cc_s']),
make_node('Range', [name+'_0', name+'_cc_s', name+'_1'], [name+'_cand_indices']),
make_node('Slice', [name+'_scores_shape_actual', name+'_0', name+'_3', name+'_m1',
name+'_2'], [name+'_shape_bat_spat']),
make_node('Slice', [name+'_shape_bat_spat', name+'_1', name+'_2'], [name+'_spat_dim']),
make_node('Expand', [name+'_cand_cnt', name+'_shape_bat_spat'], [name+'_base_indices']),
make_node('ScatterND', [name+'_base_indices', name+'_nms_sliced', name+'_cand_indices'],
[name+'_indices']),
make_node('TopK', [name+'_indices', name+'_spat_dim'], [name+'_indices_sorted', name+'__'],
largest=0, axis=-1, sorted=1),
make_node('Gather', [name+'_cand_padded', name+'_indices_sorted'], [name+'_gather']),
make_node('Reshape', [name+'_gather', name+'_shape'], [name+'0'])
]

return nodes


@mx_op.register("_greater_scalar")
def convert_greater_scalar(node, **kwargs):
"""Map MXNet's greater_scalar operator attributes to onnx's Greater
Expand Down
55 changes: 55 additions & 0 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,61 @@ def test_onnx_export_contrib_BilinearResize2D(tmp_path, dtype, params):
op_export_test('contrib_BilinearResize2D', M, [x], tmp_path)


@pytest.mark.parametrize('topk', [2, 3, 4])
@pytest.mark.parametrize('valid_thresh', [0.3, 0.4, 0.8])
@pytest.mark.parametrize('overlap_thresh', [0.4, 0.7, 1.0])
def test_onnx_export_contrib_box_nms_manual(tmp_path, topk, valid_thresh, overlap_thresh):
# Note that ONNX NMS op only supports float32

# Also note that onnxruntime's nms has slightly different implementation in handling
# overlaps and score ordering when certain boxes are suppressed than that of mxnet
# the following test tensors are manually tweaked to avoid such diferences
# The purpose of theses tests cases are to show that the high level conversion logic is
# laid out correctly

A = mx.nd.array([[
[[[[0.5, 0.1, 0.1, 0.2, 0.2],
[0.4, 0.1, 0.1, 0.2, 0.2],
[0.7, 0.5, 0.5, 0.9, 0.9],
[0.8, 0.1, 0.9, 0.11, 0.91],
[0.001, 0.01, 0.01, 0.02, 0.02]]]],

[[[[0.5, 0.1, 0.1, 0.2, 0.2],
[0.4, 0.1, 0.1, 0.2, 0.2],
[0.7, 0.5, 0.5, 0.9, 0.9],
[0.8, 0.1, 0.9, 0.11, 0.91],
[0.001, 0.01, 0.01, 0.02, 0.02]]]],

[[[[0.4, 0.1, 0.1, 0.2, 0.2],
[0.3, 0.1, 0.1, 0.2, 0.2],
[0.7, 0.5, 0.5, 0.9, 0.9],
[0.8, 0.1, 0.9, 0.11, 0.91],
[0.001, 0.01, 0.01, 0.02, 0.02]]]],
]])
M = def_model('contrib.box_nms', coord_start=1, force_suppress=True,
overlap_thresh=overlap_thresh, valid_thresh=valid_thresh, score_index=0,
topk=topk, in_format='corner', out_format='corner')
op_export_test('contrib_nms_manual_coner', M, [A], tmp_path)

B = mx.nd.array([
[[[[0.7, 0.5, 0.5, 0.2, 0.2],
[0.6, 0.48, 0.48, 0.2, 0.2],
[0.8, 0.76, 0.76, 0.2, 0.2],
[0.9, 0.7, 0.7, 0.2, 0.2],
[0.001, 0.5, 0.1, 0.02, 0.02]]]],

[[[[0.5, 0.2, 0.2, 0.2, 0.2],
[0.6, 0.4, 0.4, 0.21, 0.21],
[0.7, 0.5, 0.5, 0.9, 0.9],
[0.8, 0.1, 0.9, 0.01, 0.01],
[0.001, 0.6, 0.1, 0.02, 0.02]]]],
])
M = def_model('contrib.box_nms', coord_start=1, force_suppress=True,
overlap_thresh=overlap_thresh, valid_thresh=valid_thresh, score_index=0,
topk=topk, in_format='center', out_format='center')
op_export_test('contrib_nms_manual_center', M, [B], tmp_path)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("scalar", [0., 0.1, 0.5, 1., 5, 555.])
def test_onnx_export_greater_scalar(tmp_path, dtype, scalar):
Expand Down

0 comments on commit 5c2de1d

Please sign in to comment.