
[TOPI, Relay] A new NMS op variant for ONNX NMS / TF Combined NMS #7796

Merged Apr 14, 2021 (36 commits)

Commits
64f5d50  initial import (masahi, Apr 3, 2021)
d980761  add c++ boilerplate (masahi, Apr 3, 2021)
a1c3bf6  add python boilerplate (masahi, Apr 3, 2021)
8af8079  update onnx frontend (masahi, Apr 3, 2021)
d26d5b9  fixing (masahi, Apr 3, 2021)
0c71339  update onnx frontend (masahi, Apr 3, 2021)
a40337a  fix shape (masahi, Apr 3, 2021)
71370cf  minor update (masahi, Apr 3, 2021)
15d3bd0  fix (masahi, Apr 3, 2021)
837ce76  fix shape func (masahi, Apr 3, 2021)
e26bb4d  fix for no box (masahi, Apr 3, 2021)
65b5bba  more fix (masahi, Apr 3, 2021)
253629a  made things 64 bit (Apr 3, 2021)
9cb2505  int64 tweak (Apr 3, 2021)
ac5d79b  max_output_size doesn't need to be a callback (masahi, Apr 4, 2021)
fd868a1  remove all_class_nms schedule (masahi, Apr 4, 2021)
adaaf50  minor simplify (masahi, Apr 4, 2021)
83aa4c2  remove expand_dim (masahi, Apr 4, 2021)
0be35e6  refactoring (masahi, Apr 4, 2021)
a46bd03  simplify nms loop (masahi, Apr 4, 2021)
8699a98  cpu all_class_nms stub (masahi, Apr 4, 2021)
ef8d3c9  updating ir for cpu (masahi, Apr 4, 2021)
8400bbf  working with cpu (masahi, Apr 4, 2021)
dc437ff  update cpu strategy, relay op also working (masahi, Apr 4, 2021)
ee9c4d5  fix cpplint (masahi, Apr 4, 2021)
e67eae7  fixing pylint (masahi, Apr 4, 2021)
b4bd995  enable gpu test for onnx nms (masahi, Apr 4, 2021)
ed7f6ae  tweak parallel (masahi, Apr 5, 2021)
0b26341  pyformat and lint (masahi, Apr 5, 2021)
2361321  fix relay nms test (masahi, Apr 5, 2021)
004145a  doc update for cpp relay (masahi, Apr 8, 2021)
d207c4d  updating tests (masahi, Apr 8, 2021)
05fa415  updated tests (masahi, Apr 8, 2021)
6d314de  fix converting score_threshold to Expr (masahi, Apr 8, 2021)
56531f7  update doc (masahi, Apr 9, 2021)
b174927  doc fix (masahi, Apr 10, 2021)
11 changes: 7 additions & 4 deletions include/tvm/relay/attrs/vision.h
@@ -86,8 +86,6 @@ struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {

/*! \brief Attributes used in non_maximum_suppression operator */
struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
Optional<Integer> max_output_size;
Optional<FloatImm> iou_threshold;
bool force_suppress;
int top_k;
int coord_start;
@@ -97,8 +95,6 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
bool invalid_to_bottom;

TVM_DECLARE_ATTRS(NonMaximumSuppressionAttrs, "relay.attrs.NonMaximumSuppressionAttrs") {
TVM_ATTR_FIELD(max_output_size).describe("Max number of output valid boxes for each instance.");
TVM_ATTR_FIELD(iou_threshold).describe("Non-maximum suppression iou threshold.");
TVM_ATTR_FIELD(force_suppress)
.set_default(false)
.describe("Suppress all detections regardless of class_id.");
@@ -118,6 +114,13 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionAttrs> {
}
};

/*! \brief Attributes used in all_class_non_max_suppression operator */
struct AllClassNonMaximumSuppressionAttrs
: public tvm::AttrsNode<AllClassNonMaximumSuppressionAttrs> {
TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs,
"relay.attrs.AllClassNonMaximumSuppressionAttrs") {}
};

/*! \brief Attributes used in roi_align operators */
struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
Array<IndexExpr> pooled_size;
238 changes: 7 additions & 231 deletions python/tvm/relay/frontend/onnx.py
@@ -2456,26 +2456,13 @@ class NonMaxSuppression(OnnxOpConverter):

@classmethod
def _impl_v10(cls, inputs, attr, params):
"""
High level note: ONNX implements what TF calls combined_non_max_suppression
It passes in scores for each box for every class in the output and expects boxes to be
analyzed for each class independently

It also asks for the data to be returned in a particular format.

To support these, we implement a series of loops:
The first loop splits over class number, performs NMS, and collects the outputs.
The second (nested) loop takes the outputs and transforms them into the format ONNX wants
"""
# Get parameter values
boxes = inputs[0]
scores = inputs[1]
max_output_boxes_per_class = inputs[2]
iou_threshold = inputs[3]
score_threshold = inputs[4]

dtype = infer_type(boxes).checked_type.dtype

if "center_point_box" in attr:
if attr["center_point_box"] != 0:
raise NotImplementedError(
@@ -2498,226 +2485,15 @@ def conditionally_squeeze_scalar(x):
iou_threshold = conditionally_squeeze_scalar(iou_threshold)
score_threshold = conditionally_squeeze_scalar(score_threshold)

## prepare utility constants
zero = _op.const(np.array([0]), dtype="int64")
one = _op.const(np.array([1]), dtype="int64")
two = _op.const(np.array([2]), dtype="int64")
three = _op.const(np.array([3]), dtype="int64")
three_ones = _op.const(np.array([1, 1, 1]), dtype="int64")
four_ones = _op.const(np.array([1, 1, 1, 1]), dtype="int64")

## First loop: split by class and perform NMS
# Create Loop Vars
i = _expr.var("i", shape=(1,), dtype="int64")
scores_var = _expr.var("scores_var", shape=(_ty.Any(), _ty.Any(), _ty.Any()), dtype=dtype)
boxes_var = _expr.var("boxes_var", shape=(_ty.Any(), _ty.Any(), 4), dtype=dtype)
max_output_boxes_per_class_var = _expr.var(
"max_output_boxes_per_class_var", shape=(), dtype="int64"
)
iou_threshold_var = _expr.var("iou_threshold_var", shape=(), dtype="float32")
score_threshold_var = _expr.var("score_threshold_var", shape=(), dtype="float32")
B = _expr.var("B", shape=(1,), dtype="int64")
C = _expr.var("C", shape=(1,), dtype="int64")
S = _expr.var("S", shape=(1,), dtype="int64")
# Outputs of first loop should be padded nms values shape (B, C, S, 3)
onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
# and sizes of valid outputs, shape (B, C, 1)
nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")

def _first_cond(
i,
scores,
boxes,
B,
C,
S,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
onnx_out,
nms_size_out,
):
# Loop over classes, end when i == C
return _op.take(_op.less(i, C), _expr.const(0))

def _first_body(
i,
scores,
boxes,
B,
C,
S,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
onnx_out,
nms_size_out,
):
# slice to get current class
begin = _op.concatenate([zero, i, zero], axis=0)
end = _op.concatenate([B, i + one, S], axis=0)
class_scores = _op.strided_slice(scores, begin, end, three_ones)
class_scores = _op.expand_dims(_op.squeeze(class_scores, [1]), -1, 1)
# combine scores and boxes
data = _op.concatenate([class_scores, boxes], axis=-1)

# get valid counts
ct, data, indices = _op.vision.get_valid_counts(
data, score_threshold=score_threshold, id_index=-1, score_index=0
)
# get_valid_counts is used here for inference performance
# ONNX NMS doesn't have parameter top_k
top_k = -1
# ONNX doesn't have class id for nms input
score_index = 0
# perform nms on current class
nms_ret = _op.vision.non_max_suppression(
data=data,
valid_count=ct,
indices=indices,
max_output_size=max_output_boxes_per_class,
iou_threshold=iou_threshold,
force_suppress=True,
top_k=top_k,
coord_start=1,
score_index=score_index,
id_index=-1,
return_indices=True,
invalid_to_bottom=False,
)
# partially prepare ONNX output format by labeling batch_num, class_id
nms_padded_out = _op.expand_dims(nms_ret[0], -1, 1)
batch_num = _op.expand_dims(_op.arange(_op.squeeze(B, [0]), dtype="int64"), -1, 1)
batch_num = _op.broadcast_to(batch_num, shape_of(nms_ret[0], dtype="int64"))
batch_num = _op.expand_dims(batch_num, -1, 1)
class_num = _op.broadcast_to(i, shape_of(nms_padded_out, dtype="int64"))
new_onnx_out = _op.concatenate(
[batch_num, class_num, _op.cast(nms_padded_out, "int64")], -1
)
new_onnx_out = _op.expand_dims(new_onnx_out, 1, 1)
# store valid nms outputs for this class
nms_size = _op.cast(nms_ret[1], "int64")
nms_size = _op.expand_dims(nms_size, 1, 1)
return [
i + one,
scores,
boxes,
B,
C,
S,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
_op.concatenate([onnx_out, new_onnx_out], axis=1),
_op.concatenate([nms_size_out, nms_size], axis=1),
]

# create the first loop
first_loop = _loops.while_loop(
_first_cond,
[
i,
scores_var,
boxes_var,
B,
C,
S,
max_output_boxes_per_class_var,
iou_threshold_var,
score_threshold_var,
onnx_out,
nms_size_out,
],
_first_body,
)

## Second loop slices outputs of the first loop for valid boxes and
## concats in the order ONNX wants
# Second inner Loop Vars
i = _expr.var("i", shape=(1,), dtype="int64")
j = _expr.var("j", shape=(1,), dtype="int64")
B = _expr.var("B", shape=(1,), dtype="int64")
C = _expr.var("C", shape=(1,), dtype="int64")
# Outputs of first loop should be padded nms values shape (B, C, S, 3)
onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
# and sizes of valid outputs, shape (B, C, 1)
nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")

def _inner_cond(i, j, C, onnx_out, nms_size, out):
# inner loop over number of classes
return _op.take(_op.less(j, C), _expr.const(0))

def _inner_body(i, j, C, onnx_out, nms_size, out):
# slice to get current batch and class for valid box indicator
start = _op.concatenate([i, j + one, zero], axis=0)
end = _op.concatenate([i + one, j + two, one], axis=0)
num_valid_boxes = _op.reshape(_op.strided_slice(nms_size, start, end, three_ones), [1])
# slice to get current batch, class, and valid outputs
start = _op.concatenate([i, j + one, zero, zero], axis=0)
end = _op.concatenate([i + one, j + two, num_valid_boxes, three], axis=0)
new_out = _op.squeeze(_op.strided_slice(onnx_out, start, end, four_ones), [0, 1])
return i, j + one, C, onnx_out, nms_size, _op.concatenate([out, new_out], axis=0)

inner_loop = _loops.while_loop(
_inner_cond, [i, j, C, onnx_out, nms_size_out, out], _inner_body
)

nms_out = _op.vision.all_class_non_max_suppression(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold
)

# Second Outer Loop Vars
i = _expr.var("i", shape=(1,), dtype="int64")
j = _expr.var("j", shape=(1,), dtype="int64")
B = _expr.var("B", shape=(1,), dtype="int64")
C = _expr.var("C", shape=(1,), dtype="int64")
# Outputs of first loop should be padded nms values shape (B, C, S, 3)
onnx_out = _expr.var("onnx_out", shape=(_ty.Any(), _ty.Any(), _ty.Any(), 3), dtype="int64")
# and sizes of valid outputs, shape (B, C, 1)
nms_size_out = _expr.var("nms_size_out", shape=(_ty.Any(), _ty.Any(), 1), dtype="int64")
out = _expr.var("out", shape=(_ty.Any(), 3), dtype="int64")

def _outer_cond(i, B, C, onnx_out, nms_size_out, out):
# Outer loop is over batch size
return _op.take(_op.less(i, B), _expr.const(0))

def _outer_body(i, B, C, onnx_out, nms_size_out, out):
# Outer loop just calls inner loop
init_count = _op.const(np.array([0]), dtype="int64")
inner_loop_vals = inner_loop(i, init_count, C, onnx_out, nms_size_out, out)
return i + one, B, C, onnx_out, nms_size_out, _expr.TupleGetItem(inner_loop_vals, 5)

# Create the second loop
outer_loop = _loops.while_loop(
_outer_cond, [i, B, C, onnx_out, nms_size_out, out], _outer_body
)

# Call the first loop, perform NMS
B, C, S = _op.split(shape_of(scores, dtype="int64"), 3)
init_count = _op.const(np.array([0]), dtype="int64")
init_onnx_out = _op.const([1], dtype="int64")
init_onnx_out = _op.broadcast_to(init_onnx_out, _op.concatenate([B, one, S, three], 0))
init_nms_size_out = _op.const([1], dtype="int64")
init_nms_size_out = _op.broadcast_to(init_nms_size_out, _op.concatenate([B, one, one], 0))
loop_vals = first_loop(
init_count,
scores,
boxes,
B,
C,
S,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
init_onnx_out,
init_nms_size_out,
)
onnx_output = _expr.TupleGetItem(loop_vals, 9)
nms_size_output = _expr.TupleGetItem(loop_vals, 10)

# Call the second loop, rework outputs into correct form
init_count = _op.const(np.array([0]).astype("int64"), dtype="int64")
init_out = _op.const(np.array([1, 1, 1]).reshape([1, 3]).astype("int64"), dtype="int64")
loop_vals = outer_loop(init_count, B, C, onnx_output, nms_size_output, init_out)
loop_out = _expr.TupleGetItem(loop_vals, 5)
return _op.strided_slice(loop_out, [1, 0], shape_of(loop_out), [1, 1])
three = _op.const(np.array([3]), dtype="int64")
begin = _op.const(np.array([0, 0]), dtype="int64")
end = _op.concatenate([nms_out[1], three], axis=0)
strides = _op.const(np.array([1, 1]), dtype="int64")
return _op.strided_slice(nms_out[0], begin, end, strides)


class ATen(OnnxOpConverter):
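For context on what this frontend rewrite buys: the two nested Relay while-loops deleted above collapse into a single call to the new op plus one dynamic strided_slice that drops padding. Below is a minimal standalone sketch of that pattern, assuming the public Relay API this PR adds; the shapes and threshold values are illustrative, not taken from the PR.

import numpy as np
from tvm import relay

# Illustrative shapes: batch=1, 6 candidate boxes, 2 classes.
boxes = relay.var("boxes", shape=(1, 6, 4), dtype="float32")
scores = relay.var("scores", shape=(1, 2, 6), dtype="float32")

# The op returns (indices, num_total_detections). `indices` is padded to
# (batch * num_classes * num_boxes, 3), each row holding
# (batch_index, class_index, box_index) in the ONNX NMS layout.
nms_out = relay.vision.all_class_non_max_suppression(
    boxes,
    scores,
    max_output_boxes_per_class=3,
    iou_threshold=0.5,
    score_threshold=0.0,
)

# Slice away the padding using the dynamic valid count, mirroring the
# frontend code above.
three = relay.const(np.array([3]), dtype="int64")
begin = relay.const(np.array([0, 0]), dtype="int64")
end = relay.concatenate([nms_out[1], three], axis=0)
strides = relay.const(np.array([1, 1]), dtype="int64")
selected = relay.strided_slice(nms_out[0], begin, end, strides)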
5 changes: 5 additions & 0 deletions python/tvm/relay/op/op_attrs.py
@@ -304,6 +304,11 @@ class NonMaximumSuppressionAttrs(Attrs):
"""Attributes for vision.non_maximum_suppression"""


@tvm._ffi.register_object("relay.attrs.AllClassNonMaximumSuppressionAttrs")
class AllClassNonMaximumSuppressionAttrs(Attrs):
"""Attributes for vision.all_classnon_maximum_suppression"""


@tvm._ffi.register_object("relay.attrs.ROIAlignAttrs")
class ROIAlignAttrs(Attrs):
"""Attributes for vision.roi_align"""
12 changes: 12 additions & 0 deletions python/tvm/relay/op/strategy/cuda.py
@@ -946,6 +946,18 @@ def nms_strategy_cuda(attrs, inputs, out_type, target):
return strategy


@all_class_nms_strategy.register(["cuda", "gpu"])
def all_class_nms_strategy_cuda(attrs, inputs, out_type, target):
"""all class nms cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_all_class_nms(topi.cuda.all_class_non_max_suppression),
wrap_topi_schedule(topi.cuda.schedule_nms),
name="all_class_nms.cuda",
)
return strategy


@roi_align_strategy.register(["cuda", "gpu"])
def roi_align_strategy_cuda(attrs, inputs, out_type, target):
"""roi_align cuda strategy"""
28 changes: 24 additions & 4 deletions python/tvm/relay/op/strategy/generic.py
@@ -1000,10 +1000,6 @@ def wrap_compute_nms(topi_compute):
def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[3]
iou_threshold = inputs[4]
if attrs.max_output_size is not None:
max_output_size = attrs.max_output_size
if attrs.iou_threshold is not None:
iou_threshold = get_const_float(attrs.iou_threshold)
return_indices = bool(get_const_int(attrs.return_indices))
force_suppress = bool(get_const_int(attrs.force_suppress))
top_k = get_const_int(attrs.top_k)
@@ -1058,6 +1054,30 @@ def nms_strategy(attrs, inputs, out_type, target):
return strategy


def wrap_compute_all_class_nms(topi_compute):
"""wrap all class nms topi compute"""

def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[2]
iou_threshold = inputs[3]
score_threshold = inputs[4]
return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold)

return _compute_nms


@override_native_generic_func("all_class_non_max_suppression_strategy")
def all_class_nms_strategy(attrs, inputs, out_type, target):
"""all class nms generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_all_class_nms(topi.vision.all_class_non_max_suppression),
wrap_topi_schedule(topi.generic.schedule_nms),
name="all_class_nms.generic",
)
return strategy
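A design note on the wrapper above: unlike wrap_compute_nms, which had to reconcile attribute and input forms of max_output_size and iou_threshold, here every parameter arrives as an operator input, because AllClassNonMaximumSuppressionAttrs carries no fields. A hypothetical out-of-tree backend could therefore reuse the wrapper unchanged and swap in only its schedule, mirroring the CUDA registration shown earlier. The sketch below assumes generic.py's namespace; "mytarget" and the reuse of the generic compute/schedule are placeholders, not part of this PR.

# Sketch only: register a strategy for a hypothetical target kind.
@all_class_nms_strategy.register(["mytarget"])
def all_class_nms_strategy_mytarget(attrs, inputs, out_type, target):
    """all class nms strategy for a hypothetical backend"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_all_class_nms(topi.vision.all_class_non_max_suppression),
        wrap_topi_schedule(topi.generic.schedule_nms),
        name="all_class_nms.mytarget",
    )
    return strategy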


# roi_align
def wrap_compute_roi_align(topi_compute):
"""wrap roi_align topi compute"""
19 changes: 19 additions & 0 deletions python/tvm/relay/op/vision/_vision.py
@@ -45,6 +45,9 @@
reg.register_strategy("vision.non_max_suppression", strategy.nms_strategy)
reg.register_pattern("vision.non_max_suppression", OpPattern.OPAQUE)

reg.register_strategy("vision.all_class_non_max_suppression", strategy.all_class_nms_strategy)
reg.register_pattern("vision.all_class_non_max_suppression", OpPattern.OPAQUE)


@script
def _get_valid_counts_shape_func(data_shape):
@@ -85,6 +88,22 @@ def nms_shape_func(attrs, inputs, _):
return [topi.math.identity(inputs[0])]


@script
def _all_class_nms_shape_func(boxes_shape, scores_shape):
out_shape = output_tensor((2,), "int64")
count_shape = output_tensor((1,), "int64")

out_shape[0] = boxes_shape[0] * scores_shape[1] * boxes_shape[1]
out_shape[1] = 3
count_shape[0] = int64(1)
return out_shape, count_shape


@reg.register_shape_func("vision.all_class_non_max_suppression", False)
def all_class_nms_shape_func(attrs, inputs, _):
return _all_class_nms_shape_func(inputs[0], inputs[1])
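To make the shape contract concrete: for boxes of shape (batch, num_boxes, 4) and scores of shape (batch, num_classes, num_boxes), the indices output is statically padded to the worst case in which every box survives NMS for every class, while the actual number of valid rows is returned separately. A plain-Python mirror of the hybrid-script function above, for illustration only:

def all_class_nms_out_shapes(boxes_shape, scores_shape):
    """Mirror of _all_class_nms_shape_func: boxes (batch, num_boxes, 4),
    scores (batch, num_classes, num_boxes)."""
    batch, num_boxes, _ = boxes_shape
    _, num_classes, _ = scores_shape
    # Worst case: every box is kept for every class.
    out_shape = (batch * num_classes * num_boxes, 3)
    # The valid-row count is a one-element tensor.
    count_shape = (1,)
    return out_shape, count_shape

# Example: 1 image, 100 boxes, 5 classes -> indices padded to (500, 3).
assert all_class_nms_out_shapes((1, 100, 4), (1, 5, 100)) == ((500, 3), (1,))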


@script
def _roi_align_shape_func_nchw(data_shape, rois_shape, pooled_size):
out = output_tensor((4,), "int64")